I have a 2D NumPy ndarray, x, that I need to split into square subregions of size s. For each subregion, I want to get the greatest element (which I can already do) and its position within that subregion (which I can't figure out).
Here is a minimal example:
>>> x = np.random.randint(0, 10, (6,8))
>>> x
array([[9, 4, 8, 9, 5, 7, 3, 3],
       [3, 1, 8, 0, 7, 7, 5, 1],
       [7, 7, 3, 6, 0, 2, 1, 0],
       [7, 3, 9, 8, 1, 6, 7, 7],
       [1, 6, 0, 7, 5, 1, 2, 0],
       [8, 7, 9, 5, 8, 3, 6, 0]])
>>> h, w = x.shape
>>> s = 2
>>> f = x.reshape(h//s, s, w//s, s)
>>> mx = np.max(f, axis=(1, 3))
>>> mx
array([[9, 9, 7, 5],
       [7, 9, 6, 7],
       [8, 9, 8, 6]])
For example, the 8 in the lower left corner of mx is the greatest element of the subregion [[1, 6], [8, 7]] in the lower left corner of x.
What I want is an array shaped like mx that holds the index of the largest element within each subregion, like this:
[[0, 1, 1, 2],
[0, 2, 3, 2],
[2, 2, 2, 2]]
where, for example, the 2 in the lower left corner is the index of 8 in the linear representation of [[1, 6], [8, 7]].
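To make the indexing convention concrete, here is a small sketch for that single subregion. The names block, flat_idx, row and col are only for illustration; the index is counted in row-major order over the flattened block.

```python
import numpy as np

# The lower-left 2x2 subregion of x from the example above.
block = np.array([[1, 6],
                  [8, 7]])

# np.argmax with no axis works on the flattened (row-major) block.
flat_idx = np.argmax(block)

# Convert the flat index back to (row, column) inside the block, if needed.
row, col = np.unravel_index(flat_idx, block.shape)

print(flat_idx)  # 2
print(row, col)  # 1 0
```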
I could do it by calling np.argmax(f[i, :, j, :]) while iterating over i and j, but that Python loop is enormously slower than a single vectorized call when there are many subregions. For context, I'm trying to implement max pooling using (only) NumPy. Basically, I'm asking whether there is a faster alternative to the loop.
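For reference, the loop described above can be sketched as follows. The second half is one possible vectorized candidate, not necessarily the fastest: it assumes h and w are exact multiples of s, and it moves the two within-block axes to the end so that a single argmax over the flattened block works. The names idx_loop, blocks and idx_vec are made up for this sketch.

```python
import numpy as np

x = np.array([[9, 4, 8, 9, 5, 7, 3, 3],
              [3, 1, 8, 0, 7, 7, 5, 1],
              [7, 7, 3, 6, 0, 2, 1, 0],
              [7, 3, 9, 8, 1, 6, 7, 7],
              [1, 6, 0, 7, 5, 1, 2, 0],
              [8, 7, 9, 5, 8, 3, 6, 0]])
s = 2
h, w = x.shape  # assumed to be exact multiples of s
f = x.reshape(h // s, s, w // s, s)

# Loop version: one np.argmax call per subregion (slow for many blocks).
idx_loop = np.empty((h // s, w // s), dtype=np.intp)
for i in range(h // s):
    for j in range(w // s):
        idx_loop[i, j] = np.argmax(f[i, :, j, :])

# Candidate vectorized version: reorder axes to (block_row, block_col, s, s),
# flatten each block to length s*s, then take argmax along the last axis.
blocks = f.transpose(0, 2, 1, 3).reshape(h // s, w // s, s * s)
idx_vec = blocks.argmax(axis=-1)

print(idx_vec)
print(np.array_equal(idx_loop, idx_vec))
```

If (row, column) positions inside each block are needed instead of flat indices, np.unravel_index(idx_vec, (s, s)) should give them as a pair of arrays.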