Checkpoint creation for a fine-tuned Stable Diffusion model


I am trying to save fine-tuned parameter checkpoints after each epoch for a Stable Diffusion model. If I use 'model_state_dict': pipe.unet.state_dict(), the entire model is saved. With 'model_state_dict': fine_tuned_state_dict, nothing is saved (the entry is only 64 bytes), and the optimizer values also look wrong. I am not sure whether the problem is that the optimizer is not updating any of the parameters, or that my checkpoint saving is incorrect. The loss value also fluctuates instead of steadily decreasing. Saving the entire model does work. This is the print result after running 13568 batches:

Warning: fine_tuned_state_dict is empty. No model parameters might have received gradients during training.
Checkpoint saved to kan_sd_v1_4_bs128_e10-2_LMSE_interrupted.pt
model_state_dict: 64 bytes
batch_size: 28 bytes
training_steps: 28 bytes
lr: 24 bytes
optimizer: 232 bytes
loss: 232 bytes
sum of parameters 859520964
results after batches  13568
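
As far as I understand, sys.getsizeof only reports the shallow size of the container object, so an empty dict shows up as 64 bytes and optimizer.state_dict() shows up as a small dict of references no matter how much tensor data it actually holds. Below is a minimal sketch of how I could measure the real tensor payload instead (the helper name is just for illustration):

import torch

def state_dict_nbytes(state_dict):
    """Sum the storage bytes of every tensor in a (possibly nested) state dict."""
    total = 0
    for value in state_dict.values():
        if torch.is_tensor(value):
            total += value.numel() * value.element_size()
        elif isinstance(value, dict):
            total += state_dict_nbytes(value)
    return total

# Example: state_dict_nbytes(pipe.unet.state_dict()) / 1024 ** 2 gives the size in MB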


import sys

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.cuda.amp import autocast
from torchvision.transforms.functional import to_grayscale, to_tensor
from transformers import CLIPTokenizer

# (vgg, feature_extractor, perceptual_loss and pre_tokenize_prompts are defined elsewhere)


def create_optimizer(pipe, lr):
    """
    Create optimizer for training.

    Args:
        pipe: The pipeline instance.
        lr (float): Learning rate.

    Returns:
        torch.optim.Optimizer: Optimizer for training.
    """

    # Gather parameters of the components to fine-tune from the pipeline
    # (here the VAE and the text encoder; swap in whichever components you actually train)
    encoder_params = pipe.vae.parameters()
    decoder_params = pipe.text_encoder.parameters()

    # Combine the parameters into a single Adam optimizer
    optimizer = optim.Adam(list(encoder_params) + list(decoder_params), lr=lr)

    return optimizer
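
# Alternative I am considering (sketch): since only the UNet is supposed to be fine-tuned,
# freeze the other components and hand the optimizer only the trainable UNet parameters.
# The helper name below is just for illustration.
def create_unet_optimizer(pipe, lr):
    """Freeze the VAE and text encoder and optimize only the UNet parameters."""
    pipe.vae.requires_grad_(False)
    pipe.text_encoder.requires_grad_(False)
    pipe.unet.requires_grad_(True)

    trainable_params = [p for p in pipe.unet.parameters() if p.requires_grad]
    return optim.Adam(trainable_params, lr=lr)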


def save_checkpoint(file_name, optimizer, loss_function, pipe, epochs, batch_size, batches, epoch, lr):
    """
    Save the training checkpoint.

    Args:
        file_name (str): Name of the checkpoint file.
        optimizer (torch.optim.Optimizer): Optimizer used for training.
        loss_function (str): Type of loss function.
        pipe: The pipeline instance.
        epochs (int): Number of epochs.
        batch_size (int): Batch size.
        batches (int): Number of batches processed.
        epoch (int): Current epoch.
        lr (float): Learning rate.
    """
    try:
        # Define the loss function type and its parameters based on the actual loss function used
        loss_function_type = "perceptual_loss" if loss_function == "perceptual_loss" else "criterion_mse"
        loss_function_parameters = {
            'lpips_model': vgg if loss_function == "perceptual_loss" else None,
            'target_loss_weight': 5e-2 if loss_function == "perceptual_loss" else None
        }

        # Get the initial model state dictionary
        #initial_model_state_dict = pipe.unet.state_dict()


        
        # Alternative: filter the full UNet state dict down to fine-tuned modules if applicable
        #fine_tuned_state_dict = {key: value for key, value in pipe.unet.state_dict().items() if key in optimizer.state_dict()['state']}

        # Method 1 for getting the fine-tuned state: keep parameters that require grad
        # and appear in the optimizer state
        fine_tuned_state_dict = {}
        for name, param in pipe.unet.named_parameters():
            if param.requires_grad and name in optimizer.state_dict()['state']:
                fine_tuned_state_dict[name] = param

       
        # Report the size of the collected fine_tuned_state_dict
        fine_tuned_state_dict_size = sum(value.numel() * value.element_size() for value in fine_tuned_state_dict.values())
        print(f"Size of fine_tuned_state_dict: {fine_tuned_state_dict_size / (1024 ** 2):.2f} MB")

                
        # Error handling: warn if fine_tuned_state_dict is empty
        if not fine_tuned_state_dict:
            print("Warning: fine_tuned_state_dict is empty. No model parameters might have received gradients during training.")

        checkpoint = {
            #'model_state_dict': pipe.unet.state_dict(), #saves the entire model
            'model_state_dict': fine_tuned_state_dict,
            'batch_size': batch_size,
            'training_steps': epochs,
            'lr': lr,
            'optimizer': {
                'type': 'Adam',
                'state_dict': optimizer.state_dict(),
                'lr': lr
            },
            'loss': {
                'type': loss_function_type,
                'parameters': loss_function_parameters
            }
        }
    
        
        torch.save(checkpoint, file_name)
        print(f"Checkpoint saved to {file_name}")
        
        
        
        # Note: sys.getsizeof only reports the shallow size of each value, not the nested tensor data
        for key, value in checkpoint.items():
            print(f"{key}: {sys.getsizeof(value)} bytes")

    
        print('sum of parameters', sum(p.numel() for p in pipe.unet.parameters()))

    except Exception as e:
        print(f"Error saving checkpoint: {e}")


# Define function to train model
def train(epochs, batch_size, lr, num_steps, train_dataset, device,  loss_function, pipe):
    """
    Train the model.

    Args:
        epochs (int): Number of epochs.
        batch_size (int): Batch size.
        lr (float): Learning rate.
        num_steps (int): Number of inference steps.
        train_dataset: The training dataset.
        device: The device for computations.
        loss_function (str): Type of loss function.
        pipe: The pipeline instance.
    """
    # Instantiate the CLIP tokenizer
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
    # Pre-tokenize prompts
    tokenized_prompts = pre_tokenize_prompts(train_dataset, tokenizer)
    
    # Define the optimizer
    #optimizer = create_optimizer(pipe, lr)
    optimizer = optim.Adam(pipe.unet.parameters(), lr=lr)  # optimize the UNet parameters directly
    # Define the loss function
    loss_function = "criterion_mse"  # or "perceptual_loss" 
    
    try:
        for epoch in range(epochs):
            
            dataset_length = len(train_dataset)
            for i in range(0, len(train_dataset), batch_size):
                batch = train_dataset[i:i + batch_size]
                meanings = [str(entry['meaning']) for entry in batch]
                images = [entry['image'] for entry in batch]


                 # Convert list of meanings to a single string
                combined_meanings = ' '.join(meanings)

                # Zero the gradients
                optimizer.zero_grad()

                # Use pre-tokenized prompts
                start_idx = i * len(tokenized_prompts)
                end_idx = (i + batch_size) * len(tokenized_prompts)
                batch_tokenized_prompts = tokenized_prompts[start_idx:end_idx]

                with torch.cuda.device(device):
                    with autocast():
                        outputs = pipe(combined_meanings, tokenized_prompts=batch_tokenized_prompts, num_inference_steps=num_steps)["images"]

                # Normalize before tensor conversion to avoid values becoming zero
                normalize = lambda x: np.array(x) / 255.0  # Normalize to [0, 1]

                # Convert images to tensors, grayscale, and normalize
                images_tensor = torch.stack([to_tensor(to_grayscale(image)) for image in images])
                images_tensor = torch.stack([torch.tensor(normalize(image), dtype=torch.float32) for image in images])

                # Convert images to tensors, grayscale, and normalize
                outputs_tensor = torch.stack([to_tensor(to_grayscale(output)) for output in outputs])
                outputs_tensor = torch.stack([torch.tensor(normalize(output), dtype=torch.float32) for output in outputs])

                 # Set requires_grad=True after creating the tensor
                outputs_tensor.requires_grad_(True)   
                images_tensor.requires_grad_(True)

                # Reduce images channels  
                images_tensor = images_tensor[:, :, :, 0].unsqueeze(1)
                #print(f"1. images_tensor shape: {images_tensor.shape}")

                # Resize outputs  
                outputs_resized = F.interpolate(outputs_tensor, size=(100,100), 
                                                mode='bilinear', align_corners=False)
                # Reduce outputs channels
                outputs_resized = outputs_resized[:, 0:1, :, :]  
                #print(f"3. outputs_resized shape: {outputs_resized.shape}")


                # Move input tensor to the same device as the model's weight tensor (GPU)
                outputs_resized = outputs_resized.to(device)

                # Move feature extractor to the same device as the input tensor
                images_tensor = images_tensor.to(device)

               
                if loss_function == 'criterion_mse':
                    #calculate MSE loss
                    criterion_mse = nn.MSELoss()
                    loss = criterion_mse(outputs_resized, images_tensor) #outputs_resized it was
                else:
                    # Calculate perceptual Loss 
                    loss = perceptual_loss(outputs_resized, images_tensor, feature_extractor)

                print('Loss = ', loss.item())

                # Backward pass
                loss.backward()
                optimizer.step()
            
            # Save a checkpoint after each epoch completes
            file_name = f'kan_sd_v1_4_bs{batch_size}_e{epochs}-{epoch}_L{loss_function}_interrupted.pt'
            save_checkpoint(file_name, optimizer, loss_function, pipe, epochs, batch_size, batches, epoch, lr)

train(epochs, batch_size, lr, num_steps, train_dataset, device,  loss_function, pipe)
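
To check the other suspicion, i.e. whether any UNet parameters actually receive gradients, this is a quick diagnostic I plan to run right after loss.backward() (a minimal sketch):

def count_params_with_grad(module):
    """Count parameters that received a non-None gradient from the last backward pass."""
    with_grad = sum(1 for p in module.parameters() if p.grad is not None)
    total = sum(1 for _ in module.parameters())
    return with_grad, total

# with_grad, total = count_params_with_grad(pipe.unet)
# print(f"{with_grad}/{total} UNet parameters received gradients")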