I am training a transformer (encoder-decoder) in PyTorch. The training loss decreases, but the model fails at validation/inference. To check whether the model architecture itself is fine, I overfit a single example; I have also tested on a larger dataset and get the same result.
My guess is that the problem is in the decoder's greedy decoding or in the tgt mask. For comparison, the mask construction I would expect is right below, and a reference greedy-decoding loop is at the end of the post.
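As far as I understand, the tgt mask that nn.TransformerDecoder expects is either a boolean mask (True = position may not be attended to) or an additive float mask with -inf above the diagonal. A minimal sketch of both (the names T, causal_float and causal_bool are just for illustration):

import torch
import torch.nn as nn

T = 20
# additive float mask: 0.0 on/below the diagonal, -inf above it
causal_float = nn.Transformer.generate_square_subsequent_mask(T)
# boolean mask: True marks positions that must not be attended to
causal_bool = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)

My code below instead builds a 0/1 float mask, which is one of the things I suspect.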
import torch
import torch.nn as nn
import random
import math
random.seed(1)
eng = 'this is transformer'
hin = 'यह ट्रांसफार्मर है'
eng_dict = {key:val for val,key in enumerate(set(eng),start=1)}
hin_dict = {key:val for val,key in enumerate(set(hin),start=1)}
class PositionalEncoding1D(nn.Module):
    def __init__(self, d_model, max_len=100):
        super().__init__()
        pe = torch.zeros((max_len, d_model), requires_grad=False)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        _, T, _ = x.shape
        return x + self.pe[:, :T]
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.vocab_size = len(eng_dict) + 1
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(128, 4, 512, 0.1), 6
        )
        self.embeddings = nn.Embedding(self.vocab_size, embedding_dim=128)
        self.posEmb = PositionalEncoding1D(128, 64)

    def forward(self, src):
        src = self.posEmb(self.embeddings(src))
        x = self.encoder(src)
        return x
class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.vocab_size = len(hin_dict) + 2   # +2 for sos (0) and eos (13)
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(128, 4, 512, 0.1, batch_first=True), 6
        )
        self.embeddings = nn.Embedding(self.vocab_size, embedding_dim=128)
        self.posEmb = PositionalEncoding1D(128, 64)
        self.lin = nn.Linear(128, self.vocab_size)

    def forward_train(self, memory, tgt):
        tgt = self.posEmb(self.embeddings(tgt))
        # 0/1 float mask, upper triangle above the diagonal;
        # 20 is hard-coded to the length of the single target (sos + chars + eos)
        tgt_mask = torch.triu(torch.ones(20, 20), diagonal=1)
        x = self.decoder(tgt, memory, tgt_mask)
        x = self.lin(x)
        return x

    def forward_infer(self, memory):
        # start from the sos token (id 0)
        tgt = self.posEmb(self.embeddings(torch.tensor([[0]], dtype=torch.int32)))
        preds = []
        for i in range(20):
            x = self.decoder(tgt, memory)
            x = self.lin(x[:, -1, :])   # greedy: take the last position
            _, pr = torch.max(x, dim=1)
            preds.append(pr)
            # append the embedding of the predicted token plus its positional encoding
            tgt = torch.cat(
                [tgt,
                 self.posEmb.pe[:, len(preds)] + (self.embeddings(pr) * math.sqrt(128)).unsqueeze(0)],
                dim=1,
            )
        return torch.tensor(preds)

    def forward(self, memory, tgt):
        if tgt is not None:
            x = self.forward_train(memory, torch.tensor(tgt, dtype=torch.int).unsqueeze(0))
        else:
            x = self.forward_infer(memory=memory)
        return x
class EncDec(nn.Module):
    def __init__(self):
        super().__init__()
        self.enc = Encoder()
        self.dec = Decoder()

    def forward(self, src, target=None):
        x = self.enc(torch.tensor(src, dtype=torch.int).unsqueeze(0))
        y = self.dec.forward(x, target)
        return y
def tokenize(text, lang):
    if 'hin' in lang:
        lst = [hin_dict[i] for i in text]
        lst = [0] + lst + [13]   # sos and eos
    else:
        lst = [eng_dict[i] for i in text]
    return lst
# build the model and tokenize the single training pair
model = EncDec()
tokenized_eng = tokenize(eng, 'eng')
tokenized_hin = tokenize(hin, 'hin')

loss = torch.nn.CrossEntropyLoss(ignore_index=0)
optim = torch.optim.AdamW(model.parameters(), lr=0.0001)

for epoch in range(2000):
    for steps in range(1):
        out = model(tokenized_eng, tokenized_hin)
        optim.zero_grad()
        # 14 = decoder vocab size (len(hin_dict) + 2)
        criterion = loss(out.reshape(-1, 14), torch.tensor(tokenized_hin, dtype=torch.long))
        criterion.backward()
        optim.step()
    if epoch % 100 == 0:
        print(criterion)

model.eval()
x = model(tokenized_eng)
print('The output is = ')
print(x)
Output: the loss keeps going down, but the decoded prediction collapses to a single repeated token.
The output is = tensor([1, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8])
tensor(2.7242, grad_fn=<NllLossBackward0>)
tensor(0.0548, grad_fn=<NllLossBackward0>)
tensor(0.0243, grad_fn=<NllLossBackward0>)
tensor(0.0168, grad_fn=<NllLossBackward0>)
tensor(0.0122, grad_fn=<NllLossBackward0>)
tensor(0.0098, grad_fn=<NllLossBackward0>)
tensor(0.0077, grad_fn=<NllLossBackward0>)
tensor(0.0066, grad_fn=<NllLossBackward0>)
tensor(0.0053, grad_fn=<NllLossBackward0>)
tensor(0.0046, grad_fn=<NllLossBackward0>)
tensor(0.0042, grad_fn=<NllLossBackward0>)
tensor(0.0036, grad_fn=<NllLossBackward0>)
tensor(0.0034, grad_fn=<NllLossBackward0>)
tensor(0.0029, grad_fn=<NllLossBackward0>)
tensor(0.0026, grad_fn=<NllLossBackward0>)
tensor(0.0023, grad_fn=<NllLossBackward0>)
tensor(0.0020, grad_fn=<NllLossBackward0>)
tensor(0.0019, grad_fn=<NllLossBackward0>)
tensor(0.0017, grad_fn=<NllLossBackward0>)
tensor(0.0016, grad_fn=<NllLossBackward0>)
The output is = tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
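For comparison with my forward_infer, here is a minimal greedy-decoding sketch that feeds the generated token ids back through the same embedding + positional-encoding path used in training. It is only a sketch built on top of my classes; the helper name greedy_decode is made up, and sos_id=0 / eos_id=13 are the values hard-coded in tokenize:

def greedy_decode(model, src_ids, sos_id=0, eos_id=13, max_len=20):
    # usage: greedy_decode(model, tokenized_eng)
    model.eval()
    with torch.no_grad():
        memory = model.enc(torch.tensor(src_ids, dtype=torch.int).unsqueeze(0))
        ys = torch.tensor([[sos_id]], dtype=torch.long)   # running sequence of token ids
        for _ in range(max_len):
            # re-embed the whole id sequence each step, the same way forward_train does
            tgt = model.dec.posEmb(model.dec.embeddings(ys))
            out = model.dec.decoder(tgt, memory)
            logits = model.dec.lin(out[:, -1, :])   # greedy: look at the last position only
            next_id = logits.argmax(dim=-1, keepdim=True)
            ys = torch.cat([ys, next_id], dim=1)
            if next_id.item() == eos_id:
                break
    return ys.squeeze(0).tolist()

Is this roughly the shape the inference loop should have, or is the issue elsewhere (e.g. in the tgt_mask)?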