Scaler.update() - AssertionError: No inf checks were recorded prior to update

1.4k views Asked by At

I am new to Pytorch and trying to implement ViT on a spectrograms of raw audio . My training input consists of tensors [1,80,128] (almost 1M) of them and I am exploring AMP to speed up my training on a V100(16GB).

My training loop is as below

scaler = torch.cuda.amp.GradScaler(enabled = True)
for e in range(config_pytorch.epochs):
    for idx,train_bat in enumerate(train_dl):
           with autocast(enabled=True):
                 y_pred = model(x).float()
                 loss = criterion(y_pred, y.float())
                 scaler.scale(loss).backward()
                  train_loss += loss.detach().item()
          scaler.step(optimiser)
          scaler.update()
          optimiser.zero_grad()

I print out the losses at each step just to check their values and they are very small (~1e-5) and after a few steps the loss becomes (0) . The code errors out with the following AssertionError: No inf checks were recorded prior to update .

The entire stack trace is as below.

AssertionError                            Traceback (most recent call last)
/tmp/ipykernel_972350/3829185638.py in <module>
----> 1 model = train_model_ast(train_dl , val_dl )

/tmp/ipykernel_972350/3546603516.py in train_model_ast(train_dl, val_dl, model)
    130             bat_duration = bat_finish_time - start_time
    131             print("&&&& BATCH TRAIN DURATION = " + str(bat_duration/60))
--> 132             scaler.update()
    133             #removing all instances of 999
    134 

/opt/conda/lib/python3.8/site-packages/torch/cuda/amp/grad_scaler.py in update(self, new_scale)
    384                           for found_inf in state["found_inf_per_device"].values()]
    385 
--> 386             assert len(found_infs) > 0, "No inf checks were recorded prior to update."
    387 
    388             found_inf_combined = found_infs[0]

AssertionError: No inf checks were recorded prior to update.

The code however runs without any issues if I don’t use AMP.Appreciate if anyone could provide any pointers.

Thanks in advance..

0

There are 0 answers