.backward() taking much longer when training a Siamese network


I'm training a Siamese network for image classification and comparing it against a baseline that doesn't use a Siamese architecture. Without the Siamese architecture each epoch takes around 17 minutes, but with it each epoch is estimated to take ~5 hours. I narrowed the problem down to the .backward() call, which takes a few seconds per batch when the Siamese architecture is used.
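Since CUDA kernels run asynchronously, prints placed around a GPU call can attribute time to the wrong line; the sketch below shows how the backward time can be measured with explicit synchronization (assuming training on a CUDA device):

import time
import torch

torch.cuda.synchronize()            # wait for already-queued kernels before timing
t0 = time.perf_counter()
grad_scaler.scale(loss).backward()
torch.cuda.synchronize()            # wait for the backward kernels to finish
print(f"backward() took {time.perf_counter() - t0:.3f}s")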

This is part of the training loop for the non-Siamese network:

output = model(data1)
loss = criterion(output, target)
print("doing backward()")
grad_scaler.scale(loss).backward()
print("doing step()")
grad_scaler.step(optimizer)
print("doing update()")
grad_scaler.update()
print("done")

This is the corresponding part of the training loop for the Siamese network:

output1 = model(data1)
output2 = model(data2)
loss1 = criterion(output1, target)
loss2 = criterion(output2, target)
loss3 = criterion_mse(output1, output2)
loss = loss1 + loss2 + loss3
 
print("doing backward()")
grad_scaler.scale(loss).backward()
print("doing step()")
grad_scaler.step(optimizer)
print("doing update()")
grad_scaler.update()
print("done")

As noted above, I've narrowed the slowdown to the .backward() call. Why is backward() so much slower with the Siamese architecture, and how can I reduce the training time?
