Can't apply same transform to image and mask for data-augmentation


I'm trying to train a U-Net model built with PyTorch. For that, I built a dataset and applied data-augmentation transforms to both the image and the mask. I want the same transformation applied to both: if the image is rotated by some number of degrees, the mask should be rotated by the same number of degrees. That is where my problem lies: the image and the mask aren't rotated by the same amount.

My code is below:

Dataset

import torch
from torch.utils.data import Dataset
import os

class INBreastDataset2012(Dataset):
    def __init__(self, dict_dir, transform=None):
        self.dict_dir = dict_dir
        self.data = os.listdir(self.dict_dir)
        self.transform = transform



    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        dict_path = os.path.join(self.dict_dir, self.data[index])
        patient_dict = torch.load(dict_path)
        image = patient_dict['image'].unsqueeze(0)
        mass_mask = patient_dict['mass_mask'].unsqueeze(0)
        mass_mask[mass_mask > 1.0] = 1.0


        if self.transform is not None:
            image = self.transform(image)
            mass_mask = self.transform(mass_mask)
            
        
        return image, mass_mask


"Training" (not really training at this point, just visualizing what the dataloader returns)

from dataset import INBreastDataset2012
from torchvision.transforms import v2 as T
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

train_dir = r'directory\of\training images and masks'
test_dir = r'directory\of\testing images and masks'

train_transform = T.Compose(
        [
            T.RandomRotation(degrees=35, expand=True, fill=255.0),
            T.RandomHorizontalFlip(p=0.5),
            T.RandomVerticalFlip(p=0.5),

        ]
    )

train_data = INBreastDataset2012(train_dir,transform=train_transform)
test_data = INBreastDataset2012(test_dir)

train_dataloader = DataLoader(train_data, batch_size=1, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=1, shuffle=True)

plt.figure(figsize=(12,12))
for i, (imagen,mascara) in enumerate(train_dataloader):
    ax = plt.subplot(2,4,i+1)
    ax.title.set_text(f'imagen {i+1}')
    plt.imshow(imagen.squeeze(), cmap='gray')
    ax = plt.subplot(2,4,i+3)
    ax.title.set_text(f'mascara de imagen {i+1}')
    plt.imshow(mascara.squeeze(), cmap='gray')
    if i == 1:
        break

Result: [image: result of the transformations applied to the images and masks]

I will also add that I've tried albumentations and torchvision.transforms v1. The PyTorch examples and YouTube videos I've seen seem to do the same thing I'm doing.

If someone could help me see what I'm doing wrong, or has a way to ensure that the transformations are the same, it would be greatly appreciated.

If any extra information is needed, please ask. This is my first post, so I may have missed something. Thank you in advance.


There are 2 answers

2
Muhammed Yunus On BEST ANSWER

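Calling self.transform separately on the image and on the mask samples new random parameters on each call, which is why the two end up rotated differently. You could try concatenating the image and mask along the channel dimension, running the transform once, and then splitting the result back into two tensors. The code below assumes the image and mask are shaped channels x height x width.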

...

if self.transform is not None:
    # Concatenate along the channel dimension.
    # Assumes dim=0 is the channel dimension (not the batch dimension).
    n_channels = image.shape[0]
    image_and_mask = torch.cat([image, mass_mask], dim=0)

    # Transform both together so they share the same random parameters.
    transformed = self.transform(image_and_mask)

    # Slice the two tensors back out.
    image = transformed[:n_channels, ...]
    mass_mask = transformed[n_channels:, ...]

...
1
Karl On

You should look at the documentation section on functional transforms, which let you specify the transform parameters yourself and so deal with this exact issue.

import torchvision.transforms.functional as TF
import random

def my_segmentation_transforms(image, segmentation):
    if random.random() > 0.5:
        angle = random.randint(-30, 30)
        image = TF.rotate(image, angle) # rotate image
        segmentation = TF.rotate(segmentation, angle) # rotate segmentation mask with same angle
    # more transforms ...
    return image, segmentation