I'm trying to train/fine-tune the MobileNetV3 model on the CIFAR100 dataset. I'm using PyTorch and Hugging Face to simplify things like the training loop, which I'm not used to writing manually, coming from TensorFlow/Keras.
However, I get an error when trying to apply the pre-trained weights' transform (preprocessing) to the dataset with the .with_transform()
method. I wanted to visualize the preprocessing partly to see that it actually works.
If I apply the preprocessing manually to an image it works, but if I use the with_transform method, I get an error.
Minimal reproducible example:
import torch
from torchvision.models import MobileNet_V3_Small_Weights
from datasets import load_dataset
from matplotlib import pyplot as plt
# Reproduction script: show the same CIFAR100 sample raw, preprocessed by
# calling `preprocess` directly, and preprocessed through the dataset transform.
weights = MobileNet_V3_Small_Weights.DEFAULT
preprocess = weights.transforms()
raw_data = load_dataset("cifar100")
# NOTE(review): the traceback below shows `datasets` invoking the transform as
# `self.transform(batch)` and torchvision rejecting a `dict`, so `preprocess`
# is handed a batch rather than an image here. Presumably `with_transform`
# needs a callable that takes and returns a batch dict -- confirm against the
# `datasets` documentation.
data = raw_data.with_transform(preprocess)
raw_img = raw_data["train"][0]["img"]
# One row of three panels, one per variant of the same sample.
fig, axes = plt.subplots(1, 3)
axes[0].imshow(raw_img)
axes[0].set_title("Raw image")
# permute(1, 2, 0) moves axis 0 to the end before plotting; assumes
# `preprocess` returns a channel-first tensor -- TODO confirm.
img = preprocess(raw_img).permute(1, 2, 0) # <----- Applying preprocessing "manually" on image works
axes[1].imshow(img)
axes[1].set_title("Preprocessed image (manual)")
# This indexing is the statement the traceback reports as failing.
img = data["train"][0]["img"] # <----- Getting image from preprocessed dataset doesn't work (preprocessing is lazy)
axes[2].imshow(img.permute(1, 2, 0))
axes[2].set_title("Preprocessed image (dataset)")
plt.show()
The error I get is:
Traceback (most recent call last):
File "C:\Users\thiba\OneDrive - McGill University\Internship\ECSE301\pytorch_test.py", line 23, in <module>
img = data["train"][0]["img"]
~~~~~~~~~~~~~^^^
File "C:\Python311\Lib\site-packages\datasets\arrow_dataset.py", line 2778, in __getitem__
return self._getitem(key)
^^^^^^^^^^^^^^^^^^
File "C:\Python311\Lib\site-packages\datasets\arrow_dataset.py", line 2763, in _getitem
formatted_output = format_table(
^^^^^^^^^^^^^
File "C:\Python311\Lib\site-packages\datasets\formatting\formatting.py", line 624, in format_table
return formatter(pa_table, query_type=query_type)
return self.format_row(pa_table)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Python311\Lib\site-packages\datasets\formatting\formatting.py", line 480, in format_row
formatted_batch = self.format_batch(pa_table)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Python311\Lib\site-packages\datasets\formatting\formatting.py", line 510, in format_batch
return self.transform(batch)
^^^^^^^^^^^^^^^^^^^^^
File "C:\Python311\Lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Python311\Lib\site-packages\torchvision\transforms\_presets.py", line 58, in forward
img = F.resize(img, self.resize_size, interpolation=self.interpolation, antialias=self.antialias)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Python311\Lib\site-packages\torchvision\transforms\functional.py", line 476, in resize
_, image_height, image_width = get_dimensions(img)
^^^^^^^^^^^^^^^^^^^
File "C:\Python311\Lib\site-packages\torchvision\transforms\functional.py", line 78, in get_dimensions
return F_pil.get_dimensions(img)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "C:\Python311\Lib\site-packages\torchvision\transforms\_functional_pil.py", line 31, in get_dimensions
raise TypeError(f"Unexpected type {type(img)}")
TypeError: Unexpected type <class 'dict'>