Suppose I have a 2D NumPy array. Given n, I wish to nullify (set to zero) all elements in the matrix except the top n.

I've tried idx = (-y_pred).argsort(axis=-1)[:, :n] to find the indices of the n largest values, but the shape of idx is [H,W,n], and I don't understand why.

I've also tried:

sorted_list = sorted(y_pred, key=lambda x: x[0], reverse=True)
top_ten = sorted_list[:10]

But it didn't return the indices of the top 10 values.

Is there an efficient way to find top n indices and zero the rest?

EDIT: The input is an NxM matrix of values, and the output is a matrix of the same size NxM in which all values are 0 except at the indices that correspond to the top 10 values.

  • What's y_pred.shape? Commented May 12, 2019 at 13:16
  • @Alex Goft mention your input and expected output as well Commented May 12, 2019 at 13:19
  • @Divakar please see edit Commented May 12, 2019 at 13:26
  • Could you show us a minimal representative sample data? Commented May 12, 2019 at 13:27
  • @kmario23 A view is not possible. There isn't a specific striding pattern there. Commented May 12, 2019 at 14:08

2 Answers

Here's one approach using numpy.argpartition(), based on the idea from "How do I get indices of N maximum values in a NumPy array?"

# sample input to work with (assumes `import numpy as np` has been run)
In [62]: arr = np.random.randint(0, 30, 36).reshape(6, 6)

In [63]: arr
Out[63]: 
array([[ 8, 25, 12, 26, 21, 29],
       [24, 22,  7, 14, 23, 13],
       [ 1, 22, 18, 20, 10, 19],
       [26, 10, 27, 19,  6, 28],
       [17, 28,  9, 13, 11, 12],
       [18, 25, 15, 29, 25, 25]])


# initialize an array filled with zeros
In [59]: nullified_arr = np.zeros_like(arr)
In [64]: top_n = 10

# get top_n indices of `arr`
In [57]: top_n_idxs = np.argpartition(arr.reshape(-1), -top_n)[-top_n:]

# copy `top_n` values to output array
In [60]: nullified_arr.reshape(-1)[top_n_idxs] = arr.reshape(-1)[top_n_idxs]


In [71]: nullified_arr
Out[71]: 
array([[ 0, 25,  0, 26,  0, 29],
       [ 0,  0,  0,  0,  0,  0],
       [ 0,  0,  0,  0,  0,  0],
       [26,  0, 27,  0,  0, 28],
       [ 0, 28,  0,  0,  0,  0],
       [ 0,  0,  0, 29, 25, 25]])
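
The steps above can be wrapped into a small helper. This is only a minimal sketch, not part of the original session: the function name `keep_top_n` is hypothetical, and it assumes `1 <= n <= arr.size`. Because `np.argpartition` breaks ties at the cut-off arbitrarily, exactly n elements are kept.

```python
import numpy as np

def keep_top_n(arr, n):
    """Return a copy of `arr` with everything except the n largest values set to 0.

    Hypothetical helper; assumes 1 <= n <= arr.size.
    """
    flat = arr.ravel()  # flattened view (a copy if arr is non-contiguous)
    # indices of the n largest values, in no particular order
    top_idx = np.argpartition(flat, -n)[-n:]
    out = np.zeros_like(flat)
    out[top_idx] = flat[top_idx]
    return out.reshape(arr.shape)

arr = np.array([[8, 25, 12],
                [24, 22, 7],
                [1, 30, 18]])
result = keep_top_n(arr, 3)
print(result)  # only 30, 25 and 24 remain nonzero
```

Returning a new array leaves the input untouched; if in-place modification is preferred, the same indices can be used to build a boolean mask instead.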

The following code zeroes, in place, every element of an NxM matrix X that is smaller than its n-th largest value.

threshold = np.sort(X.ravel())[-n]  # get the nth largest value
idx = X < threshold
X[idx] = 0

Note: this method can leave more than n nonzero elements when the n-th largest value is duplicated, because every element equal to the threshold is kept.
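
To make the caveat about duplicates concrete, here is a small sketch with made-up data in which the n-th largest value is repeated. It finds the threshold with `np.partition` rather than a full sort, which is a substitution and not what the snippet above uses; the helper name `zero_below_nth_largest` is hypothetical, and it assumes `1 <= n <= X.size`.

```python
import numpy as np

def zero_below_nth_largest(X, n):
    """Zero, in place, every element smaller than the n-th largest value of X.

    Hypothetical helper; assumes 1 <= n <= X.size.
    """
    # np.partition puts the n-th largest value at position -n without a full sort
    threshold = np.partition(X.ravel(), -n)[-n]
    X[X < threshold] = 0
    return X

X = np.array([[5, 5, 5],
              [1, 2, 3]])
zero_below_nth_largest(X, 2)
print(X)                    # the three 5s all survive
print(np.count_nonzero(X))  # 3, even though n = 2
```

If exactly n elements must remain, an index-based selection such as `np.argpartition` avoids this, at the cost of breaking ties arbitrarily.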
