I'm currently stuck writing a NumPy script whose main goal is efficiency (so vectorization is mandatory).

Let's assume a 3-D array:

arr = [[[0, 0, 0, 0],
        [0, 0, 3, 4],
        [0, 0, 3, 0],
        [0, 2, 3, 0]],

       [[0, 0, 3, 0],
        [0, 0, 0, 0],
        [1, 0, 3, 0],
        [0, 0, 0, 0]],

       [[0, 2, 3, 4],
        [0, 0, 0, 0],
        [0, 0, 3, 4],
        [0, 0, 3, 0]],

       [[0, 0, 3, 4],
        [0, 0, 3, 4],
        [0, 0, 0, 0],
        [0, 0, 0, 0]]]

My goal is to zero out every column that has more than one nonzero value. So, given the array above, the result should be something like:

filtered = [[[0, 0, 0, 0],
             [0, 0, 0, 4],
             [0, 0, 0, 0],
             [0, 2, 0, 0]],

            [[0, 0, 0, 0],
             [0, 0, 0, 0],
             [1, 0, 0, 0],
             [0, 0, 0, 0]],

            [[0, 2, 0, 0],
             [0, 0, 0, 0],
             [0, 0, 0, 0],
             [0, 0, 0, 0]],

            [[0, 0, 0, 0],
             [0, 0, 0, 0],
             [0, 0, 0, 0],
             [0, 0, 0, 0]]]

I've managed to work around this with a combination of np.count_nonzero, np.repeat and reshape:

# Keep only columns holding exactly one nonzero value (all-zero columns stay zero either way)
mask = np.repeat(np.count_nonzero(a=arr, axis=1) == 1, repeats=4, axis=0).reshape(4, 4, 4)
result = mask * arr

This produces correct results but seems to miss the point (there is a lot of cryptic shape manipulation just to line the counts up with the array). Furthermore, I'd like this function to be flexible enough to work along other axes too (e.g. rows), resulting in:

rows_fil = [[[0, 0, 0, 0],
             [0, 0, 0, 0],
             [0, 0, 3, 0],
             [0, 0, 0, 0]],

            [[0, 0, 3, 0],
             [0, 0, 0, 0],
             [0, 0, 0, 0],
             [0, 0, 0, 0]],

            [[0, 0, 0, 0],
             [0, 0, 0, 0],
             [0, 0, 0, 0],
             [0, 0, 3, 0]],

            [[0, 0, 0, 0],
             [0, 0, 0, 0],
             [0, 0, 0, 0],
             [0, 0, 0, 0]]]

Is there an idiomatic NumPy way to write such a flexible function?

1 Answer

Here's a solution that handles a generic axis parameter:

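import numpy as np

def mask_nnzcount(a, axis):
    a = np.asarray(a)  # a is the input array (nested lists are accepted too)
    # True where the slice along `axis` holds more than one nonzero value
    mask = (a != 0).sum(axis=axis, keepdims=True) > 1
    return np.where(mask, 0, a)  # zero the flagged slices, keep the rest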

The trick is keepdims=True: the reduced axis is kept with length 1, so the mask broadcasts against a no matter which axis was reduced. That is what makes the solution generic.
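To make the keepdims point concrete, here is a minimal sketch on a small made-up array (the values and the names counts_dropped and counts_kept are only for illustration, they are not from the question):

```python
import numpy as np

# Illustrative array of shape (2, 2, 3); not the array from the question
a = np.array([[[0, 2, 3],
               [0, 0, 3]],
              [[1, 0, 0],
               [1, 0, 4]]])

# Without keepdims the reduced axis disappears, so the result
# no longer matches `a` axis-for-axis.
counts_dropped = (a != 0).sum(axis=1)

# With keepdims=True the reduced axis stays with length 1, so the
# counts broadcast back onto `a` along exactly that axis.
counts_kept = (a != 0).sum(axis=1, keepdims=True)

print(counts_dropped.shape)  # (2, 3)
print(counts_kept.shape)     # (2, 1, 3)

# The comparison result has shape (2, 1, 3) and is stretched
# along axis 1 by np.where.
filtered = np.where(counts_kept > 1, 0, a)
print(filtered.shape)        # (2, 2, 3)
```

The same lines work unchanged for any other axis, because the length-1 axis always sits where the reduction happened.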

With a 3D array, the column filter from your question corresponds to axis=1 and the row filter to axis=2.
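As a quick check, the function can be run on the array from the question with both axes. The names cols and rows are just illustrative:

```python
import numpy as np

def mask_nnzcount(a, axis):
    # Same function as in the answer above
    mask = (a != 0).sum(axis=axis, keepdims=True) > 1
    return np.where(mask, 0, a)

# The 3-D array from the question
arr = np.array([[[0, 0, 0, 0],
                 [0, 0, 3, 4],
                 [0, 0, 3, 0],
                 [0, 2, 3, 0]],

                [[0, 0, 3, 0],
                 [0, 0, 0, 0],
                 [1, 0, 3, 0],
                 [0, 0, 0, 0]],

                [[0, 2, 3, 4],
                 [0, 0, 0, 0],
                 [0, 0, 3, 4],
                 [0, 0, 3, 0]],

                [[0, 0, 3, 4],
                 [0, 0, 3, 4],
                 [0, 0, 0, 0],
                 [0, 0, 0, 0]]])

cols = mask_nnzcount(arr, axis=1)  # zero columns with more than one nonzero
rows = mask_nnzcount(arr, axis=2)  # zero rows with more than one nonzero

print(cols[0])
# [[0 0 0 0]
#  [0 0 0 4]
#  [0 0 0 0]
#  [0 2 0 0]]
print(rows[0])
# [[0 0 0 0]
#  [0 0 0 0]
#  [0 0 3 0]
#  [0 0 0 0]]
```

The first blocks printed here match the first blocks of filtered and rows_fil in the question.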

For a generic ndarray, you might want to use axis=-2 for the column filter and axis=-1 for the row filter.
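A small sketch of why negative axes generalize: the same call works on a 2-D matrix and on a 4-D stack of matrices. The array x and the batch built from it are made up for illustration:

```python
import numpy as np

def mask_nnzcount(a, axis):
    # Same function as in the answer above
    mask = (a != 0).sum(axis=axis, keepdims=True) > 1
    return np.where(mask, 0, a)

# Illustrative 2-D matrix
x = np.array([[0, 2, 3],
              [0, 0, 3],
              [1, 0, 0]])

# Hypothetical batch with two leading axes, shape (2, 1, 3, 3)
batch = np.stack([x, 10 * x])[:, None]

# axis=-2 always means "down the columns", axis=-1 "along the rows",
# regardless of how many leading axes there are.
col_2d = mask_nnzcount(x, axis=-2)
row_2d = mask_nnzcount(x, axis=-1)
col_4d = mask_nnzcount(batch, axis=-2)
row_4d = mask_nnzcount(batch, axis=-1)

print(col_2d)
# [[0 2 0]
#  [0 0 0]
#  [1 0 0]]
print(row_2d)
# [[0 0 0]
#  [0 0 3]
#  [1 0 0]]
print(np.array_equal(col_4d[0, 0], col_2d))  # True
print(np.array_equal(row_4d[0, 0], row_2d))  # True
```

Each matrix in the batch is filtered independently, so the leading axes act as pure batch dimensions.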

Alternatively, we could use element-wise multiplication at the last step instead and get the output with a*(~mask). Or build an inverted mask, say inv_mask = (a!=0).sum(axis=axis, keepdims=True)<=1, and then do a*inv_mask.
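The variants can be compared side by side. This is a sketch on made-up data (the arrays a and f and the result names are illustrative). It also shows one case where multiplication and np.where differ, namely float input containing inf:

```python
import numpy as np

# Illustrative integer array, reduced along axis 0 (columns)
a = np.array([[0, 2, 3],
              [0, 0, 3],
              [1, 0, 0]])
axis = 0

mask = (a != 0).sum(axis=axis, keepdims=True) > 1
inv_mask = (a != 0).sum(axis=axis, keepdims=True) <= 1

out_where = np.where(mask, 0, a)  # variant from the function above
out_mult = a * (~mask)            # multiply by the negated mask
out_inv = a * inv_mask            # multiply by the inverted mask

# For integer input all three variants give the same result
print(np.array_equal(out_where, out_mult))  # True
print(np.array_equal(out_where, out_inv))   # True

# Float input with inf: inf * 0 is nan, while np.where writes a plain 0
f = np.array([[np.inf, 1.0],
              [2.0, 0.0]])
f_mask = (f != 0).sum(axis=0, keepdims=True) > 1
with np.errstate(invalid="ignore"):  # silence the inf * 0 warning
    f_mult = f * (~f_mask)
f_where = np.where(f_mask, 0, f)

print(f_mult[0, 0])   # nan
print(f_where[0, 0])  # 0.0
```

So the multiplication forms are fine for integer data or finite floats, while np.where is the safer choice if the input may contain inf or nan.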
