Suppose that we have a numpy array a of shape (n, d). For example,
import numpy as np
np.random.seed(1)
n, d = 5, 3
a = np.random.randn(n, d)
Now let indices be an (m, n)-shaped array of integers ranging over 0, 1, ..., d - 1. That is, each entry indexes the second dimension of a. For example,
m = 10
indices = np.random.randint(low=0, high=d, size=(m, n))
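I would like to use indices to index the second dimension of a row by row, aligned along n and batched over m. That is, the result should have shape (m, n) with result[j, i] = a[i, indices[j, i]] for every j in range(m) and i in range(n).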
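To pin down the intended semantics, here is a plain double loop over the same setup. It is a sketch meant only as an unambiguous reference (it is slow). It assumes the target is result[j, i] = a[i, indices[j, i]], which is what both solutions below compute.

```python
import numpy as np

np.random.seed(1)
n, d, m = 5, 3, 10
a = np.random.randn(n, d)
indices = np.random.randint(low=0, high=d, size=(m, n))

# Reference definition (assumed from the solutions below):
# for every batch row j and every row i of a, pick column indices[j, i] of row i.
expected = np.empty((m, n), dtype=a.dtype)
for j in range(m):
    for i in range(n):
        expected[j, i] = a[i, indices[j, i]]

print(expected.shape)  # (10, 5)
```

Any faster version can be checked against expected with np.array_equal.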
My solution is
result = np.vstack([a[i, :][indices[:, i]] for i in range(n)]).T
print(result.shape)
# (10, 5)
Another solution is
np.diagonal(a.T[indices], axis1=1, axis2=2)
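A short sketch of what the diagonal approach does internally, using the same setup: a.T[indices] materializes an (m, n, n) array, of which only the n diagonal entries per batch row are kept. This is why it costs O(m·n²) time and memory even though the output is only (m, n).

```python
import numpy as np

np.random.seed(1)
n, d, m = 5, 3, 10
a = np.random.randn(n, d)
indices = np.random.randint(low=0, high=d, size=(m, n))

# a.T has shape (d, n); indexing its first axis with an (m, n) integer array
# yields an (m, n, n) array where big[j, k, l] == a[l, indices[j, k]].
big = a.T[indices]
print(big.shape)  # (10, 5, 5)

# Only entries with k == l are needed, i.e. big[j, i, i] == a[i, indices[j, i]].
diag = np.diagonal(big, axis1=1, axis2=2)
loop = np.vstack([a[i, :][indices[:, i]] for i in range(n)]).T
print(np.array_equal(diag, loop))  # True
```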
but I think both methods are unnecessarily complicated. Is there an elegant, "numpythonic" broadcasting approach to achieve this, for instance something like a.T[indices]?
Note: "elegant numpythonic" may be ambiguous, so let's say it means the fastest approach when m and n are quite large.
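Under the "fastest for large m and n" criterion, a minimal benchmarking sketch may help compare approaches. Besides the two solutions above, it includes two candidates that are assumptions of this sketch rather than part of the question: broadcast integer indexing a[np.arange(n), indices] and np.take_along_axis on a.T. The sizes are arbitrary; because the diagonal approach allocates an (m, n, n) array, n is kept modest here. No timings are shown, since they depend on hardware.

```python
import timeit

import numpy as np

rng = np.random.default_rng(1)
m, n, d = 200, 200, 3  # arbitrary sizes chosen for this sketch
a = rng.standard_normal((n, d))
indices = rng.integers(low=0, high=d, size=(m, n))


def vstack_loop(a, indices):
    # Original solution: one Python-level step per row of a.
    return np.vstack([a[i, :][indices[:, i]] for i in range(a.shape[0])]).T


def diagonal_of_fancy(a, indices):
    # Original alternative: builds an (m, n, n) intermediate, keeps the diagonal.
    return np.diagonal(a.T[indices], axis1=1, axis2=2)


def broadcast_fancy(a, indices):
    # Candidate: np.arange(n) (shape (n,)) broadcasts against indices (m, n),
    # so out[j, i] == a[i, indices[j, i]].
    return a[np.arange(a.shape[0]), indices]


def take_along(a, indices):
    # Candidate: gather along axis 0 of a.T (shape (d, n)),
    # so out[j, i] == a.T[indices[j, i], i] == a[i, indices[j, i]].
    return np.take_along_axis(a.T, indices, axis=0)


reference = vstack_loop(a, indices)
for f in [vstack_loop, diagonal_of_fancy, broadcast_fancy, take_along]:
    # Check correctness against the original solution before timing.
    assert np.array_equal(f(a, indices), reference), f.__name__
    seconds = timeit.timeit(lambda: f(a, indices), number=5)
    print(f"{f.__name__:>18}: {seconds / 5 * 1e3:.3f} ms per call")
```

Both candidates avoid the Python-level loop and the (m, n, n) intermediate; which one is faster on a given machine is something this sketch measures rather than assumes.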