Have I implemented self-attention correctly in PyTorch?


This is my attempt at implementing self-attention using PyTorch. Have I done anything wrong, or could it be improved somehow?

import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self, embedding_dim):
        super(SelfAttention, self).__init__()

        self.keys = nn.Linear(embedding_dim, embedding_dim)
        self.queries = nn.Linear(embedding_dim, embedding_dim)
        self.values = nn.Linear(embedding_dim, embedding_dim)

    def forward(self, x):
        keys = self.keys(x)
        queries = self.queries(x)
        values = self.values(x)
        
        scores_prime = torch.matmul(queries.T, keys)
        scores = nn.functional.softmax(scores_prime)

        context_vectors = torch.matmul(values, scores)

        return context_vectors

My test vector ran through without error, but I can't be sure I didn't make a mistake.
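For example, a call along these lines (the embedding size and sequence length here are arbitrary, purely for illustration) goes through and returns a tensor of the same shape as the input:

attn = SelfAttention(4)   # hypothetical embedding_dim
x = torch.randn(3, 4)     # hypothetical "test vector": 3 tokens, dim 4
out = attn(x)
print(out.shape)          # torch.Size([3, 4])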


There is 1 answer

Answered by Shai

To better test your implementation, I suggest you use a different dimension for the queries and keys, so that shape mismatches expose incorrect matrix multiplications. I think you swapped the roles of the queries and the keys.
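For reference, here is a minimal sketch of standard scaled dot-product self-attention, assuming a single unbatched input x of shape [seq_len, embedding_dim] (the shape that the .T in the question suggests): the scores are queries @ keys.T scaled by sqrt(d), the softmax is taken over the last dimension, and the resulting attention weights multiply the values from the left. The class name here is just a placeholder to distinguish it from the code in the question.

import math
import torch
import torch.nn as nn

class SelfAttentionSketch(nn.Module):
    def __init__(self, embedding_dim):
        super().__init__()
        self.keys = nn.Linear(embedding_dim, embedding_dim)
        self.queries = nn.Linear(embedding_dim, embedding_dim)
        self.values = nn.Linear(embedding_dim, embedding_dim)

    def forward(self, x):
        # x: [seq_len, embedding_dim]
        keys = self.keys(x)
        queries = self.queries(x)
        values = self.values(x)

        # pairwise scores between positions: [seq_len, seq_len]
        scores = torch.matmul(queries, keys.T) / math.sqrt(keys.shape[-1])

        # normalize each row over the key positions
        weights = nn.functional.softmax(scores, dim=-1)

        # weighted sum of values: [seq_len, embedding_dim]
        return torch.matmul(weights, values)

With this shape convention, weights[i, j] is how much position i attends to position j. For batched inputs of shape [batch, seq_len, embedding_dim] you would replace keys.T with keys.transpose(-2, -1) and keep dim=-1 in the softmax.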