Transformer-augmented cGAN


I am working on a problem where I am trying to improve the quality of the images generated by a cGAN using a transformer. What I am basically trying to achieve: I have a basic generator and discriminator as follows:

```
import torch
import torch.nn as nn

# latent_dim and hidden_dim are hyperparameters defined elsewhere in my script.

# Generator
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        # Embed the class label into a latent_dim-sized vector (10 classes for MNIST)
        self.label_embedding = nn.Embedding(10, latent_dim)
        self.main = nn.Sequential(
            nn.Linear(latent_dim * 2, hidden_dim),
            # nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 784),  # 28x28 image size
            nn.Tanh()
        )

    def forward(self, noise, class_label):
        # Concatenate noise and label embedding, then map to a 28x28 image
        class_embedding = self.label_embedding(class_label)
        combined_input = torch.cat((noise, class_embedding), dim=1)
        return self.main(combined_input).view(noise.size(0), 1, 28, 28)


# Discriminator
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        # Embed the class label into a 784-sized vector (10 classes for MNIST)
        self.label_embedding = nn.Embedding(10, 784)
        self.main = nn.Sequential(
            nn.Linear(784 * 2, hidden_dim),
            # nn.BatchNorm1d(hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, image, class_label):
        # Flatten the image, concatenate with the label embedding, output real/fake probability
        class_embedding = self.label_embedding(class_label)
        class_embedding = class_embedding.view(class_embedding.size(0), -1)
        image = image.view(image.size(0), -1)
        combined_input = torch.cat((image, class_embedding), dim=1)
        return self.main(combined_input)
```
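
For context, here is a minimal sketch of the adversarial step that produces the `g_loss` used further below (the hyperparameter values, optimizer settings, and batch size are placeholders, not my exact setup):

```
import torch
import torch.nn as nn

latent_dim = 100   # placeholder value
hidden_dim = 256   # placeholder value
batch_size = 64    # placeholder value
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

G = Generator().to(device)
D = Discriminator().to(device)
adversarial_loss = nn.BCELoss()
opt_G = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))

# One generator step (the discriminator step is omitted for brevity)
noise = torch.randn(batch_size, latent_dim, device=device)
labels = torch.randint(0, 10, (batch_size,), device=device)
x_fake = G(noise, labels)

valid = torch.ones(batch_size, 1, device=device)
g_loss = adversarial_loss(D(x_fake, labels), valid)  # standard cGAN generator loss
# opt_G.zero_grad() / total_g_loss.backward() / opt_G.step() happen after the
# attention loss is added (see the snippet below).
```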

Now I also have a pretrained transformer that I am using to improve the quality of the generated images. What I utilize for this is the attention probe (the first row of the transformer's attention map, i.e. the class token's attention to all the other tokens). I calculate the attention probe of the generated images, then take a loss between the generated attention probe and the class attention probe (the average of the attention probes of all training images of that particular class), and use this loss along with the adversarial loss to train the generator, as follows:

```
attn_data = cal_attnprime_blk7_batchwise(checkpoint, embed_dim, num_heads, x_fake)        # attention probes of generated images
attention_loss1 = patch_attention_probe_loss(class_avg_attnprime_batch, attn_data_tensor) # attention probe loss
total_g_loss = 0.2 * g_loss + 0.8 * attention_loss1                                       # total generator loss
```
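
For completeness, `attn_data_tensor` is just `attn_data` converted to a tensor; a minimal sketch of that conversion (the exact dtype/device handling here is an assumption, the conversion itself is not shown above):

```
# Assumed glue code: convert the NumPy probes of the generated images to a CUDA tensor
attn_data_tensor = torch.from_numpy(attn_data).float().cuda()  # shape (B, 8, 8)
```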

Here are the functions `cal_attnprime_blk7_batchwise` and `patch_attention_probe_loss`:

```
import torch.nn.functional as F

def patch_attention_probe_loss(feature_T, feature_S):
    # feature_T: class-average attention probes, shape (B, M, N)
    # feature_S: attention probes of the generated images, same shape
    B = feature_T.shape[0]
    M = feature_T.shape[1]
    N = feature_T.shape[2]

    # L2-normalize both probes along the last two dimensions
    feature_T_norm = F.normalize(feature_T, p=2, dim=1)
    feature_T_norm = F.normalize(feature_T_norm, p=2, dim=2)
    feature_S_norm = F.normalize(feature_S, p=2, dim=1)
    feature_S_norm = F.normalize(feature_S_norm, p=2, dim=2)

    # Mean squared difference between the normalized probes
    patch_attn_diff = feature_T_norm - feature_S_norm
    patch_attn_loss = (patch_attn_diff * patch_attn_diff).sum() / (B * M * (N - 1))

    return patch_attn_loss.squeeze()
```
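
A quick shape check for this loss with dummy `(B, 8, 8)` probes (random values, just to illustrate the expected shapes):

```
import torch

B = 4
feature_T = torch.rand(B, 8, 8)  # e.g. class-average attention probes
feature_S = torch.rand(B, 8, 8)  # e.g. probes of the generated images
loss = patch_attention_probe_loss(feature_T, feature_S)
print(loss)  # a single scalar tensor
```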


##########################################################################################################

```
def cal_attnprime_blk7_batchwise(checkpoint, embed_dim, num_heads, images):
    # `teacher` is the pretrained (DataParallel-wrapped) transformer, defined globally.
    embed_dim = int(embed_dim / num_heads)  # per-head dimension
    scale = embed_dim ** -0.5
    teacher.eval()

    # Obtain the qkv projection weights and bias of block 7 from the checkpoint
    linear_weight_blk_7 = checkpoint["model"]['blocks.7.attn.qkv.weight'].cuda()
    linear_bias_blk_7 = checkpoint["model"]['blocks.7.attn.qkv.bias'].cuda()

    # Hook that captures the input of block 7's attention module
    attn_inputs_blk7 = []
    hook = teacher.module.blocks[7].attn.register_forward_hook(
        lambda self, input, output: attn_inputs_blk7.append(input)
    )

    images = images.cuda()  # 'images' is a batch of image tensors

    with torch.no_grad():
        outputs, output_feature = teacher(images)

        B, N, C = attn_inputs_blk7[0][0].shape
        uniform = (torch.ones(B, N - 1) / (N - 1)).float().cuda()

        # Recompute q, k, v for block 7 from the hooked input
        qkv_blk_7 = torch.bmm(attn_inputs_blk7[0][0],
                              linear_weight_blk_7.unsqueeze(0).repeat(B, 1, 1).permute(0, 2, 1)) + linear_bias_blk_7
        qkv_blk_7 = qkv_blk_7.reshape(B, N, 3, num_heads, embed_dim).permute(2, 0, 3, 1, 4)
        q_blk_7, k_blk_7, v_blk_7 = qkv_blk_7[0], qkv_blk_7[1], qkv_blk_7[2]
        attn_blk_7 = (q_blk_7 @ k_blk_7.transpose(-2, -1)) * scale
        attn_blk_7 = attn_blk_7.softmax(dim=-1)

        # Attention probe: the class token's attention to the patch tokens (first head)
        attnprime_blk_7 = attn_blk_7[:, 0, 0, 1:]

        avg_attnprime = attnprime_blk_7.squeeze(dim=1)  # shape (B, N-1)

        attn_data = avg_attnprime.cpu().numpy()
        attn_data = attn_data.reshape((B, 8, 8))  # 8x8 patch grid

    hook.remove()

    return attn_data  # batch of attention probes, shape (B, 8, 8)
```
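
For reference, `class_avg_attnprime_batch` is built from the per-class average of these probes over the training set. A rough sketch of how that can be done with the same helper (`train_loader` and the batching here are illustrative, not my exact code):

```
import numpy as np

# Illustration only: accumulate the attention probe of every training image,
# grouped by class, then average. class_avg_attnprime[c] has shape (8, 8).
sums = {c: np.zeros((8, 8), dtype=np.float32) for c in range(10)}
counts = {c: 0 for c in range(10)}

for images, labels in train_loader:  # assumed MNIST DataLoader
    attn = cal_attnprime_blk7_batchwise(checkpoint, embed_dim, num_heads, images)
    for a, c in zip(attn, labels.tolist()):
        sums[c] += a
        counts[c] += 1

class_avg_attnprime = {c: sums[c] / max(counts[c], 1) for c in range(10)}

# For a batch of generated labels, stack the matching class averages, e.g.:
# class_avg_attnprime_batch = torch.stack(
#     [torch.from_numpy(class_avg_attnprime[int(c)]) for c in labels]).float().cuda()
```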


###########################################################################################################

The issue is that the generated images from this transformer-augmented cGAN do not get a better FID. In fact, since I am using an additional loss that injects more relevant information, I would expect the opposite. Can anyone help me with this?

I tried calculating the FID at various epochs.
