I want to insert several multi-head attention layers into a pretrained EfficientNet-B0 model using PyTorch: after each sequential block, I want to add a multi-head attention layer.
I tried to do this by generating the model from scratch following https://www.kaggle.com/code/vikramsandu/efficientnet-from-scratch and then adding the new layers to it, but it didn't work. I wrote the multi-head attention layer as follows.
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, input_dim, num_heads):
        super().__init__()
        self.input_dim = input_dim
        self.num_heads = num_heads
        self.head_dim = input_dim // num_heads
        # key, query, value projections for all heads, computed in one batched matmul
        self.c_attn = nn.Linear(input_dim, 3 * input_dim, bias=False)
        # output projection
        self.c_proj = nn.Linear(input_dim, input_dim, bias=False)

    def forward(self, x):
        # x is a conv feature map: (B, C, H, W)
        B, C, H, W = x.size()
        T = H * W
        # flatten the spatial dims into a sequence of T = H*W tokens: (B, T, C)
        x = x.flatten(2).transpose(1, 2)
        # compute query, key, value for all heads in one projection, then split
        q, k, v = self.c_attn(x).split(self.input_dim, dim=2)
        k = k.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)  # (B, nh, T, hs)
        # efficient attention using the fused scaled_dot_product_attention kernel;
        # spatial tokens have no ordering, so no causal mask is applied
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
        # output projection
        y = self.c_proj(y)
        # restore the original (B, C, H, W) layout
        return y.transpose(1, 2).reshape(B, C, H, W)
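For context, this is how I intend to call the module on a conv feature map; input_dim is the channel count of the block output and has to be divisible by num_heads:

attn = MultiHeadAttention(input_dim=64, num_heads=4)
x = torch.randn(2, 64, 14, 14)   # dummy feature map (B, C, H, W)
print(attn(x).shape)             # expected: torch.Size([2, 64, 14, 14])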
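And this is roughly how I am trying to attach the layers after each block. It is only a sketch of what I am aiming for: it assumes torchvision's efficientnet_b0 instead of the Kaggle from-scratch model, and it probes each block's output channels with a dummy forward pass rather than hard-coding them.

from torchvision import models

backbone = models.efficientnet_b0(weights="DEFAULT")
backbone.eval()

# interleave an attention layer after every block in backbone.features
blocks = list(backbone.features.children())
layers = []
x = torch.zeros(1, 3, 224, 224)
with torch.no_grad():
    for block in blocks:
        x = block(x)  # probe the block's output channel count
        layers.append(block)
        layers.append(MultiHeadAttention(input_dim=x.size(1), num_heads=4))
backbone.features = nn.Sequential(*layers)

Since the attention module keeps the (B, C, H, W) shape, it can be dropped straight into the features Sequential, but attention over the early high-resolution maps (112x112 = 12544 tokens after the stem) is very expensive, so it may only be practical after the later blocks.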