
Say I have a tensor and index:

x = torch.tensor([1,2,3,4,5])
idx = torch.tensor([0,2,4])

If I want to select all elements not in the index, I can manually define a Boolean mask like so:

mask = torch.ones_like(x, dtype=torch.bool)
mask[idx] = False

x[mask]
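Note that the mask must have a boolean dtype: an integer tensor of ones would be interpreted as indices rather than as a mask. A runnable version:

```python
import torch

x = torch.tensor([1, 2, 3, 4, 5])
idx = torch.tensor([0, 2, 4])

# The mask must be boolean; an integer tensor of ones would be
# interpreted as indices, not as a mask.
mask = torch.ones_like(x, dtype=torch.bool)
mask[idx] = False

print(x[mask])  # tensor([2, 4])
```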

Is there a more elegant way of doing this?

I.e., a syntax where I can directly pass the indices rather than building a mask, e.g. something like:

x[~idx]

2 Answers


I couldn't find a satisfactory solution for finding the complement of a multi-dimensional tensor of indices, so I finally implemented my own. It works on CUDA and benefits from fast parallel computation.

import torch

def complement_idx(idx, dim):
    """
    Compute the complement: set(range(dim)) - set(idx).
    idx is a multi-dimensional tensor; the complement is taken over its
    trailing dimension, and all other dimensions are treated as batch dims.
    Args:
        idx: input indices, shape: [N, *, K]
        dim: the exclusive upper bound of the index range
    Returns:
        complement indices, shape: [N, *, dim - K]
    """
    n_idx = idx.shape[-1]
    # Broadcast arange(dim) across the batch dimensions of idx.
    a = torch.arange(dim, device=idx.device).expand(*idx.shape[:-1], -1)
    # Overwrite the positions listed in idx with 0 and sort ascending:
    # the smallest n_idx entries are exactly the overwritten ones
    # (0 itself survives the slice only when it is not in idx),
    # so dropping them leaves the complement.
    masked = torch.scatter(a, -1, idx, 0)
    compl, _ = torch.sort(masked, dim=-1, descending=False)
    return compl[..., n_idx:]

Example:

>>> import torch
>>> a = torch.rand(3, 4, 5)
>>> a
tensor([[[0.7849, 0.7404, 0.4112, 0.9873, 0.2937],
         [0.2113, 0.9923, 0.6895, 0.1360, 0.2952],
         [0.9644, 0.9577, 0.2021, 0.6050, 0.7143],
         [0.0239, 0.7297, 0.3731, 0.8403, 0.5984]],

        [[0.9089, 0.0945, 0.9573, 0.9475, 0.6485],
         [0.7132, 0.4858, 0.0155, 0.3899, 0.8407],
         [0.2327, 0.8023, 0.6278, 0.0653, 0.2215],
         [0.9597, 0.5524, 0.2327, 0.1864, 0.1028]],

        [[0.2334, 0.9821, 0.4420, 0.1389, 0.2663],
         [0.6905, 0.2956, 0.8669, 0.6926, 0.9757],
         [0.8897, 0.4707, 0.5909, 0.6522, 0.9137],
         [0.6240, 0.1081, 0.6404, 0.1050, 0.6413]]])
>>> b, c = torch.topk(a, 2, dim=-1)
>>> b
tensor([[[0.9873, 0.7849],
         [0.9923, 0.6895],
         [0.9644, 0.9577],
         [0.8403, 0.7297]],

        [[0.9573, 0.9475],
         [0.8407, 0.7132],
         [0.8023, 0.6278],
         [0.9597, 0.5524]],

        [[0.9821, 0.4420],
         [0.9757, 0.8669],
         [0.9137, 0.8897],
         [0.6413, 0.6404]]])
>>> c
tensor([[[3, 0],
         [1, 2],
         [0, 1],
         [3, 1]],

        [[2, 3],
         [4, 0],
         [1, 2],
         [0, 1]],

        [[1, 2],
         [4, 2],
         [4, 0],
         [4, 2]]])
>>> compl = complement_idx(c, 5)
>>> compl
tensor([[[1, 2, 4],
         [0, 3, 4],
         [2, 3, 4],
         [0, 2, 4]],

        [[0, 1, 4],
         [1, 2, 3],
         [0, 3, 4],
         [2, 3, 4]],

        [[0, 3, 4],
         [0, 1, 3],
         [1, 2, 3],
         [0, 1, 3]]])
>>> al = torch.cat([c, compl], dim=-1)
>>> al
tensor([[[3, 0, 1, 2, 4],
         [1, 2, 0, 3, 4],
         [0, 1, 2, 3, 4],
         [3, 1, 0, 2, 4]],

        [[2, 3, 0, 1, 4],
         [4, 0, 1, 2, 3],
         [1, 2, 0, 3, 4],
         [0, 1, 2, 3, 4]],

        [[1, 2, 0, 3, 4],
         [4, 2, 0, 1, 3],
         [4, 0, 1, 2, 3],
         [4, 2, 0, 1, 3]]])
>>> al, _ = al.sort(dim=-1)
>>> al
tensor([[[0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4]],

        [[0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4]],

        [[0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4],
         [0, 1, 2, 3, 4]]])
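As an aside, here is an alternative formulation (my own sketch, not part of the answer above) that avoids the sort entirely by building a boolean mask. It assumes the indices within each trailing slice are distinct, which holds for torch.topk output:

```python
import torch

def complement_idx_mask(idx, dim):
    """Complement of idx along the last dim via a boolean mask.

    Assumes the indices within each trailing slice are distinct
    (true for torch.topk output).
    """
    k = idx.shape[-1]
    mask = torch.ones(*idx.shape[:-1], dim, dtype=torch.bool, device=idx.device)
    mask.scatter_(-1, idx, False)  # unmark the selected indices
    # Boolean indexing flattens in row-major order; each slice keeps
    # exactly dim - k entries, so we can reshape back.
    full = torch.arange(dim, device=idx.device).expand_as(mask)
    return full[mask].reshape(*idx.shape[:-1], dim - k)

idx = torch.tensor([[3, 0], [1, 2]])
print(complement_idx_mask(idx, 5))
# tensor([[1, 2, 4],
#         [0, 3, 4]])
```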



You may want to try this single-line expression:

import numpy as np

x[np.setdiff1d(range(len(x)), idx)]

Though it is admittedly still not that elegant. :)
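For a pure-PyTorch variant that avoids the NumPy round-trip, torch.isin (available since PyTorch 1.10) can build the same mask; a sketch for the 1-D case:

```python
import torch

x = torch.tensor([1, 2, 3, 4, 5])
idx = torch.tensor([0, 2, 4])

# Keep the positions whose index does not appear in idx.
keep = ~torch.isin(torch.arange(len(x)), idx)
print(x[keep])  # tensor([2, 4])
```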

