RuntimeError with PyTorch's MultiheadAttention: How to resolve shape mismatch?


I'm running into an input-shape issue with PyTorch's MultiheadAttention. I initialize the module like this:

attention = MultiheadAttention(embed_dim=1536, num_heads=4)

The input tensors have the following shapes:

  • query.shape is torch.Size([1, 1, 1536])
  • Both key.shape and value.shape are torch.Size([1, 23, 1536])
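For reference, here is a minimal, self-contained version of the call; the tensors are random placeholders with the same shapes as my real thumbnail/embedding data:

import torch
from torch.nn import MultiheadAttention

attention = MultiheadAttention(embed_dim=1536, num_heads=4)

# random placeholders matching the shapes of my actual inputs
query = torch.randn(1, 1, 1536)    # stands in for thumbnail
key = torch.randn(1, 23, 1536)     # stands in for embedding
value = torch.randn(1, 23, 1536)

output, attn_weights = attention(query, key, value)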

However, when attempting to use these inputs, I encounter the following error:

RuntimeError                              Traceback (most recent call last)
Cell In[15], line 1
----> 1 _ = cal_attn_weight_embedding(attention, top_j_sim_video_embeddings_list)

File ~/main/reproduct/choi/make_embedding.py:384, in cal_attn_weight_embedding(attention, top_j_sim_video_embeddings_list)
    381 print(embedding.shape)
    383 # attention
--> 384 output, attn_weights = attention(thumbnail, embedding, embedding)
    385 # attn_weight shape: (1, 1, j+1)
    387 attn_weights = attn_weights.squeeze(0).unsqueeze(-1)  # shape: (j+1, 1)

File ~/anaconda3/envs/choi_venv/lib/python3.8/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/anaconda3/envs/choi_venv/lib/python3.8/site-packages/torch/nn/modules/activation.py:1205, in MultiheadAttention.forward(self, query, key, value, key_padding_mask, need_weights, attn_mask, average_attn_weights, is_causal)
   1191     attn_output, attn_output_weights = F.multi_head_attention_forward(
   1192         query, key, value, self.embed_dim, self.num_heads,
...
   5281     # TODO finish disentangling control flow so we don't do in-projections when statics are passed
   5282     assert static_k.size(0) == bsz * num_heads, \
   5283         f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"

RuntimeError: shape '[1, 4, 384]' is invalid for input of size 35328

Why am I encountering this error?

My execution environment is as follows:

  • Ubuntu 20.04
  • Anaconda 1.7.2
  • Python 3.8.5
  • VSCode 1.87.2
  • PyTorch 2.0.1

Thanks in advance.

1 Answer

Answer from Karl:

You need to change

attention = MultiheadAttention(embed_dim=1536, num_heads=4)

to

attention = MultiheadAttention(embed_dim=1536, num_heads=4, batch_first=True)

With the default batch_first=False, MultiheadAttention interprets its inputs as (seq_len, batch, embed_dim). Your query of shape (1, 1, 1536) is therefore read as batch size 1, while your key/value of shape (1, 23, 1536) are read as batch size 23. The internal reshape of the key to (seq_len, batch * num_heads, head_dim) = [1, 4, 384] then fails, because the key actually holds 1 × 23 × 1536 = 35328 elements. With batch_first=True the inputs are read as (batch, seq_len, embed_dim), which matches how your tensors are laid out.
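For example, a minimal sketch with random placeholder tensors matching your shapes:

import torch
from torch.nn import MultiheadAttention

# batch_first=True makes the module read inputs as (batch, seq_len, embed_dim)
attention = MultiheadAttention(embed_dim=1536, num_heads=4, batch_first=True)

query = torch.randn(1, 1, 1536)   # (batch=1, target length=1, embed_dim=1536)
key = torch.randn(1, 23, 1536)    # (batch=1, source length=23, embed_dim=1536)
value = key

output, attn_weights = attention(query, key, value)
print(output.shape)        # torch.Size([1, 1, 1536])
print(attn_weights.shape)  # torch.Size([1, 1, 23])

If you prefer to keep batch_first=False, you could instead transpose your inputs to (seq_len, batch, embed_dim), but since your tensors are already batch-first, setting the flag is the simpler fix.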