How to implement an Optuna pruner in PyTorch Lightning?

I am trying to carry out hyperparameter optimization of a TFT model from the PyTorch Forecasting library. For this I am using PyTorch Lightning for training and Optuna for hyperparameter optimization.

I am facing issues with the pruner implementation. I have found the following four approaches, all of which either won't work with my setup or did not work when I tried them out:

  1. trial.report() followed by if trial.should_prune(): I cannot use this method as I am using pl.Trainer and trainer.fit() directly to fit the model.
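
For context, this is the report/prune pattern in a plain training loop, which only works when you control the loop yourself (a minimal sketch; train_one_epoch is a placeholder, not code from my project):

import optuna

def train_one_epoch(lr: float, epoch: int) -> float:
    # Placeholder for a real training epoch; returns a dummy validation loss.
    return 1.0 / (lr * (epoch + 1))

def objective(trial: optuna.Trial) -> float:
    lr = trial.suggest_float('learning_rate', 1e-4, 1e-1, log=True)
    val_loss = float('inf')
    for epoch in range(10):
        val_loss = train_one_epoch(lr, epoch)
        trial.report(val_loss, step=epoch)  # hand the metric to the pruner
        if trial.should_prune():  # pruner decides based on reported history
            raise optuna.TrialPruned()
    return val_loss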

  2. PyTorchLightningPruningCallbackAdjusted:

import optuna
import pytorch_lightning as pl
from optuna.integration import PyTorchLightningPruningCallback
...

class PyTorchLightningPruningCallbackAdjusted(pl.Callback, PyTorchLightningPruningCallback):
    pass

def optimize_hyperparameters_modified(..., pruner: optuna.pruners.BasePruner = optuna.pruners.SuccessiveHalvingPruner(min_resource=1, reduction_factor=2)):
    ...
    default_trainer_kwargs = dict(
        ...,
        callbacks=[PyTorchLightningPruningCallbackAdjusted(trial, monitor="val_loss")],
        ...,
    )
    trainer = pl.Trainer(**default_trainer_kwargs)
    trainer.fit(model)

These are modifications of the code from the PyTorch Forecasting source: https://pytorch-forecasting.readthedocs.io/en/stable/_modules/pytorch_forecasting/models/temporal_fusion_transformer/tuning.html#optimize_hyperparameters
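
For comparison, this is a minimal, self-contained sketch of how I understand the intended wiring with the stock callback (toy model and random data; everything here is illustrative, not my actual TFT setup):

import optuna
import pytorch_lightning as pl
import torch
from optuna.integration import PyTorchLightningPruningCallback
from torch.utils.data import DataLoader, TensorDataset

class ToyModel(pl.LightningModule):
    def __init__(self, lr):
        super().__init__()
        self.lr = lr
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        # The pruning callback reads this logged metric by name.
        self.log('val_loss', torch.nn.functional.mse_loss(self.layer(x), y))

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

def objective(trial):
    lr = trial.suggest_float('learning_rate', 1e-4, 1e-1, log=True)
    data = TensorDataset(torch.randn(64, 4), torch.randn(64, 1))
    loader = DataLoader(data, batch_size=16)
    trainer = pl.Trainer(
        max_epochs=5,
        logger=False,
        enable_checkpointing=False,
        # The callback reports 'val_loss' to the trial after each validation epoch.
        callbacks=[PyTorchLightningPruningCallback(trial, monitor='val_loss')],
    )
    trainer.fit(ToyModel(lr), loader, loader)
    return trainer.callback_metrics['val_loss'].item()

study = optuna.create_study(
    direction='minimize',
    pruner=optuna.pruners.SuccessiveHalvingPruner(min_resource=1, reduction_factor=2),
)
study.optimize(objective, n_trials=3)

In my case, though, the trainer is built inside optimize_hyperparameters, which is why I am trying to inject the callback through default_trainer_kwargs instead.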

  3. For the third attempt I tried replacing the early_stop_callback passed to the trainer with the PyTorchLightningPruningCallbackAdjusted, instead of adding it as a separate callback.

  4. In the fourth attempt I tried to define the class PyTorchLightningPruningCallbackAdjusted as follows:

import optuna
import pytorch_lightning as pl
import torch
from optuna.integration import PyTorchLightningPruningCallback

class PyTorchLightningPruningCallbackAdjusted(pl.Callback, PyTorchLightningPruningCallback):
    def __init__(self, trial, monitor: str):
        self._trial = trial
        self._monitor = monitor
        self._best_score = float('inf')

    def on_epoch_end(
        self,
        trainer: 'pl.Trainer',
        pl_module: 'pl.LightningModule',
    ):
        # Look up the monitored metric; skip this epoch if it was not logged.
        logs = trainer.callback_metrics
        if self._monitor not in logs:
            print('no logs')
            return

        current_score = logs[self._monitor]
        if torch.isnan(current_score):
            print('current score is nan')
            return

        # Report the metric to Optuna and let the pruner decide.
        self._trial.report(current_score, step=trainer.global_step)
        if self._trial.should_prune():
            message = f'Trial was pruned at epoch {trainer.current_epoch}.'
            raise optuna.exceptions.TrialPruned(message)
        else:
            print('Trial checked but not pruned')

        # Track the best score seen so far as a user attribute on the trial.
        if current_score < self._best_score:
            self._best_score = current_score
            self._trial.set_user_attr('best_score', self._best_score)
        print(f'Current_score = {current_score} and Best_score = {self._best_score}')

        # If a 'threshold' hyperparameter was suggested for this trial, record
        # any score that exceeds it.
        threshold = self._trial.params.get('threshold')
        print(f'Threshold: {threshold}')
        if threshold is None:
            return

        if current_score > threshold:
            # optuna.trial.Trial has no set_params(); store the new threshold
            # as a user attribute instead.
            self._trial.set_user_attr('threshold', current_score)
            print(f'Threshold reset: {current_score}')

But the callback isn't being called.
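
To narrow down whether any epoch-level hook fires at all, a minimal probe callback like this could be added to the trainer's callbacks list (a sketch; hook availability differs across Lightning versions, and on_epoch_end in particular was removed in Lightning 2.0 in favor of on_train_epoch_end / on_validation_epoch_end):

import pytorch_lightning as pl

class HookProbe(pl.Callback):
    # Prints which epoch-level hooks the installed Lightning version invokes.
    def on_train_epoch_end(self, trainer, pl_module):
        print('on_train_epoch_end fired')

    def on_validation_epoch_end(self, trainer, pl_module):
        print('on_validation_epoch_end fired')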

I have been struggling with this problem for quite some time now. Any help will be appreciated!

Quick summary: the following approaches were implemented:

  1. PyTorchLightningPruningCallbackAdjusted introduced as a separate callback
  2. PyTorchLightningPruningCallbackAdjusted used in place of the early-stop callback
  3. PyTorchLightningPruningCallbackAdjusted redefined with a custom on_epoch_end hook

In all three cases the training proceeds normally without any pruning.
