How to prepare multimodal data for fine-tuning ViLT

121 views Asked by At

We need to fine-tune the ViLT model on the UPMC-Food-101 dataset. The pre-trained processor for `ViltForImagesAndTextClassification` has the following syntax:

# Processor call as quoted: a list of two images plus one text; return_tensors="pt" requests PyTorch tensors.
encoding = processor([image1, image2], text, return_tensors="pt")

Initially, I worked only with the image data to fine-tune ViT, using the following method:

# One entry per column: image file paths and their multi-hot label vectors.
val_data = {'image': image_file_paths, 'label': multi_hot_labels}

# Presumably `datasets.Dataset` (Hugging Face); from_dict builds one column per key.
ds_val = Dataset.from_dict(val_data)

def transform(examples):
    """Turn a batch of dataset rows into model inputs.

    Loads every image path in ``examples["image"]`` as an RGB picture, runs
    the module-level ``processor`` over the batch (PyTorch tensors), and
    attaches ``examples["label"]`` under the ``labels`` key.

    Relies on the globals ``pil`` (PIL's ``Image`` module) and ``processor``.
    """
    images = []
    for path in examples["image"]:
        # `with` closes the underlying file as soon as the RGB copy exists,
        # instead of leaving the handle to the garbage collector.
        with pil.open(path) as img:
            images.append(img.convert("RGB"))

    inputs = processor(images, return_tensors="pt")
    inputs["labels"] = examples["label"]
    return inputs

# Attach `transform` to ds_val; per the Hugging Face datasets docs,
# with_transform applies it on the fly whenever examples are accessed.
val_dataset = ds_val.with_transform(transform)

But now I cannot use the `Dataset.from_dict` function, as it doesn't support three lists. Currently, I have a dictionary with the following lists:

# Three aligned lists: image paths, their texts, and their labels.
val_data = dict(
    image=image_file_paths,
    text=texts_csv_lst,
    label=labels,
)
2

There are 2 answers

0
Sanad Bhowmik On
import torch
from PIL import Image
from torch.utils.data import Dataset

class CustomImageTextDataset(Dataset):
    """Map-style dataset pairing an image file, its text and its label.

    Each item contains every tensor returned by ``processor`` with the
    leading batch-of-one dimension removed. For ViLT this presumably includes
    ``input_ids``, ``attention_mask`` and the image tensors such as
    ``pixel_values`` (confirm against the processor in use). Each item also
    carries the raw ``text`` and a ``labels`` tensor.

    Args:
        image_file_paths: Image file paths, one per example.
        texts_csv_lst: Text for each example, aligned with ``image_file_paths``.
        labels: Label for each example; converted with ``torch.tensor``.
        processor: Callable accepting ``images=``, ``text=`` and tokenizer
            keyword arguments and returning a mapping of tensors with a
            leading batch dimension of 1 (presumably a ``ViltProcessor``).
    """

    def __init__(self, image_file_paths, texts_csv_lst, labels, processor):
        self.image_file_paths = image_file_paths
        self.texts_csv_lst = texts_csv_lst
        self.labels = labels
        self.processor = processor

    def __len__(self):
        # Length follows the image list; the other lists are assumed aligned.
        return len(self.image_file_paths)

    def __getitem__(self, idx):
        text = self.texts_csv_lst[idx]
        # `with` closes the file handle once the RGB copy has been made.
        with Image.open(self.image_file_paths[idx]) as img:
            image = img.convert("RGB")

        # Fixed-length padding plus truncation lets the default collate
        # function stack items of different text lengths into one batch.
        encoding = self.processor(
            images=image,
            text=text,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        # Keep every key so the image tensors reach the model too. Remove only
        # the leading batch-of-one axis so collation yields (B, ...) rather
        # than (B, 1, ...).
        item = {key: value.squeeze(0) for key, value in encoding.items()}
        # Raw text kept for inspection; presumably it must be dropped before
        # calling the model's forward() — TODO confirm.
        item["text"] = text
        item["labels"] = torch.tensor(self.labels[idx])
        return item


# Build the dataset from the three aligned lists and the processor.
custom_dataset = CustomImageTextDataset(
    image_file_paths=image_file_paths,
    texts_csv_lst=texts_csv_lst,
    labels=labels,
    processor=processor,
)
0
Muhammad Irzam On

The following worked for me:

import torch
from PIL import Image as pil

class CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding processor encodings for (image, text, label) triples.

    Args:
        images: Image file paths (opened with ``pil.open``), one per example.
        texts: Text for each example, aligned with ``images``.
        labels: Label for each example; converted with ``torch.tensor``.
        processor: Callable invoked as ``processor(image, text, **kwargs)``
            that returns a mutable mapping of tensors with a leading batch
            dimension of 1 (presumably a ``ViltProcessor``; confirm).

    Note: images of different sizes may yield ``pixel_values`` of different
    shapes, which the default collate function cannot stack. A custom
    collate_fn that pads images may be needed — TODO confirm for your data.
    """

    def __init__(self, images, texts, labels, processor):
        self.images = images
        self.texts = texts
        self.labels = labels
        self.processor = processor

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]

        # `with` closes the file handle once the RGB copy has been made.
        with pil.open(self.images[idx]) as img:
            image = img.convert("RGB")

        encoding = self.processor(
            image,
            text,
            is_split_into_words=False,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # Remove only the leading batch-of-one axis; a bare squeeze() would
        # also drop any other axis that happens to have size 1.
        for key, value in encoding.items():
            encoding[key] = value.squeeze(0)

        encoding["labels"] = torch.tensor(label)
        return encoding