I am trying to implement an attention mechanism on the CIFAR10 dataset, with the attention layer using only a single head. As a reference, I used the multi-head implementations given here and here.
But I am failing monumentally. When I run the code, I get the following error:
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1108         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110             return forward_call(*input, **kwargs)
   1111         # Do not call functions when jit is used
   1112         full_backward_hooks, non_full_backward_hooks = [], []

TypeError: forward() takes 1 positional argument but 2 were given
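As far as I understand, this TypeError means that forward is being called with more positional arguments than it declares. A toy module of my own (unrelated to the real code below) reproduces the same message, so I suspect I am wiring the layer into the model incorrectly:

import torch
import torch.nn as nn

class Toy(nn.Module):
    def forward(self):  # forward takes no input besides self
        return torch.zeros(1)

Toy()(torch.zeros(1))   # TypeError: forward() takes 1 positional argument but 2 were given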
I know this is a fundamental programming question, but I also think that my implementation is incorrect. I would be glad if anyone could give me some hints.
For reference, I leave the code below.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models


class ScaledDotProductAttention(nn.Module):
    def __init__(self, input_dim, output_dim, attn_dropout=0.1):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.q = nn.Linear(input_dim, output_dim, bias=False)
        self.k = nn.Linear(input_dim, output_dim, bias=False)
        self.v = nn.Linear(input_dim, output_dim, bias=False)
        # print(self.q, self.k, self.v)
        # self.dropout = nn.Dropout(attn_dropout)
        # print(self.dropout)
        self.layer_norm = nn.LayerNorm(input_dim, eps=1e-6)
        # print(self.layer_norm)

    def forward(self, q, k, v, mask=None):
        batch = q.shape[0]
        # print(batch)
        dim_k, dim_v = self.k, self.v
        len_q, len_k, len_v = q.size(1), k.size(1), v.size(1)
        q_s = self.q(q).view(batch, dim_k)
        k_s = self.k(k).view(batch, dim_k)
        v_s = self.v(v).view(batch, dim_v)
        print(q_s)
        # q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        attn = torch.matmul(q_s / dim_k, k_s.transpose(-1, -2)) / np.sqrt(self.d_k)
        if mask is not None:
            attn = attn.masked_fill(mask == 0, -1e9)
        attn = self.dropout(F.softmax(attn, dim=-1))
        # print(attn)
        output = torch.matmul(attn, v_s)
        # print(output)
        return output, attn
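For comparison, this is the shape handling I think a single-head layer would need if it receives the VGG feature map directly as one tensor. Treating the 7x7 spatial grid as 49 tokens and keeping the channel dimension at 512 are my own assumptions, not something taken from the references above:

class SingleHeadSelfAttention(nn.Module):
    def __init__(self, channels=512):
        super().__init__()
        self.q = nn.Linear(channels, channels, bias=False)
        self.k = nn.Linear(channels, channels, bias=False)
        self.v = nn.Linear(channels, channels, bias=False)

    def forward(self, x):                      # x: (batch, C, H, W)
        b, c, h, w = x.shape
        tokens = x.flatten(2).transpose(1, 2)  # (batch, H*W, C): each spatial position is a token
        q, k, v = self.q(tokens), self.k(tokens), self.v(tokens)
        attn = torch.matmul(q, k.transpose(-2, -1)) / (c ** 0.5)  # scaled dot-product scores
        attn = F.softmax(attn, dim=-1)
        out = torch.matmul(attn, v)            # (batch, H*W, C)
        return out.transpose(1, 2).reshape(b, c, h, w)  # back to the conv layout

I kept the output in the (batch, 512, 7, 7) layout on purpose so that vgg16.avgpool and the classifier still receive what they expect, but I am not sure this is the right way to plug attention into the backbone.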
class VGG(nn.Module):
    def __init__(self, num_classes=10, attention=False):
        super().__init__()
        self.num_classes = num_classes
        self.attention = attention
        vgg16 = models.vgg16(pretrained=True)
        self.feature_extractor = vgg16.features
        self.avg_pool = vgg16.avgpool
        self.clf = vgg16.classifier
        self.clf[6] = nn.Linear(in_features=4096, out_features=self.num_classes)
        if self.attention:
            self.attn_layer = ScaledDotProductAttention(512, 64)

    def forward(self, x):
        x = self.feature_extractor(x)
        if self.attention:
            x = self.attn_layer(x)
        x = self.avg_pool(x)
        x = torch.flatten(x, 1)
        x = self.clf(x)
        return x
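For completeness, this is roughly how I build and call the model; the batch size and the 224x224 resize of the CIFAR10 images are my own choices for the pretrained VGG16, not fixed requirements:

model = VGG(num_classes=10, attention=True)
model.eval()

# Dummy CIFAR10 batch, resized to 224x224 for the pretrained VGG16 (my assumption)
images = torch.randn(4, 3, 224, 224)

with torch.no_grad():
    out = model(images)   # the TypeError above is raised on this call
print(out.shape)          # what I expect once it works: torch.Size([4, 10])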