Implementing 1D self-attention in PyTorch


I'm trying to implement in PyTorch the 1D self-attention block shown below:

[figure: the 1D self-attention block proposed in the paper]

Below you can find my (provisional) attempt:

import torch.nn as nn
import torch

# INPUT shape: (B, CH, H, W)


class Self_Attention1D(nn.Module):
    
    def __init__(self, in_channels=1, out_channels=3):
        
        super().__init__()
        
        # 1x1 (pointwise) convolutions over the (H, W) feature map
        self.pointwise_conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
        self.pointwise_conv2 = nn.Conv2d(in_channels=out_channels, out_channels=in_channels, kernel_size=1)

        # per-entry MLPs acting on the channel dimension
        self.phi = MLP(in_size=out_channels, out_size=32)
        self.psi = MLP(in_size=out_channels, out_size=32)
        self.gamma = MLP(in_size=32, out_size=out_channels)
                
    def forward(self, x):
        # x: (B, CH, H, W) -> (B, out_channels, H, W)
        x = self.pointwise_conv1(x)

        # move channels to the last axis so the MLPs act on them: (B, W, H, C)
        phi = self.phi(x.transpose(1, 3))
        psi = self.psi(x.transpose(1, 3))

        delta = phi - psi

        # map back to channel weights and restore the (B, C, H, W) layout
        gamma = self.gamma(delta).transpose(3, 1)

        # element-wise gating of the features, then project back to in_channels
        out = self.pointwise_conv2(torch.mul(gamma, x))

        return out



class MLP(nn.Module):
    # small MLP applied independently to every entry (acts on the last dimension)
    def __init__(self, in_size, out_size):
        
        super().__init__()
        
        self.in_size = in_size
        self.out_size = out_size
        
        self.layers = nn.Sequential(
            
            nn.Linear(in_size, 64),
            nn.ReLU(),
            nn.Linear(64,128),
            nn.ReLU(),
            nn.Linear(128,64),
            nn.ReLU(),
            nn.Linear(64,out_size))
    
    def forward(self, x):
        
        out = self.layers(x)
        
        return out
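
For reference, this is how I'm sanity-checking the output shapes (the batch size and the 32x32 spatial size are arbitrary choices for the test):

block = Self_Attention1D(in_channels=1, out_channels=3)
x = torch.randn(2, 1, 32, 32)   # (B, CH, H, W)
out = block(x)
print(out.shape)                # expected: torch.Size([2, 1, 32, 32])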

I'm not sure at all that this is correct: the operations in my implementation happen globally, whereas in the figure some operation should be computed between each entry and its neighbours, one at a time. I was initially tempted to write a for loop that iteratively evaluates phi, psi and delta for each entry, but that didn't feel like the right way to do it.
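
For example, I was wondering whether the neighbour gathering could be done with unfold instead of a loop. This is just a sketch, assuming the neighbourhood is a window of k entries along the W axis (which may not be what the paper actually does), and that phi_map/psi_map are the phi/psi outputs transposed back to (B, C, H, W):

import torch.nn.functional as F

def local_delta(phi_map, psi_map, k=3):
    # phi_map, psi_map: (B, C, H, W); k: odd window size along W
    pad = k // 2
    psi_padded = F.pad(psi_map, (pad, pad, 0, 0), mode='replicate')
    # sliding windows over the W axis: (B, C, H, W, k)
    psi_windows = psi_padded.unfold(3, k, 1)
    # difference between each entry and each of its k neighbours
    return phi_map.unsqueeze(-1) - psi_windows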

Apologies if this is trivial, but I don't have much experience with PyTorch yet.
