PyTorch RuntimeError: Invalid Shape During Reshaping for Multi-Head Attention


I'm implementing a multi-head self-attention mechanism in PyTorch as part of a Text2Image model I am trying to build. I get a runtime error when reshaping the output of the linear projections before splitting it into multiple heads. The text embeddings from my model have shape [32, 26, 768], and I'm using 8 attention heads with an embedding size of 768, but the reshape fails with an invalid-shape error. Below are the individual blocks (as screenshots), the overall model definition, and the error message. Can you help me correct the error?

  • Text Encoder Block (screenshot)

  • Generator Block (screenshot)

  • Discriminator Block (screenshot)

  • Attention Mechanism (screenshot)

  • Transformer Block (screenshot)

  • Overall model architecture:

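import torch.nn as nn
from transformers import GPT2Tokenizer

# TextEncoder, SelfAttention, TransformerBlock, Generator and Discriminator
# are defined in the blocks shown in the screenshots above.

class Text2ImageModel(nn.Module):
    def __init__(self, image_size, text_embedding_dim, noise_dim, embed_size, heads, dropout, forward_expansion):
        super().__init__()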

        # Initialize tokenizer from pretrained GPT-2 model
        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # Text encoder (GPT-2 or BERT based) - assuming it outputs a text embedding
        self.text_encoder = TextEncoder()

        # Attention block
        self.attention_block = SelfAttention(embed_size=embed_size, heads=heads, dropout=dropout)

        # Transformer block
        self.transformer_block = TransformerBlock(embed_size=embed_size,
                                                  heads=heads,
                                                  dropout=dropout,
                                                  forward_expansion=forward_expansion)

        # Generator and Discriminator
        self.generator = Generator(text_embedding_dim=text_embedding_dim, z_dim=noise_dim, img_size=image_size)
        self.discriminator = Discriminator(image_size=image_size, text_embedding_dim=text_embedding_dim)

    def forward(self, text_input, images, noise, attention_mask=None):
        # Tokenize and encode the text input
        # tokens = self.tokenizer(text_input, return_tensors='pt', padding=True, truncation=True)
        text_embeddings = self.text_encoder(text_input)
        print(f"Text embeddings shape: {text_embeddings.shape}") 

        # Apply self-attention to text embeddings
        attention_output = self.attention_block(value=text_embeddings, key=text_embeddings, query=text_embeddings, mask=attention_mask)  
        print(f"Attention output shape: {attention_output.shape}")
        # Pass the output of attention through the transformer block
        transformer_output = self.transformer_block(value=attention_output, key=attention_output, query=attention_output, mask=attention_mask)
        print(f"Transformer output shape: {transformer_output.shape}")
        # Generate an image from the transformer output and noise
        generated_images = self.generator(transformer_output, noise)

        # Discriminator takes real images and the corresponding text embeddings
        real_image_discrimination = self.discriminator(images, transformer_output)
        # Discriminator also takes the fake images and text embeddings
        fake_image_discrimination = self.discriminator(generated_images.detach(), transformer_output)

        return generated_images, real_image_discrimination, fake_image_discrimination
  • The error occurs when I pass the text embeddings to the attention module:
Text embeddings shape: torch.Size([32, 27, 768])
Before self attention: torch.Size([32, 27, 768]) torch.Size([32, 27, 768]) torch.Size([32, 27, 768])
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-147-a568d03c70d5> in <cell line: 61>()
     71 
     72         # Forward pass through the model
---> 73         fake_images, real_preds, fake_preds = model(list(captions), images, noise, attention_mask=None)
     74         # Flatten the output for the discriminator
     75         real_preds = real_preds.view(-1)

5 frames
<ipython-input-144-6021d3c7122c> in forward(self, value, key, query, mask)
     24         print("Before self attention:", value.shape, key.shape, query.shape)
     25         # Transform and split for multi-head attention
---> 26         values = self.values(value).view(N, value_len, self.heads, self.head_dim)
     27         keys = self.keys(key).view(N, key_len, self.heads, self.head_dim)
     28         queries = self.queries(query).view(N, query_len, self.heads, self.head_dim)

RuntimeError: shape '[32, 27, 8, 768]' is invalid for input of size 663552
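The numbers in the error can be checked directly. The sketch below only redoes the arithmetic from the traceback, assuming the linear projection outputs 768 features per token (which is what an input size of 663552 implies for a [32, 27] batch). It shows that the requested view needs 8 times as many elements as exist, and that a per-head size of 768 // 8 = 96 would make the counts match.

```python
# Element-count check for the failing .view() call (numbers copied from the traceback).
batch, seq_len, heads, embed = 32, 27, 8, 768

# The linear projection outputs one 768-wide vector per token.
available = batch * seq_len * embed
print(available)  # 663552, the "input of size" in the error

# The requested view [32, 27, 8, 768] needs 8x as many elements.
requested = batch * seq_len * heads * embed
print(requested)  # 5308416

# For the view to succeed, heads * head_dim must equal the projection width.
head_dim = embed // heads
print(head_dim)  # 96
assert batch * seq_len * heads * head_dim == available
```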

I expect the output of the linear layer to be reshaped to [batch_size, seq_len, heads, head_dim], but the error indicates a mismatch in the total number of elements. I've also confirmed that my sequence length is 26, so I'm not sure where the 27 in the error message comes from.
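One possible source of the 27, assuming the text encoder tokenizes each batch with padding=True (as the commented-out tokenizer line in forward suggests): with that setting, each batch is padded to the length of its own longest caption, so the sequence length can differ from batch to batch. The sketch below illustrates the same pad-to-longest behaviour with torch.nn.utils.rnn.pad_sequence; the token-id lists are hypothetical.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Hypothetical token-id sequences for two different batches of captions.
batch_a = [torch.tensor([5, 8, 2]), torch.tensor([7, 1, 4, 9, 3])]
batch_b = [torch.tensor([6, 2]), torch.tensor([3, 3, 1, 8, 2, 4])]

# pad_sequence pads every sequence to the longest one in its own list.
padded_a = pad_sequence(batch_a, batch_first=True, padding_value=0)
padded_b = pad_sequence(batch_b, batch_first=True, padding_value=0)

print(padded_a.shape)  # torch.Size([2, 5])
print(padded_b.shape)  # torch.Size([2, 6])
```

A varying sequence length is not by itself a problem for the reshape, as long as the view uses the current batch's length (value.shape[1]) rather than a fixed number.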

Why is my reshaping operation failing, and how can I correct the shape to be compatible with the expected dimensions for multi-head attention?
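A minimal sketch of the head split, assuming SelfAttention applies full-width nn.Linear(embed_size, embed_size) projections before reshaping (as the values/keys/queries lines in the traceback suggest) and derives head_dim as embed_size // heads. The class below is an illustrative stand-in with the same forward(value, key, query, mask) signature, not the code from the screenshot. The fc_out layer and the mask convention (0 means "ignore this position") are assumptions.

```python
import torch
import torch.nn as nn


class SelfAttention(nn.Module):
    """Illustrative multi-head self-attention; head_dim = embed_size // heads."""

    def __init__(self, embed_size, heads, dropout=0.0):
        super().__init__()
        # Each head gets an equal slice of the embedding, so the sizes must divide.
        assert embed_size % heads == 0, "embed_size must be divisible by heads"
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads  # 768 // 8 = 96, not 768

        # Full-width projections: embed_size in, embed_size (= heads * head_dim) out.
        self.values = nn.Linear(embed_size, embed_size)
        self.keys = nn.Linear(embed_size, embed_size)
        self.queries = nn.Linear(embed_size, embed_size)
        self.fc_out = nn.Linear(embed_size, embed_size)  # assumed output projection
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask=None):
        N = query.shape[0]
        value_len, key_len, query_len = value.shape[1], key.shape[1], query.shape[1]

        # [N, len, embed_size] -> [N, len, heads, head_dim]; element counts now match.
        values = self.values(value).view(N, value_len, self.heads, self.head_dim)
        keys = self.keys(key).view(N, key_len, self.heads, self.head_dim)
        queries = self.queries(query).view(N, query_len, self.heads, self.head_dim)

        # Attention scores per head: [N, heads, query_len, key_len].
        energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)
        if mask is not None:
            # Assumed convention: positions where mask == 0 are ignored.
            energy = energy.masked_fill(mask == 0, float("-inf"))
        attention = torch.softmax(energy / self.head_dim ** 0.5, dim=3)
        attention = self.dropout(attention)

        # Weighted sum of values, then merge heads back: [N, query_len, embed_size].
        out = torch.einsum("nhql,nlhd->nqhd", attention, values)
        out = out.reshape(N, query_len, self.heads * self.head_dim)
        return self.fc_out(out)


# Shapes from the question: batch 32, sequence length 27, embed_size 768, 8 heads.
attn = SelfAttention(embed_size=768, heads=8, dropout=0.1)
x = torch.randn(32, 27, 768)
out = attn(value=x, key=x, query=x, mask=None)
print(out.shape)  # torch.Size([32, 27, 768])
```

With head_dim = 96, the view requests 32 * 27 * 8 * 96 = 663552 elements, exactly what the projection produces. If the screenshot's code ends up with head_dim equal to 768 (for example by assigning embed_size to it instead of embed_size // heads), the view would request the [32, 27, 8, 768] shape seen in the error.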
