Is this the correct implementation of a MAML model?

122 views Asked by At

I have used CLIP embeddings of image and text as the input and the output is a label ranging from 0 to 5 (6 way label). I tried to make an implemention of this multimodal 6 way classification using meta learning. I tried using a code which includes MAML (Model Agnostic Meta Learning). What am I doing wrong?

import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

import warnings

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

class CustomDataset(Dataset):
    def __init__(self, x, y):
        self.x = torch.tensor(x, dtype=torch.float32).to(device)
        self.y = torch.tensor(y, dtype=torch.long).to(device)
    
    def __len__(self):
        return len(self.x)
    
    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

class MAML(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(MAML, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_samples = 10
        self.epochs = 20
        self.alpha = 0.001  # Adjusted learning rate
        self.beta = 0.001  # Adjusted meta learning rate
        self.theta = nn.Parameter(torch.randn(input_dim, output_dim).to(device))
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        a = torch.matmul(x, self.theta)
        return self.softmax(a)

    def sample_points(self, k, x, y):
        indices = np.random.choice(len(x), k)
        return x[indices], y[indices]

    def train(self, x_train, y_train, x_val, y_val):
        train_dataset = CustomDataset(x_train, y_train)
        train_loader = DataLoader(train_dataset, batch_size=self.num_samples, shuffle=True)

        optimizer = optim.Adam(self.parameters(), lr=self.alpha)

        for e in range(1, self.epochs + 1):
            self.theta_ = []
            for x_batch, y_batch in train_loader:
                x_batch = x_batch.to(device)
                y_batch = y_batch.to(device)

                y_hat = self.forward(x_batch)
                y_batch_encoded = torch.eye(self.output_dim, device=device)[y_batch]
                loss = -torch.mean(y_batch_encoded * torch.log(y_hat + 1e-7))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                self.theta_.append(self.theta.detach().clone())

            meta_gradient = torch.zeros_like(self.theta, dtype=torch.float32).to(device)
            for i in range(self.num_samples):
                x_test, y_test = self.sample_points(10, x_train, y_train)
                x_test = torch.tensor(x_test, dtype=torch.float32).to(device)
                y_pred = self.forward(x_test)
                y_test_encoded = torch.eye(self.output_dim)[y_test].to(device)
                meta_gradient += torch.matmul(x_test.T, (y_pred - y_test_encoded)) / self.num_samples

            self.theta.data -= self.beta * meta_gradient

            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=UserWarning)
                x_val = torch.tensor(x_val, dtype=torch.float32).to(device).clone().detach().requires_grad_(True)
            y_val_pred = self.forward(x_val)
            val_loss = -torch.mean(torch.eye(self.output_dim, device=device)[y_val] * torch.log(y_val_pred + 1e-7))

    def predict(self, x):
        with torch.no_grad():
            x = torch.tensor(x, dtype=torch.float32).to(device)
            y_pred = self.forward(x)
            _, predictions = torch.max(y_pred, dim=1)
            return predictions.cpu().numpy()

# Load the dataset
data = pd.read_csv('data/text_image_embeddings.csv')
x_text = data['text_embedding'].str.split('\t', expand=True).astype(float).values
x_image = data['image_embedding'].str.split('\t', expand=True).astype(float).values
x = np.concatenate((x_text, x_image), axis=1)
label_encoder = LabelEncoder()
y = label_encoder.fit_transform(data['label'])
len(data)
num_labels = len(label_encoder.classes_)
print(num_labels)
models = []
accuracies = []
for i in range(num_labels):
    # Divide data into train and validation for the current label/task
    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.8, stratify=y, random_state=i)
    
    # Create the CustomDataset for the current label/task
    train_dataset = CustomDataset(x_train, y_train)
    train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True)
    
    # Create the MAML model for the current label/task
    model = MAML(input_dim=x.shape[1], output_dim=num_labels).to(device)
    models.append(model)
    
    # Train the model for the current label/task
    model.train(x_train, y_train, x_test, y_test)
    
    # Calculate accuracy on the validation dataset for the current label/task
    val_predictions = model.predict(x_test)
    accuracy = accuracy_score(y_test, val_predictions)
    accuracies.append(accuracy)

# Print the accuracies for each label/task
for label, accuracy in zip(label_encoder.classes_, accuracies):
    print(f"Label: {label}, Accuracy: {accuracy:.4f}")
1

There are 1 answers

4
user22238474 On

It seems to be mostly correct but something is wrong with respect to the way the accuracy is calculated.

from sklearn.model_selection import StratifiedKFold

# ... (the rest of the code remains unchanged) ...

# Initialize the number of outer and inner folds for nested cross-validation
num_outer_folds = 5
num_inner_folds = 3

# Perform nested cross-validation for each label
for label_idx, label in enumerate(label_encoder.classes_):
    # Get the indices of data points corresponding to the current label
    label_indices = np.where(y == label_idx)[0]

    # Outer loop: Perform stratified k-fold cross-validation for evaluation
    outer_kfold = StratifiedKFold(n_splits=num_outer_folds, shuffle=True, random_state=42)

    # List to store accuracy for each outer fold
    outer_fold_accuracies = []

    for outer_fold_idx, (train_outer_idx, test_outer_idx) in enumerate(outer_kfold.split(label_indices, y[label_indices])):
        # Split data into outer training and test sets for the current outer fold
        x_train_outer, x_test_outer = x[label_indices[train_outer_idx]], x[label_indices[test_outer_idx]]
        y_train_outer, y_test_outer = y[label_indices[train_outer_idx]], y[label_indices[test_outer_idx]]

        # Inner loop: Perform stratified k-fold cross-validation for model selection
        inner_kfold = StratifiedKFold(n_splits=num_inner_folds, shuffle=True, random_state=42)

        # List to store accuracy for each inner fold
        inner_fold_accuracies = []

        for inner_fold_idx, (train_inner_idx, val_inner_idx) in enumerate(inner_kfold.split(x_train_outer, y_train_outer)):
            # Split data into inner training and validation sets for the current inner fold
            x_train_inner, x_val_inner = x_train_outer[train_inner_idx], x_train_outer[val_inner_idx]
            y_train_inner, y_val_inner = y_train_outer[train_inner_idx], y_train_outer[val_inner_idx]

            # Create the CustomDataset for the current inner fold
            train_dataset = CustomDataset(x_train_inner, y_train_inner)
            train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True)

            # Create the MAML model for the current label
            model = MAML(input_dim=x.shape[1], output_dim=num_labels).to(device)

            # Train the model for the current inner fold
            model.train(x_train_inner, y_train_inner, x_val_inner, y_val_inner)

            # Calculate accuracy on the validation dataset for the current inner fold
            val_predictions = model.predict(x_val_inner)
            accuracy = accuracy_score(y_val_inner, val_predictions)
            inner_fold_accuracies.append(accuracy)

        # Calculate and report average accuracy for the current label across all inner folds
        avg_inner_accuracy = np.mean(inner_fold_accuracies)

        # Store the average accuracy for the current outer fold
        outer_fold_accuracies.append(avg_inner_accuracy)

    # Calculate and report average accuracy for the current label across all outer folds
    avg_accuracy = np.mean(outer_fold_accuracies)
    print(f"Label: {label}, Average Accuracy: {avg_accuracy:.4f}")