How to implement Text2Image with CNNs and Transposed CNNs


I want to implement a text-to-image neural network like the one in the image below, using CNNs and transposed CNNs with an embedding layer: ![Image2text](https://drive.google.com/file/d/1A82iC29omu2yQrKEJrv1ropNtaL3urfH/view?usp=share_link)

import torch
from torch import nn

Input text:

text = "A cat wearing glasses and playing the guitar "

# Simple preprocessing the text
word_to_ix = {"A": 0, "cat": 1, "wearing": 2, "glasses": 3, "and": 4, "playing": 5, "the": 6, "guitar":7}
lookup_tensor = torch.tensor(list(word_to_ix.values()), dtype = torch.long) # a tensor representing words by integers

vocab_size = len(lookup_tensor)
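
(Aside: a minimal sketch of how the same index mapping could be built from the sentence itself; the tokens variable and the comprehension below are illustrative, not part of the original code.)

tokens = text.split()  # ["A", "cat", "wearing", "glasses", "and", "playing", "the", "guitar"]
word_to_ix = {word: ix for ix, word in enumerate(tokens)}  # same mapping as the hand-written dict
lookup_tensor = torch.tensor([word_to_ix[w] for w in tokens], dtype=torch.long)
vocab_size = len(word_to_ix)  # 8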

Architecture implementation:

class TextToImage(nn.Module):
    def __init__(self, vocab_size):
        super(TextToImage, self).__init__()
        
        self.vocab_size = vocab_size
        self.noise = torch.rand((56,64))
        
        # DEFINE the layers
        # Embedding
        self.embed = nn.Embedding(num_embeddings=self.vocab_size, embedding_dim = 64)
        
        # Conv
        self.conv2d_1 = nn.Conv2d(in_channels=64, out_channels=3, kernel_size=(3, 3), stride=(2, 2), padding='valid')
        self.conv2d_2 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=(3, 3), stride=(2, 2), padding='valid')
        
        # Transposed CNNs
        self.conv2dTran_1 = nn.ConvTranspose2d(in_channels=16, out_channels=16, kernel_size=(3, 3), stride=(1, 1), padding=1)
        self.conv2dTran_2 = nn.ConvTranspose2d(in_channels=16, out_channels=3, kernel_size=(3, 3), stride=(2, 2), padding=0)
        self.conv2dTran_3 = nn.ConvTranspose2d(in_channels=6, out_channels=3, kernel_size=(4, 4), stride=(2, 2), padding=0)
        
        self.relu    = torch.nn.ReLU(inplace=False)
        self.dropout = torch.nn.Dropout(0.4)
        

    def forward(self, text_tensor):
        #SEND the input text tensor to the embedding layer
        emb = self.embed(text_tensor)
        
        #COMBINE the embedding with the noise tensor. Make it have 3 dimensions
        combine1 = torch.cat((emb, self.noise), dim=1, out=None)

        #SEND the noisy embedding to the convolutional and transposed convolutional layers
        conv2d_1 = self.conv2d_1(combine1)
        conv2d_2 = self.conv2d_2(conv2d_1)
        dropout  = self.dropout(conv2d_2)
                                               
        conv2dTran_1 = self.conv2dTran_1(dropout)
        conv2dTran_2 = self.conv2dTran_2(conv2dTran_1)
                                               
        #COMBINE the outputs having a skip connection in the image of the architecture
        combine2 = torch.cat((conv2d_1, conv2dTran_2), dim=1, out=None)
        conv2dTran_3 = self.conv2dTran_3(combine2)

        #SEND the combined outputs to the final layer. Please name your final output variable as "image" so you that it can be returned
        image = self.relu(conv2dTran_3)

        return image

Expected output: torch.Size([3, 64, 64])

texttoimage = TextToImage(vocab_size=vocab_size)

output = texttoimage(lookup_tensor)

output.size()

Generated random noisy image:

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

plt.imshow(np.moveaxis(output.detach().numpy(), 0,-1))

The error I got:

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 8 but got size 56 for tensor number 1 in the list.

Does anyone know how to solve this issue? I think it comes from concatenating the noise with the embedding.


There are 2 answers

Answer by Mohammed (accepted):

It works after changing the concatenation to dim=0 and expanding the result to 3 dimensions. In addition, there was an issue with the input channels of the first conv layer (conv2d_1), which I changed from 64 to 1:

class TextToImage(nn.Module):
    def __init__(self, vocab_size):
        super(TextToImage, self).__init__()
        
        self.vocab_size = vocab_size
        self.noise = torch.rand((56,64))
        
        # DEFINE the layers
        # Embedding
        self.embed = nn.Embedding(num_embeddings=self.vocab_size, embedding_dim = 64)
        
        # Conv
        self.conv2d_1 = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=(3, 3), stride=(2, 2), padding='valid')
        self.conv2d_2 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=(3, 3), stride=(2, 2), padding='valid')
        
        # Transposed CNNs
        self.conv2dTran_1 = nn.ConvTranspose2d(in_channels=16, out_channels=16, kernel_size=(3, 3), stride=(1, 1), padding=1)
        self.conv2dTran_2 = nn.ConvTranspose2d(in_channels=16, out_channels=3, kernel_size=(3, 3), stride=(2, 2), padding=0)
        self.conv2dTran_3 = nn.ConvTranspose2d(in_channels=6, out_channels=3, kernel_size=(4, 4), stride=(2, 2), padding=0)
        
        self.relu    = torch.nn.ReLU(inplace=False)
        self.dropout = torch.nn.Dropout(0.4)
        

    def forward(self, text_tensor):
        #SEND the input text tensor to the embedding layer
        emb = self.embed(text_tensor)
        
        #COMBINE the embedding with the noise tensor. Make it have 3 dimensions
        combined = torch.cat((emb, self.noise), dim=0)  # changed from dim=1
        print(combined.shape)
        combined_3d = combined[None, :]
        print(combined_3d.shape)   

        # SEND the noisy embedding to the convolutional and transposed convolutional layers
        conv2d_1 = self.conv2d_1(combined_3d)
        conv2d_2 = self.conv2d_2(conv2d_1)
        dropout  = self.dropout(conv2d_2)
                                               
        conv2dTran_1 = self.conv2dTran_1(dropout)
        conv2dTran_2 = self.conv2dTran_2(conv2dTran_1)
                                               
        #COMBINE the outputs having a skip connection in the image of the architecture
        combined_2 = torch.cat((conv2d_1, conv2dTran_2), dim=0)  # changed from dim=1
        conv2dTran_3 = self.conv2dTran_3(combined_2)

        #SEND the combined outputs to the final layer. Please name your final output variable as "image" so you that it can be returned
        image = self.relu(conv2dTran_3)

        return image
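
For reference, a quick usage check of the fixed model (a sketch reusing lookup_tensor from the question; the shape trace in the comments follows from the layer arithmetic of this architecture):

# Shape trace for the fixed forward pass (unbatched 3D input):
#   emb: (8, 64)  +  noise: (56, 64)  --cat dim=0-->  (64, 64)  --[None, :]-->  (1, 64, 64)
#   conv2d_1 -> (3, 31, 31); conv2d_2 -> (16, 15, 15)
#   conv2dTran_1 -> (16, 15, 15); conv2dTran_2 -> (3, 31, 31)
#   skip cat -> (6, 31, 31); conv2dTran_3 -> (3, 64, 64)
texttoimage = TextToImage(vocab_size=vocab_size)
output = texttoimage(lookup_tensor)
print(output.size())  # expected: torch.Size([3, 64, 64])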
Answer by Matt Eng:

The cat function requires the tensor shapes to match in every dimension except the one you're concatenating along, so to concatenate (8, 64) and (56, 64) your cat should use dim 0 instead of 1:

combine1 = torch.cat((emb, self.noise), dim=0, out=None)

After that, I'm not seeing where you give combine1 a 3rd dimension.
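
For example, one way to add that 3rd dimension after the dim-0 concatenation (a sketch assuming conv2d_1 is changed to expect a single input channel, as in the accepted answer):

combine1 = torch.cat((emb, self.noise), dim=0)  # (8, 64) cat (56, 64) -> (64, 64)
combine1 = combine1.unsqueeze(0)                # -> (1, 64, 64): one channel of a 64x64 "image"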