I'm trying to implement the 1D self-attention block below using PyTorch
:
proposed in the following paper. Below you can find my (provisional) attempt:
import torch.nn as nn
import torch
#INPUT shape ((B), CH, H, W)
class Self_Attention1D(nn.Module):
def __init__(self, in_channels=1, out_channels=3):
super().__init__()
self.pointwise_conv1 = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=(1,1))
self.pointwise_conv2 = nn.Conv1d(in_channels=out_channels, out_channels=in_channels, kernel_size=(1,1))
self.phi = MLP(in_size = out_channels, out_size=32)
self.psi = MLP(in_size = out_channels, out_size=32)
self.gamma = MLP(in_size=32, out_size=out_channels)
def forward(self, x):
x = self.pointwise_conv1(x)
phi = self.phi(x.transpose(1,3))
psi = self.psi(x.transpose(1,3))
delta = phi-psi
gamma = self.gamma(delta).transpose(3,1)
out = self.pointwise_conv2(torch.mul(gamma,x))
return out
class MLP(nn.Module):
def __init__(self, in_size, out_size):
super().__init__()
self.in_size = in_size
self.out_size = out_size
self.layers = nn.Sequential(
nn.Linear(in_size, 64),
nn.ReLU(),
nn.Linear(64,128),
nn.ReLU(),
nn.Linear(128,64),
nn.ReLU(),
nn.Linear(64,out_size))
def forward(self, x):
out = self.layers(x)
return out
I'm not sure at all that this is correct, as the operations in my implementation are happening globally while as displayed in the image we should compute some operation between each entry and its neighbours one at a time. I was initially tempted to instantiate a for loop to iteratively compute the neural networks delta,phi,psi
for each entry, but I felt that it wasn't the right way to do that.
Apologies if this is trivial but I still don't have a huge experience in PyTorch
.