PyTorch Temporal Fusion Transformer prediction output length


I have trained a temporal fusion transformer on some training data and would like to predict on some unseen data. To do so, I'm using the pytorch_forecasting TimeSeriesDataSet data structure:

testing = TimeSeriesDataSet.from_dataset(training, df[lambda x: x.year > validation_cutoff], predict=True, stop_randomization=True)

with

df[lambda x: x.year > validation_cutoff].shape
(97036, 13)

Given that

testing.data['reals'].shape
torch.Size([97036, 9])

I would expect to receive a prediction output containing 97036 rows. I generate my predictions like so:

test_dataloader = testing.to_dataloader(train=False, batch_size=128 * 10, num_workers=0)
raw_predictions, x = best_tft.predict(test_dataloader, mode="raw", return_x=True)

However, I receive an output of size

raw_predictions['prediction'].shape
torch.Size([25476, 1, 7])

Why are some of these 97036 observations being removed?

Alternatively, how can I find out which of these 97036 observations are being dropped and why they are being removed?


There are 2 answers

Ayda Farhadi On

Remove mode="raw" to get a point forecast over the max_prediction_length horizon. The output then has one row per sample (one per group) and one column per step of the prediction horizon.

torch.Size([25476, 1, 7])

This is one prediction per granular group (each individual time series) in the test set; where it falls in time depends on the date range of the test set.
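To make the shape concrete, here is a minimal toy sketch in plain Python. It is not pytorch_forecasting code. It assumes that predict=True keeps only the last window of each series and that the trailing 7 is the number of quantiles (the default QuantileLoss uses seven). The helper last_window_per_group and the sample rows are made up for illustration.

```python
from collections import defaultdict


def last_window_per_group(rows, prediction_length):
    """Toy model of predict mode: keep only the final decoder window per group.

    rows is an iterable of (group_id, time_idx) pairs. A group is kept only
    if it has at least one encoder step in addition to the prediction horizon.
    """
    by_group = defaultdict(list)
    for group_id, time_idx in rows:
        by_group[group_id].append(time_idx)

    windows = {}
    for group_id, times in by_group.items():
        times.sort()
        if len(times) > prediction_length:
            # The decoder window is the last `prediction_length` time steps.
            windows[group_id] = times[-prediction_length:]
    return windows


# Hypothetical data: 9 rows spread over 3 series of different lengths.
rows = [("a", t) for t in range(5)] + [("b", t) for t in range(3)] + [("c", 0)]

windows = last_window_per_group(rows, prediction_length=1)
n_quantiles = 7  # assumption: seven quantiles, as in the default QuantileLoss

# (number of kept series, prediction length, number of quantiles)
output_shape = (len(windows), 1, n_quantiles)
print(len(rows), "rows ->", output_shape)
```

Under the same assumptions, 25476 would be the number of series in the test set that are long enough to predict on, not the number of rows.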

ptushev On

The TimeSeriesDataSet source code contains filters that remove short time series. When you set predict=True in TimeSeriesDataSet.from_dataset, min_prediction_length is set to max_prediction_length. Then, when the test dataloader is created, every time series shorter than min_prediction_length is removed. This can remove a large part of the testing set and, in the extreme case, leave you with 0 observations. I don't know exactly why it is implemented this way. To make predictions, just set:

testing = TimeSeriesDataSet.from_dataset(training, df[lambda x: x.year > validation_cutoff], predict=False, stop_randomization=True)
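The length filter described above can be illustrated with a simplified toy model in plain Python. It assumes one fixed encoder length and one fixed prediction length, whereas the real TimeSeriesDataSet allows a range for both, so the counts are only indicative. The helper count_windows and the sample rows are hypothetical.

```python
from collections import Counter


def count_windows(rows, encoder_length, prediction_length):
    """Toy model of predict=False: count sliding windows per group.

    rows is an iterable of (group_id, time_idx) pairs. A series with n rows
    yields n - (encoder_length + prediction_length) + 1 windows, and series
    with fewer rows than one full window are dropped entirely.
    """
    lengths = Counter(group_id for group_id, _ in rows)
    needed = encoder_length + prediction_length

    kept = {g: n - needed + 1 for g, n in lengths.items() if n >= needed}
    dropped = sorted(g for g, n in lengths.items() if n < needed)
    return kept, dropped


# Hypothetical data: 9 rows spread over 3 series of different lengths.
rows = [("a", t) for t in range(5)] + [("b", t) for t in range(3)] + [("c", 0)]

kept, dropped = count_windows(rows, encoder_length=2, prediction_length=1)
print("windows per series:", kept)
print("series too short to use:", dropped)
print("total samples:", sum(kept.values()))
```

Even with predict=False the number of samples is generally smaller than the number of rows, because the first rows of each series serve only as encoder input. If your installed version supports it, passing return_index=True to predict() should return the group and time index of each prediction, which shows which observations were kept.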