Is the padding mask for the cross attention in the decoder of the transformer architecture derived only from the encoder-side input (the source sequence), i.e. is the padding mask found using this function:
def make_source_mask(source_ids, source_pad_id):
    return (source_ids != source_pad_id).unsqueeze(-2)
where source_ids is the encoder-side sequence whose outputs feed into the decoder?
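For example, with a toy batch of source ids and a pad id of 0 (values picked just for illustration), the mask would come out like this:

import torch

# toy batch of 2 source sequences, padded with id 0 (illustration only)
source_ids = torch.tensor([[5, 7, 9, 0, 0],
                           [3, 2, 0, 0, 0]])
mask = make_source_mask(source_ids, 0)
print(mask.shape)  # torch.Size([2, 1, 5]) -> (batch, 1, source_len), True where the token is not padding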
Here is my complete code for the cross attention module:
import torch
import torch.nn as nn

class CrossMultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(CrossMultiHeadAttention, self).__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.kv_layer = nn.Linear(d_model, 2 * d_model)  # keys and values come from the encoder output
        self.q_layer = nn.Linear(d_model, d_model)       # queries come from the decoder side
        self.linear_layer = nn.Linear(d_model, d_model)

    def forward(self, x, y, x_source_shape, PADDING_TOKEN):
        # x: encoder output, y: decoder-side input, both (batch, max_sequence_length, d_model)
        batch_size, max_sequence_length, d_model = x.size()
        kv = self.kv_layer(x)
        q = self.q_layer(y)
        # split into heads: (batch, num_heads, seq_len, head_dim)
        kv = kv.reshape(batch_size, max_sequence_length, self.num_heads, self.head_dim * 2)
        q = q.reshape(batch_size, max_sequence_length, self.num_heads, self.head_dim)
        kv = kv.permute(0, 2, 1, 3)
        q = q.permute(0, 2, 1, 3)
        k, v = kv.chunk(2, dim=-1)
        values, attention = scaled_dot_product(q, k, v, mask=make_source_mask(x_source_shape, PADDING_TOKEN))
        # merge heads back: (batch, seq_len, d_model)
        values = values.permute(0, 2, 1, 3).reshape(batch_size, max_sequence_length, self.num_heads * self.head_dim)
        out = self.linear_layer(values)
        return out
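scaled_dot_product is not shown above; a standard implementation would look roughly like the sketch below (sketch only, not necessarily identical to what I use), which is what I have in mind when the mask is applied:

import math
import torch
import torch.nn.functional as F

def scaled_dot_product(q, k, v, mask=None):
    # q, k, v: (batch, num_heads, seq_len, head_dim)
    d_k = q.size(-1)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)  # (batch, num_heads, q_len, k_len)
    if mask is not None:
        # positions where the mask is False/0 are blocked before the softmax;
        # the mask has to broadcast against the (batch, num_heads, q_len, k_len) scores
        scores = scores.masked_fill(mask == 0, float('-inf'))
    attention = F.softmax(scores, dim=-1)
    values = torch.matmul(attention, v)
    return values, attention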
Is the source mask computed incorrectly here?