I'm implementing a multi-head self-attention mechanism in PyTorch as part of a Text2Image model I'm building, and I'm hitting a runtime error when reshaping the output of the linear transformations before splitting them into multiple heads. The text embeddings from my model have a shape of [32, 26, 768], and I'm using 8 attention heads with an embedding size of 768. However, the reshape fails with an invalid-shape error. I've included the individual blocks as screenshots; the overall model definition and the error message are pasted below. Can you help me find what's going wrong?
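For reference, this is the split I expect to happen, written out with a dummy tensor (purely illustrative shape bookkeeping, not my actual attention module):

import torch

embed_size, heads = 768, 8
head_dim = embed_size // heads  # 768 / 8 = 96

# Stand-in for the output of one of the Q/K/V linear projections
x = torch.randn(32, 26, embed_size)
x = x.view(32, 26, heads, head_dim)
print(x.shape)  # torch.Size([32, 26, 8, 96])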
Overall model architecture:

import torch
import torch.nn as nn
from transformers import GPT2Tokenizer

# TextEncoder, SelfAttention, TransformerBlock, Generator and Discriminator
# are defined in separate cells (the screenshots mentioned above).
class Text2ImageModel(nn.Module):
    def __init__(self, image_size, text_embedding_dim, noise_dim, embed_size, heads, dropout, forward_expansion):
        super(Text2ImageModel, self).__init__()
        # Initialize tokenizer from pretrained GPT-2 model
        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # Text encoder (GPT-2 or BERT based) - assuming it outputs a text embedding
        self.text_encoder = TextEncoder()

        # Attention block
        self.attention_block = SelfAttention(embed_size=embed_size, heads=heads, dropout=dropout)

        # Transformer block
        self.transformer_block = TransformerBlock(embed_size=embed_size,
                                                  heads=heads,
                                                  dropout=dropout,
                                                  forward_expansion=forward_expansion)

        # Generator and Discriminator
        self.generator = Generator(text_embedding_dim=text_embedding_dim, z_dim=noise_dim, img_size=image_size)
        self.discriminator = Discriminator(image_size=image_size, text_embedding_dim=text_embedding_dim)

    def forward(self, text_input, images, noise, attention_mask=None):
        # Tokenize and encode the text input
        # tokens = self.tokenizer(text_input, return_tensors='pt', padding=True, truncation=True)
        text_embeddings = self.text_encoder(text_input)
        print(f"Text embeddings shape: {text_embeddings.shape}")

        # Apply self-attention to the text embeddings
        attention_output = self.attention_block(value=text_embeddings, key=text_embeddings, query=text_embeddings, mask=attention_mask)
        print(f"Attention output shape: {attention_output.shape}")

        # Pass the attention output through the transformer block
        transformer_output = self.transformer_block(value=attention_output, key=attention_output, query=attention_output, mask=attention_mask)
        print(f"Transformer output shape: {transformer_output.shape}")

        # Generate an image from the transformer output and noise
        generated_images = self.generator(transformer_output, noise)

        # Discriminator takes real images and the corresponding text embeddings
        real_image_discrimination = self.discriminator(images, transformer_output)

        # Discriminator also takes the fake images and text embeddings
        fake_image_discrimination = self.discriminator(generated_images.detach(), transformer_output)

        return generated_images, real_image_discrimination, fake_image_discrimination
The error comes up when the text embeddings are passed to the attention module:
Text embeddings shape: torch.Size([32, 27, 768])
Before self attention: torch.Size([32, 27, 768]) torch.Size([32, 27, 768]) torch.Size([32, 27, 768])
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-147-a568d03c70d5> in <cell line: 61>()
71
72 # Forward pass through the model
---> 73 fake_images, real_preds, fake_preds = model(list(captions), images, noise, attention_mask=None)
74 # Flatten the output for the discriminator
75 real_preds = real_preds.view(-1)
5 frames
<ipython-input-144-6021d3c7122c> in forward(self, value, key, query, mask)
24 print("Before self attention:", value.shape, key.shape, query.shape)
25 # Transform and split for multi-head attention
---> 26 values = self.values(value).view(N, value_len, self.heads, self.head_dim)
27 keys = self.keys(key).view(N, key_len, self.heads, self.head_dim)
28 queries = self.queries(query).view(N, query_len, self.heads, self.head_dim)
RuntimeError: shape '[32, 27, 8, 768]' is invalid for input of size 663552
The output of each linear layer should reshape to [batch_size, seq_len, heads, head_dim], but instead I get an error indicating a mismatch in the total number of elements. I've also confirmed that my sequence length is 26, so I'm not sure where the 27 in the error message is coming from.
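To sanity-check the numbers in the error message, I worked out the element counts (plain arithmetic, nothing model-specific):

# Elements actually coming out of the linear layer, per the error message
print(32 * 27 * 768)             # 663552  -> matches "input of size 663552"
# Elements the requested view [32, 27, 8, 768] would need
print(32 * 27 * 8 * 768)         # 5308416 -> 8x more than are available
# The view I thought was supposed to happen
print(32 * 27 * 8 * (768 // 8))  # 663552  -> [32, 27, 8, 96] would fit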
Why is my reshaping operation failing, and how can I correct the shape to be compatible with the expected dimensions for multi-head attention?