Implementation of the discriminator in CGAN

141 views Asked by At

I am trying to implement a conditional GAN (CGAN) with convolutional layers. I have written a discriminator; the code runs, but I am not sure whether it is correct. Below is my code:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from inception_score import inception_score

# Module-level handle for the results CSV. It is None until the __main__ block
# opens the file; train_CGAN writes inception-score rows to it.
f = None

# Define generator network
class Generator(nn.Module):
    """Conditional convolutional generator for a CGAN.

    A learned label embedding is concatenated with the noise vector. A linear
    layer projects the result to a 256-channel square feature map of side
    ``img_size // 4``. Two stride-2 transposed convolutions then upsample it
    x4 to an RGB image.

    Args:
        latent_dim: length of the noise vector.
        num_classes: number of class labels.
        img_size: side length of the square images produced; must be a
            positive multiple of 4. The default of 32 makes generated images
            the same 32x32 size as the CIFAR-10 images the discriminator sees
            as real. A 7x7 seed map would yield 28x28 images instead.
    """

    def __init__(self, latent_dim, num_classes, img_size=32):
        super(Generator, self).__init__()
        if img_size <= 0 or img_size % 4 != 0:
            raise ValueError(f"img_size must be a positive multiple of 4, got {img_size}")
        self.latent_dim = latent_dim
        self.num_classes = num_classes
        self.img_size = img_size
        # Side length of the seed feature map; conv1 and conv3 each double it.
        self.init_size = img_size // 4

        self.label_emb = nn.Embedding(num_classes, num_classes)
        self.fc1 = nn.Linear(latent_dim + num_classes, 256 * self.init_size * self.init_size)
        self.conv1 = nn.Sequential(
            # k=3, s=2, p=1, output_padding=1 maps a side of n to 2n.
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True)
        )
        self.conv2 = nn.Sequential(
            # k=3, s=1, p=1 keeps the spatial size unchanged.
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        self.conv3 = nn.Sequential(
            nn.ConvTranspose2d(64, 3, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Tanh()
        )

    def forward(self, noise, labels):
        """Generate one image per (noise, label) pair.

        Args:
            noise: float tensor of shape ``(batch, latent_dim)``.
            labels: integer class indices of shape ``(batch,)`` or ``(batch, 1)``.

        Returns:
            Tensor of shape ``(batch, 3, img_size, img_size)`` with values in
            ``[-1, 1]`` (Tanh output).
        """
        # Flatten so that (batch, 1) labels also embed to (batch, num_classes).
        # Otherwise the concatenation with the 2-D noise fails.
        labels = labels.reshape(-1)
        gen_input = torch.cat((self.label_emb(labels), noise), -1)
        x = self.fc1(gen_input)
        x = x.view(x.shape[0], 256, self.init_size, self.init_size)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        return x

# Define discriminator network
class Discriminator(nn.Module):
    """Conditional convolutional discriminator for a CGAN.

    Each class label is mapped by a learned embedding to a ``num_classes``-dim
    vector. That vector is broadcast over every pixel and concatenated to the
    image channels, so the first convolution sees ``3 + num_classes`` channels.
    Three stride-2 convolutions and a linear layer with a sigmoid produce the
    probability that the (image, label) pair is real.

    Args:
        num_classes: number of class labels.
        img_size: side length of the square input images. Each
            ``Conv2d(k=3, s=2, p=1)`` maps a side of ``n`` to ``ceil(n / 2)``.
            The default of 32 therefore gives a 4x4 map, i.e. ``256 * 4 * 4``
            features into the final linear layer.
    """

    def __init__(self, num_classes, img_size=32):
        super(Discriminator, self).__init__()
        self.num_classes = num_classes
        self.img_size = img_size

        self.label_emb = nn.Embedding(num_classes, num_classes)
        self.conv1 = nn.Sequential(
            nn.Conv2d(3 + num_classes, 64, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True)
        )
        # Spatial side after the three stride-2 convolutions: ceil(n / 2) each.
        feat_size = img_size
        for _ in range(3):
            feat_size = (feat_size + 1) // 2
        self.fc = nn.Sequential(
            nn.Linear(256 * feat_size * feat_size, 1),
            nn.Sigmoid()
        )

    def forward(self, img, labels):
        """Score (image, label) pairs.

        Args:
            img: float tensor of shape ``(batch, 3, img_size, img_size)``.
            labels: integer class indices of shape ``(batch,)`` or ``(batch, 1)``.

        Returns:
            Tensor of shape ``(batch, 1)`` with probabilities in ``[0, 1]``.
        """
        labels = labels.reshape(-1)  # accept (batch, 1) labels as well
        label_emb = self.label_emb(labels)  # shape: (batch_size, num_classes)
        label_emb = label_emb.view(label_emb.size(0), label_emb.size(1), 1, 1)  # shape: (batch_size, num_classes, 1, 1)
        # Broadcast the label vector to every pixel (a view, no copy).
        label_emb = label_emb.expand(-1, -1, img.size(2), img.size(3))  # shape: (batch_size, num_classes, img_height, img_width)
        dis_input = torch.cat((img, label_emb), dim=1)  # shape: (batch_size, 3 + num_classes, img_height, img_width)
        x = self.conv1(dis_input)
        x = self.conv2(x)
        x = self.conv3(x)
        x = x.view(x.shape[0], -1)
        x = self.fc(x)
        return x

# Define the training function for CGAN
def train_CGAN(generator, discriminator, data_loader, num_epochs=200, log_file=None):
    """Train a conditional GAN with the non-saturating BCE objective.

    Each iteration runs two steps:

    * One discriminator step: real (image, label) pairs are scored against 1,
      and generated pairs with random labels against 0.
    * One generator step: fresh generated pairs are scored against 1.

    After every epoch the inception score of the last generated batch is
    appended to ``log_file``. Every 10 epochs a grid with one fixed-noise
    sample per class is saved under ``cgan_images/``.

    Args:
        generator: conditional generator exposing ``latent_dim`` and
            ``num_classes``, called as ``generator(noise, labels)``.
        discriminator: conditional discriminator returning probabilities,
            called as ``discriminator(images, labels)``.
        data_loader: iterable yielding ``(images, labels)`` batches.
        num_epochs: number of passes over ``data_loader``.
        log_file: writable text stream that receives ``epoch,is_mean, is_std``
            rows. Defaults to the module-level ``f``. If that is also ``None``,
            the inception score is neither computed nor logged.
    """
    import os

    # Train on whatever device the models already live on, rather than a
    # ``device`` global that only exists when this file runs as a script.
    device = next(generator.parameters()).device
    if log_file is None:
        log_file = f
    num_classes = generator.num_classes

    criterion = nn.BCELoss()
    optimizer_g = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

    # One fixed noise vector per class for the periodic sample grid. The labels
    # must be 1-D: an (N, 1) tensor makes the generator's label embedding 3-D,
    # and its concatenation with the 2-D noise fails.
    fixed_noise = torch.randn(num_classes, generator.latent_dim).to(device)
    fixed_labels = torch.arange(num_classes).to(device)

    for epoch in range(num_epochs):
        for i, (images, labels) in enumerate(data_loader):
            batch_size = images.size(0)
            real_images = images.to(device)
            real_labels = labels.to(device)

            # ---- Discriminator step ----
            fake_labels = torch.randint(low=0, high=num_classes, size=(batch_size,)).to(device)
            fake_images = generator(torch.randn(batch_size, generator.latent_dim).to(device), fake_labels)

            real_outputs = discriminator(real_images, real_labels)
            # detach(): the discriminator update must not backpropagate into G.
            fake_outputs = discriminator(fake_images.detach(), fake_labels)

            d_loss_real = criterion(real_outputs, torch.ones_like(real_outputs))
            d_loss_fake = criterion(fake_outputs, torch.zeros_like(fake_outputs))
            d_loss = d_loss_real + d_loss_fake

            discriminator.zero_grad()
            d_loss.backward()
            optimizer_d.step()

            # ---- Generator step: push D to score fresh fakes as real ----
            fake_labels = torch.randint(low=0, high=num_classes, size=(batch_size,)).to(device)
            fake_images = generator(torch.randn(batch_size, generator.latent_dim).to(device), fake_labels)
            fake_outputs = discriminator(fake_images, fake_labels)
            g_loss = criterion(fake_outputs, torch.ones_like(fake_outputs))

            generator.zero_grad()
            g_loss.backward()
            optimizer_g.step()

        print('Epoch [{}/{}], Discriminator Loss: {:.4f}, Generator Loss: {:.4f}'.format(epoch+1, num_epochs, d_loss.item(), g_loss.item()))

        if log_file is not None:
            # Score once per epoch, so each row matches its epoch tag. Use the
            # last generated batch, without building an autograd graph, and
            # only use CUDA if the models are actually on a GPU.
            with torch.no_grad():
                is_score, is_std = inception_score(fake_images.detach(), cuda=(device.type == 'cuda'),
                                                   batch_size=32, resize=True, splits=10)
            log_file.write(f"{epoch},{is_score}, {is_std}\n")
            log_file.flush()

        # Save generated images
        if (epoch+1) % 10 == 0:
            os.makedirs('cgan_images', exist_ok=True)
            with torch.no_grad():
                samples = generator(fixed_noise, fixed_labels)
            # The generator emits Tanh values in [-1, 1]; save_image expects
            # [0, 1] and would clamp every negative value to black.
            save_image((samples + 1) / 2, 'cgan_images/{}_{}.png'.format(epoch+1, i+1), nrow=10)



if __name__ == '__main__':
    # Set random seed for reproducibility
    torch.manual_seed(42)

    # ToTensor yields values in [0, 1]. Normalizing each of the three RGB
    # channels with mean 0.5 / std 0.5 maps them to [-1, 1], the range of the
    # generator's Tanh output.
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])

    # Download and load CIFAR10 dataset
    train_dataset = datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

    # Initialize generator and discriminator networks. ``device`` is bound at
    # module level, where train_CGAN's loop may refer to it.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    generator = Generator(latent_dim=100, num_classes=10).to(device)
    discriminator = Discriminator(num_classes=10).to(device)

    # ``with`` closes the results file even if training raises. Binding it as
    # the module-level ``f`` keeps it visible to train_CGAN, which writes to f.
    with open("cgan cifar10 results.csv", "w") as f:
        train_CGAN(generator, discriminator, train_loader, num_epochs=200)

    # Save trained models
    torch.save(generator.state_dict(), 'cgan_generator.pth')
    torch.save(discriminator.state_dict(), 'cgan_discriminator.pth')

I am particularly concerned about my discriminator code:

# Define discriminator network
class Discriminator(nn.Module):
    """Conditional convolutional discriminator for a CGAN.

    Each class label is mapped by a learned embedding to a ``num_classes``-dim
    vector. That vector is broadcast over every pixel and concatenated to the
    image channels, so the first convolution sees ``3 + num_classes`` channels.
    Three stride-2 convolutions and a linear layer with a sigmoid produce the
    probability that the (image, label) pair is real.

    Args:
        num_classes: number of class labels.
        img_size: side length of the square input images. Each
            ``Conv2d(k=3, s=2, p=1)`` maps a side of ``n`` to ``ceil(n / 2)``.
            The default of 32 therefore gives a 4x4 map, i.e. ``256 * 4 * 4``
            features into the final linear layer.
    """

    def __init__(self, num_classes, img_size=32):
        super(Discriminator, self).__init__()
        self.num_classes = num_classes
        self.img_size = img_size

        self.label_emb = nn.Embedding(num_classes, num_classes)
        self.conv1 = nn.Sequential(
            nn.Conv2d(3 + num_classes, 64, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True)
        )
        # Spatial side after the three stride-2 convolutions: ceil(n / 2) each.
        feat_size = img_size
        for _ in range(3):
            feat_size = (feat_size + 1) // 2
        self.fc = nn.Sequential(
            nn.Linear(256 * feat_size * feat_size, 1),
            nn.Sigmoid()
        )

    def forward(self, img, labels):
        """Score (image, label) pairs.

        Args:
            img: float tensor of shape ``(batch, 3, img_size, img_size)``.
            labels: integer class indices of shape ``(batch,)`` or ``(batch, 1)``.

        Returns:
            Tensor of shape ``(batch, 1)`` with probabilities in ``[0, 1]``.
        """
        labels = labels.reshape(-1)  # accept (batch, 1) labels as well
        label_emb = self.label_emb(labels)  # shape: (batch_size, num_classes)
        label_emb = label_emb.view(label_emb.size(0), label_emb.size(1), 1, 1)  # shape: (batch_size, num_classes, 1, 1)
        # Broadcast the label vector to every pixel (a view, no copy).
        label_emb = label_emb.expand(-1, -1, img.size(2), img.size(3))  # shape: (batch_size, num_classes, img_height, img_width)
        dis_input = torch.cat((img, label_emb), dim=1)  # shape: (batch_size, 3 + num_classes, img_height, img_width)
        x = self.conv1(dis_input)
        x = self.conv2(x)
        x = self.conv3(x)
        x = x.view(x.shape[0], -1)
        x = self.fc(x)
        return x

Am I on the right path?

0

There are 0 answers