This is my attempt at implementing self-attention using PyTorch. Have I done anything wrong, or could it be improved somehow?
import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self, embedding_dim):
        super(SelfAttention, self).__init__()
        self.keys = nn.Linear(embedding_dim, embedding_dim)
        self.queries = nn.Linear(embedding_dim, embedding_dim)
        self.values = nn.Linear(embedding_dim, embedding_dim)

    def forward(self, x):
        keys = self.keys(x)
        queries = self.queries(x)
        values = self.values(x)
        scores_prime = torch.matmul(queries.T, keys)
        scores = nn.functional.softmax(scores_prime)
        context_vectors = torch.matmul(values, scores)
        return context_vectors
My test vector ran through without error, but I can't be sure I didn't make a mistake.
To test your implementation more thoroughly, I suggest giving the query and key projections a different output dimension from the one used for the values. With everything square (embedding_dim × embedding_dim), a misplaced transpose still produces shapes that happen to fit together, so the mistake runs silently instead of raising an error. I think you have swapped the roles of the queries and keys when computing the attention scores.
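For comparison, here is a minimal sketch of the textbook scaled dot-product self-attention, written the way I would structure it. The class name ScaledSelfAttention and the qk_dim argument are illustrative choices of mine, not anything from your code; qk_dim simply lets the query/key projections use a different size than the embedding, which is the test I suggested above. The scores are computed as queries times keys transposed over the sequence dimension, scaled by the square root of the key dimension, and the softmax is taken along the key positions before weighting the values:

import torch
import torch.nn as nn

class ScaledSelfAttention(nn.Module):
    # Illustrative reference implementation, not your original code.
    # qk_dim defaults to embedding_dim but can be set differently so that
    # shape mistakes surface as errors instead of passing silently.
    def __init__(self, embedding_dim, qk_dim=None):
        super().__init__()
        self.qk_dim = qk_dim if qk_dim is not None else embedding_dim
        self.queries = nn.Linear(embedding_dim, self.qk_dim)
        self.keys = nn.Linear(embedding_dim, self.qk_dim)
        self.values = nn.Linear(embedding_dim, embedding_dim)

    def forward(self, x):
        # x: (seq_len, embedding_dim) or (batch, seq_len, embedding_dim)
        q = self.queries(x)
        k = self.keys(x)
        v = self.values(x)
        # scores: (..., seq_len, seq_len) = Q K^T / sqrt(d_k)
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.qk_dim ** 0.5)
        # normalize over the key positions (last dimension)
        weights = torch.softmax(scores, dim=-1)
        # context vectors: (..., seq_len, embedding_dim) = weights V
        return torch.matmul(weights, v)

With something like attn = ScaledSelfAttention(embedding_dim=8, qk_dim=4) and x = torch.randn(5, 8), this returns a (5, 8) tensor. If you instead change your own keys and queries projections to nn.Linear(embedding_dim, 4) while leaving the rest of your forward pass as it is, the final matmul fails with a shape mismatch, which is exactly the kind of failure that makes the misplaced transpose visible. Note also two smaller differences from your version: the standard formulation scales the scores by sqrt(d_k), and it passes an explicit dim to the softmax rather than relying on the default.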