Deactivate PyTorch Lightning Model Logging During Prediction


I am trying to serve a PyTorch Forecasting model using FastAPI. I load the model from a checkpoint on startup using the following code:

from pytorch_forecasting.models import BaseModel

model = BaseModel.load_from_checkpoint(model_path)

model.eval()

Although the predictions come out fine, every prediction generates a new version in the lightning_logs folder, with the hyperparameters stored in a new file each time. I use the following code for the predictions:

raw_predictions = model.predict(df, mode="raw", return_x=True)

How can I stop logging when I serve the model for predictions?


There are 2 answers

tzik (best answer):

Someone posted the answer on GitHub around the same time I discovered it myself after a lot of reading. It's not that obvious, at least it wasn't for me:

trainer_kwargs={'logger':False}

For the code in my question, the prediction call becomes:

raw_predictions = model.predict(df, mode="raw", return_x=True, trainer_kwargs=dict(accelerator="cpu", logger=False))  # accelerator can also be "gpu"
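Putting it together, here is a minimal sketch of what the serving side might look like with logging disabled. The endpoint shape, the checkpoint path, and the build_dataframe helper are illustrative assumptions, not part of the original setup:

from fastapi import FastAPI
from pytorch_forecasting.models import BaseModel

app = FastAPI()
model = None

@app.on_event("startup")
def load_model():
    # Load the checkpoint once at startup, as in the question
    global model
    model = BaseModel.load_from_checkpoint("checkpoints/model.ckpt")  # assumed path
    model.eval()

@app.post("/predict")
def predict(payload: dict):
    df = build_dataframe(payload)  # hypothetical helper turning the request into a DataFrame
    # logger=False stops Lightning from creating a new lightning_logs version per call
    raw_predictions = model.predict(
        df, mode="raw", return_x=True,
        trainer_kwargs=dict(accelerator="cpu", logger=False),
    )
    return {"ok": True}  # serialize raw_predictions as needed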
1
Edwin Cheong:

Hi, here's what I normally do:

  1. Save as a normal .pt file; PyTorch Lightning is fully compatible with plain PyTorch (of course, you have to redesign the LightningModule as a normal nn.Module class)
  2. Save as an ONNX model
from model import Model
import pytorch_lightning as pl
import torch

model: pl.LightningModule = Model()
torch.save(model.state_dict(), 'weights.pt')

# Or export to ONNX; example_inputs is a sample batch matching the model's input shape
example_inputs = torch.randn(1, 10)  # placeholder shape
torch.onnx.export(model, example_inputs, 'model.onnx')
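Either artifact can then be served without Lightning creating any logs. A quick sketch of loading each one back for inference; the input shape is an assumption and should match your network:

import torch
import onnxruntime as ort
from model import Model

# Plain PyTorch: load the saved weights and run inference
model = Model()
model.load_state_dict(torch.load('weights.pt'))
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 10))  # assumed input shape

# ONNX: run the exported graph with onnxruntime
session = ort.InferenceSession('model.onnx')
input_name = session.get_inputs()[0].name  # look up the exported input name
outputs = session.run(None, {input_name: torch.randn(1, 10).numpy()})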