I am trying to replicate a paper (DAGMM), but training is unstable: rerunning the same code yields wildly different results. That said, the outcomes seem to fall into 4 main patterns.
The model and the training are as shown below (full code at https://www.kaggle.com/code/adelphene/dagmm):
class DAGMM(nn.Module):
    """Deep Autoencoding Gaussian Mixture Model: compression + estimation networks.

    The compression network is an autoencoder. Its latent code is concatenated
    with two reconstruction features (relative Euclidean distance and cosine
    similarity between input and reconstruction) and fed to the estimation
    network, which outputs soft mixture-membership probabilities.

    Args:
        input_dim: number of features per input sample.
        latent_dim: size of the autoencoder bottleneck.
        n_gmm: number of mixture components (width of the softmax output).
        eps: lower bound applied to each input row's L2 norm when computing the
            relative reconstruction error, so an all-zero row yields a finite
            feature instead of inf/NaN. The default matches the default ``eps``
            of ``F.cosine_similarity``, used for the other feature.
    """

    def __init__(self, input_dim=118, latent_dim=1, n_gmm=4, eps=1e-8):
        super().__init__()
        self.eps = eps
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 60), nn.Tanh(),
            nn.Linear(60, 30), nn.Tanh(),
            nn.Linear(30, 10), nn.Tanh(),
            nn.Linear(10, latent_dim),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 10), nn.Tanh(),
            nn.Linear(10, 30), nn.Tanh(),
            nn.Linear(30, 60), nn.Tanh(),
            nn.Linear(60, input_dim),
        )
        # Input is the latent code plus the two reconstruction features.
        self.estimator = nn.Sequential(
            nn.Linear(latent_dim + 2, 10), nn.Tanh(),
            nn.Dropout(),
            nn.Linear(10, n_gmm), nn.Softmax(dim=1),
        )

    def forward(self, x):
        """Encode, reconstruct and estimate mixture memberships.

        Args:
            x: batch of samples, (batch, input_dim).

        Returns:
            r: reconstruction of ``x``, same shape as ``x``.
            z: (batch, latent_dim + 2) -- latent code, relative Euclidean
               reconstruction error, cosine similarity.
            g: (batch, n_gmm) softmax membership probabilities.
        """
        l = self.encoder(x)
        r = self.decoder(l)
        # Clamp the denominator: an all-zero row would otherwise divide by 0,
        # and the resulting inf/NaN would flow into the estimator and the loss.
        re = (x - r).norm(p=2, dim=1) / x.norm(p=2, dim=1).clamp_min(self.eps)
        cs = F.cosine_similarity(x, r, dim=1)
        z = torch.cat((l, re.unsqueeze(-1), cs.unsqueeze(-1)), dim=1)
        g = self.estimator(z)
        return r, z, g
# ------------------------------------------------------
# Training loop. Epoch count, optimizer (Adam) and learning rate follow the paper.
torch.autograd.set_detect_anomaly(False)

# Results vary from run to run. Fixing the seeds makes a given run repeatable
# (up to nondeterministic CUDA kernels), so each outcome can be re-run and
# inspected. Change `seed` to sample a different run.
seed = 0
np.random.seed(seed)
# Seeds all devices; the DataLoader's shuffling also draws from this generator
# when it was built without an explicit `generator` -- confirm for this loader.
torch.manual_seed(seed)

epochs = 200
lr = 1e-4
loss_fn = Loss()
batches = len(train_dataloader)
model = DAGMM(n_gmm=4).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
for epoch in range(epochs):
    for batch, (x, _) in enumerate(train_dataloader):
        model.train()
        r, z, g = model(x)
        loss, _ = loss_fn(x, r, z, g)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 400 == 0:
            # The validation loss is only reported, never back-propagated, so
            # skip building an autograd graph over the whole validation set.
            model.eval()  # disables the estimator's Dropout
            with torch.no_grad():
                r_val, z_val, g_val = model(val_dataset.x)
                val_loss, e = loss_fn(val_dataset.x, r_val, z_val, g_val)
            # Predict 1 for the samples whose score `e` lies above the 80th
            # percentile (presumably the per-sample energy -- see Loss).
            threshold = np.percentile(e.detach().cpu(), 80)
            y_pred = (e > threshold).long()
            y_true = val_dataset.y
            report = classification_report(y_true.cpu(), y_pred.cpu(), output_dict=True)
            acc = round(report["accuracy"], 2)
            prec = round(report["macro avg"]["precision"], 2)
            rec = round(report["macro avg"]["recall"], 2)
            print(
                f"loss: {round(loss.item(), 3)}, val_loss: {round(val_loss.item(), 3)}, "
                f"accuracy: {acc}, precision: {prec}, recall: {rec} "
                f"[{batch+1}/{batches}] [{epoch+1}/{epochs}]"
            )
To be clear, the number of epochs, the optimizer, and the learning rate are taken directly from the paper.
The patterns that I get are as follows:
- Loss decreases to 0.5/0.6/0.7, the accuracy is ~0.88, and precision and recall are both ~0.82
- Loss stuck at ~2.0, accuracy ~0.6, precision and recall ~0.44
- Loss between ~1.5 and ~0.8, accuracy between ~0.75 and ~0.85, precision and recall between ~0.65 and ~0.75
- (Best) Loss goes down to ~1.2/~0.8, accuracy goes up to ~0.94, and precision and recall go up to ~0.92
I am not exactly sure how I should approach this. Any guidance is appreciated!
PS: I believe I observe pattern 2 more often than the others.