I am new to PyTorch and am trying to implement a ViT on spectrograms of raw audio. My training input consists of almost 1M tensors of shape [1, 80, 128], and I am exploring AMP (automatic mixed precision) to speed up training on a V100 (16 GB).
My training loop is below:
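```python
import torch
from torch.cuda.amp import autocast

# model, criterion, optimiser, train_dl and config_pytorch are defined elsewhere
scaler = torch.cuda.amp.GradScaler(enabled=True)
for e in range(config_pytorch.epochs):
    train_loss = 0.0
    for idx, train_bat in enumerate(train_dl):
        x, y = train_bat  # assuming each batch is an (input, target) pair
        x, y = x.cuda(), y.cuda()
        with autocast(enabled=True):  # forward pass in mixed precision
            y_pred = model(x).float()
            loss = criterion(y_pred, y.float())
        scaler.scale(loss).backward()  # backprop the scaled loss
        train_loss += loss.item()
        scaler.step(optimiser)  # unscales grads; skips the step if inf/NaN found
        scaler.update()  # adjusts the scale for the next iteration
        optimiser.zero_grad()
```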
I print the loss at each step to check its value: it is very small (~1e-5), and after a few steps it becomes 0. The code eventually fails with AssertionError: No inf checks were recorded prior to update.
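For context on why AMP needs a GradScaler at all: float16's smallest positive (subnormal) value is 2**-24 ≈ 5.96e-8, and values below 2**-25 round to zero in half precision unless the loss is multiplied by a scale factor first. The sketch below uses only Python's standard struct module (its "e" format is IEEE 754 half precision) to round-trip values through fp16; to_fp16 is a hypothetical helper, and 2**16 is GradScaler's default init_scale. It only illustrates the number format: since y_pred is cast to float32 before the loss in the loop above, it does not by itself explain why the loss here drops to 0.

```python
import struct


def to_fp16(x):
    """Round a Python float to IEEE 754 half precision and back (hypothetical helper)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


tiny_grad = 1e-8
scale = 2.0 ** 16  # GradScaler's default init_scale

print(to_fp16(1e-5))       # nonzero, but stored as a low-precision subnormal
print(to_fp16(tiny_grad))  # 0.0: below 2**-25, so it underflows
# Scaling first moves the value into fp16's normal range; dividing by the
# scale afterwards (here in float64) recovers roughly 1e-8.
print(to_fp16(tiny_grad * scale) / scale)
```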
The full stack trace is:
```text
AssertionError Traceback (most recent call last)
/tmp/ipykernel_972350/3829185638.py in <module>
----> 1 model = train_model_ast(train_dl , val_dl )

/tmp/ipykernel_972350/3546603516.py in train_model_ast(train_dl, val_dl, model)
130 bat_duration = bat_finish_time - start_time
131 print("&&&& BATCH TRAIN DURATION = " + str(bat_duration/60))
--> 132 scaler.update()
133 #removing all instances of 999
134

/opt/conda/lib/python3.8/site-packages/torch/cuda/amp/grad_scaler.py in update(self, new_scale)
384 for found_inf in state["found_inf_per_device"].values()]
385
--> 386 assert len(found_infs) > 0, "No inf checks were recorded prior to update."
387
388 found_inf_combined = found_infs[0]

AssertionError: No inf checks were recorded prior to update.
```
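The assertion comes from GradScaler.update(). As far as I can tell from torch/cuda/amp/grad_scaler.py, the found_inf_per_device records it checks are written when scaler.step(optimiser) unscales the gradients and are reset at the end of every update(), so update() fails if no step() has run since the previous update() (or since the scaler was created). The pure-Python model below is a simplified, hypothetical ToyLossScaler, not PyTorch's implementation: it keeps only the scale and the per-step inf flags, and uses GradScaler's default init_scale=2**16, growth_factor=2.0, backoff_factor=0.5 and growth_interval=2000 for the backoff/growth rule.

```python
class ToyLossScaler:
    """Hypothetical, simplified model of GradScaler's loss-scale bookkeeping.

    Not PyTorch's implementation: no tensors, just a float scale and the
    per-step "found inf" flags that update() consumes.
    """

    def __init__(self, init_scale=2.0 ** 16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._growth_tracker = 0
        self._found_infs = []  # written by step(), cleared by update()

    def step(self, found_inf):
        # In PyTorch this flag comes from unscaling the optimizer's gradients;
        # here the caller passes it in directly.
        self._found_infs.append(bool(found_inf))

    def update(self):
        # Same check (and message) as the assertion in the traceback.
        assert len(self._found_infs) > 0, "No inf checks were recorded prior to update."
        if any(self._found_infs):
            # Overflow seen: shrink the scale and restart the growth count.
            self.scale *= self.backoff_factor
            self._growth_tracker = 0
        else:
            # Clean step: grow the scale after growth_interval clean steps in a row.
            self._growth_tracker += 1
            if self._growth_tracker == self.growth_interval:
                self.scale *= self.growth_factor
                self._growth_tracker = 0
        self._found_infs = []  # records are consumed by every update()


scaler = ToyLossScaler()
scaler.step(found_inf=False)
scaler.update()  # OK: one inf check recorded since the last update()
try:
    scaler.update()  # no step() since the previous update()
except AssertionError as err:
    message = str(err)
    print("AssertionError:", message)
```

If this model is right, one thing worth checking is whether scaler.step(optimiser) always runs before scaler.update() in the real train_model_ast; the traceback shows batch-timing code around line 132 that is not in the loop posted above, so the posted loop may not match the code that failed.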
The code runs without any issues if I don't use AMP. I'd appreciate any pointers.
Thanks in advance.