AttentionQKV from Trax

The AttentionQKV layer implemented by Trax is as follows:

# From trax/layers/attention.py; `cb` is trax.layers.combinators and `core` is
# trax.layers.core. PureAttention is defined in the same module.
def AttentionQKV(d_feature, n_heads=1, dropout=0.0, mode='train'):
  """Returns a layer that maps (q, k, v, mask) to (activations, mask).
  See `Attention` above for further context/details.
  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with values.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """
  return cb.Serial(
      cb.Parallel(
          core.Dense(d_feature),
          core.Dense(d_feature),
          core.Dense(d_feature),
      ),
      PureAttention(  # pylint: disable=no-value-for-parameter
          n_heads=n_heads, dropout=dropout, mode=mode),
      core.Dense(d_feature),
  )

In particular, what is the purpose of the three parallel dense layers? The input to this layer is (q, k, v, mask). Why are q, k, and v each put through a dense layer?


There is 1 answer.

Best answer, by Jindřich:

This code snippet is an implementation of the equation at the top of page 5 of the "Attention Is All You Need" paper, which introduced the Transformer model in 2017. The computation is illustrated in Figure 2 of the paper:

[Figure 2 of the paper: (left) Scaled Dot-Product Attention; (right) Multi-Head Attention, consisting of several attention layers running in parallel.]
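For reference, the equations from that page define multi-head attention as

\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\, W^O,
\qquad \mathrm{head}_i = \mathrm{Attention}(Q W_i^Q,\; K W_i^K,\; V W_i^V),

where \mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left( Q K^\top / \sqrt{d_k} \right) V. The three parallel Dense layers correspond to the learned projections W^Q, W^K, W^V, and the final Dense layer corresponds to W^O.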

The hidden states are projected into h attention heads, which perform scaled dot-product attention in parallel. Each projection can be interpreted as extracting the information that is relevant to its head. Each head then performs a probabilistic retrieval based on different (learned) criteria.
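To make this concrete, here is a minimal NumPy sketch of what the three dense projections and the per-head scaled dot-product attention compute. It is not Trax's actual implementation: the function name, the weight arguments, and the shapes are illustrative, and the mask input is omitted for brevity.

import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(q, k, v, w_q, w_k, w_v, w_o, n_heads):
    """Sketch: q, k, v have shape [seq_len, d_feature]; w_* have shape [d_feature, d_feature]."""
    seq_len, d_feature = q.shape
    d_head = d_feature // n_heads

    # The three "parallel" Dense layers: learned projections of q, k, v.
    q_p, k_p, v_p = q @ w_q, k @ w_k, v @ w_v

    # Split each projection into n_heads pieces of size d_head.
    def split(x):  # [seq, d_feature] -> [n_heads, seq, d_head]
        return x.reshape(seq_len, n_heads, d_head).transpose(1, 0, 2)
    qh, kh, vh = split(q_p), split(k_p), split(v_p)

    # Scaled dot-product attention, run independently per head.
    scores = qh @ kh.transpose(0, 2, 1) / np.sqrt(d_head)  # [n_heads, seq, seq]
    weights = softmax(scores, axis=-1)
    per_head = weights @ vh                                 # [n_heads, seq, d_head]

    # Concatenate the heads and apply the final Dense layer.
    concat = per_head.transpose(1, 0, 2).reshape(seq_len, d_feature)
    return concat @ w_o

rng = np.random.default_rng(0)
x = rng.normal(size=(5, 8)).astype(np.float32)
w_q, w_k, w_v, w_o = (rng.normal(size=(8, 8)).astype(np.float32) for _ in range(4))
print(multi_head_attention(x, x, x, w_q, w_k, w_v, w_o, n_heads=2).shape)  # (5, 8)

Splitting one d_feature-wide projection into n_heads chunks is equivalent to giving each head its own smaller projection matrix, which is why a single Dense layer per input suffices.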