How can I trim / remove part of a Tensor to match the shape of another Tensor with PyTorch?


I have 2 tensors:

outputs: torch.Size([4, 27, 161])
pred: torch.Size([4, 30, 161])

I want to trim pred along dimension 1, dropping entries from the end, so that it has the same shape as outputs.

What's the best way to do it with PyTorch?



Answer by Felipe Curti (best answer)

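You can use Tensor.narrow(dim, start, length), which returns a view containing length entries along dim, beginning at index start.

For example: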

import torch

a = torch.randn(4, 30, 161)
a.size()                   # torch.Size([4, 30, 161])
a.narrow(1, 0, 27).size()  # torch.Size([4, 27, 161]): dim 1, start 0, length 27
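Applied to the tensors from the question, a minimal sketch could read the target length from outputs instead of hard-coding 27. It assumes, as in the question, that only dimension 1 differs and that pred is the longer tensor. Note that narrow returns a view sharing storage with pred, so call clone() if you need an independent copy.

```python
import torch

# Stand-ins for the tensors in the question (random values, same shapes).
outputs = torch.randn(4, 27, 161)
pred = torch.randn(4, 30, 161)

# Keep the first outputs.size(1) entries along dim 1, i.e. drop the trailing ones.
pred_trimmed = pred.narrow(1, 0, outputs.size(1))
print(pred_trimmed.shape)  # torch.Size([4, 27, 161])

# narrow() is equivalent to basic slicing and returns a view on pred's memory.
assert torch.equal(pred_trimmed, pred[:, :outputs.size(1), :])
assert pred_trimmed.data_ptr() == pred.data_ptr()

# clone() gives a trimmed tensor that no longer aliases pred.
pred_copy = pred_trimmed.clone()
```

Because the result is a view, writing into pred_trimmed in place would also modify pred; the clone() step avoids that.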
Answer by Tengerye

If the two tensors have a fixed, known number of dimensions, you can slice the larger one using the shape of the smaller one:

import torch

a = torch.randn(3, 5)
b = torch.zeros(3, 2)
b_h, b_w = b.shape
c = a[:b_h, :b_w]  # torch.Size([3, 2])

c has the same shape as b, but its values are taken from a.
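When the number of dimensions is not fixed, the same slicing idea can be generalized by building one slice per dimension from the target shape. This is a sketch: trim_to_shape is a hypothetical helper name, and it assumes every dimension of the source is at least as large as the matching target dimension. It raises ValueError otherwise, because plain slicing would silently stop at the end of a shorter dimension instead of failing.

```python
import torch

def trim_to_shape(src: torch.Tensor, target_shape) -> torch.Tensor:
    """Hypothetical helper: keep the leading target_shape[i] entries of each dim of src."""
    if src.dim() != len(target_shape):
        raise ValueError("src and target_shape must have the same number of dimensions")
    if any(s < t for s, t in zip(src.shape, target_shape)):
        raise ValueError("every dimension of src must be >= the target dimension")
    # One slice(0, n) per dimension; indexing with the tuple returns a view,
    # exactly like a[:b_h, :b_w] in the 2-D case.
    index = tuple(slice(0, n) for n in target_shape)
    return src[index]

a = torch.randn(3, 5)
b = torch.zeros(3, 2)
c = trim_to_shape(a, b.shape)
print(c.shape)  # torch.Size([3, 2])

# The tensors from the question work the same way.
pred = torch.randn(4, 30, 161)
outputs = torch.randn(4, 27, 161)
print(trim_to_shape(pred, outputs.shape).shape)  # torch.Size([4, 27, 161])
```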