Understanding the output dimensionality for torch.nn.MultiheadAttention.forward


I want to implement cross attention between two modalities. In my implementation, I take Q from modality A, and K and V from modality B. Modality A is only used for guidance via the cross attention, while the main operations are done on modality B.

Here is an example of my current implementation:

import torch

batch_size = 1
embedding_dims = 128
n_heads = 8
seqlen_A = 100
seqlen_B = 30

q = torch.randn(batch_size, seqlen_A, embedding_dims)  # queries from modality A
k = torch.randn(batch_size, seqlen_B, embedding_dims)  # keys from modality B
v = torch.randn(batch_size, seqlen_B, embedding_dims)  # values from modality B

attn = torch.nn.MultiheadAttention(embedding_dims, n_heads, batch_first=True)

attn_out, attn_map = attn(q, k, v)

I notice that the output shape of attn_out is (1, 100, 128), which is the same as q's shape, not v's.
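A quick shape check (the attn_map shape below is what I expect from the docs, i.e. (N, L, S), under the setup above):

print(attn_out.shape)  # torch.Size([1, 100, 128]) -- (N, L, E): follows q
print(attn_map.shape)  # torch.Size([1, 100, 30])  -- (N, L, S): one row of weights per query item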

My intuition about the attention mechanism is that q and k are used to extract the relationship between each other, and v is the actual value. That's why I set Q from modality A, which is only used for guidance, and K and V from modality B, which is what I mostly care about. But since attn_out has the same dimensionality as Q, not V, I am a little lost.

Is my understanding of the attention mechanism wrong? And how can I implement cross attention between modalities A and B so that the output has the same shape as the latent variable of modality B?

As a reference, the docstring of torch.nn.MultiheadAttention.forward is below.

Signature:
torch.nn.MultiheadAttention.forward(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    key_padding_mask: Optional[torch.Tensor] = None,
    need_weights: bool = True,
    attn_mask: Optional[torch.Tensor] = None,
    average_attn_weights: bool = True,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]
Source:   
    def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None,
                need_weights: bool = True, attn_mask: Optional[Tensor] = None,
                average_attn_weights: bool = True) -> Tuple[Tensor, Optional[Tensor]]:
        r"""
    Args:
        query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
            or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
            :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
            Queries are compared against key-value pairs to produce the output.
            See "Attention Is All You Need" for more details.
        key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
            or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
            :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
            See "Attention Is All You Need" for more details.
        value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
            ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
            sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
            See "Attention Is All You Need" for more details.
        key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
            to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
            Binary and byte masks are supported.
            For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
            the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
        need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
            Default: ``True``.
        attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
            :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
            :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
            broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
            Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the
            corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the
            corresponding position is not allowed to attend. For a float mask, the mask values will be added to
            the attention weight.
        average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
            heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
            effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)

    Outputs:
        - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
          :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
          where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
          embedding dimension ``embed_dim``.
        - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
          returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
          :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
          :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
          head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`.

        .. note::
            `batch_first` argument is ignored for unbatched inputs.
        """

It says "**attn_output** - Attention outputs of shape (L, E) when input is unbatched, (L, N, E) when batch_first=False...", and I don't understand why the shape is not (S, E) or (S, N, E).

I ran the example code (as shown in the question).

1 Answer

Answered by Karl

In QKV attention, the query is the main object we are interested in: the input data we want to add information to.

The key acts as an "addressing mechanism": it is compared against the query to compute attention weights based on the similarity between query and key items, and those weights determine how much attention each query item gives to each value item.

The value is the information we want to add to the query. We use attention scores from the query and key to weight and aggregate items from the value.

The attention result is the input query, enriched by information from the value, with the contribution of each value item weighted by its key. This is why attn_output has one row per query item: its sequence length is L (the target/query length), no matter how long the key/value sequence S is.
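To make the shape bookkeeping concrete, here is a minimal sketch of single-head scaled dot-product attention (it ignores the learned input/output projections that nn.MultiheadAttention also applies, so it illustrates the idea rather than reproducing the module's exact internals):

import torch
import torch.nn.functional as F

L, S, E = 100, 30, 128               # target (query) length, source (key/value) length, embed dim
q = torch.randn(L, E)                # one query vector per item we want to update
k = torch.randn(S, E)                # one key vector per source item
v = torch.randn(S, E)                # one value vector per source item

scores = q @ k.T / E ** 0.5          # (L, S): similarity of every query item to every key item
weights = F.softmax(scores, dim=-1)  # (L, S): each query item's weights over the source items
out = weights @ v                    # (L, E): one weighted sum of value vectors per query item
print(out.shape)                     # torch.Size([100, 128]) -- the length follows the query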

For a more tangible example, look at the attention mechanism in the decoder of an encoder-decoder transformer. In that context, the query (the thing we want to update) comes from the decoder, while the key and value come from the hidden states of the encoder.
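Applied to the question: if the output should line up with modality B, then modality B should supply the query while modality A supplies the key and value. A sketch of that swap, reusing the shapes from the question (feat_A and feat_B are just placeholder names for the two modalities' features):

import torch

batch_size, embedding_dims, n_heads = 1, 128, 8
seqlen_A, seqlen_B = 100, 30

feat_A = torch.randn(batch_size, seqlen_A, embedding_dims)  # guidance modality
feat_B = torch.randn(batch_size, seqlen_B, embedding_dims)  # main modality

attn = torch.nn.MultiheadAttention(embedding_dims, n_heads, batch_first=True)

# Modality B queries modality A: B is enriched with information drawn from A
attn_out, attn_map = attn(query=feat_B, key=feat_A, value=feat_A)
print(attn_out.shape)  # torch.Size([1, 30, 128]) -- same sequence length as modality B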