Triplet loss model for text embedding does not learn as the theory predicts


I am working on a triplet-loss-based model for text embedding.
Short description:
I have a database for an online shop, and I need to find the most suitable products when a user types a query into the search bar. I want a model that works better than plain string matching and can capture the user's intent. I define a triplet network as follows: my input is (query text [anchor], the product the user viewed after searching [positive], a random product [negative]). I built an encoder based on a bi-LSTM and trained it so that the anchor-positive distance is minimized and the anchor-negative distance is maximized, using triplet loss.
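To make the data setup concrete, here is a minimal sketch of how such (anchor, positive, negative) triplets can be assembled from search logs; the list-of-pairs format and the catalog list are illustrative assumptions, not my exact data pipeline:

import random

def build_triplets(search_log, catalog):
    """Assemble (anchor, positive, negative) text triplets.

    search_log: list of (query_text, viewed_product_text) pairs
    catalog:    list of all product texts, used to draw random negatives
    """
    triplets = []
    for query_text, viewed_product in search_log:
        # positive = the product the user viewed after searching;
        # negative = a random product that is not the positive
        negative = random.choice(catalog)
        while negative == viewed_product:
            negative = random.choice(catalog)
        triplets.append((query_text, viewed_product, negative))
    return triplets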
I tried to implement this network, following the approach in https://arxiv.org/pdf/2104.08558.pdf.
My encoderNet

import torch
import torch.nn as nn

class encodeNet(nn.Module):

    def __init__(self, vocab_size, embedding_dim, hidden_dim, n_layers,
                 bidirectional, dropout):
        # constructor
        super().__init__()

        # embedding layer
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.directions = bidirectional

        # LSTM layer
        self.lstm = nn.LSTM(embedding_dim,
                            hidden_dim,
                            num_layers=n_layers,
                            bidirectional=bidirectional,
                            dropout=dropout,
                            batch_first=True)

        # projection head on top of the concatenated forward/backward hidden states
        self.fc1 = nn.Linear(hidden_dim * 2, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 512)
        self.dropout = nn.Dropout(p=0.3)
        self.batchnorm1 = nn.BatchNorm1d(1024)
        self.batchnorm2 = nn.BatchNorm1d(512)
        self.relu = nn.ReLU()
        self.P1 = nn.MaxPool1d(2, stride=2)
        self.act = nn.Sigmoid()   # defined but currently unused

    def LM(self, text):
        # text: (batch, seq_len) tensor of token ids
        embedded = self.embedding(text)
        output, (hidden, cell) = self.lstm(embedded)
        # concat the final forward and backward hidden state
        hidden = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)
        hidden = self.dropout(hidden)
        hidden = self.fc1(hidden)
        hidden = self.batchnorm1(hidden)
        hidden = self.relu(hidden)
        hidden = self.fc2(hidden)
        hidden = self.batchnorm2(hidden)
        hidden = self.fc3(hidden)
        return hidden

    def forward(self, anchor, pos, neg):
        # all three inputs are encoded with the same shared-weight encoder
        anchor = self.LM(anchor)
        pos = self.LM(pos)
        neg = self.LM(neg)
        # max-pool over the feature dimension: 512 -> 256
        anchor = self.P1(anchor)
        pos = self.P1(pos)
        neg = self.P1(neg)
        return anchor, pos, neg

For the loss function I used PyTorch's nn.TripletMarginLoss(margin=1.0, p=2).
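Roughly, one training step looks like the sketch below; the optimizer choice, learning rate, and layer sizes are placeholder assumptions rather than my exact values:

import torch
import torch.nn as nn

# placeholder hyperparameters; batch tensors are (batch_size, seq_len) LongTensors of token ids
model = encodeNet(vocab_size=8572, embedding_dim=300, hidden_dim=256,
                  n_layers=2, bidirectional=True, dropout=0.3)
triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def train_step(anchor_batch, pos_batch, neg_batch):
    model.train()
    optimizer.zero_grad()
    # encode the triplet and pull the positive closer than the negative
    anchor_emb, pos_emb, neg_emb = model(anchor_batch, pos_batch, neg_batch)
    loss = triplet_loss(anchor_emb, pos_emb, neg_emb)
    loss.backward()
    optimizer.step()
    return loss.item()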
The result: on the training set the loss dropped very quickly to a very small value, but on the validation set the loss showed no meaningful trend; it went up and down as if at random.
I trained the model with a vocabulary of 8,572 tokens and 81,822 training samples. Is that dataset too small?
Can you help me identify the issue in my solution?


1 Answer

Answer by Reza Tanakizadeh:

I suggest you use hard triplets (hard negative mining). You can learn more about this in the FaceNet paper. I hope this helps.
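As a rough illustration (not code from your model: the margin value and the assumption that row i of each batch is a matching query/product pair are mine), mining the hardest in-batch negative can look like this:

import torch
import torch.nn.functional as F

def batch_hard_triplet_loss(query_emb, product_emb, margin=1.0):
    """query_emb, product_emb: (batch, dim); row i of product_emb is the
    positive for row i of query_emb, every other row is a candidate negative."""
    # pairwise L2 distances between every query and every product in the batch
    dist = torch.cdist(query_emb, product_emb, p=2)          # (batch, batch)
    pos_dist = dist.diag()                                   # matching pairs
    # mask out the positives on the diagonal, then take the hardest
    # (closest) negative for each query
    neg_dist = dist + torch.eye(len(dist), device=dist.device) * 1e9
    hard_neg_dist = neg_dist.min(dim=1).values
    return F.relu(pos_dist - hard_neg_dist + margin).mean()

Training on these hardest negatives (instead of random products) usually keeps the loss informative, because random negatives quickly become too easy and the training loss collapses without the embeddings generalizing.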