I have been trying to make a custom mask for targeted combinations of queries and keys for my MultiHeadAttention layer, but I cannot figure out how to use this layer's masking.
Here is an example with a dummy dataset (batch size 1):
import tensorflow as tf
key = tf.ones([1, 32, 128])
# 1 = the query may attend to the key, 0 = attention is blocked
mask = tf.concat([
    tf.concat([tf.zeros([16, 16]), tf.zeros([16, 16])], 0),
    tf.concat([tf.zeros([16, 16]), tf.ones([16, 16])], 0)], 1)
mask = mask[tf.newaxis, tf.newaxis, :, :]
# key shape -> (1, 32, 128)
# mask shape -> (1, 1, 32, 32)
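To make the layout of this mask explicit, here is a small pure-Python sketch (no TensorFlow needed) that rebuilds the same 32×32 block pattern and counts, for each query row, how many keys it may attend to. It assumes the Keras convention that 1 means "may attend" and 0 means "may not attend"; `build_block_mask` is a hypothetical helper written only for this illustration.

```python
def build_block_mask(n=32):
    """Rebuild the mask above: 1 only in the bottom-right quadrant.

    Mirrors the tf.concat construction: the left column block is all
    zeros, the right column block is zeros stacked on top of ones.
    """
    half = n // 2
    return [
        [1 if (row >= half and col >= half) else 0 for col in range(n)]
        for row in range(n)
    ]


mask = build_block_mask()
# Number of keys each query position may attend to (1 = attend, assumed).
allowed_per_query = [sum(row) for row in mask]
print(allowed_per_query[:16])  # queries 0-15: no allowed keys at all
print(allowed_per_query[16:])  # queries 16-31: 16 allowed keys each
```

In other words, the first 16 query rows are masked against every key, which is the "entire query" case the question asks about.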
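When I print mask[0][0].numpy(), I get a 32×32 matrix that is 1 only in the bottom-right 16×16 block (queries 16 to 31 with keys 16 to 31) and 0 everywhere else.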
Now, using the following layer (1 head, self-attention):
mha_layer = tf.keras.layers.MultiHeadAttention(num_heads=1, key_dim=128)
attention_output, attention_scores = mha_layer(key, key, attention_mask=mask, return_attention_scores=True)
I get the following attention scores (attention_scores[0][0].numpy()):
Here, dark violet stands for 0.0, yellow for 0.06 and green-blue for 0.03.
I expected the green-blue part to be 0.0 because of the masking.
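The observed values can be reproduced with a plain-Python masked softmax, under the assumption that the layer (as in `tf.keras` 2.x) applies the mask by adding a large negative constant (assumed here to be -1e9) to the masked logits before the softmax, without zeroing rows afterwards. Because every input position is a vector of ones (`key = tf.ones(...)`), all logits in a row are equal, so the sketch uses one constant logit. `masked_softmax` is a hypothetical helper, not a Keras API.

```python
import math


def masked_softmax(logits, mask_row, large_negative=-1e9):
    """Softmax over one row of logits with additive masking.

    Assumption: masked positions (mask == 0) get `large_negative` added
    before the softmax, which is assumed to mirror tf.keras 2.x behavior.
    """
    shifted = [x + (0.0 if m else large_negative) for x, m in zip(logits, mask_row)]
    peak = max(shifted)  # subtract the row maximum for numerical stability
    exps = [math.exp(x - peak) for x in shifted]
    total = sum(exps)
    return [e / total for e in exps]


n, half = 32, 16
# Same block layout as the mask above: 1 only in the bottom-right quadrant.
mask = [[1 if (r >= half and c >= half) else 0 for c in range(n)] for r in range(n)]
# All-ones inputs make every logit in a row identical; 0.5 is an arbitrary value.
logits = [0.5] * n

scores = [masked_softmax(logits, mask_row) for mask_row in mask]
print(scores[0][0])    # 0.03125: a fully masked query row becomes uniform (1/32)
print(scores[20][20])  # 0.0625: allowed key in a partly masked row (1/16)
print(scores[20][0])   # 0.0: masked key in a partly masked row

# Hypothetical post-processing (not claimed to be what the layer does):
# multiplying by the mask zeroes the fully masked rows as well.
zeroed = [[s * m for s, m in zip(srow, mrow)] for srow, mrow in zip(scores, mask)]
print(sum(zeroed[0]))  # 0.0
```

Since a softmax row must sum to 1, a row in which every key is penalized equally cannot become all zeros; it falls back to a uniform 1/32, which is consistent with the 0.03 and 0.06 values in the plot.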
Am I using the mask wrong, or is it not possible to mask entire queries/keys?
I hope my question makes sense and is not too obvious. Thank you in advance if you can help :)