How to get the padding mask for cross attention in the transformer decoder


Is the padding mask for the cross attention of the transformer decoder derived only from the input tensors on the encoder side, i.e. is the padding mask found with this function:

def make_source_mask(source_ids, source_pad_id):
    # True for real tokens, False for padding positions
    return (source_ids != source_pad_id).unsqueeze(-2)

where source_ids is the encoder output that is passed into the decoder.
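
For what it's worth, here is a small shape sketch of what I expect that mask to do. The shapes and the source_pad_id value are made up, and I am assuming the convention that True means a real token and False means padding:

import torch

batch_size, src_len, tgt_len, num_heads = 2, 7, 5, 4   # hypothetical sizes
source_pad_id = 0                                       # hypothetical padding id

source_ids = torch.randint(1, 10, (batch_size, src_len))  # source token IDs
source_ids[:, -2:] = source_pad_id                        # pretend the last two positions are padding

mask = (source_ids != source_pad_id).unsqueeze(-2)        # (batch, 1, src_len), True = real token
scores = torch.randn(batch_size, num_heads, tgt_len, src_len)

# To line up with scores of shape (batch, num_heads, tgt_len, src_len),
# an extra head dimension is still needed, e.g. mask.unsqueeze(1):
masked_scores = scores.masked_fill(~mask.unsqueeze(1), float("-inf"))
print(masked_scores.shape)  # torch.Size([2, 4, 5, 7])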

Here is my complete code for the cross-attention module:

import torch
import torch.nn as nn


class CrossMultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(CrossMultiHeadAttention, self).__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.kv_layer = nn.Linear(d_model, 2 * d_model)  # keys and values are projected from the encoder output
        self.q_layer = nn.Linear(d_model, d_model)        # queries are projected from the decoder states
        self.linear_layer = nn.Linear(d_model, d_model)

    def forward(self, x, y, x_source_shape, PADDING_TOKEN):
        # x: encoder output, y: decoder hidden states
        # (both are padded to the same max_sequence_length)
        batch_size, max_sequence_length, d_model = x.size()
        kv = self.kv_layer(x)
        q = self.q_layer(y)
        # split into heads: (batch, num_heads, seq_len, head_dim)
        kv = kv.reshape(batch_size, max_sequence_length, self.num_heads, self.head_dim * 2)
        q = q.reshape(batch_size, max_sequence_length, self.num_heads, self.head_dim)
        kv = kv.permute(0, 2, 1, 3)
        q = q.permute(0, 2, 1, 3)
        k, v = kv.chunk(2, dim=-1)
        # padding mask is built from the source side only
        values, attention = scaled_dot_product(q, k, v, mask=make_source_mask(x_source_shape, PADDING_TOKEN))
        values = values.permute(0, 2, 1, 3).reshape(batch_size, max_sequence_length, self.num_heads * self.head_dim)
        out = self.linear_layer(values)
        return out
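
I have not shown scaled_dot_product above; I am assuming a standard implementation along these lines (a sketch, applying the mask with masked_fill under the True-means-keep convention):

import math
import torch
import torch.nn.functional as F

def scaled_dot_product(q, k, v, mask=None):
    # q, k, v: (batch, num_heads, seq_len, head_dim)
    d_k = q.size(-1)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)  # (batch, heads, q_len, k_len)
    if mask is not None:
        # the mask is expected to broadcast to (batch, heads, q_len, k_len);
        # positions where it is False (padding) get -inf before the softmax
        scores = scores.masked_fill(~mask, float("-inf"))
    attention = F.softmax(scores, dim=-1)
    values = torch.matmul(attention, v)
    return values, attention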

Is the source_mask computed incorrectly?
