How to suppress PyTorch Lightning Logging Output in DARTs


I am using the Python Darts package and would like to run the prediction method without generating log output. I appear unable to do so; none of the suggestions I've found work, even when I try applying them to the Darts source code.

Here is a reproducible example, which generates the output logs:

import datetime

import numpy as np
import pandas as pd

from darts import TimeSeries

# reproducible example data
yseries = np.random.rand(100)
xseries = np.random.rand(100)
zseries = np.random.rand(100)

d = datetime.datetime.now()
tseries = [d + datetime.timedelta(days=i) for i in range(100)]

df = pd.DataFrame({"y":yseries,"x":xseries,"z":zseries,"t":tseries})
yseries = TimeSeries.from_dataframe(df, "t", "y").astype(np.float32)
xseries = TimeSeries.from_dataframe(df, "t", ["x","z"]).astype(np.float32)

from darts.models import NBEATSModel

model3 = NBEATSModel(                # init
    input_chunk_length=20,
    output_chunk_length=1,
    n_epochs=50,
)

yseries_train, yseries_val = yseries.split_before(0.5)
xseries_train, xseries_val = xseries.split_before(0.5)

model3.fit(series=yseries_train,past_covariates=xseries_train,max_samples_per_ts=50)

for t in range(100):
    forecast = model3.predict(n=1, series=yseries_train, past_covariates=xseries_train)

which prints the following block 100 times (once per call to predict):

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

Here are the leads I've tried, to no avail: https://github.com/Lightning-AI/lightning/issues/2757, How to disable logging from PyTorch-Lightning logger?

1 Answer

Answer from uchiiii:

I believe this can be solved by attaching a null handler to the Lightning loggers that emit those messages:

import logging

# Swallow records from the loggers that print the accelerator summary
logging.getLogger("pytorch_lightning.utilities.rank_zero").addHandler(logging.NullHandler())
logging.getLogger("pytorch_lightning.accelerators.cuda").addHandler(logging.NullHandler())

Alternatively, you can raise the log level of those loggers:

import logging
logging.getLogger("pytorch_lightning.utilities.rank_zero").setLevel(logging.WARNING)
logging.getLogger("pytorch_lightning.accelerators.cuda").setLevel(logging.WARNING)
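For completeness, here is a minimal, self-contained sketch of the log-level approach. The helper name `silence_lightning` is my own, not a Darts or Lightning API; the logger names are the ones used by recent Lightning versions, as in the snippets above. Call it once before the predict loop:

```python
import logging

# Lightning loggers that print the "GPU available ..." summary lines
# (assumed names, matching recent PyTorch Lightning versions).
LIGHTNING_LOGGERS = (
    "pytorch_lightning.utilities.rank_zero",
    "pytorch_lightning.accelerators.cuda",
)

def silence_lightning(level=logging.WARNING):
    """Raise the level of the Lightning loggers so INFO-level
    accelerator messages are filtered out."""
    for name in LIGHTNING_LOGGERS:
        logging.getLogger(name).setLevel(level)

silence_lightning()

# The INFO-level records are now filtered before any handler sees them.
assert not logging.getLogger(
    "pytorch_lightning.utilities.rank_zero"
).isEnabledFor(logging.INFO)
```

Note that `setLevel` filters records at the logger itself, so it works regardless of which handlers are attached, whereas `addHandler(logging.NullHandler())` only adds a no-op handler and relies on no other handler picking the records up.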