I'm trying to train/fine-tune the MobileNetV3 model on the CIFAR-100 dataset. I'm using PyTorch and Hugging Face to simplify things like the training loop, which, coming from TensorFlow/Keras, I'm not used to writing manually.

However, I get an error when I try to apply the pre-trained weights' transform (preprocessing) to the dataset with the .with_transform() method. I wanted to visualize the preprocessed images, partly to check that the preprocessing actually works.

If I apply the preprocessing manually to a single image, it works, but if I use the with_transform method, I get an error.

Minimal reproducible example:

import torch
from torchvision.models import MobileNet_V3_Small_Weights
from datasets import load_dataset
from matplotlib import pyplot as plt

weights = MobileNet_V3_Small_Weights.DEFAULT
preprocess = weights.transforms()

raw_data = load_dataset("cifar100")
data = raw_data.with_transform(preprocess)

raw_img = raw_data["train"][0]["img"]

fig, axes = plt.subplots(1, 3)

axes[0].imshow(raw_img)
axes[0].set_title("Raw image")

img = preprocess(raw_img).permute(1, 2, 0)    # <----- Applying preprocessing "manually" on image works
axes[1].imshow(img)
axes[1].set_title("Preprocessed image (manual)")

img = data["train"][0]["img"]                 # <----- Getting image from preprocessed dataset doesn't work (preprocessing is lazy)
axes[2].imshow(img.permute(1, 2, 0))
axes[2].set_title("Preprocessed image (dataset)")

plt.show()

The error I get is:

Traceback (most recent call last):
  File "C:\Users\thiba\OneDrive - McGill University\Internship\ECSE301\pytorch_test.py", line 23, in <module>
    img = data["train"][0]["img"]
          ~~~~~~~~~~~~~^^^
  File "C:\Python311\Lib\site-packages\datasets\arrow_dataset.py", line 2778, in __getitem__
    return self._getitem(key)
           ^^^^^^^^^^^^^^^^^^
  File "C:\Python311\Lib\site-packages\datasets\arrow_dataset.py", line 2763, in _getitem
    formatted_output = format_table(
                       ^^^^^^^^^^^^^
  File "C:\Python311\Lib\site-packages\datasets\formatting\formatting.py", line 624, in format_table
    return formatter(pa_table, query_type=query_type)
    return self.format_row(pa_table)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Python311\Lib\site-packages\datasets\formatting\formatting.py", line 480, in format_row
    formatted_batch = self.format_batch(pa_table)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Python311\Lib\site-packages\datasets\formatting\formatting.py", line 510, in format_batch
    return self.transform(batch)
           ^^^^^^^^^^^^^^^^^^^^^
  File "C:\Python311\Lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Python311\Lib\site-packages\torchvision\transforms\_presets.py", line 58, in forward
    img = F.resize(img, self.resize_size, interpolation=self.interpolation, antialias=self.antialias)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Python311\Lib\site-packages\torchvision\transforms\functional.py", line 476, in resize
    _, image_height, image_width = get_dimensions(img)
                                   ^^^^^^^^^^^^^^^^^^^
  File "C:\Python311\Lib\site-packages\torchvision\transforms\functional.py", line 78, in get_dimensions
    return F_pil.get_dimensions(img)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Python311\Lib\site-packages\torchvision\transforms\_functional_pil.py", line 31, in get_dimensions
    raise TypeError(f"Unexpected type {type(img)}")
TypeError: Unexpected type <class 'dict'>
