I _love_ numpy, and I am getting excited about jax, too.
However, I do have one request for it. Getting the argmax of a multi-dimensional array, in terms of the array's dimensions, is difficult for new users.
`np.argmax(np.array([[1,2,3],[1,9,3],[1,2,3]]))` returns `4` rather than `(1, 1)`. I understand why, but it seems strange to me that `argmax` cannot return a value the user can use to index their array.
Having to then feed that `4` into `unravel_index()`, with the array's shape as a parameter, seems less elegant than, say, passing a parameter like `as_index=True` to `argmax`.
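For reference, here is roughly what the current two-step workaround looks like. This is a minimal sketch: `argmax_as_index` is a hypothetical helper name used only for illustration, not an existing NumPy function, and `as_index=True` above is only a proposed parameter.

```python
import numpy as np


def argmax_as_index(a):
    """Hypothetical helper (not part of NumPy): return the argmax of `a`
    as a tuple of per-dimension indices instead of a flat index."""
    flat = np.argmax(a)                     # flat index into a.ravel()
    return np.unravel_index(flat, a.shape)  # convert flat index to (row, col, ...)


a = np.array([[1, 2, 3], [1, 9, 3], [1, 2, 3]])

print(np.argmax(a))       # 4, the flat (raveled) index
idx = argmax_as_index(a)  # tuple of indices, one per dimension
print(a[idx])             # 9, the tuple can index the array directly
```

Because `np.unravel_index` returns a tuple, the result can be used directly as `a[idx]`, which is the behaviour an `as_index=True` option would presumably provide in a single call.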