I have a tensor X such as [0.1, 0.5, -1.0, 0, 1.2, 0], and I want to implement a function filter_positive() that collects the positive values into a new tensor and also returns their indices in the original tensor. For example:

new_tensor, index = filter_positive(X)

new_tensor = [0.1, 0.5, 1.2]
index = [0, 1, 4]

How can I implement this function most efficiently in PyTorch?

2 Answers

Take a look at torch.nonzero, which is roughly equivalent to np.where called with a single argument. It translates a boolean mask into the indices of its True entries:

>>> X = torch.tensor([0.1, 0.5, -1.0, 0, 1.2, 0])
>>> mask = X > 0
>>> mask
tensor([ True,  True, False, False,  True, False])

>>> indices = torch.nonzero(mask)
>>> indices
tensor([[0],
        [1],
        [4]])

>>> X[indices]
tensor([[0.1000],
        [0.5000],
        [1.2000]])

The mask uses a strict comparison (X > 0) so that zeros are excluded. torch.nonzero returns one row per match, so for a 1-D input the result has shape (N, 1). A solution would then be to write:

mask = X > 0
new_tensor = X[mask]
indices = torch.nonzero(mask).squeeze(1)  # flatten (N, 1) to (N,)
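
For completeness, here is a minimal sketch of how this could be wrapped into the filter_positive() function from the question. It assumes a 1-D input tensor and treats "positive" as strictly greater than zero; the parameter name x and the return names are illustrative choices, not part of any PyTorch API.

```python
import torch


def filter_positive(x: torch.Tensor):
    """Return the strictly positive entries of a 1-D tensor and their indices."""
    # Boolean mask: True where the element is strictly positive.
    mask = x > 0
    # Boolean-mask indexing returns the selected values as a new 1-D tensor.
    values = x[mask]
    # torch.nonzero gives shape (N, 1) for a 1-D mask; squeeze(1) makes it (N,).
    indices = torch.nonzero(mask).squeeze(1)
    return values, indices


X = torch.tensor([0.1, 0.5, -1.0, 0, 1.2, 0])
new_tensor, index = filter_positive(X)
print(new_tensor)  # the values found at positions 0, 1 and 4 of X
print(index)       # the positions themselves
```

Both results come from the same mask, so the comparison is evaluated only once.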

If the index is not necessary, you could just do:

X = X[X > 0]
