Torchmetrics Frechet Inception Distance Weird Behaviour


I am trying to create an FID to measure the performance of my generative models on MNIST.

I provide my own feature extractor.

However, in order to find the output dimension of the feature extractor you provide, torchmetrics tries to pass it a dummy image to see what dimension it outputs.

The problem is that the dummy image it generates does not match the shape or data type my feature extractor expects.

There is no way for me to manually specify the dummy image that gets passed in, so I can't control that.

Here is an example of what I'm trying to do:

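import torch as th
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics.image.fid import FrechetInceptionDistance

N = <appropriate number>  # flattened size of the conv output

class SimpleConvFeatureExtractor(nn.Module):
  def __init__(self, embed_dim):
    super().__init__()
    self.conv = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=2)
    self.out = nn.Sequential(nn.Linear(N, embed_dim))

  def forward(self, x):
    # return th.randn(size=(1, 128))  # debugging stub; left active, nothing below would run
    # Print what torchmetrics actually passes in
    print(x.shape)
    print(x.dtype)
    x = F.silu(self.conv(x))  # was self.conv1, which is never defined
    x = self.out(x.view(x.shape[0], -1))
    return x

fid = FrechetInceptionDistance(feature=SimpleConvFeatureExtractor(128))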

with output

torch.Size([1, 3, 299, 299])
torch.uint8
RuntimeError: Input type (unsigned char) and bias type (float) should be the same

As you can see the image being passed through is hardly an MNIST image.


There are 2 answers

Answer by dbel

I had a similar error with a project of mine. I wanted to see if anyone else would be able to answer your post, but given the silence, I will give my best attempt at an answer! For me the solution lay in the class definitions. When you create your class and define __init__ you should try to pass in a transform which will make its input a tensor.

If you want to see the similarity between our issues you can check out my question here.
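For the shape and dtype mismatch in the question, one possible way to apply this idea is to do the conversion inside a wrapper module, so that whatever tensor arrives is coerced into what the extractor expects. The sketch below is only an illustration: `MNISTInputAdapter` and `TinyExtractor` are hypothetical names, the 28x28 single-channel float format is an assumption about the MNIST pipeline, and the code does not call torchmetrics at all. Whether passing the wrapped module as `feature=` is enough may depend on the torchmetrics version.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MNISTInputAdapter(nn.Module):
    """Hypothetical wrapper: coerce any image batch to float, 1 channel, 28x28."""

    def __init__(self, extractor: nn.Module, size=(28, 28)):
        super().__init__()
        self.extractor = extractor
        self.size = tuple(size)

    def forward(self, x):
        if not torch.is_floating_point(x):
            x = x.float() / 255.0  # assumes uint8 pixels in [0, 255]
        if x.shape[1] != 1:
            x = x.mean(dim=1, keepdim=True)  # average the channels into one
        if tuple(x.shape[-2:]) != self.size:
            x = F.interpolate(x, size=self.size, mode="bilinear", align_corners=False)
        return self.extractor(x)


class TinyExtractor(nn.Module):
    """Stand-in extractor; 32 * 27 * 27 assumes 28x28 inputs and kernel_size=2."""

    def __init__(self, embed_dim=128):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=2)
        self.out = nn.Linear(32 * 27 * 27, embed_dim)

    def forward(self, x):
        x = F.silu(self.conv(x))
        return self.out(x.flatten(start_dim=1))


wrapped = MNISTInputAdapter(TinyExtractor(128))

# Same shape and dtype as the dummy image reported in the question
dummy = torch.randint(0, 255, (1, 3, 299, 299), dtype=torch.uint8)
dummy_feats = wrapped(dummy)

# MNIST-like batches pass through the adapter unchanged
mnist_feats = wrapped(torch.rand(4, 1, 28, 28))
```

The trade-off is that every batch, not just the dummy one, goes through the adapter, so real images should already be in the expected format if the resizing and channel averaging are not wanted.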

Answer by Donshel

I don't know how to modify torchmetrics, but you can use the following function to compute the Fréchet distance from your features' mean and covariance.

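import torch
from torch import Tensor

def frechet_distance(mu_x: Tensor, sigma_x: Tensor, mu_y: Tensor, sigma_y: Tensor) -> Tensor:
    # Squared Euclidean distance between the means
    a = (mu_x - mu_y).square().sum(dim=-1)
    # Sum of the traces of the two covariance matrices
    b = sigma_x.trace() + sigma_y.trace()
    # Trace of sqrt(sigma_x @ sigma_y), taken from the eigenvalues of the product
    c = torch.linalg.eigvals(sigma_x @ sigma_y).sqrt().real.sum(dim=-1)

    return a + b - 2 * c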

This is the implementation of PIQA, an image quality assessment (IQA) package. You can also directly use PIQA's implementation:

import piqa
fid = piqa.FID()
fid(x_feats, y_feats)

Disclaimer: I am the author of PIQA.
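To use the `frechet_distance` function above on raw features, the mean and covariance have to be estimated first. A minimal sketch of that step follows, assuming the features are stored as an (n, d) matrix; `feature_stats` is a hypothetical helper introduced here, not part of PIQA or torchmetrics, and the random features are placeholders for real extractor outputs.

```python
from typing import Tuple

import torch
from torch import Tensor


def feature_stats(feats: Tensor) -> Tuple[Tensor, Tensor]:
    """Return the mean (d,) and covariance (d, d) of an (n, d) feature matrix."""
    mu = feats.mean(dim=0)
    sigma = torch.cov(feats.T)  # torch.cov expects variables in rows
    return mu, sigma


def frechet_distance(mu_x: Tensor, sigma_x: Tensor, mu_y: Tensor, sigma_y: Tensor) -> Tensor:
    # Same computation as the function quoted in the answer above
    a = (mu_x - mu_y).square().sum(dim=-1)
    b = sigma_x.trace() + sigma_y.trace()
    c = torch.linalg.eigvals(sigma_x @ sigma_y).sqrt().real.sum(dim=-1)
    return a + b - 2 * c


torch.manual_seed(0)
x_feats = torch.randn(512, 16, dtype=torch.float64)  # placeholder "real" features
y_feats = x_feats + 1.0  # same covariance, mean shifted by 1 in each of 16 dims

d_same = frechet_distance(*feature_stats(x_feats), *feature_stats(x_feats))
d_shift = frechet_distance(*feature_stats(x_feats), *feature_stats(y_feats))
```

With identical feature sets the distance should be close to zero, and with a pure mean shift it should be close to the squared norm of the shift (16 here), which makes for a quick sanity check before plugging in real features.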