
I need to keep the N largest values (here N = 3) in each row of an array.

import numpy as np
a=np.array([[1,2,3,4],[8,7,6,5],[5,3,1,2]])
a
Out[135]: 
array([[1, 2, 3, 4],
       [8, 7, 6, 5],
       [5, 3, 1, 2]])

The indices of those values can be found with np.argpartition:

n=3
np.argpartition(a, -n, axis=1)[:,-n:]
Out[136]: 
array([[1, 2, 3],
       [2, 1, 0],
       [3, 0, 1]], dtype=int64)
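To check what those indices actually select, here is a small sketch that gathers the corresponding values with np.take_along_axis (available since NumPy 1.15). np.argpartition does not sort the n largest entries, so the values are sorted afterwards only to make the result easy to read.

```python
import numpy as np

a = np.array([[1, 2, 3, 4], [8, 7, 6, 5], [5, 3, 1, 2]])
n = 3

# Column indices of the n largest values in each row (order within a row is arbitrary)
top_idx = np.argpartition(a, -n, axis=1)[:, -n:]

# Pull those values out row by row, then sort each row for readability
top_vals = np.sort(np.take_along_axis(a, top_idx, axis=1), axis=1)
print(top_vals)
```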

So my question is: how can I keep the values at those indices and set all other values to zero, to get:

array([[0, 2, 3, 4],
       [8, 7, 6, 0],
       [5, 3, 0, 2]])

2 Answers

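import numpy as np

a = np.array([[1,2,3,4],[8,7,6,5],[5,3,1,2]])
n = 3

# Column indices of the (columns - n) smallest values in each row
drop = np.argpartition(a, -n, axis=1)[:, :a.shape[1] - n]

# Mark those positions in a boolean mask, then zero them
mask = np.zeros(a.shape, dtype=bool)
np.put_along_axis(mask, drop, True, axis=1)
a[mask] = 0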
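Another option, not taken from either answer here, is to avoid index bookkeeping entirely: use np.partition to find the n-th largest value in each row and zero everything below it. This is a sketch assuming you are fine with ties at the cutoff; if several entries equal the n-th largest value, more than n values per row will survive.

```python
import numpy as np

a = np.array([[1, 2, 3, 4], [8, 7, 6, 5], [5, 3, 1, 2]])
n = 3

# After partitioning, column -n of each row holds that row's n-th largest value
kth_largest = np.partition(a, -n, axis=1)[:, -n]

# Keep values >= the per-row threshold; ties at the threshold are all kept
res = np.where(a >= kth_largest[:, None], a, 0)
print(res)
```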

One option is to use your indices as fancy indices to overwrite the values in a zero array:

import numpy as np

# input
a = np.array([[1,2,3,4], [8,7,6,5], [5,3,1,2]])
n = 3

# Row indices as a column vector so they broadcast against the (rows, n) column indices
indices = (np.arange(a.shape[0])[:, None], np.argpartition(a, -n, axis=1)[:,-n:])
res = np.zeros_like(a)
res[indices] = a[indices]

Then you get

>>> res
array([[0, 2, 3, 4],
       [8, 7, 6, 0],
       [5, 3, 0, 2]])
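The row indices must have shape (rows, 1) so that each row index pairs with its own row of column indices. A quick sanity check with n = 2 (a variation chosen here for illustration, not from the question) shows the fancy-indexing approach generalizes when n differs from the number of rows.

```python
import numpy as np

a = np.array([[1, 2, 3, 4], [8, 7, 6, 5], [5, 3, 1, 2]])
n = 2  # differs from the number of rows, to exercise the broadcasting

# (rows, 1) row indices broadcast against (rows, n) column indices
rows = np.arange(a.shape[0])[:, None]
cols = np.argpartition(a, -n, axis=1)[:, -n:]

# Copy the selected values into an array of zeros
res = np.zeros_like(a)
res[rows, cols] = a[rows, cols]
print(res)
```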
