I'm writing code to optimize quantities that depend on a variable number of parameters. For the optimization I would like to apply index-selecting functions such as numpy.argmax and numpy.argmin across multiple axes at once. Below is the code I'm using right now. Is there a more built-in or efficient way to do this across an arbitrary number of axes that may or may not be sequential?

import numpy as np

def nd_arg_axes(func, array, start):
    """Applies an index selecting function over trailing axes from start."""

    n_trail = len(array.shape[start:])  # Number of trailing axes to apply to.

    indices = np.zeros((n_trail,)+array.shape[:start], dtype=np.intp)
    for i in np.ndindex(array.shape[:start]):
        indices[(Ellipsis,)+i] = np.unravel_index(func(array[i]),
                                               array.shape[start:])
    return tuple(indices)

# Test showing nd_arg_axes does indeed return the correct indices.
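array = np.arange(27).reshape(3, 3, 3)
max_js = nd_arg_axes(np.argmax, array, 1)

# np.indices takes a shape, so build the grid over the leading (kept) axes only.
print(array[tuple(np.indices(array.shape[:1])) + max_js] ==
      np.squeeze(np.apply_over_axes(np.amax, array, axes=[1, 2])))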

1 Answer

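If you are selecting over trailing axes, you can collapse them into a single axis with reshape (using -1 for the combined length), apply func along axis=-1, and convert the flat result back into per-axis indices with np.unravel_index: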

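import numpy as np

def f(func, array, start):
    shape = array.shape
    # Collapse all trailing axes (from `start` on) into one last axis.
    tmp = array.reshape(shape[:start] + (-1,))
    # Flat position of the selected element along that combined axis.
    indices = func(tmp, axis=-1)
    # Convert flat positions back into one index array per trailing axis.
    return np.unravel_index(indices, shape[start:])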
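The reshape trick needs the reduced axes to be adjacent and last. For axes that are not sequential, which the question also asks about, one option is to move them to the end with np.moveaxis first and then apply the same idea. NumPy's argmax/argmin only accept a single integer axis, so there is no built-in tuple-axis version. Below is a minimal sketch; the name `arg_over_axes` is hypothetical (not a NumPy function), and it assumes `axes` contains distinct, non-empty entries (duplicates or an empty tuple are not handled).

```python
import numpy as np


def arg_over_axes(func, array, axes):
    """Apply an index-selecting function such as np.argmax over several axes.

    `axes` may be any collection of distinct axes, contiguous or not.
    Returns a tuple with one index array per entry of `axes` (in that
    order), each shaped like the remaining, non-reduced axes.
    """
    nd = array.ndim
    axes = tuple(ax % nd for ax in axes)  # normalize negative axes
    n_keep = nd - len(axes)
    # Move the reduced axes to the end (in the order given); the kept axes
    # stay in front in their original order.
    moved = np.moveaxis(array, axes, tuple(range(n_keep, nd)))
    # Collapse the reduced axes into one; reshape copies if the view is
    # not contiguous.
    flat = moved.reshape(moved.shape[:n_keep] + (-1,))
    flat_idx = func(flat, axis=-1)
    # Turn the flat positions back into per-axis indices.
    return np.unravel_index(flat_idx, tuple(array.shape[ax] for ax in axes))


# Usage: reduce over the non-adjacent axes 0 and 2 of a 3x3x3 array.
array = np.arange(27).reshape(3, 3, 3)
i0, i2 = arg_over_axes(np.argmax, array, (0, 2))
# array[i, j, k] == 9*i + 3*j + k, so both i0 and i2 are [2, 2, 2].

# Gather the selected values by pairing the indices with kept axis 1.
j = np.arange(array.shape[1])
print(array[i0, j, i2])  # [20 23 26], same as array.max(axis=(0, 2))
```

For trailing axes this reduces to the same computation as `f` above, since np.moveaxis then leaves the order unchanged. np.moveaxis itself only returns a view; any copy happens in the reshape step.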