How to integrate stable_baselines3 with dagshub and MLflow?

132 views Asked by At

I am trying to integrate stable_baselines3 with DagsHub and MLflow. I am new to MLOps.

Here is some sample code that is easy to run:

import mlflow
import gym
from gym import spaces
import numpy as np
from stable_baselines3 import PPO
import os

# Credentials and tracking endpoint, exported as environment variables before
# any MLflow call is made further down in the script.
os.environ['MLFLOW_TRACKING_USERNAME'] = "correct_dagshub_username"
os.environ['MLFLOW_TRACKING_PASSWORD'] = "correct_dagshub_token"
# The original line ended with a stray ")" which made the whole script a SyntaxError.
os.environ['MLFLOW_TRACKING_URI'] = "correct_URL"

# Create a simple custom gym environment
class SimpleEnv(gym.Env):
    """Placeholder environment with 3 discrete actions and a constant 4-dim observation.

    Every observation is a zero vector, every reward is 0 and ``done`` is
    always False, so the environment carries no learnable signal; it only
    exists to give the training call something to step through.

    NOTE(review): ``step`` returns the 4-tuple ``(obs, reward, done, info)``
    and ``reset`` returns only the observation. gym >= 0.26 documents a
    5-tuple ``step`` and an ``(obs, info)`` ``reset`` -- confirm which API the
    installed gym / stable_baselines3 versions expect.
    """

    def __init__(self):
        super().__init__()
        self.action_space = spaces.Discrete(3)
        # dtype is spelled out so the observations returned below can match
        # the declared space exactly.
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(4,), dtype=np.float32
        )

    def step(self, action):
        """Ignore ``action`` and return ``(observation, reward, done, info)``."""
        # float32 zeros instead of np.array([0, 0, 0, 0]), which is an
        # integer array and therefore not a member of the float32 Box above.
        return np.zeros(4, dtype=np.float32), 0, False, {}

    def reset(self):
        """Return the initial (all-zero) observation."""
        return np.zeros(4, dtype=np.float32)



# Create and train the model
env = SimpleEnv()
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=1000)

# Save the model using MLflow
# NOTE(review): nothing in this script writes "model.zip" (there is no
# model.save(...) call), so this upload relies on the file already existing
# in the working directory -- confirm.
mlflow.log_artifact("model.zip")

# Load the model from MLflow using the captured run_id
# NOTE(review): no mlflow.start_run() is visible here; this presumably relies
# on log_artifact having started a run implicitly -- verify.
run_id = mlflow.active_run().info.run_id
# NOTE(review): the file above was logged under the name "model.zip", but this
# URI asks for an artifact called "model". pyfunc.load_model presumably also
# expects a model logged through an MLflow flavor (e.g. pyfunc.log_model)
# rather than a raw zip file -- check the MLflow model-loading docs.
loaded_model = mlflow.pyfunc.load_model(f"runs:/{run_id}/model")

The problem is that I always get this error:

---------------------------------------------------------------------------
MlflowException                           Traceback (most recent call last)
Cell In[13], line 11
      6 # Now the model is saved to MLflow with the corresponding run_id
      7 
      8 # Step 5: Load the model from MLflow
      9 run_id = mlflow.active_run().info.run_id
---> 11 loaded_model = mlflow.pytorch.load_model(f"runs:/{run_id}/model")

File ~\anaconda3\envs\metatrader\lib\site-packages\mlflow\pytorch\__init__.py:698, in load_model(model_uri, dst_path, **kwargs)
    637 """
    638 Load a PyTorch model from a local file or a run.
    639 
   (...)
    694     predict X: 30.0, y_pred: 60.48
    695 """
    696 import torch
--> 698 local_model_path = _download_artifact_from_uri(artifact_uri=model_uri, output_path=dst_path)
    699 pytorch_conf = _get_flavor_configuration(model_path=local_model_path, flavor_name=FLAVOR_NAME)
    700 _add_code_from_conf_to_system_path(local_model_path, pytorch_conf)

File ~\anaconda3\envs\metatrader\lib\site-packages\mlflow\tracking\artifact_utils.py:100, in _download_artifact_from_uri(artifact_uri, output_path)
     94 """
     95 :param artifact_uri: The *absolute* URI of the artifact to download.
     96 :param output_path: The local filesystem path to which to download the artifact. If unspecified,
     97                     a local output path will be created.
     98 """
     99 root_uri, artifact_path = _get_root_uri_and_artifact_path(artifact_uri)
--> 100 return get_artifact_repository(artifact_uri=root_uri).download_artifacts(
    101     artifact_path=artifact_path, dst_path=output_path
    102 )

File ~\anaconda3\envs\metatrader\lib\site-packages\mlflow\store\artifact\runs_artifact_repo.py:125, in RunsArtifactRepository.download_artifacts(self, artifact_path, dst_path)
    110 def download_artifacts(self, artifact_path, dst_path=None):
    111     """
    112     Download an artifact file or directory to a local directory if applicable, and return a
    113     local path for it.
   (...)
    123     :return: Absolute path of the local filesystem location containing the desired artifacts.
    124     """
--> 125     return self.repo.download_artifacts(artifact_path, dst_path)

File ~\anaconda3\envs\metatrader\lib\site-packages\mlflow\store\artifact\artifact_repo.py:200, in ArtifactRepository.download_artifacts(self, artifact_path, dst_path)
    197         failed_downloads[path] = repr(e)
    199 if failed_downloads:
--> 200     raise MlflowException(
    201         message=(
    202             "The following failures occurred while downloading one or more"
    203             f" artifacts from {self.artifact_uri}: {failed_downloads}"
    204         )
    205     )
    207 return os.path.join(dst_path, artifact_path)

MlflowException: The following failures occurred while downloading one or more artifacts from URL/artifacts: {'model': 'MlflowException("API request to some api', port=443): Max retries exceeded with url: some_url (Caused by ResponseError(\'too many 500 error responses\'))")'}

Stable_baselines3 saves the model as a zip file. I can see the artifact in MLflow, but whatever I do, I cannot load the model from MLflow. I also tried it with

# Alternative attempt through the PyTorch flavor instead of pyfunc.
# NOTE(review): model_uri is not defined in this snippet; presumably it is a
# runs:/<run_id>/... URI for the same run -- confirm.
loaded_model = mlflow.pytorch.load_model(model_uri)

Any help would be greatly appreciated

1

There is 1 answer

1
Jinen Setpal On BEST ANSWER

When I ran your example, I got a different error:

Traceback (most recent call last):
  File "/tmp/stable_baselines3/./train.py", line 36, in <module>
    mlflow.pytorch.log_model(model, "model")
  File "/tmp/stable_baselines3/.venv/lib/python3.11/site-packages/mlflow/pytorch/__init__.py", line 293, in log_model
    return Model.log(
           ^^^^^^^^^^
  File "/tmp/stable_baselines3/.venv/lib/python3.11/site-packages/mlflow/models/model.py", line 572, in log
    flavor.save_model(path=local_path, mlflow_model=mlflow_model, **kwargs)
  File "/tmp/stable_baselines3/.venv/lib/python3.11/site-packages/mlflow/pytorch/__init__.py", line 455, in save_model
    raise TypeError("Argument 'pytorch_model' should be a torch.nn.Module")
TypeError: Argument 'pytorch_model' should be a torch.nn.Module

I am using gym==0.26.2, mlflow==2.5.0 and stable-baselines3==2.0.0 on Python 3.11.3. I think the error is a lot clearer in this case - PPO isn't a torch model, and I couldn't find information on autologging stable_baselines3 models. So I set up a class through pyfunc:

class PPOModelWrapper(mlflow.pyfunc.PythonModel):
    """Expose a saved stable_baselines3 PPO model through the MLflow pyfunc interface."""

    def load_context(self, context):
        """Load the PPO model from the artifact registered under the ``"path"`` key."""
        checkpoint_location = context.artifacts["path"]
        self.model = PPO.load(checkpoint_location)

    def predict(self, context, model_input):
        """Run the wrapped model on ``model_input``.

        Returns a dict holding the two values produced by ``PPO.predict``
        under the keys ``"action"`` and ``"states"``.
        """
        chosen_action, next_states = self.model.predict(model_input)
        return dict(action=chosen_action, states=next_states)

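From there, you can log the model using mlflow.pyfunc.log_model, passing an instance of the wrapper as python_model and the zip file saved by stable_baselines3 under the "path" key of artifacts, which is the key load_context reads; for example: mlflow.pyfunc.log_model("model", python_model=PPOModelWrapper(), artifacts={"path": "model.zip"}).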

I've added the source code to the following repository: https://dagshub.com/jinensetpal/stable_baselines3, the logged model can be seen at: https://dagshub.com/jinensetpal/stable_baselines3.mlflow/#/experiments/0/runs/1f9e29528b5649b6a56a37ffb6a79a28/artifactPath/model

Hope this helps!