I suspect this is a limitation of MLflow not fully supporting plain PyTorch (as opposed to PyTorch Lightning), but maybe someone knows a workaround that can be used at scale. When I run the experiment, code and artifacts are logged correctly, but torchvision is completely missing from the conda.yaml.
The generated conda.yaml contains most of the critical dependencies:
channels:
- conda-forge
dependencies:
- python=3.9.18
- pip<=23.3
- pip:
  - mlflow==2.7.1
  - cloudpickle==2.2.1
  - numpy==1.26.0
  - packaging==23.2
  - pandas==2.1.1
  - pyyaml==6.0.1
  - torch==1.13.0
  - tqdm==4.66.1
name: mlflow-env
I tried recreating the environment with conda, and after manually adding the correct version of torchvision I was able to run the experiment in that environment.
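For reference, the manual fix was just one extra pin under the pip section of conda.yaml (the exact torchvision version has to match the torch build; 0.14.0 below is my assumption for torch 1.13.0):

- pip:
  - torchvision==0.14.0  # added by hand; assumed to match torch 1.13.0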
Reduced pseudocode:
import logging
import pickle

import mlflow
import mlflow.pytorch

logging.getLogger("mlflow").setLevel(logging.DEBUG)

mlflow.set_tracking_uri(<server URI>)
mlflow.set_experiment(<exp_name>)

def log_scalar(name, value, step=0):
    # small helper used below (assumed to wrap mlflow.log_metric)
    mlflow.log_metric(name, value, step=step)

def train(args):
    for epoch in range(epochs):
        # put metrics for this epoch in the `metrics` dictionary
        mlflow.log_metrics(metrics, step=epoch)

with mlflow.start_run():
    log_scalar("lr", learning_rate)
    log_scalar("epochs", epochs)
    mlflow.log_param("optimiser", f"{optimizer_ft}")
    mlflow.log_param("learning_rate", learning_rate)
    train(args)
    mlflow.pytorch.log_model(
        model_ft,
        artifact_path="pytorch_model",
        pickle_module=pickle,
        code_paths=[<source_files>],
    )
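A workaround I am considering, but have not yet verified at scale, is declaring the missing dependency explicitly at logging time. mlflow.pytorch.log_model accepts an extra_pip_requirements argument, so something along these lines should force torchvision into the generated conda.yaml and requirements.txt (the version pin is my assumption for torch 1.13.0):

mlflow.pytorch.log_model(
    model_ft,
    artifact_path="pytorch_model",
    pickle_module=pickle,
    code_paths=[<source_files>],
    # explicitly declare the dependency that is otherwise missed;
    # the pin is assumed to match the torch 1.13.0 build
    extra_pip_requirements=["torchvision==0.14.0"],
)

As far as I understand, the rest of the environment is still inferred automatically and the listed packages are only appended, but I would prefer a solution that does not require hard-coding the pin in every training script.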