I am working on a problem where I am trying to improve the quality of the images generated by a cGAN using a pretrained transformer. What I am basically trying to achieve is the following. I have a basic generator and discriminator, as shown below:
```python
# Generator
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        # Define the generator layers
        self.label_embedding = nn.Embedding(10, latent_dim)  # 10 classes for MNIST
        self.main = nn.Sequential(
            nn.Linear(latent_dim * 2, hidden_dim),
            # nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 784),  # 28x28 image size
            nn.Tanh()
        )

    def forward(self, noise, class_label):
        # Forward pass for the generator
        class_embedding = self.label_embedding(class_label)
        combined_input = torch.cat((noise, class_embedding), dim=1)
        return self.main(combined_input).view(noise.size(0), 1, 28, 28)


# Discriminator
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        # Define the discriminator layers
        self.label_embedding = nn.Embedding(10, 784)  # 10 classes for MNIST
        self.main = nn.Sequential(
            nn.Linear(784 * 2, hidden_dim),
            # nn.BatchNorm1d(hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, image, class_label):
        # Forward pass for the discriminator
        class_embedding = self.label_embedding(class_label)
        class_embedding = class_embedding.view(class_embedding.size(0), -1)
        image = image.view(image.size(0), -1)
        combined_input = torch.cat((image, class_embedding), dim=1)
        return self.main(combined_input)
```
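Just for reference, a quick shape check of how I use the two networks (assuming `generator = Generator().cuda()`, `discriminator = Discriminator().cuda()`, and a batch size of 64 as an example):

```python
generator = Generator().cuda()
discriminator = Discriminator().cuda()

noise = torch.randn(64, latent_dim).cuda()
labels = torch.randint(0, 10, (64,)).cuda()
x_fake = generator(noise, labels)         # -> (64, 1, 28, 28)
validity = discriminator(x_fake, labels)  # -> (64, 1)
```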
Now I also have a pretrained transformer that I am using to improve the quality of the generated images. What I am utilizing for this is the attention probe (the first row of the transformer's attention map, i.e. the class token's attention to all the other tokens). I calculate the attention probe of the generated images and then take a loss between this generated attention probe and the class attention probe (the average of the attention probes of all the training images of a particular class). This loss is then used together with the adversarial loss to train the generator, as shown below:
```python
attn_data = cal_attnprime_blk7_batchwise(checkpoint, embed_dim, num_heads, x_fake)  # attention probe of generated images
attention_loss1 = patch_attention_probe_loss(class_avg_attnprime_batch, attn_data_tensor)  # attention loss
total_g_loss = 0.2 * g_loss + 0.8 * attention_loss1  # total loss
```
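For context, the generator step as a whole looks roughly like this (names such as `optimizer_G`, `adversarial_loss` and `valid` are placeholders, and the NumPy output of `cal_attnprime_blk7_batchwise` is converted to a tensor along these lines):

```python
# Sketch of the full generator step around the lines above
# (optimizer_G is a placeholder; adversarial_loss is e.g. nn.BCELoss();
#  labels is the batch of class labels and valid is a tensor of ones)
optimizer_G.zero_grad()
noise = torch.randn(labels.size(0), latent_dim).cuda()
x_fake = generator(noise, labels)
g_loss = adversarial_loss(discriminator(x_fake, labels), valid)                      # adversarial term
attn_data = cal_attnprime_blk7_batchwise(checkpoint, embed_dim, num_heads, x_fake)   # NumPy array
attn_data_tensor = torch.from_numpy(attn_data).float().cuda()                        # NumPy -> tensor
attention_loss1 = patch_attention_probe_loss(class_avg_attnprime_batch, attn_data_tensor)
total_g_loss = 0.2 * g_loss + 0.8 * attention_loss1
total_g_loss.backward()
optimizer_G.step()
```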
Here are the functions `cal_attnprime_blk7_batchwise` and `patch_attention_probe_loss`:
```python
def patch_attention_probe_loss(feature_T, feature_S):
    B = feature_T.shape[0]
    M = feature_T.shape[1]
    N = feature_T.shape[2]
    feature_T_norm = F.normalize(feature_T, p=2, dim=1)
    feature_T_norm = F.normalize(feature_T_norm, p=2, dim=2)
    feature_S_norm = F.normalize(feature_S, p=2, dim=1)
    feature_S_norm = F.normalize(feature_S_norm, p=2, dim=2)
    patch_attn_diff = feature_T_norm - feature_S_norm
    patch_attn_loss = (patch_attn_diff * patch_attn_diff).sum() / (B * M * (N - 1))
    return patch_attn_loss.squeeze()
```
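Both arguments are expected to have shape (batch, 8, 8). For example (dummy tensors just to illustrate the shapes):

```python
# Dummy call illustrating the expected shapes (batch, 8, 8)
feature_T = torch.rand(64, 8, 8)  # per-class average attention probes for the batch
feature_S = torch.rand(64, 8, 8)  # attention probes of the generated images
loss = patch_attention_probe_loss(feature_T, feature_S)  # scalar tensor
```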
```python
def cal_attnprime_blk7_batchwise(checkpoint, embed_dim, num_heads, images):
    embed_dim = int(embed_dim / num_heads)
    scale = embed_dim ** -0.5
    teacher.eval()
    # Obtain weights and bias for block 7
    linear_weight_blk_7 = checkpoint["model"]['blocks.7.attn.qkv.weight'].cuda()
    linear_bias_blk_7 = checkpoint["model"]['blocks.7.attn.qkv.bias'].cuda()
    attn_inputs_blk7 = []
    hook = teacher.module.blocks[7].attn.register_forward_hook(
        lambda self, input, output: attn_inputs_blk7.append(input)
    )
    images = images.cuda()  # Assuming 'images' is a batch of image tensors
    with torch.no_grad():
        outputs, output_feature = teacher(images)
    B, N, C = attn_inputs_blk7[0][0].shape
    uniform = (torch.ones(B, N - 1) / (N - 1)).float().cuda()
    qkv_blk_7 = torch.bmm(attn_inputs_blk7[0][0],
                          linear_weight_blk_7.unsqueeze(0).repeat(B, 1, 1).permute(0, 2, 1)) + linear_bias_blk_7
    qkv_blk_7 = qkv_blk_7.reshape(B, N, 3, num_heads, embed_dim).permute(2, 0, 3, 1, 4)
    q_blk_7, k_blk_7, v_blk_7 = qkv_blk_7[0], qkv_blk_7[1], qkv_blk_7[2]
    attn_blk_7 = (q_blk_7 @ k_blk_7.transpose(-2, -1)) * scale
    attn_blk_7 = attn_blk_7.softmax(dim=-1)
    attnprime_blk_7 = attn_blk_7[:, 0, 0, 1:]
    # Calculate the average attnprime_blk7 for each image in the batch
    avg_attnprime = attnprime_blk_7.squeeze(dim=1)  # Squeeze the dimensions for each image
    attn_data = avg_attnprime.cpu().numpy()
    attn_data = attn_data.reshape((B, 8, 8))
    hook.remove()
    return attn_data  # Return the batch of attention data
```
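And this is roughly how I precompute the per-class average attention probes over the training set (`train_loader` is a placeholder name and the exact code differs slightly; `class_avg_attnprime_batch` is then built by indexing this average with the labels of the current batch):

```python
# Rough sketch of building the per-class average attention probes (10 MNIST classes)
class_sums = torch.zeros(10, 8, 8)
class_counts = torch.zeros(10)
for images, labels in train_loader:
    attn = cal_attnprime_blk7_batchwise(checkpoint, embed_dim, num_heads, images)
    attn = torch.from_numpy(attn)  # the function returns a NumPy array
    for c in range(10):
        mask = labels == c
        if mask.any():
            class_sums[c] += attn[mask].sum(dim=0)
            class_counts[c] += mask.sum()
class_avg_attnprime = class_sums / class_counts.view(-1, 1, 1)

# During training, the targets for a batch are gathered by label:
# class_avg_attnprime_batch = class_avg_attnprime[labels].cuda()
```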
The issue is that the images generated by this transformer-augmented cGAN do not have a better FID. Since the additional loss feeds in more relevant information, I would expect the FID to improve, not stay the same or get worse. Can anyone help me figure out why this is happening?
I tried calculating the FID at various epochs.
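Roughly, the measurement looks like this (using torchmetrics' `FrechetInceptionDistance`, which as far as I understand needs 3-channel images in [0, 1] when `normalize=True`; `test_loader` is a placeholder, and I assume the real images are also scaled to [-1, 1] by the dataloader transforms):

```python
from torchmetrics.image.fid import FrechetInceptionDistance

def to_fid_format(x):
    # map from the Tanh output range [-1, 1] to [0, 1] and repeat to 3 channels
    x = (x.clamp(-1, 1) + 1) / 2
    return x.repeat(1, 3, 1, 1)

fid = FrechetInceptionDistance(feature=2048, normalize=True).cuda()
generator.eval()
with torch.no_grad():
    for real_images, labels in test_loader:
        real_images, labels = real_images.cuda(), labels.cuda()
        noise = torch.randn(real_images.size(0), latent_dim).cuda()
        fake_images = generator(noise, labels)
        fid.update(to_fid_format(real_images), real=True)
        fid.update(to_fid_format(fake_images), real=False)
print(fid.compute())
generator.train()
```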