Custom PyTorch Lightning metric used as a loss function causes errors during metric.forward and update calls

  • PyTorch-Forecasting version: 1.0.0
  • PyTorch version: 2.1.0
  • Python version: 3.10.13
  • Torchmetrics version: 1.2.0
  • Operating System: Windows 10

Expected behavior

I defined a custom loss-function class and attempted to train a TemporalFusionTransformer (TFT) with it, placing a breakpoint inside the loss function. The expectation was for the model to train using the custom loss. When I run the model without specifying a loss, it trains fine.

My dataset comes from a dataframe with 8 columns: 3 positions (float), 3 parameters (float), 'i' for the group number (int), and 'time' for a time index that resets for each group. I defined the training set as the first 75% of the max time and the validation set as the full dataset.
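
For concreteness, a toy frame with the same layout (placeholder values, not my real data) would look like:

import pandas as pd

# positions x/y/z (float), parameters a/b/c (float),
# group id 'i' (int), per-group time index 'time' (int)
train_df = pd.DataFrame({
    "x": [0.0, 0.1, 0.2], "y": [0.0, 0.1, 0.2], "z": [0.0, 0.1, 0.2],
    "a": [1.0, 1.0, 1.0], "b": [2.0, 2.0, 2.0], "c": [3.0, 3.0, 3.0],
    "i": [0, 0, 0], "time": [0, 1, 2],
})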

There are no issues when loss=plPINNLoss is commented out, but errors are thrown when it is supplied. This can only mean there is an issue either in my implementation or in MultiHorizonMetric.

Problematic Behavior

I've been getting errors such as 'Tensor' object has no attribute '_is_synced', raised in torchmetrics' Metric.forward(self, *args: Any, **kwargs: Any) where it checks if self._is_synced; a tensor is arriving in place of the metric instance. The call comes from pytorch_forecasting.metrics.base_metrics.TorchMetricWrapper.forward(), which does return self.torchmetric.forward(y_pred_flattened, target_flattened, **kwargs).
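
This first error looks like what happens when Metric.forward is invoked unbound, so that the tensor itself lands in the self slot. A minimal Python sketch (a toy class, not the torchmetrics code) reproduces it:

import torch

class FakeMetric:
    def forward(self, *args):
        # when called on the class, the first positional argument becomes 'self'
        return self._is_synced

FakeMetric.forward(torch.zeros(3))  # AttributeError: 'Tensor' object has no attribute '_is_synced'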

A breakpoint in base_metrics.py shows that self.torchmetric is my plPINNLoss, which inherits from MultiHorizonMetric. When I modify the forward call in the debug console to self.torchmetric.forward(self, y_pred_flattened, target_flattened, **kwargs), the error becomes TypeError: Metric.reset() missing 1 required positional argument: 'self', the kind of error Python raises when a method is called on a class rather than an instance. Again in the debug console, self.torchmetric.reset(self) works. Making that modification to the source of MultiHorizonMetric leads to another error in MultiHorizonMetric.update on line 782:

lengths = torch.full((target.size(0),), fill_value=target.size(1), dtype=torch.long, device=target.device), where the error is IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1). The target tensor passed to update is 1-dimensional with size 128 (the batch size), so target.size(1) is invalid.
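
The IndexError itself is easy to reproduce on any 1-dimensional tensor:

import torch

target = torch.zeros(128)  # 1-D, as seen in the debugger
target.size(1)  # IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)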

Something is clearly wrong, but I don't know enough about the internals to correct it. Is the problem in how I construct the metric that I use as the loss function?

Code

import torch
import lightning.pytorch as pl
from lightning.pytorch.tuner import Tuner
from pytorch_forecasting import TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.metrics import MultiHorizonMetric

train_AllTargets = TimeSeriesDataSet(data = train_df[lambda x: x.time <= int(0.75*t_max)],
                                     time_idx = "time",
                                     target = ['x','y','z','a','b','c'],
                                     group_ids = ['i'],
                                     static_reals = ['a','b','c'],
                                     time_varying_known_reals = ['time'],
                                     time_varying_unknown_reals = ['x','y','z']
                                     )

train_loader = train_AllTargets.to_dataloader(train=True, batch_size=128)
# val_AllTargets is the analogous TimeSeriesDataSet over the full dataframe; definition omitted
val_loader = val_AllTargets.to_dataloader(train=False, batch_size=128)

class plPINNLoss(MultiHorizonMetric):
    higher_is_better = False
    full_state_update = False

    def __init__(self):
        super().__init__()
        # accumulator for the physics-informed loss term
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="mean")

    def phys_loss(self, y_pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Simple toy case: mean squared magnitude of all six predicted channels
        x, y, z = y_pred[:, 0], y_pred[:, 1], y_pred[:, 2]
        a, b, c = y_pred[:, 3], y_pred[:, 4], y_pred[:, 5]
        loss = torch.mean(x**2 + y**2 + z**2 + a**2 + b**2 + c**2)
        self.total += loss
        return loss

    def loss(self, y_pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        return self.total

tftmodel = TemporalFusionTransformer.from_dataset(
    dataset=train_AllTargets,
    loss=plPINNLoss,  # the custom metric class is supplied here
)

# lr_logger and logger are defined elsewhere; omitted here
trainer = pl.Trainer(
    max_epochs=300,
    accelerator="auto",
    gradient_clip_val=0.1,
    callbacks=[lr_logger],
    logger=logger,
    check_val_every_n_epoch=1,
)
res = Tuner(trainer).lr_find(tftmodel, train_dataloaders=train_loader, val_dataloaders=val_loader)
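
For comparison, the pytorch-forecasting tutorials pass an instantiated metric rather than the class, e.g. with the built-in QuantileLoss. My code above passes the class plPINNLoss itself, which may be related (tft below is just a placeholder name):

from pytorch_forecasting.metrics import QuantileLoss

tft = TemporalFusionTransformer.from_dataset(
    dataset=train_AllTargets,
    loss=QuantileLoss(),  # an instance, not the class
)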
