Mapping a tensor to a smaller one along one dimension: (B, F, D) -> (B, F-n, D) in PyTorch


I have a tensor of embeddings that I want to reduce to a smaller number of embeddings. I am working in a batched setting. The tensor shape is (B, F, D), where B is the number of items in the batch, F is the number of embeddings, and D is the embedding dimension. I want to learn a reduction to (B, F-n, D).

e.g.

import torch

B = 10
F = 20
F_desired = 17
D = 64

x = torch.randn(B, F, D)
# torch.Size([10, 20, 64])

reduction = torch.?

y = reduction(x)

print(y.shape)
# torch.Size([10, 17, 64])

I think a 1x1 convolution would make sense here, but I'm not sure how to confirm it is actually doing what I expect. I would love to hear whether this is the right approach, and whether there are better ones:

reduction = torch.nn.Conv1d(
    in_channels=F,
    out_channels=F_desired,
    kernel_size=1,
)
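
For reference, one way to sanity-check this (a sketch, assuming the reduction module above): a conv with kernel size 1 only mixes across the channel dimension F, so each position along D should be transformed independently, and perturbing the input at one position should only change the output at that position.

y = reduction(x)
print(y.shape)
# torch.Size([10, 17, 64])

# Perturb the input at a single position along D; only that output
# position should change
x_perturbed = x.clone()
x_perturbed[:, :, 0] += 1.0
diff = (reduction(x_perturbed) - y).abs().sum(dim=(0, 1))
print(diff[0] > 0, torch.allclose(diff[1:], torch.zeros_like(diff[1:])))
# tensor(True) True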

1 Answer

Accepted answer by Karl:

A 1d conv with a kernel size of 1 accomplishes this. Conv1d expects input of shape (batch, channels, length), so it treats F as the channel dimension and D as the length, and a kernel of size 1 mixes across F independently at each of the D positions:

import torch
import torch.nn as nn

B = 10
F = 20
F_desired = 17
D = 64

x = torch.randn(B, F, D)

# F is the channel dimension, D the length
reduction1 = nn.Conv1d(F, F_desired, kernel_size=1)
x1 = reduction1(x)
print(x1.shape)
> torch.Size([10, 17, 64])
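
It helps to look at the conv weight shape here, since it makes the equivalence shown below concrete:

print(reduction1.weight.shape)
> torch.Size([17, 20, 1])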

You could also do a linear layer, provided you permute the axes:

# Linear acts on the last dim, so move F there and back
reduction2 = nn.Linear(F, F_desired)
x2 = reduction2(x.permute(0, 2, 1)).permute(0, 2, 1)
print(x2.shape)
> torch.Size([10, 17, 64])

Note that when the convolution kernel has size 1, these are actually equivalent operations. The conv weight of shape (F_desired, F, 1) is just the linear weight of shape (F_desired, F) with a trailing singleton dimension:

# Copy the conv parameters into the linear layer; squeezing the
# trailing kernel dim turns (F_desired, F, 1) into (F_desired, F)
reduction2.weight.data = reduction1.weight.squeeze(-1).data
reduction2.bias.data = reduction1.bias.data

x2 = reduction2(x.permute(0,2,1)).permute(0,2,1)
print(torch.allclose(x1, x2, atol=1e-6))
> True
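
Either way, what is being learned is just an (F_desired, F) mixing matrix applied at every position along D. As a sketch, here is the same computation written out with einsum, using the weights from reduction2 above (x3 is just a name introduced here):

# out[b, o, d] = sum_f W[o, f] * x[b, f, d] + bias[o]
W = reduction2.weight                      # (F_desired, F)
bias = reduction2.bias                     # (F_desired,)
x3 = torch.einsum('of,bfd->bod', W, x) + bias[None, :, None]
print(torch.allclose(x1, x3, atol=1e-6))
> True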