Inputting a grayscale image to a 3-channel torch model using dataloaders


I've trained a convolutional VAE using PyTorch without any errors, and I'm getting reasonable reconstructed images. But now I'm having a problem that I don't fully understand.

When I load an image from my training set and pass it through the model, it returns the latent-space representation and reconstructs the image without any trouble, unless it is a single-channel image.

What I can't understand is why training didn't raise any errors. If the model cannot encode a single-channel image at inference, how could it train on one?

# This is my "inference", it uses the same images from training.
# I just want to get the latent representation of them all.

import os
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
from Autoencoder import ConvolutionalVariationalAutoEncoder as VAE


image_size = (256, 256)
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
])

images_path = "../data/images/"
images_names = os.listdir(images_path)

# out_channels, kernel_size, stride, padding
layer_params = [
    (16, 4, 2, 1),
    (32, 4, 2, 1),
    (64, 4, 2, 1),
]
input_shape = (3,) + image_size
latent_dim = 64

model = VAE(input_shape, layer_params, latent_dim)
model.load_state_dict(torch.load("../models/third.1_model.pth"))
model.eval()

latent_representations = dict()
k = 0
for image_name in images_names:
    print(k)
    image = Image.open(images_path+image_name)
    image = transform(image).unsqueeze(0)

    print(images_path+image_name, image.shape)
    with torch.no_grad():
        image = image.to(model.device)
        mu, logvar = model.encode(image)
        latent_representations[image_name] = model.reparameterize(mu, logvar)
    k+= 1
latent_representations

This code runs for a few iterations, prints these lines, and raises this error when it reaches a single-channel image:

0
../data/images/1.jpg torch.Size([1, 3, 256, 256])
1
../data/images/10.jpg torch.Size([1, 3, 256, 256])
2
../data/images/100.jpg torch.Size([1, 1, 256, 256])
RuntimeError: Given groups=1, weight of size [16, 3, 4, 4], expected input[1, 1, 256, 256] to have 3 channels, but got 1 channels instead

Just for reference, this is how I ran training:

import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader
from torchvision import transforms

from Autoencoder import ConvolutionalVariationalAutoEncoder as VAE
from Autoencoder import ImageDataset

image_size = (256, 256)

transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
])

image_folder_path = "../data/"
image_dataset = ImageDataset(image_folder_path, transform=transform)
data_loader = DataLoader(image_dataset, batch_size=32, shuffle=True)

# out_channels, kernel_size, stride, padding
layer_params = [
    (16, 4, 2, 1),
    (32, 4, 2, 1),
    (64, 4, 2, 1),
]

input_shape = (3,) + image_size
latent_dim = 64

vae = VAE(input_shape, layer_params, latent_dim)

vae.train_model(10, data_loader)

Where .train_model() is a method I implemented in the autoencoder class:

def train_model(self, n_epochs, data_loader):
    for epoch in range(n_epochs):
        self.train()
        total_loss = 0
        for data in data_loader:
            data = data.to(self.device)
            recon_data, mu, logvar = self(data)
            loss = self.loss(recon_data, data, mu, logvar)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()
        average_loss = total_loss / len(data_loader.dataset)
        print(f'Epoch [{epoch + 1}/{n_epochs}], Loss: {average_loss:.4f}')

As I said, it didn't throw any errors; it just trained. How is that possible? Should I worry that it didn't use the grayscale images? And how do I fix my inference?

I can fix my inference in many ways, for example by checking the number of channels and replicating them. My biggest concern is what happened during training and why it didn't cause any errors. I want my model to have been trained on all my images.
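For context on the training side: DataLoader's default collate function builds each batch by stacking the samples, which requires every tensor in the batch to have the same shape. With batch_size=32, a mix of 1-, 3- and 4-channel tensors would therefore normally fail during batching, before the model is ever called. The sketch below illustrates that behavior. MixedChannelDataset and batching_fails are hypothetical names made up for the illustration, and the real ImageDataset is not shown here, so whether it normalizes channels itself (for example with .convert("RGB")) is only a guess that has to be checked in its source.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class MixedChannelDataset(Dataset):
    """Hypothetical stand-in for a dataset that does NOT normalize channels."""

    def __init__(self):
        # One RGB-like, one grayscale-like and one 4-channel tensor.
        self.items = [torch.zeros(c, 8, 8) for c in (3, 1, 4)]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]


def batching_fails(dataset, batch_size):
    """Return True if iterating the DataLoader raises while building a batch."""
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    try:
        for _ in loader:
            pass
    except RuntimeError:
        # The default collate function stacks samples and rejects mixed shapes.
        return True
    return False


mixed_fails = batching_fails(MixedChannelDataset(), batch_size=3)
single_fails = batching_fails(MixedChannelDataset(), batch_size=1)
print(mixed_fails, single_fails)
```

If training ran for ten epochs with shuffled batches of 32 and never hit this, the dataset class is probably returning tensors with a consistent channel count already. Printing the shape of one batch from data_loader would confirm it.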

EDIT: I've changed my inference to solve my immediate problem. The code now looks like this:

from tqdm import tqdm  # progress bar used in the loop below

for image_name in tqdm(images_names):
    image = Image.open(images_path+image_name)
    image = transform(image).unsqueeze(0)
    print(images_path+image_name, image.shape)
    with torch.no_grad():
        image = image.to(model.device)
        if image.shape[1] == 1:
            image = image.expand(-1, 3, -1, -1)
        mu, logvar = model.encode(image)
        latent_representations[image_name] = model.reparameterize(mu, logvar)
latent_representations

And I just found out I also have 4-channel images in my dataset, haha. How the heck didn't my training break?

751
../data/images/1682.jpg torch.Size([1, 3, 256, 256])
752
../data/images/1683.jpg torch.Size([1, 4, 256, 256])
RuntimeError: Given groups=1, weight of size [16, 3, 4, 4], expected input[1, 4, 256, 256] to have 3 channels, but got 4 channels instead