I am trying to implement a CGAN (conditional GAN) with convolutions. I have written a discriminator. The code runs, but I am not sure whether it is correct. Below is my code:
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image

from inception_score import inception_score
# Module-level handle for the results CSV file; assigned in __main__ before
# training starts and written to by train_CGAN once per epoch.
f = None
# Define generator network
class Generator(nn.Module):
    """Conditional generator producing 32x32 RGB images (CIFAR-10 sized).

    The class label is embedded, concatenated with the latent noise vector,
    projected to a small feature map, and upsampled with transposed
    convolutions to full image resolution.
    """

    def __init__(self, latent_dim, num_classes):
        super(Generator, self).__init__()
        self.latent_dim = latent_dim
        self.num_classes = num_classes
        # One embedding vector per class, num_classes wide.
        self.label_emb = nn.Embedding(num_classes, num_classes)
        # BUG FIX: project to an 8x8 base map (not 7x7) so the two stride-2
        # upsampling stages yield 32x32, matching CIFAR-10. The original
        # 7x7 base produced 28x28 fakes while the real images are 32x32,
        # so the discriminator saw different spatial sizes for real vs fake.
        self.fc1 = nn.Linear(latent_dim + num_classes, 256 * 8 * 8)
        self.conv1 = nn.Sequential(
            # 8x8 -> 16x16
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True)
        )
        self.conv2 = nn.Sequential(
            # 16x16 -> 16x16 (stride-1 refinement stage)
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        self.conv3 = nn.Sequential(
            # 16x16 -> 32x32; Tanh maps output to [-1, 1], matching the
            # Normalize(mean=0.5, std=0.5) preprocessing of real images.
            nn.ConvTranspose2d(64, 3, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Tanh()
        )

    def forward(self, noise, labels):
        """Generate a batch of images.

        Args:
            noise: (batch, latent_dim) latent vectors.
            labels: (batch,) int64 class indices.

        Returns:
            (batch, 3, 32, 32) tensor of images in [-1, 1].
        """
        gen_input = torch.cat((self.label_emb(labels), noise), -1)
        x = self.fc1(gen_input)
        x = x.view(x.shape[0], 256, 8, 8)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        return x
# Define discriminator network
class Discriminator(nn.Module):
    """Conditional discriminator: scores (image, label) pairs as real or fake.

    The class label is embedded, broadcast across the image's spatial grid,
    and stacked with the RGB channels before a strided convolutional stack.
    The final Linear layer assumes a 32x32 input (32 -> 16 -> 8 -> 4).
    """

    def __init__(self, num_classes):
        super(Discriminator, self).__init__()
        self.num_classes = num_classes
        self.label_emb = nn.Embedding(num_classes, num_classes)
        self.conv1 = nn.Sequential(
            nn.Conv2d(3 + num_classes, 64, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.fc = nn.Sequential(
            nn.Linear(256 * 4 * 4, 1),
            nn.Sigmoid()
        )

    def forward(self, img, labels):
        """Return a (batch, 1) probability that each (img, label) pair is real."""
        batch = img.size(0)
        # (batch, num_classes) -> (batch, num_classes, 1, 1)
        cond = self.label_emb(labels).unsqueeze(-1).unsqueeze(-1)
        # Tile the label planes over the full image grid (no copy; expand).
        cond = cond.expand(batch, self.num_classes, img.size(2), img.size(3))
        # Stack label planes with RGB: (batch, 3 + num_classes, H, W).
        h = torch.cat((img, cond), dim=1)
        h = self.conv1(h)
        h = self.conv2(h)
        h = self.conv3(h)
        h = torch.flatten(h, start_dim=1)
        return self.fc(h)
# Define the training function for CGAN
def train_CGAN(generator, discriminator, data_loader, num_epochs=200):
    """Train a conditional GAN on labelled images.

    Relies on two module-level globals set up in __main__ before calling:
    `device` (torch.device) and `f` (open CSV file for inception scores).

    Args:
        generator: Generator module exposing a `latent_dim` attribute.
        discriminator: Discriminator taking (images, class_labels).
        data_loader: iterable yielding (images, labels) batches.
        num_epochs: number of passes over the data.
    """
    criterion = nn.BCELoss()
    optimizer_g = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    # Fixed noise/labels to visualize training progress (one image per class).
    fixed_noise = torch.randn(10, generator.latent_dim)
    # BUG FIX: labels must be 1-D (shape (10,)) so label_emb(labels) is 2-D
    # and concatenates with the 2-D noise. The original (10, 1)-shaped
    # labels produced a 3-D embedding and torch.cat raised at the first
    # image-save epoch.
    fixed_labels = torch.arange(10)
    # BUG FIX: save_image below writes into this directory; create it first.
    os.makedirs('cgan_images', exist_ok=True)
    for epoch in range(num_epochs):
        for i, (images, labels) in enumerate(data_loader):
            batch_size = images.size(0)
            # ---- Train discriminator ----
            # (Removed the dead real_labels/fake_labels ones/zeros tensors;
            # the original immediately shadowed them with class labels.)
            real_images = images.to(device)
            real_class_labels = labels.to(device)
            fake_class_labels = torch.randint(low=0, high=10, size=(batch_size,)).to(device)
            noise = torch.randn(batch_size, generator.latent_dim).to(device)
            fake_images = generator(noise, fake_class_labels)
            real_outputs = discriminator(real_images, real_class_labels)
            # detach() so the D step does not backprop into the generator.
            fake_outputs = discriminator(fake_images.detach(), fake_class_labels)
            d_loss_real = criterion(real_outputs, torch.ones_like(real_outputs))
            d_loss_fake = criterion(fake_outputs, torch.zeros_like(fake_outputs))
            d_loss = d_loss_real + d_loss_fake
            discriminator.zero_grad()
            d_loss.backward()
            optimizer_d.step()
            # ---- Train generator ----
            fake_class_labels = torch.randint(low=0, high=10, size=(batch_size,)).to(device)
            noise = torch.randn(batch_size, generator.latent_dim).to(device)
            fake_images = generator(noise, fake_class_labels)
            fake_outputs = discriminator(fake_images, fake_class_labels)
            g_loss = criterion(fake_outputs, torch.ones_like(fake_outputs))
            generator.zero_grad()
            g_loss.backward()
            optimizer_g.step()
        # Inception score on the epoch's last generated batch; detached so
        # the autograd-attached images can be handed to external code.
        # NOTE(review): scoring a single batch per epoch is noisy — the
        # original did the same; confirm this is intended.
        is_score, is_std = inception_score(fake_images.detach(), cuda=True,
                                           batch_size=32, resize=True, splits=10)
        f.write(f"{epoch},{is_score}, {is_std}\n")
        print('Epoch [{}/{}], Discriminator Loss: {:.4f}, Generator Loss: {:.4f}'.format(
            epoch + 1, num_epochs, d_loss.item(), g_loss.item()))
        # Save generated images every 10 epochs.
        if (epoch + 1) % 10 == 0:
            fake_images = generator(fixed_noise.to(device), fixed_labels.to(device))
            save_image(fake_images.data, 'cgan_images/{}_{}.png'.format(epoch + 1, i + 1), nrow=10)
if __name__ == '__main__':
    # Set random seed for reproducibility
    torch.manual_seed(42)
    # Map pixels from [0, 1] to [-1, 1] to match the generator's Tanh output.
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize(mean=[0.5], std=[0.5])])
    # Download and load CIFAR10 dataset
    train_dataset = datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
    # Initialize generator and discriminator networks
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    generator = Generator(latent_dim=100, num_classes=10).to(device)
    discriminator = Discriminator(num_classes=10).to(device)
    # ROBUSTNESS FIX: guarantee the results file is closed even if training
    # raises (the original leaked the handle on any error in train_CGAN).
    f = open("cgan cifar10 results.csv", "w")
    try:
        # Train CGAN (reads the module-level globals `device` and `f`).
        train_CGAN(generator, discriminator, train_loader, num_epochs=200)
    finally:
        f.close()
    # Save trained models
    torch.save(generator.state_dict(), 'cgan_generator.pth')
    torch.save(discriminator.state_dict(), 'cgan_discriminator.pth')
I am particularly concerned about my discriminator code:
# Define discriminator network
# NOTE(review): this duplicates the Discriminator class defined earlier in
# the file; in a script, this second definition silently replaces the first.
class Discriminator(nn.Module):
    def __init__(self, num_classes):
        super(Discriminator, self).__init__()
        self.num_classes = num_classes
        # One num_classes-wide embedding vector per class label.
        self.label_emb = nn.Embedding(num_classes, num_classes)
        # Input is 3 RGB channels + num_classes label planes; each stride-2
        # conv halves the spatial size.
        self.conv1 = nn.Sequential(
            nn.Conv2d(3 + num_classes, 64, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True)
        )
        # 256 * 4 * 4 assumes a 32x32 input image (32 -> 16 -> 8 -> 4);
        # a 28x28 input also happens to reduce to 4x4, so a size mismatch
        # between real and fake images would go unnoticed here.
        self.fc = nn.Sequential(
            nn.Linear(256 * 4 * 4, 1),
            nn.Sigmoid()
        )
    def forward(self, img, labels):
        label_emb = self.label_emb(labels)  # shape: (batch_size, num_classes)
        label_emb = label_emb.view(label_emb.size(0), label_emb.size(1), 1, 1)  # shape: (batch_size, num_classes, 1, 1)
        label_emb = label_emb.expand(-1, -1, img.size(2), img.size(3))  # shape: (batch_size, num_classes, img_height, img_width)
        dis_input = torch.cat((img, label_emb), dim=1)  # shape: (batch_size, 3 + num_classes, img_height, img_width)
        x = self.conv1(dis_input)
        x = self.conv2(x)
        x = self.conv3(x)
        x = x.view(x.shape[0], -1)  # flatten to (batch_size, 256*4*4)
        x = self.fc(x)
        return x  # (batch_size, 1) probability each pair is real
Am I on the right path?