Single-head attention mechanism in PyTorch


I am trying to implement an attention mechanism on the CIFAR10 dataset. The idea is to implement the attention layer with only a single head, so I took the multi-head implementation given here as a reference:

https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/5c0264915ab43485adc576f88971fc3d42b10445/transformer/SubLayers.py

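For context, the core operation in that reference is scaled dot-product attention, which for a single head (leaving out the linear projections, dropout and layer norm) boils down to roughly the following. This is only a minimal sketch; the function name and the (batch, seq_len, dim) shape convention are my own assumptions, not taken from the repo.

import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(q, k, v, mask=None):
    # q, k, v: (batch, seq_len, dim) -- a single attention head
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    attn = F.softmax(scores, dim=-1)
    return torch.matmul(attn, v), attn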

But I am monumentally failing. When I run the code, I get the following error:

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1108         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110             return forward_call(*input, **kwargs)
   1111         # Do not call functions when jit is used
   1112         full_backward_hooks, non_full_backward_hooks = [], []

TypeError: forward() takes 1 positional argument but 2 were given
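As far as I understand, this TypeError appears when a module is called with more positional arguments than its forward accepts. A toy module (made-up name, just to illustrate the message) reproduces it:

import torch
import torch.nn as nn

class Toy(nn.Module):
    def forward(self):  # takes only self, no inputs
        return torch.zeros(1)

Toy()(torch.randn(2, 3))
# TypeError: forward() takes 1 positional argument but 2 were given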

I know this is a basic programming question, but I also think my implementation is incorrect. I would be glad if anyone could give me some hints.

For reference, my code is below.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models


class ScaledDotProductAttention(nn.Module):
    def __init__(self, input_dim, output_dim,  attn_dropout=0.1):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        
        
        self.q = nn.Linear(input_dim, output_dim, bias=False)
        self.k = nn.Linear(input_dim, output_dim, bias=False)
        self.v = nn.Linear(input_dim, output_dim, bias=False)
        # print(self.q, self.k, self.v)
        # self.dropout = nn.Dropout(attn_dropout)
        # print(self.dropout)
        self.layer_norm = nn.LayerNorm(input_dim, eps=1e-6)
        # print(self.layer_norm)


    def forward(self, q, k, v, mask=None):

        batch = q.shape[0]
        #print(batch)

        dim_k, dim_v = self.k, self.v
        len_q, len_k, len_v  = q.size(1), k.size(1), v.size(1)

        
        q_s = self.q(q).view(batch, dim_k)
        k_s = self.k(k).view(batch, dim_k)
        v_s = self.v(v).view(batch, dim_v)
        print(q_s)
        #q, k, v = q.transpose(1,2), k.transpose(1,2), v.transpose(1,2)       
        attn = torch.matmul(q_s/dim_k , k_s.transpose(-1, -2))/np.sqrt(self.d_k)
        

        if mask is not None:
           attn = attn.masked_fill(mask == 0, -1e9)

        attn = self.dropout(F.softmax(attn, dim=-1))
        # print(attn)
        output = torch.matmul(attn, v_s)   
        # print(output)
        return output, attn     

  




class VGG(nn.Module):
  def __init__(self, num_classes=10, attention=False):
    super().__init__()

    self.num_classes = num_classes
    self.attention = attention

    vgg16 = models.vgg16(pretrained=True)
    self.feature_extractor = vgg16.features
    self.avg_pool = vgg16.avgpool
    self.clf = vgg16.classifier
    self.clf[6] = nn.Linear(in_features=4096, out_features=self.num_classes)
    
    if self.attention:
      self.attn_layer = ScaledDotProductAttention(512, 64)


  def forward(self, x):
    x = self.feature_extractor(x)

    if self.attention:
      x = self.attn_layer(x)

    x = self.avg_pool(x)
    x = torch.flatten(x, 1)
    x = self.clf(x)

    return x
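For clarity, this is roughly the kind of single-head self-attention over the VGG feature map that I am trying to get to. It is only a minimal sketch under my own assumptions (the class name and its internals are mine, not taken from the referenced repo): it flattens the (batch, channels, H, W) feature map into a sequence of H*W tokens, applies scaled dot-product self-attention, and reshapes back.

import torch
import torch.nn as nn
import torch.nn.functional as F

class SingleHeadSelfAttention(nn.Module):
    def __init__(self, input_dim, attn_dropout=0.1):
        super().__init__()
        self.q = nn.Linear(input_dim, input_dim, bias=False)
        self.k = nn.Linear(input_dim, input_dim, bias=False)
        self.v = nn.Linear(input_dim, input_dim, bias=False)
        self.dropout = nn.Dropout(attn_dropout)
        self.layer_norm = nn.LayerNorm(input_dim, eps=1e-6)

    def forward(self, x):
        # x: (batch, channels, H, W) -> sequence of H*W tokens of size channels
        b, c, h, w = x.shape
        seq = x.flatten(2).transpose(1, 2)                    # (batch, H*W, channels)
        q, k, v = self.q(seq), self.k(seq), self.v(seq)
        attn = torch.matmul(q, k.transpose(-2, -1)) / (c ** 0.5)
        attn = self.dropout(F.softmax(attn, dim=-1))
        out = self.layer_norm(torch.matmul(attn, v) + seq)    # residual + layer norm
        return out.transpose(1, 2).reshape(b, c, h, w)        # back to feature-map shape

Because the output has the same shape as the input, such a layer could be called as self.attn_layer(x) inside VGG.forward without changing avg_pool or the classifier.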

There are 0 answers