I would like to use a custom feature extractor to calculate FID.
According to https://lightning.ai/docs/torchmetrics/stable/image/frechet_inception_distance.html, I can pass an nn.Module
as the `feature` argument.
What is wrong with the following code?
import torch
_ = torch.manual_seed(123)
from torchmetrics.image.fid import FrechetInceptionDistance
from torchvision.models import inception_v3


class Uint8FeatureExtractor(torch.nn.Module):
    """Adapt a float-input network so it can be used as the FID ``feature``.

    FrechetInceptionDistance probes a custom ``feature`` module with a uint8
    dummy image inside ``__init__`` and this script also feeds it uint8
    images, while the wrapped network's convolution weights are float. This
    wrapper converts integer inputs to float before calling the network.

    Args:
        backbone: Network mapping a float image batch to a 2-D feature tensor.
    """

    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone

    def forward(self, imgs):
        """Return ``backbone`` features for ``imgs`` (uint8 or float)."""
        if not imgs.is_floating_point():
            # NOTE(review): scaling [0, 255] -> [0, 1] is an assumption; apply
            # whatever preprocessing the checkpoint was trained with.
            imgs = imgs.float() / 255.0
        return self.backbone(imgs)


net = inception_v3()
checkpoint = torch.load('checkpoint.pt')
net.load_state_dict(checkpoint['state_dict'])
net.eval()
# Passing `net` directly fails with "expected scalar type Byte but found
# Float", so hand the metric the dtype-converting wrapper instead.
# NOTE(review): as written the features are the network's final outputs; to
# use pooled features instead, replace `net.fc` with torch.nn.Identity().
extractor = Uint8FeatureExtractor(net).eval()
fid = FrechetInceptionDistance(feature=extractor)
# generate two slightly overlapping image intensity distributions
imgs_dist1 = torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8)
imgs_dist2 = torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8)
fid.update(imgs_dist1, real=True)
fid.update(imgs_dist2, real=False)
result = fid.compute()
print(result)
Traceback (most recent call last):
File "foo.py", line 12, in <module>
fid = FrechetInceptionDistance(feature=net)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Lib/site-packages/torchmetrics/image/fid.py", line 304, in __init__
num_features = self.inception(dummy_image).shape[-1]
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Lib/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Lib/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Lib/site-packages/torchvision/models/inception.py", line 166, in forward
x, aux = self._forward(x)
^^^^^^^^^^^^^^^^
File "/Lib/site-packages/torchvision/models/inception.py", line 105, in _forward
x = self.Conv2d_1a_3x3(x)
^^^^^^^^^^^^^^^^^^^^^
File "/Lib/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Lib/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Lib/site-packages/torchvision/models/inception.py", line 405, in forward
x = self.conv(x)
^^^^^^^^^^^^
File "/Lib/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Lib/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Lib/site-packages/torch/nn/modules/conv.py", line 460, in forward
return self._conv_forward(input, self.weight, self.bias)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Lib/site-packages/torch/nn/modules/conv.py", line 456, in _conv_forward
return F.conv2d(input, weight, bias, self.stride,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: expected scalar type Byte but found Float
Process finished with exit code 1
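The problem is a dtype mismatch: the network is being fed a
dtype=torch.uint8
tensor, but the model expects a float tensor.
Note that the traceback points at FrechetInceptionDistance.__init__, where torchmetrics passes its own uint8 dummy_image through the custom module, so the error occurs before any of your images are used. Changing only the dtype of imgs_dist1/imgs_dist2 will therefore not remove it.
Instead, make the feature extractor accept uint8 input, for example by wrapping the network in a small nn.Module whose forward() converts the input to float (and applies your usual preprocessing) before calling the model.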