I want to implement cross attention between two modalities. In my implementation, I set Q from modality A, and K and V from modality B. Modality A is used for guidance via cross attention, and the main operations are done on modality B.
Here is an example of my current implementation:
import torch

batch_size = 1
embedding_dims = 128
n_heads = 8
seqlen_A = 100
seqlen_B = 30

q = torch.randn(batch_size, seqlen_A, embedding_dims)  # modality A (guidance)
k = torch.randn(batch_size, seqlen_B, embedding_dims)  # modality B
v = torch.randn(batch_size, seqlen_B, embedding_dims)  # modality B

attn = torch.nn.MultiheadAttention(embedding_dims, n_heads, batch_first=True)
attn_out, attn_map = attn(q, k, v)
I notice that the output shape of `attn_out` is (1, 100, 128), which is the same as `q`'s shape, not `v`'s.
My intuition about the attention mechanism is that `q` and `k` are used to extract the relationship between each other, and `v` is the actual value. That's why I set Q from modality A, which is only used for guidance, and K and V from modality B, which I mostly care about. But since `attn_out` has the same dimensionality as Q, not V, I am a little lost.
Is my understanding of the attention mechanism wrong? And how can I implement cross attention between modalities A and B so that the output has the same shape as the latent variable of modality B?
For reference, here is the docstring of `torch.nn.MultiheadAttention.forward`:
Signature:
torch.nn.MultiheadAttention.forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
key_padding_mask: Optional[torch.Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[torch.Tensor] = None,
average_attn_weights: bool = True,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]
Source:
def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True, attn_mask: Optional[Tensor] = None,
average_attn_weights: bool = True) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Args:
query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
:math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
Queries are compared against key-value pairs to produce the output.
See "Attention Is All You Need" for more details.
key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
:math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
See "Attention Is All You Need" for more details.
value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
See "Attention Is All You Need" for more details.
key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
Binary and byte masks are supported.
For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
Default: ``True``.
attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
:math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
:math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the
corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the
corresponding position is not allowed to attend. For a float mask, the mask values will be added to
the attention weight.
average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)
Outputs:
- **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
:math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
embedding dimension ``embed_dim``.
- **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
:math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
:math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`.
.. note::
`batch_first` argument is ignored for unbatched inputs.
"""
it says "**attn_output** - Attention outputs of shape :math:(L, E)
when input is unbatched :math:(L, N, E)
when batch_first=False
...", and I don't understand why it is not (S,E) or (S,N,E).
I ran the example code shown in the question.
In QKV attention, the `query` is the main object we are interested in: it can be thought of as the input data we want to add information to. The `key` is sort of an "addressing mechanism": it is used to calculate attention weights based on the similarity between `query` and `key` items, and it determines how much attention each item of the `query` should give to each item of the `value`. The `value` is the information we want to add to the `query`: we use the attention scores computed from the `query` and `key` to weight and aggregate items from the `value`. The attention result is therefore the input `query`, enriched by information from the `value`, with the contribution of each `value` item weighted by the `key`.
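To make the shapes concrete, here is a minimal single-head attention written out by hand (an illustrative sketch, not the actual internals of `nn.MultiheadAttention`). The weight matrix has one row per `query` item and one column per `key`/`value` item, so the output has one vector per `query` item, i.e. length L, not S:

import torch

L, S, E = 100, 30, 128                        # target (query) length, source (key/value) length, embedding dim
q = torch.randn(1, L, E)
k = torch.randn(1, S, E)
v = torch.randn(1, S, E)

scores = q @ k.transpose(-2, -1) / E ** 0.5   # (1, L, S): one score per (query item, key/value item) pair
weights = scores.softmax(dim=-1)              # each query row sums to 1 over the S key/value items
out = weights @ v                             # (1, L, E): one enriched vector per query item

print(weights.shape, out.shape)               # torch.Size([1, 100, 30]) torch.Size([1, 100, 128])

Because the softmax runs over the key/value axis, which modality you feed in as the query is what determines the output's sequence length.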
For a more tangible example, look at the attention mechanism in the decoder of an encoder-decoder transformer. In that context, the `query` (the thing we want to update) comes from the decoder, while the `key` and `value` items come from the hidden states of the encoder.
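Applied to your two modalities: if the output should live on modality B's sequence axis, take the `query` from modality B (the stream you want to update) and the `key`/`value` from modality A (the guidance). A minimal sketch, reusing the sizes from your question (the names `feat_A`/`feat_B` are placeholders of mine):

import torch

batch_size, embedding_dims, n_heads = 1, 128, 8
seqlen_A, seqlen_B = 100, 30

feat_A = torch.randn(batch_size, seqlen_A, embedding_dims)  # modality A: guidance
feat_B = torch.randn(batch_size, seqlen_B, embedding_dims)  # modality B: the main stream

attn = torch.nn.MultiheadAttention(embedding_dims, n_heads, batch_first=True)

# query from modality B (what gets updated), key/value from modality A (what it attends to)
attn_out, attn_map = attn(query=feat_B, key=feat_A, value=feat_A)

print(attn_out.shape)   # torch.Size([1, 30, 128]) -- same sequence length as modality B
print(attn_map.shape)   # torch.Size([1, 30, 100]) -- how much each B item attends to each A item

This orientation also matches the encoder-decoder example above: modality B plays the role of the decoder states, and modality A plays the role of the encoder states that are attended to.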