
Suppose we have a NumPy array a of shape (n, d). For example,

import numpy as np

np.random.seed(1)

n, d = 5, 3
a = np.random.randn(n, d)

Now let indices be an (m, n)-shaped array of integers with values in 0, 1, ..., d - 1; that is, each entry indexes the second dimension of a. For example,

m = 10
indices = np.random.randint(low=0, high=d, size=(m, n))

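I would like to use indices to index the second dimension of a so that column i of indices selects from row i of a, batched over the m rows of indices. In other words, result[j, i] = a[i, indices[j, i]], and the result has shape (m, n).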
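To pin down the intended semantics, here is a minimal reference sketch with explicit loops, reusing the seeded example above. The name expected is only illustrative; the point is the rule result[j, i] = a[i, indices[j, i]], which any vectorized answer should reproduce.

```python
import numpy as np

np.random.seed(1)
n, d, m = 5, 3, 10
a = np.random.randn(n, d)
indices = np.random.randint(low=0, high=d, size=(m, n))

# Reference definition: expected[j, i] = a[i, indices[j, i]]
# Column i of `indices` picks entries from row i of `a`;
# each of the m rows of `indices` is one "batch".
expected = np.empty((m, n), dtype=a.dtype)
for j in range(m):
    for i in range(n):
        expected[j, i] = a[i, indices[j, i]]

print(expected.shape)  # (10, 5)
```

This is slow for large m and n, but it is a convenient ground truth to compare faster versions against.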

My solution is

result = np.vstack([a[i, :][indices[:, i]] for i in range(n)]).T
print(result.shape)
# (10, 5)

Another solution is

np.diagonal(a.T[indices], axis1=1, axis2=2)
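To see why the diagonal version does more work than necessary, it helps to look at the intermediate shape. A minimal sketch on the same seeded example: a.T[indices] builds an (m, n, n) array, and np.diagonal then keeps only the m * n entries whose two trailing indices match, so time and memory for the intermediate grow with n squared.

```python
import numpy as np

np.random.seed(1)
n, d, m = 5, 3, 10
a = np.random.randn(n, d)
indices = np.random.randint(low=0, high=d, size=(m, n))

# a.T has shape (d, n); indexing its first axis with an (m, n) array
# yields an (m, n, n) array with full[j, i, k] == a[k, indices[j, i]].
full = a.T[indices]
print(full.shape)  # (10, 5, 5)

# Only the entries with k == i are wanted, hence the diagonal.
# np.diagonal removes axes 1 and 2 and appends the diagonal last -> (m, n).
diag = np.diagonal(full, axis1=1, axis2=2)
print(diag.shape)  # (10, 5)
```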

but I think both of my approaches are unnecessarily complicated. Is there a more elegant, "NumPythonic" broadcasting trick for this, something like a.T[indices]?

Note: "elegant NumPythonic" may be ambiguous, so let's say the fastest approach when m and n are large.
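Since the criterion is speed for large m and n, a small timing harness makes candidates easy to compare. This is a sketch under assumptions: the helper names (loop_vstack, via_diagonal, make_inputs, benchmark, CANDIDATES) are hypothetical, the inputs come from np.random.default_rng rather than np.random.seed (so values differ from the sample above), and the sizes in the main block are arbitrary. Each candidate is checked against the loop version before being timed.

```python
import timeit

import numpy as np


def loop_vstack(a, indices):
    # The question's first approach: one fancy-index per row of `a`.
    n = a.shape[0]
    return np.vstack([a[i, :][indices[:, i]] for i in range(n)]).T


def via_diagonal(a, indices):
    # The question's second approach; builds an (m, n, n) intermediate.
    return np.diagonal(a.T[indices], axis1=1, axis2=2)


# Add further candidates here; each takes (a, indices) and returns (m, n).
CANDIDATES = {
    "loop_vstack": loop_vstack,
    "via_diagonal": via_diagonal,
}


def make_inputs(m, n, d, seed=1):
    rng = np.random.default_rng(seed)
    a = rng.standard_normal((n, d))
    indices = rng.integers(0, d, size=(m, n))
    return a, indices


def benchmark(m, n, d, number=5):
    """Return mean seconds per call for each candidate, after checking
    that every candidate agrees with loop_vstack."""
    a, indices = make_inputs(m, n, d)
    reference = loop_vstack(a, indices)
    timings = {}
    for name, fn in CANDIDATES.items():
        if not np.array_equal(fn(a, indices), reference):
            raise AssertionError(f"{name} disagrees with loop_vstack")
        # The lambda is called immediately by timeit, so capturing fn is safe.
        total = timeit.timeit(lambda: fn(a, indices), number=number)
        timings[name] = total / number
    return timings


if __name__ == "__main__":
    # Keep n modest: via_diagonal allocates m * n * n floats
    # (here 1000 * 100 * 100 * 8 bytes, about 80 MB).
    for name, secs in benchmark(m=1000, n=100, d=3).items():
        print(f"{name:>14}: {secs * 1e3:.3f} ms")
```

Timings depend heavily on hardware and NumPy version, so it is worth running this at the sizes you actually care about.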

2 Answers


Maybe this one:

np.take_along_axis(a.T, indices, axis=0)

It gives correct results:

np.take_along_axis(a.T, indices, axis=0) == result

output:

array([[ True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True]])
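Why this works: for a 2-D array with axis=0, np.take_along_axis returns out[j, i] = arr[indices[j, i], i]; with arr = a.T (shape (d, n)) that is exactly a[i, indices[j, i]]. A minimal sketch on the question's seeded data, also spelling the same lookup out with plain advanced indexing (the name manual is illustrative) and using np.array_equal to collapse the elementwise == grid into a single boolean. Note that take_along_axis was added in NumPy 1.15.

```python
import numpy as np

np.random.seed(1)
n, d, m = 5, 3, 10
a = np.random.randn(n, d)
indices = np.random.randint(low=0, high=d, size=(m, n))

# For a 2-D array and axis=0, take_along_axis computes
#     out[j, i] = arr[indices[j, i], i]
# With arr = a.T (shape (d, n)) that is a[i, indices[j, i]].
out = np.take_along_axis(a.T, indices, axis=0)

# Same lookup with explicit advanced indexing: the (n,) column index
# array broadcasts against the (m, n) `indices` array.
manual = a.T[indices, np.arange(n)]

result = np.vstack([a[i, :][indices[:, i]] for i in range(n)]).T
print(np.array_equal(out, result))     # True
print(np.array_equal(manual, result))  # True
```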

What about:

result = a[np.indices(indices.shape)[1], indices]

or:

result = a[np.tile(np.arange(n), m), indices.ravel()].reshape(m,n)

output:

array([[-0.61175641, -1.07296862,  1.74481176,  1.46210794, -0.3224172 ],
       [ 1.62434536,  0.86540763,  0.3190391 ,  1.46210794, -0.3224172 ],
       [-0.52817175, -2.3015387 , -0.7612069 ,  1.46210794, -0.38405435],
       [ 1.62434536, -1.07296862, -0.7612069 , -0.24937038,  1.13376944],
       [ 1.62434536, -1.07296862, -0.7612069 ,  1.46210794,  1.13376944],
       [ 1.62434536, -1.07296862, -0.7612069 , -2.06014071,  1.13376944],
       [-0.61175641, -1.07296862,  0.3190391 ,  1.46210794,  1.13376944],
       [-0.61175641, -1.07296862, -0.7612069 ,  1.46210794,  1.13376944],
       [ 1.62434536, -1.07296862,  0.3190391 , -2.06014071, -0.38405435],
       [ 1.62434536, -2.3015387 ,  0.3190391 , -2.06014071, -0.3224172 ]])
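How the np.indices version works: np.indices((m, n)) has shape (2, m, n), and its second slice holds the column number, so cols[j, i] == i and a[cols, indices] picks a[i, indices[j, i]]. The explicit grid is not strictly needed, because NumPy broadcasts advanced index arrays against each other; a[np.arange(n), indices] produces the same (m, n) result. That broadcast spelling is an addition here, not part of the answer above. A minimal sketch checking all three forms on the question's seeded data:

```python
import numpy as np

np.random.seed(1)
n, d, m = 5, 3, 10
a = np.random.randn(n, d)
indices = np.random.randint(low=0, high=d, size=(m, n))

# np.indices((m, n)) has shape (2, m, n); slice [1] holds column numbers,
# so cols[j, i] == i for every row j.
cols = np.indices(indices.shape)[1]
print(cols[0])  # [0 1 2 3 4]

# Pairing it with `indices` gives r[j, i] = a[i, indices[j, i]].
r_indices = a[cols, indices]

# Broadcasting: the (n,) row-index array is broadcast to (m, n) automatically.
r_broadcast = a[np.arange(n), indices]

# The tile/ravel variant builds the same (row, column) pairs as flat arrays.
r_tile = a[np.tile(np.arange(n), m), indices.ravel()].reshape(m, n)

print(np.array_equal(r_indices, r_broadcast))  # True
print(np.array_equal(r_tile, r_broadcast))     # True
```

The broadcast form avoids allocating the (m, n) index grid that np.indices and np.tile build, which may matter when m and n are large.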
