Can I Use MLflow Autologging with Vanilla PyTorch?

I came across the following statement in the official MLflow documentation:

Autologging support for vanilla PyTorch (i.e., models that only subclass torch.nn.Module) only autologs calls to torch.utils.tensorboard.SummaryWriter’s add_scalar and add_hparams methods to mlflow.

Based on this, I assumed that, even with vanilla PyTorch, autologging would automatically invoke the add_scalar and add_hparams methods and record their values in MLflow.

To test this assumption, I ran the following sample code, but nothing was recorded in MLflow:

```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import mlflow

# Enable MLflow autologging for all supported libraries.
mlflow.autolog()

# Load FashionMNIST, scaling pixel values from [0, 1] to [-1, 1].
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = torchvision.datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True)


# A plain fully connected classifier that only subclasses nn.Module (no Lightning).
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)  # flatten each 28x28 image
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x


net = Net()

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

for epoch in range(2):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        # Print the average loss over the last 2000 mini-batches.
        if i % 2000 == 1999:
            print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 2000}')
            running_loss = 0.0
```

Is there something I'm doing wrong?

I experimented with the following versions:

  • mlflow==2.11.1 (latest)
  • torch==2.2.1
  • torchvision==2.2.1
  • tensorboard==2.16.2