How to insert a multi-head attention layer into a pretrained EfficientNet-B0 model using PyTorch

407 views Asked by At

I want to insert several multi-head attention layers into a pretrained EfficientNet-B0 model using PyTorch. After each sequential block, I want to add a multi-head attention layer.

I tried to do this by generating the model from scratch, following https://www.kaggle.com/code/vikramsandu/efficientnet-from-scratch, and then adding new layers to it, but it didn't work. I wrote the multi-head attention layer as follows.

class MultiHeadAttention(nn.Module):
    """Multi-head self-attention over the spatial positions of a feature map.

    The module takes a convolutional feature map of shape ``(B, C, H, W)``,
    treats each of the ``H * W`` spatial positions as one token with a
    ``C``-dimensional embedding, applies multi-head self-attention across
    those tokens, and returns a tensor with the same ``(B, C, H, W)`` shape
    so it can be dropped between convolutional stages.

    Args:
        input_dim: Number of channels ``C`` of the incoming feature map;
            also the embedding size of each token.
        num_heads: Number of attention heads. Must divide ``input_dim``.
        is_causal: If True, each position only attends to positions that
            come earlier in row-major order. Defaults to False because
            image positions have no natural causal ordering.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``input_dim``.
    """

    def __init__(self, input_dim, num_heads, is_causal=False):
        super().__init__()
        if num_heads <= 0 or input_dim % num_heads != 0:
            raise ValueError(
                f"input_dim ({input_dim}) must be divisible by "
                f"num_heads ({num_heads})"
            )
        self.input_dim = input_dim
        self.num_heads = num_heads
        self.head_dim = input_dim // num_heads
        self.is_causal = is_causal

        # Key, query and value projections for all heads, fused in one matmul.
        self.c_attn = nn.Linear(input_dim, 3 * input_dim, bias=False)
        # Output projection applied after the heads are concatenated.
        self.c_proj = nn.Linear(input_dim, input_dim, bias=False)

    def forward(self, x):
        """Apply self-attention across the spatial positions of ``x``.

        Args:
            x: Tensor of shape ``(B, input_dim, H, W)``.

        Returns:
            Tensor of shape ``(B, input_dim, H, W)``.

        Raises:
            ValueError: If ``x`` is not 4-D or its channel dimension does
                not equal ``input_dim``.
        """
        if x.dim() != 4 or x.size(1) != self.input_dim:
            raise ValueError(
                f"expected input of shape (B, {self.input_dim}, H, W), "
                f"got {tuple(x.shape)}"
            )
        B, C, H, W = x.size()

        # (B, C, H, W) -> (B, C, H*W) -> (B, T, C). The transpose is what
        # actually moves channels to the last axis; a bare reshape would
        # only reinterpret memory and scramble channels with positions.
        tokens = x.flatten(2).transpose(1, 2)
        T = tokens.size(1)

        # Compute query, key, values for all heads at once, then move the
        # head axis forward so it acts as an extra batch dimension.
        q, k, v = self.c_attn(tokens).split(self.input_dim, dim=2)
        q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)  # (B, nh, T, hs)
        k = k.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)  # (B, nh, T, hs)

        # Fused scaled-dot-product attention (PyTorch selects the backend).
        y = F.scaled_dot_product_attention(
            q, k, v, attn_mask=None, dropout_p=0.0, is_causal=self.is_causal
        )

        # Re-assemble all head outputs side by side: (B, nh, T, hs) -> (B, T, C).
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        # Output projection.
        y = self.c_proj(y)

        # (B, T, C) -> (B, C, T) -> (B, C, H, W): the inverse of the
        # flatten/transpose above, so every output position lines up with
        # the input position it was computed for.
        return y.transpose(1, 2).reshape(B, C, H, W)
0

There are 0 answers