
I _love_ numpy, and I am getting excited about jax, too.

However, I do have one request for it. Getting the argmax of a multi-dimensional array, in terms of the array's dimensions, is difficult for new users.

np.argmax(np.array([[1,2,3],[1,9,3],[1,2,3]])) is 4, rather than (1,1). I understand why, but it seems strange to me that argmax cannot return a value the user can use to index their array.

Having to then feed that `4` into unravel_index(), with the array's shape as a parameter, seems less elegant than, say, passing an "as_index=True" parameter to argmax.
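For reference, here is a minimal sketch of the two-step approach described above, wrapped in a small helper. The name `argmax_index` is hypothetical (it is not part of NumPy); only `np.argmax` and `np.unravel_index` are real NumPy calls.

```python
import numpy as np

def argmax_index(a):
    """Hypothetical helper: return the argmax of `a` as a tuple usable for indexing."""
    a = np.asarray(a)
    flat_idx = np.argmax(a)                      # index into the flattened array
    multi_idx = np.unravel_index(flat_idx, a.shape)
    return tuple(int(i) for i in multi_idx)      # plain Python ints for readability

a = np.array([[1, 2, 3], [1, 9, 3], [1, 2, 3]])
idx = argmax_index(a)
print(idx)     # (1, 1)
print(a[idx])  # 9
```

This is roughly what an "as_index=True" option would have to do internally; the wrapper just hides the shape argument from the caller.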




Consider this:

  In [1]: np.argmax(np.array([[1,2,3],[1,9,3],[1,2,3]]).flat)
  Out[1]: 4


Alternatively, you could use flat:

  a = np.array([[1,2,3],[1,9,3],[1,2,3]])
  idx = np.argmax(a)
  a.flat[idx]  # 9


Does that work the same way with strided arrays?


Assuming you mean what I think you mean, it does work.

e.g. a[::2, ::3].flat[idx], where idx runs from 0 to width*height - 1 of the view (so idx should come from the argmax of the view, not of the original array)

(idx can also be a NumPy array, for getting multiple values)
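A small sketch of the strided case, assuming the question is about non-contiguous views produced by slicing with steps. The array contents and the names `view` and `idx` are made up for illustration.

```python
import numpy as np

a = np.arange(24).reshape(4, 6)
a[2, 3] = 100                  # plant a known maximum

view = a[::2, ::3]             # rows 0 and 2, columns 0 and 3; not C-contiguous
idx = np.argmax(view)          # flat index relative to the view, not to `a`

print(view.flags['C_CONTIGUOUS'])         # False
print(view.flat[idx])                     # 100
print(np.unravel_index(idx, view.shape))  # position of the max inside the view

# An array of flat indices returns multiple values at once
print(view.flat[np.array([0, 3])])        # first and last element of the view
```

Note that idx is only meaningful for the view it was computed from; using it with a.flat would pick a different element.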



