Confused about MultiHeadAttention output shapes (TensorFlow)

import numpy as np
from tensorflow.keras.layers import Dense, MultiHeadAttention

random_time = np.random.random([1, 1, 100])  # (batch, seq_len=1, features=100)
random_2 = np.random.random([1, 100, 1])     # (batch, seq_len=100, features=1)

out_attention = MultiHeadAttention(
        num_heads = 4,
        key_dim = 1,
        #output_shape=400
    )
output = out_attention(random_2, random_2)      # self-attention: query = value
output = Dense(1, activation = 'relu')(output)  # note: 'relu' must be lowercase
print(output)

Hi, I am playing around with the code above since I have been tasked with creating a transformer for 1D time-series data.

The issue is that out_attention(random_2, random_time), out_attention(random_time, random_2), out_attention(random_time, random_time), and out_attention(random_2, random_2) all give valid outputs, but with different shapes. The code above gives me a tensor of shape (1, 100, 1). I would expect the correct code to give me a tensor of shape (1, 1, 1), and I can get that by choosing (random_time, random_time)... But that doesn't seem to fit my interpretation of the official documentation for the layer.
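To make the shape behavior concrete, here is a small check of all four combinations (a minimal sketch, assuming the standard tf.keras imports; a fresh layer is built per call because the weight shapes depend on the inputs):

import numpy as np
from tensorflow.keras.layers import MultiHeadAttention

random_time = np.random.random([1, 1, 100])  # (batch, seq=1, features=100)
random_2 = np.random.random([1, 100, 1])     # (batch, seq=100, features=1)

for q, v in [(random_2, random_2), (random_2, random_time),
             (random_time, random_2), (random_time, random_time)]:
    mha = MultiHeadAttention(num_heads=4, key_dim=1)  # fresh layer each time
    print(q.shape, v.shape, '->', mha(q, v).shape)

In every case the output seems to inherit the query's sequence length and feature dimension: (1, 100, 1) when random_2 is the query, (1, 1, 100) when random_time is.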

The "Attention Is All You Need" paper also says:

In addition to attention sub-layers, each of the layers in our encoder and decoder contains a fully connected feed-forward network, which is applied to each position separately and identically

This seems to imply that my implementation above is the right one for a transformer.
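For reference, my understanding of that position-wise feed-forward network in Keras terms is sketched below (d_model and dff are hypothetical toy sizes, not the paper's; a Dense layer applied to a 3-D tensor acts on the last axis only, so the same weights are applied to each position independently):

import tensorflow as tf

d_model, dff = 1, 4  # hypothetical toy sizes, not from the paper
ffn = tf.keras.Sequential([
    tf.keras.layers.Dense(dff, activation='relu'),  # applied per position
    tf.keras.layers.Dense(d_model),                 # project back to model dim
])
x = tf.random.normal([1, 100, d_model])
print(ffn(x).shape)  # (1, 100, 1): all 100 positions are preserved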

However, this shape seems to propagate to all downstream layers... so every layer will carry at least 100 tokens. I can find a partial explanation in the TensorFlow transformer tutorial, which states:

Note: The model is optimized for efficient training and makes a next-token prediction for each token in the output simultaneously. This is redundant during inference, and only the last prediction is used. This model can be made more efficient for inference if you only calculate the last prediction when running in inference mode (training=False).

Their model seems to use similar layers... So is that what I am supposed to do: just ignore 99 of the outputs and focus on the last one? After all, the Vaswani et al. paper states that it produces tokens "one element at a time".
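If I read that note correctly, at inference it would amount to something like this (my sketch, continuing from the code at the top, not the tutorial's exact code):

# output from the Dense layer above has shape (batch, seq_len, features)
last_step = output[:, -1, :]  # keep only the newest position's prediction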

Finally, the output_shape parameter brings even more confusion; it seems able to project the output of MultiHeadAttention to an arbitrary feature dimension...
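For example, this sketch (based on my reading of the docs) suggests output_shape only replaces the feature dimension, while the sequence length still comes from the query:

import numpy as np
from tensorflow.keras.layers import MultiHeadAttention

random_2 = np.random.random([1, 100, 1])
mha = MultiHeadAttention(num_heads=4, key_dim=1, output_shape=400)
print(mha(random_2, random_2).shape)  # (1, 100, 400): features projected to 400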

Is there any good resource for understanding this layer? Or do I just have to dig through every line of the source code to figure out what it actually does...?

