I want to know whether PyTorch has a slice function (like TensorFlow's tf.slice). In particular, I want to select the rows highlighted in orange.
You can use slicing just as in NumPy:
import torch
A = torch.rand((3,5,500))
first_three_rows = A[:, :3, :]
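If you specifically want something shaped like TensorFlow's `tf.slice(input, begin, size)`, `torch.narrow(input, dim, start, length)` does the same thing along a single dimension. The helper `tf_style_slice` below is a hypothetical name, not part of PyTorch: it is a minimal sketch that turns `begin`/`size` lists into an ordinary indexing tuple, treating a size of -1 as "everything to the end", as tf.slice does.

```python
import torch

def tf_style_slice(t: torch.Tensor, begin, size) -> torch.Tensor:
    """Hypothetical helper mimicking tf.slice(t, begin, size) with basic indexing.

    A size of -1 means "take everything from begin to the end of that dimension".
    Like plain slicing, the result is a view that shares memory with t.
    """
    index = tuple(
        slice(b, None) if s == -1 else slice(b, b + s)
        for b, s in zip(begin, size)
    )
    return t[index]

A = torch.rand((3, 5, 500))

# Three equivalent ways to take the first three rows of every matrix in A
by_indexing = A[:, :3, :]
by_narrow = A.narrow(1, 0, 3)  # dim=1, start=0, length=3
by_helper = tf_style_slice(A, begin=[0, 0, 0], size=[-1, 3, -1])

print(by_indexing.shape)  # torch.Size([3, 3, 500])
print(torch.equal(by_indexing, by_narrow), torch.equal(by_indexing, by_helper))  # True True
```

Plain indexing is usually the most readable; `narrow` is handy when the dimension, start, and length are computed at runtime.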
However, to take a different number of rows from each matrix, as you asked in the question, you can do:
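import torch
A = torch.rand((3,5,500))
indices = [2,4,5]  # number of leading rows to keep from A[0], A[1], A[2]
# A[idx, :index, :] has shape (index, 500); concatenating along dim 0 gives (2+4+5, 500) = (11, 500)
result = torch.cat([A[idx, :index, :] for idx, index in enumerate(indices)], dim=0)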
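The list comprehension above loops in Python. A vectorized alternative, sketched here under the assumption that every entry of `indices` is at most `A.shape[1]`, builds a boolean mask of shape (3, 5) that is True for the first `indices[i]` rows of matrix `i`, then indexes `A` with it. A mask covering the first two dimensions returns the selected rows flattened in row-major order, which is the same order `torch.cat` produces, so the result is again (2+4+5) x 500 = 11 x 500.

```python
import torch

A = torch.rand((3, 5, 500))
indices = [2, 4, 5]  # number of leading rows to keep from A[0], A[1], A[2]

# Loop-based version from the answer above
looped = torch.cat([A[idx, :index, :] for idx, index in enumerate(indices)], dim=0)

# Vectorized version: mask[i, j] is True when row j of A[i] should be kept
counts = torch.tensor(indices)
mask = torch.arange(A.shape[1]) < counts.unsqueeze(1)  # (5,) vs (3, 1) broadcasts to (3, 5)
masked = A[mask]  # selected rows in row-major order

print(masked.shape)                 # torch.Size([11, 500])
print(torch.equal(looped, masked))  # True
```

Note that boolean-mask indexing copies the data, just like `torch.cat`; neither result is a view of `A`.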
Comment from wrek: from the first one I want 2 rows, from the second one 4 rows, and from the last one 5 rows.
This is not currently supported by PyTorch as a single operation: slicing that tensor this way would produce a tensor whose subtensors have different dimensions, which a regular tensor cannot hold.
One way to solve this is to iterate over the subtensors and index each one:
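import torch

tensor = torch.rand((3, 5, 500))  # example data
# row indices to keep from each subtensor (here 2, 4 and 5 rows, as in the question)
slices_idx = [[0, 1], [0, 1, 2, 3], [0, 1, 2, 3, 4]]
sliced_tensors = []
# zip pairs each subtensor tensor[i] with its own list of row indices
for subtensor, slice_idx in zip(tensor, slices_idx):
    sliced_tensors.append(subtensor[slice_idx, :])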
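If you want both a single contiguous tensor and the per-subtensor pieces, the two answers can be combined: concatenate once, then call `torch.split` with a list of section sizes to get one view per subtensor. A minimal sketch, reusing the 2/4/5 row counts from the comment above:

```python
import torch

tensor = torch.rand((3, 5, 500))
row_counts = [2, 4, 5]  # rows to keep from each subtensor

# One (11, 500) tensor holding all selected rows back to back
flat = torch.cat([sub[:n] for sub, n in zip(tensor, row_counts)], dim=0)

# split with a list of section sizes returns a tuple of views along dim 0
pieces = torch.split(flat, row_counts, dim=0)
for piece in pieces:
    print(piece.shape)
# torch.Size([2, 500])
# torch.Size([4, 500])
# torch.Size([5, 500])
```

The pieces share memory with `flat`, so this gives you the "ragged" list without copying the data a second time.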

(2+4+5)x500 == 11x500?