I am trying to carry out hyperparameter optimization of a TFT (Temporal Fusion Transformer) model from the PyTorch Forecasting library. For this I am using PyTorch Lightning for training and Optuna for hyperparameter optimization.
I am having trouble implementing the pruner. I have found the following four approaches, all of which either won't work with my implementation or did not work when I tried them:
First, trial.report() ... if trial.should_prune(): I cannot use this approach because I fit the model directly with pl.Trainer and trainer.fit rather than writing the training loop myself.
Second, PyTorchLightningPruningCallbackAdjusted added as a separate callback:
import ...

class PyTorchLightningPruningCallbackAdjusted(pl.Callback, PyTorchLightningPruningCallback):
    pass

def optimize_hyperparameters_modified(..., pruner: optuna.pruners.BasePruner = optuna.pruners.SuccessiveHalvingPruner(min_resource=1, reduction_factor=2)):
    ...
    # keyword arguments collected in a dict and unpacked into the Trainer
    default_trainer_kwargs = dict(..., callbacks=[PyTorchLightningPruningCallbackAdjusted(trial, monitor="val_loss")], ...)
    trainer = pl.Trainer(**default_trainer_kwargs)
    trainer.fit(model, ...)
I adapted this code from the source of optimize_hyperparameters in PyTorch Forecasting: https://pytorch-forecasting.readthedocs.io/en/stable/_modules/pytorch_forecasting/models/temporal_fusion_transformer/tuning.html#optimize_hyperparameters
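The empty subclass in the snippet above works as a mixin: it inherits every hook from optuna's callback, while the extra pl.Callback base makes isinstance checks against the trainer's own Callback class pass. The usual motivation (an assumption here, since the post doesn't say) is a version mismatch in which optuna's callback derives from a different Callback class than the one the trainer checks, for example pytorch_lightning versus lightning.pytorch. A stdlib-only illustration with hypothetical stand-in classes:

```python
class TrainerCallback:
    """Hypothetical stand-in for the Callback class the trainer checks against."""


class OtherFrameworkCallback:
    """Hypothetical stand-in for the base class the pruning callback derives from."""


class PruningCallback(OtherFrameworkCallback):
    def on_validation_end(self, trainer, pl_module):
        return "pruning check"


class PruningCallbackAdjusted(TrainerCallback, PruningCallback):
    # No body needed: the MRO supplies PruningCallback's hooks,
    # and the extra base class satisfies the trainer's type check.
    pass


def validate_callbacks(callbacks):
    # Mimics a trainer that rejects objects which are not its own Callback type.
    for cb in callbacks:
        if not isinstance(cb, TrainerCallback):
            raise TypeError(f"{type(cb).__name__} is not a TrainerCallback")
    return True


cb = PruningCallbackAdjusted()
print(validate_callbacks([cb]))          # True
print(cb.on_validation_end(None, None))  # pruning check
```

The mixin only changes what type the callback is; which of its methods actually run still depends on the hook names the trainer calls.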
For the third attempt, I replaced the early_stop_callback passed to the trainer with PyTorchLightningPruningCallbackAdjusted instead of adding it as a separate callback.
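Whichever position the callback takes, it is worth confirming that it is still in the list when pl.Trainer is constructed. If the modified function merges extra trainer arguments into default_trainer_kwargs with dict.update (an assumption, since that part of the code is elided here), a callbacks entry in those extra arguments replaces the whole default list, pruning callback included. A stdlib-only sketch; build_trainer_kwargs is a hypothetical helper and the strings are placeholders for callback objects:

```python
def build_trainer_kwargs(trainer_kwargs: dict, merge_callbacks: bool) -> dict:
    # Placeholder strings stand in for real callback objects.
    default_trainer_kwargs = dict(
        max_epochs=20,
        callbacks=["lr_monitor", "pruning_callback"],
    )
    user_kwargs = dict(trainer_kwargs)  # copy so the caller's dict is untouched
    if merge_callbacks:
        # Append user callbacks instead of letting update() overwrite the list.
        extra = user_kwargs.pop("callbacks", [])
        default_trainer_kwargs["callbacks"] = default_trainer_kwargs["callbacks"] + list(extra)
    default_trainer_kwargs.update(user_kwargs)
    return default_trainer_kwargs


user = {"callbacks": ["early_stopping"]}
print(build_trainer_kwargs(user, merge_callbacks=False)["callbacks"])
# ['early_stopping']  (pruning callback silently dropped)
print(build_trainer_kwargs(user, merge_callbacks=True)["callbacks"])
# ['lr_monitor', 'pruning_callback', 'early_stopping']
```

Printing [type(c).__name__ for c in trainer.callbacks] right after creating the real trainer is a quick way to check this.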
In the fourth attempt, I defined PyTorchLightningPruningCallbackAdjusted myself as follows:
import math

import lightning.pytorch as pl  # or `import pytorch_lightning as pl`, matching the rest of the code
import optuna
from optuna.integration import PyTorchLightningPruningCallback


class PyTorchLightningPruningCallbackAdjusted(pl.Callback, PyTorchLightningPruningCallback):
    def __init__(self, trial, monitor: str):
        self._trial = trial
        self._monitor = monitor
        self._best_score = float('inf')

    def on_epoch_end(
        self,
        trainer: 'pl.Trainer',
        pl_module: 'pl.LightningModule',
    ):
        logs = trainer.callback_metrics
        if self._monitor not in logs:
            print('no logs')
            return
        # callback_metrics holds tensors; convert to a plain float for optuna
        current_score = logs[self._monitor].item()
        if math.isnan(current_score):
            print('current score is nan')
            return
        # Report the monitored metric and ask the pruner whether to stop
        self._trial.report(current_score, step=trainer.global_step)
        if self._trial.should_prune():
            message = f'Trial was pruned at epoch {trainer.current_epoch}.'
            raise optuna.exceptions.TrialPruned(message)
        else:
            print('Trial checked but not pruned')
        # Track the best score seen so far as a user attribute
        if current_score < self._best_score:
            self._best_score = current_score
            self._trial.set_user_attr('best_score', self._best_score)
        print(f'Current_score = {current_score} and Best_score = {self._best_score}')
        threshold = self._trial.params.get('threshold')
        print(f'Threshold: {threshold}')
        if threshold is None:
            return
        if current_score > threshold:
            # optuna's Trial has no set_params(): sampled parameters cannot be
            # changed afterwards, so only the user attribute is updated here
            self._trial.set_user_attr('threshold', current_score)
            print(f'Threshold reset: {current_score}')
But the callback isn't being called.
I have been struggling with this problem for quite some time now. Any help will be appreciated!
Quick summary: the following approaches were implemented:
- PyTorchLightningPruningCallbackAdjusted added as a separate callback
- PyTorchLightningPruningCallbackAdjusted used in place of the early stopping callback
- PyTorchLightningPruningCallbackAdjusted defined with a custom implementation
In all three cases the training proceeds normally without any pruning.