Mixed Precision Training: Loss Function Data Type Mismatch in PyTorch

I'm working on a PyTorch-based denoising autoencoder and have implemented mixed precision training. My training loop includes a loss function, but I'm running into a data type mismatch when computing the loss: under autocast the model's outputs come back as float16 (the default autocast dtype on CUDA), while the target images remain in float32. Here's a simplified version of my code:

import random

import torch
from tqdm import tqdm

for i, (images, _) in tqdm(enumerate(training_loader, 0), total=len(training_loader)):
    images = images.to("cuda")
    # Apply augmentations roughly 30% of the time (random.randint(0, 1) only
    # returns 0 or 1, so comparing it against 0.7 did not give that probability)
    if random.random() > 0.7:
        inputs = apply_image_transformations(images)
    else:
        inputs = images

    optimizer.zero_grad()

    # Forward pass and loss under autocast
    with torch.autocast(device_type="cuda"):
        outputs = model(inputs)
        loss = criterion(outputs, images)

    # Backward pass and optimizer step outside the autocast region
    loss.backward()
    optimizer.step()

The criterion here is an instance of my custom loss function, DenoisingAutoencoderLoss, which combines MSE, L1, and SSIM terms. The problem is that under autocast the outputs tensor is cast to float16, while the images stay in float32, which leads to a data type mismatch during the loss calculation.
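For reference, here is a stripped-down sketch of roughly what the loss looks like; the weights and the dtype handling shown are illustrative rather than my exact implementation, and the SSIM term is only indicated as a comment:

import torch.nn as nn
import torch.nn.functional as F

class DenoisingAutoencoderLoss(nn.Module):
    def __init__(self, mse_weight=1.0, l1_weight=0.5, ssim_weight=0.5):
        super().__init__()
        self.mse_weight = mse_weight
        self.l1_weight = l1_weight
        self.ssim_weight = ssim_weight

    def forward(self, outputs, targets):
        # One idea: align the target's dtype with the (autocast) output
        # before computing the individual terms.
        targets = targets.to(outputs.dtype)
        mse = F.mse_loss(outputs, targets)
        l1 = F.l1_loss(outputs, targets)
        # ssim_term = 1 - ssim(outputs, targets)  # e.g. via pytorch-msssim or torchmetrics
        return self.mse_weight * mse + self.l1_weight * l1  # + self.ssim_weight * ssim_term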

How can I ensure consistent data types for outputs and images during mixed precision training? Should I cast the outputs up to float32, or the images down to float16? What's the best practice here?
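For example, I've considered restructuring the loop along the lines of the GradScaler pattern from the PyTorch AMP documentation, with an explicit cast on the outputs; whether that cast is the right approach is exactly what I'm unsure about:

import torch

scaler = torch.cuda.amp.GradScaler()

for images, _ in training_loader:
    images = images.to("cuda")
    optimizer.zero_grad()

    with torch.autocast(device_type="cuda"):
        outputs = model(images)
        # Cast the float16 outputs up to float32 so they match the float32 targets;
        # alternatively the criterion itself could handle the dtype alignment.
        loss = criterion(outputs.float(), images)

    # Scale the loss to avoid float16 gradient underflow, then step via the scaler.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()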

Any guidance or recommendations on how to handle data type consistency during mixed precision training would be greatly appreciated. Thank you!
