Temporal Fusion Transformer training runs into vanishing gradients

I am training a Temporal Fusion Transformer on financial data. Although the model has skip connections and residual connections to help information flow through it, I believe it is suffering from vanishing gradients, at least at the final output layer. My implementation is based on open-source code, including pytorch-forecasting, which I customized for ease of use. I have checked the model several times and have not found a technical problem so far. I am using MSELoss instead of QuantileLoss to rule out problems in the loss function. My config and training loop are below. Can anyone help me figure out what is causing the vanishing gradients, whether in the training setup, the modelling, or elsewhere?
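For reference, here is a minimal sketch of the quantile (pinball) loss that MSELoss stands in for. It is not the pytorch-forecasting implementation; it assumes predictions of shape `(batch, horizon, n_quantiles)`, targets of shape `(batch, horizon)`, and a quantile list such as `[0.1, 0.5, 0.9]` (the actual `QUANTILES` values aren't shown). `quantile_loss` is a hypothetical name.

```python
from typing import Sequence

import torch


def quantile_loss(pred: torch.Tensor, target: torch.Tensor, quantiles: Sequence[float]) -> torch.Tensor:
    """Pinball loss averaged over batch, horizon and quantiles.

    pred:   (batch, horizon, n_quantiles)
    target: (batch, horizon)
    """
    q = torch.tensor(list(quantiles), dtype=pred.dtype, device=pred.device)
    # Give the target a trailing quantile axis so the shapes match explicitly.
    errors = target.unsqueeze(-1) - pred  # (batch, horizon, n_quantiles)
    # Under-prediction (error > 0) costs q * error; over-prediction costs (q - 1) * error.
    return torch.maximum(q * errors, (q - 1) * errors).mean()


if __name__ == "__main__":
    pred = torch.zeros(2, 3, 3)
    target = torch.ones(2, 3)
    print(quantile_loss(pred, target, [0.1, 0.5, 0.9]))  # tensor(0.5000)
```

With `quantiles=[0.5]` this reduces to half the mean absolute error.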

# config.py
import torch

NUM_GPU = torch.cuda.device_count()
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 128
NUM_EPOCHES = 20
NUM_WORKER = 2

# model param
DROPOUT = 0.2
LEARNING_RATE = 0.01
ENCODER_STEPS = 120
DECODER_STEPS = 24+120   # predicting future 24 obs
HIDDEN_LAYER_SIZE = 6
EMBEDDING_DIMENSION = 4
NUM_LSTM_LAYERS = 3
NUM_ATTENTION_HEADS = 2

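# train model
import time

from loguru import logger  # the log format below looks like loguru's default
from torch import nn, optim

# Assuming torch.nn.MSELoss: it takes no quantile list. Passing QUANTILES
# positionally would set the deprecated `size_average` argument instead.
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)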

print_every_k = 100
losses = []

for epoch in range(NUM_EPOCHES):
    if epoch==NUM_EPOCHES-1:
        print('debug')

    model.train()
    t0 = time.time()
    logger.debug(f"===== Epoch {epoch+1} =========")
    train_epoch_loss = 0.0
    train_running_loss = 0.0

    for i, batch in enumerate(train_dataloader):
        labels = batch['outputs'][:,:,0].float().to(DEVICE)

        # Zero the parameter gradients
        optimizer.zero_grad()

        outputs, attention_weights = model(batch)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Print statistics
        train_running_loss += loss.item()*(labels.shape[0]*labels.shape[1])
        train_epoch_loss += loss.item()*(labels.shape[0]*labels.shape[1])

        if (i+1) % print_every_k == 0:
            logger.debug(f"Mini-batch {i+1} average loss: {round(train_running_loss/len(labels)/(DECODER_STEPS-ENCODER_STEPS)/print_every_k, 5)}")
            train_running_loss = 0.0

    # Log the output layer's gradients. param.grad here holds only the gradient
    # of the last mini-batch of the epoch, since zero_grad() runs every step.
    for name, param in model.named_parameters():
        if param.grad is None:
            continue
        layer_name, _, others_names = name.partition('.')
        if layer_name == 'output':
            logger.debug(f'{others_names}: {param.grad}')
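The loop above logs the raw gradient tensors of the `output` layer only, taken from the last mini-batch of each epoch. A minimal sketch of a per-layer alternative, assuming the top-level attribute names of the model (the first component of each `named_parameters()` name, such as `output`) are a useful grouping: it accumulates each top-level module's gradient L2 norm over every batch so an epoch average can be logged. `accumulate_grad_norms` is a hypothetical helper, and the small `nn.Sequential` model is only a stand-in for the TFT.

```python
from typing import Dict

import torch
from torch import nn


def accumulate_grad_norms(model: nn.Module, totals: Dict[str, float]) -> None:
    """Add each top-level submodule's gradient L2 norm to `totals` (call once per batch).

    Call after loss.backward() and before the next optimizer.zero_grad().
    """
    squared: Dict[str, float] = {}
    for name, param in model.named_parameters():
        if param.grad is None:
            continue  # unused or frozen parameter
        top = name.split(".")[0]  # e.g. "output" for "output.module.weight"
        # .item() synchronises with the GPU, so this adds some per-batch overhead.
        squared[top] = squared.get(top, 0.0) + param.grad.detach().pow(2).sum().item()
    for top, sq in squared.items():
        totals[top] = totals.get(top, 0.0) + sq ** 0.5


if __name__ == "__main__":
    torch.manual_seed(0)
    # Hypothetical stand-in for the TFT; swap in the real model and dataloader.
    model = nn.Sequential(nn.Linear(4, 6), nn.ReLU(), nn.Linear(6, 1))
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    totals: Dict[str, float] = {}
    n_batches = 5
    for _ in range(n_batches):
        x, y = torch.randn(8, 4), torch.randn(8, 1)
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(model(x), y)
        loss.backward()
        accumulate_grad_norms(model, totals)
        optimizer.step()
    for top, total in sorted(totals.items()):
        print(f"{top}: mean grad norm per batch {total / n_batches:.3e}")
```

Grouping by top-level module keeps the log to one number per block, which makes it easier to see at which depth the norms drop towards 1e-5.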

Some of the log output is as follows:

| INFO     | __main__:tft_pipe:56 - 18533 model params, 139588 samples.
| DEBUG    | __main__:tft_pipe:79 - ===== Epoch 1 =========
2023-10-28 23:25:55.885 | DEBUG    | __main__:tft_pipe:132 - module.weight: tensor([[-0.1970, -0.1778, -0.0316, -0.1864, -0.2094, -0.1853]],
       device='cuda:0')
2023-10-28 23:25:55.891 | DEBUG    | __main__:tft_pipe:132 - module.bias: tensor([0.2512], device='cuda:0')
2023-10-28 23:25:55.893 | DEBUG    | __main__:tft_pipe:141 - Epoch trained for 184.34 seconds
2023-10-28 23:25:55.894 | DEBUG    | __main__:tft_pipe:142 - Epoch Train loss: 0.00959
2023-10-28 23:26:18.379 | DEBUG    | __main__:tft_pipe:163 - Epoch Val loss: 0.00603
2023-10-28 23:26:18.384 | DEBUG    | __main__:tft_pipe:79 - ===== Epoch 2 =========
2023-10-28 23:29:26.948 | DEBUG    | __main__:tft_pipe:132 - module.weight: tensor([[0.3031, 0.2802, 0.0672, 0.1487, 0.1613, 0.1379]], device='cuda:0')
2023-10-28 23:29:26.953 | DEBUG    | __main__:tft_pipe:132 - module.bias: tensor([0.1593], device='cuda:0')
2023-10-28 23:29:26.955 | DEBUG    | __main__:tft_pipe:141 - Epoch trained for 188.57 seconds
2023-10-28 23:29:26.957 | DEBUG    | __main__:tft_pipe:142 - Epoch Train loss: 0.00492
2023-10-28 23:29:46.950 | DEBUG    | __main__:tft_pipe:163 - Epoch Val loss: 0.00484
2023-10-28 23:29:46.954 | DEBUG    | __main__:tft_pipe:79 - ===== Epoch 3 =========
2023-10-28 23:32:57.256 | DEBUG    | __main__:tft_pipe:132 - module.weight: tensor([[ 0.1013,  0.1008,  0.0258, -0.0917, -0.0930, -0.0923]],
       device='cuda:0')
2023-10-28 23:32:57.267 | DEBUG    | __main__:tft_pipe:132 - module.bias: tensor([0.5331], device='cuda:0')
2023-10-28 23:32:57.269 | DEBUG    | __main__:tft_pipe:141 - Epoch trained for 190.31 seconds
2023-10-28 23:32:57.271 | DEBUG    | __main__:tft_pipe:142 - Epoch Train loss: 0.00474
2023-10-28 23:33:18.197 | DEBUG    | __main__:tft_pipe:163 - Epoch Val loss: 0.00524
2023-10-28 23:33:18.203 | DEBUG    | __main__:tft_pipe:79 - ===== Epoch 4 =========
2023-10-28 23:36:28.117 | DEBUG    | __main__:tft_pipe:132 - module.weight: tensor([[-0.3200, -0.2964,  0.0014, -0.2076, -0.2188, -0.1801]],
       device='cuda:0')
2023-10-28 23:36:28.124 | DEBUG    | __main__:tft_pipe:132 - module.bias: tensor([-0.0208], device='cuda:0')
2023-10-28 23:36:28.127 | DEBUG    | __main__:tft_pipe:141 - Epoch trained for 189.92 seconds
2023-10-28 23:36:28.129 | DEBUG    | __main__:tft_pipe:142 - Epoch Train loss: 0.0047
2023-10-28 23:36:48.394 | DEBUG    | __main__:tft_pipe:163 - Epoch Val loss: 0.00477
2023-10-28 23:36:48.399 | DEBUG    | __main__:tft_pipe:79 - ===== Epoch 5 =========
2023-10-28 23:39:58.502 | DEBUG    | __main__:tft_pipe:132 - module.weight: tensor([[ 0.1297,  0.1257,  0.0035, -0.0225, -0.0231, -0.0268]],
       device='cuda:0')
2023-10-28 23:39:58.513 | DEBUG    | __main__:tft_pipe:132 - module.bias: tensor([0.3566], device='cuda:0')
2023-10-28 23:39:58.515 | DEBUG    | __main__:tft_pipe:141 - Epoch trained for 190.12 seconds
2023-10-28 23:39:58.518 | DEBUG    | __main__:tft_pipe:142 - Epoch Train loss: 0.00464
2023-10-28 23:40:18.631 | DEBUG    | __main__:tft_pipe:163 - Epoch Val loss: 0.00512
2023-10-28 23:40:18.636 | DEBUG    | __main__:tft_pipe:79 - ===== Epoch 6 =========
2023-10-28 23:43:29.830 | DEBUG    | __main__:tft_pipe:132 - module.weight: tensor([[ 0.0386,  0.0408, -0.0006, -0.0571, -0.0618, -0.0497]],
       device='cuda:0')
2023-10-28 23:43:29.840 | DEBUG    | __main__:tft_pipe:132 - module.bias: tensor([0.2770], device='cuda:0')
2023-10-28 23:43:29.843 | DEBUG    | __main__:tft_pipe:141 - Epoch trained for 191.21 seconds
2023-10-28 23:43:29.845 | DEBUG    | __main__:tft_pipe:142 - Epoch Train loss: 0.00456
2023-10-28 23:43:50.012 | DEBUG    | __main__:tft_pipe:163 - Epoch Val loss: 0.00599
2023-10-28 23:43:50.017 | DEBUG    | __main__:tft_pipe:79 - ===== Epoch 7 =========
2023-10-28 23:47:01.444 | DEBUG    | __main__:tft_pipe:132 - module.weight: tensor([[-0.0039, -0.0085,  0.0006,  0.0589,  0.0636,  0.0435]],
       device='cuda:0')
2023-10-28 23:47:01.450 | DEBUG    | __main__:tft_pipe:132 - module.bias: tensor([-0.2083], device='cuda:0')
2023-10-28 23:47:01.452 | DEBUG    | __main__:tft_pipe:141 - Epoch trained for 191.44 seconds
2023-10-28 23:47:01.454 | DEBUG    | __main__:tft_pipe:142 - Epoch Train loss: 0.00447
2023-10-28 23:47:21.836 | DEBUG    | __main__:tft_pipe:163 - Epoch Val loss: 0.00467
2023-10-28 23:47:21.843 | DEBUG    | __main__:tft_pipe:79 - ===== Epoch 8 =========
2023-10-28 23:50:32.598 | DEBUG    | __main__:tft_pipe:132 - module.weight: tensor([[-3.0020e-02, -2.9020e-02, -1.3130e-05, -1.6233e-02, -1.6929e-02,
         -9.7950e-03]], device='cuda:0')
2023-10-28 23:50:32.606 | DEBUG    | __main__:tft_pipe:132 - module.bias: tensor([-0.0196], device='cuda:0')
2023-10-28 23:50:32.608 | DEBUG    | __main__:tft_pipe:141 - Epoch trained for 190.77 seconds
2023-10-28 23:50:32.610 | DEBUG    | __main__:tft_pipe:142 - Epoch Train loss: 0.00444
2023-10-28 23:50:53.131 | DEBUG    | __main__:tft_pipe:163 - Epoch Val loss: 0.00469
2023-10-28 23:50:53.137 | DEBUG    | __main__:tft_pipe:79 - ===== Epoch 9 =========
2023-10-28 23:54:03.649 | DEBUG    | __main__:tft_pipe:132 - module.weight: tensor([[3.5368e-02, 3.3262e-02, 5.8795e-06, 4.1231e-02, 4.7169e-02, 1.7816e-02]],
       device='cuda:0')
2023-10-28 23:54:03.654 | DEBUG    | __main__:tft_pipe:132 - module.bias: tensor([-0.0711], device='cuda:0')
2023-10-28 23:54:03.656 | DEBUG    | __main__:tft_pipe:141 - Epoch trained for 190.52 seconds
2023-10-28 23:54:03.658 | DEBUG    | __main__:tft_pipe:142 - Epoch Train loss: 0.00439
2023-10-28 23:54:23.823 | DEBUG    | __main__:tft_pipe:163 - Epoch Val loss: 0.00494
2023-10-28 23:54:23.828 | DEBUG    | __main__:tft_pipe:79 - ===== Epoch 10 =========
2023-10-28 23:57:36.051 | DEBUG    | __main__:tft_pipe:132 - module.weight: tensor([[-0.0286, -0.0357, -0.0002,  0.0101,  0.0245,  0.0015]],
       device='cuda:0')
2023-10-28 23:57:36.059 | DEBUG    | __main__:tft_pipe:132 - module.bias: tensor([-0.2316], device='cuda:0')
2023-10-28 23:57:36.062 | DEBUG    | __main__:tft_pipe:141 - Epoch trained for 192.23 seconds
2023-10-28 23:57:36.063 | DEBUG    | __main__:tft_pipe:142 - Epoch Train loss: 0.00436
2023-10-28 23:57:56.124 | DEBUG    | __main__:tft_pipe:163 - Epoch Val loss: 0.00451
2023-10-28 23:57:56.129 | DEBUG    | __main__:tft_pipe:79 - ===== Epoch 11 =========
2023-10-29 00:00:01.309 | DEBUG    | __main__:tft_pipe:115 - Mini-batch 700 average loss: 0.00436
2023-10-29 00:00:17.189 | DEBUG    | __main__:tft_pipe:115 - Mini-batch 800 average loss: 0.00441
2023-10-29 00:00:33.102 | DEBUG    | __main__:tft_pipe:115 - Mini-batch 900 average loss: 0.00437
2023-10-29 00:00:49.106 | DEBUG    | __main__:tft_pipe:115 - Mini-batch 1000 average loss: 0.00433
2023-10-29 00:01:07.808 | DEBUG    | __main__:tft_pipe:132 - module.weight: tensor([[ 3.8436e-03,  9.6467e-03,  6.3097e-05, -3.9691e-03, -1.3984e-02,
         -8.0875e-05]], device='cuda:0')
2023-10-29 00:01:07.818 | DEBUG    | __main__:tft_pipe:132 - module.bias: tensor([0.1691], device='cuda:0')
2023-10-29 00:01:07.820 | DEBUG    | __main__:tft_pipe:141 - Epoch trained for 191.69 seconds
2023-10-29 00:01:07.825 | DEBUG    | __main__:tft_pipe:142 - Epoch Train loss: 0.00433
2023-10-29 00:01:27.898 | DEBUG    | __main__:tft_pipe:163 - Epoch Val loss: 0.00452
2023-10-29 00:01:27.902 | DEBUG    | __main__:tft_pipe:79 - ===== Epoch 12 =========
2023-10-29 00:04:39.496 | DEBUG    | __main__:tft_pipe:132 - module.weight: tensor([[-2.6736e-03, -3.2901e-04, -2.3421e-03, -8.6103e-05, -7.5090e-03,
         -9.9114e-04]], device='cuda:0')
2023-10-29 00:04:39.502 | DEBUG    | __main__:tft_pipe:132 - module.bias: tensor([0.3272], device='cuda:0')
2023-10-29 00:04:39.504 | DEBUG    | __main__:tft_pipe:141 - Epoch trained for 191.6 seconds
2023-10-29 00:04:39.505 | DEBUG    | __main__:tft_pipe:142 - Epoch Train loss: 0.00429
2023-10-29 00:05:00.339 | DEBUG    | __main__:tft_pipe:163 - Epoch Val loss: 0.0053
2023-10-29 00:05:00.343 | DEBUG    | __main__:tft_pipe:79 - ===== Epoch 13 =========
2023-10-29 00:08:12.582 | DEBUG    | __main__:tft_pipe:132 - module.weight: tensor([[ 3.5385e-03, -5.6710e-03, -1.4565e-07,  5.9994e-04,  1.7111e-02,
          5.9326e-03]], device='cuda:0')
2023-10-29 00:08:12.593 | DEBUG    | __main__:tft_pipe:132 - module.bias: tensor([0.1814], device='cuda:0')
2023-10-29 00:08:12.595 | DEBUG    | __main__:tft_pipe:141 - Epoch trained for 192.25 seconds
2023-10-29 00:08:12.596 | DEBUG    | __main__:tft_pipe:142 - Epoch Train loss: 0.00428
2023-10-29 00:08:33.280 | DEBUG    | __main__:tft_pipe:163 - Epoch Val loss: 0.00456
2023-10-29 00:08:33.286 | DEBUG    | __main__:tft_pipe:79 - ===== Epoch 14 =========
2023-10-29 00:11:24.993 | DEBUG    | __main__:tft_pipe:115 - Mini-batch 1000 average loss: 0.00425
2023-10-29 00:11:43.654 | DEBUG    | __main__:tft_pipe:132 - module.weight: tensor([[-1.1735e-03, -2.6554e-03,  4.9117e-07, -1.3693e-04,  2.1181e-03,
         -1.3661e-03]], device='cuda:0')
2023-10-29 00:11:43.662 | DEBUG    | __main__:tft_pipe:132 - module.bias: tensor([-0.1520], device='cuda:0')
2023-10-29 00:11:43.664 | DEBUG    | __main__:tft_pipe:141 - Epoch trained for 190.38 seconds
2023-10-29 00:11:43.666 | DEBUG    | __main__:tft_pipe:142 - Epoch Train loss: 0.00427
2023-10-29 00:12:03.697 | DEBUG    | __main__:tft_pipe:163 - Epoch Val loss: 0.00474
2023-10-29 00:12:03.701 | DEBUG    | __main__:tft_pipe:79 - ===== Epoch 15 =========
2023-10-29 00:15:15.817 | DEBUG    | __main__:tft_pipe:132 - module.weight: tensor([[-1.0530e-04, -2.1193e-04, -1.6541e-05, -2.8272e-04, -3.5081e-03,
         -4.4950e-06]], device='cuda:0')
2023-10-29 00:15:15.823 | DEBUG    | __main__:tft_pipe:132 - module.bias: tensor([0.1189], device='cuda:0')
2023-10-29 00:15:15.825 | DEBUG    | __main__:tft_pipe:141 - Epoch trained for 192.12 seconds
2023-10-29 00:15:15.827 | DEBUG    | __main__:tft_pipe:142 - Epoch Train loss: 0.00427
2023-10-29 00:15:35.848 | DEBUG    | __main__:tft_pipe:163 - Epoch Val loss: 0.00449
2023-10-29 00:15:35.852 | DEBUG    | __main__:tft_pipe:79 - ===== Epoch 16 =========
2023-10-29 00:18:47.608 | DEBUG    | __main__:tft_pipe:132 - module.weight: tensor([[-5.6954e-04, -7.8295e-04,  8.9046e-06, -2.4048e-05,  4.6034e-03,
         -1.6678e-04]], device='cuda:0')
2023-10-29 00:18:47.621 | DEBUG    | __main__:tft_pipe:132 - module.bias: tensor([0.2328], device='cuda:0')
2023-10-29 00:18:47.624 | DEBUG    | __main__:tft_pipe:141 - Epoch trained for 191.77 seconds
2023-10-29 00:18:47.626 | DEBUG    | __main__:tft_pipe:142 - Epoch Train loss: 0.00427
2023-10-29 00:19:08.119 | DEBUG    | __main__:tft_pipe:163 - Epoch Val loss: 0.0045
2023-10-29 00:19:08.123 | DEBUG    | __main__:tft_pipe:79 - ===== Epoch 17 =========
2023-10-29 00:22:19.299 | DEBUG    | __main__:tft_pipe:132 - module.weight: tensor([[-1.1427e-05,  5.2137e-04,  3.7489e-06,  2.4374e-04,  1.9881e-03,
         -3.9823e-05]], device='cuda:0')
2023-10-29 00:22:19.304 | DEBUG    | __main__:tft_pipe:132 - module.bias: tensor([-0.2279], device='cuda:0')
2023-10-29 00:22:19.306 | DEBUG    | __main__:tft_pipe:141 - Epoch trained for 191.18 seconds
2023-10-29 00:22:19.307 | DEBUG    | __main__:tft_pipe:142 - Epoch Train loss: 0.00425
2023-10-29 00:22:39.591 | DEBUG    | __main__:tft_pipe:163 - Epoch Val loss: 0.00461
2023-10-29 00:22:39.596 | DEBUG    | __main__:tft_pipe:79 - ===== Epoch 18 =========
2023-10-29 00:25:52.057 | DEBUG    | __main__:tft_pipe:132 - module.weight: tensor([[-5.8683e-05, -8.8668e-04, -5.4194e-06, -6.6186e-05,  1.6081e-03,
          8.8617e-05]], device='cuda:0')
2023-10-29 00:25:52.065 | DEBUG    | __main__:tft_pipe:132 - module.bias: tensor([0.1201], device='cuda:0')
2023-10-29 00:25:52.066 | DEBUG    | __main__:tft_pipe:141 - Epoch trained for 192.47 seconds
2023-10-29 00:25:52.067 | DEBUG    | __main__:tft_pipe:142 - Epoch Train loss: 0.00425
2023-10-29 00:26:13.248 | DEBUG    | __main__:tft_pipe:163 - Epoch Val loss: 0.00453
2023-10-29 00:26:13.254 | DEBUG    | __main__:tft_pipe:79 - ===== Epoch 19 =========
2023-10-29 00:29:23.420 | DEBUG    | __main__:tft_pipe:132 - module.weight: tensor([[-4.8534e-04, -2.9827e-04, -4.2051e-06, -2.5660e-04, -6.4457e-03,
         -5.5604e-05]], device='cuda:0')
2023-10-29 00:29:23.428 | DEBUG    | __main__:tft_pipe:132 - module.bias: tensor([0.1287], device='cuda:0')
2023-10-29 00:29:23.431 | DEBUG    | __main__:tft_pipe:141 - Epoch trained for 190.18 seconds
2023-10-29 00:29:23.433 | DEBUG    | __main__:tft_pipe:142 - Epoch Train loss: 0.00423
2023-10-29 00:29:43.669 | DEBUG    | __main__:tft_pipe:163 - Epoch Val loss: 0.00479
2023-10-29 00:29:43.676 | DEBUG    | __main__:tft_pipe:79 - ===== Epoch 20 =========
2023-10-29 00:32:56.241 | DEBUG    | __main__:tft_pipe:132 - module.weight: tensor([[-4.9815e-05, -7.2576e-05,  1.1909e-06, -3.8043e-05, -2.6047e-03,
         -5.8518e-07]], device='cuda:0')
2023-10-29 00:32:56.247 | DEBUG    | __main__:tft_pipe:132 - module.bias: tensor([0.1311], device='cuda:0')
2023-10-29 00:32:56.249 | DEBUG    | __main__:tft_pipe:141 - Epoch trained for 192.57 seconds
2023-10-29 00:32:56.251 | DEBUG    | __main__:tft_pipe:142 - Epoch Train loss: 0.00423
2023-10-29 00:33:16.435 | DEBUG    | __main__:tft_pipe:163 - Epoch Val loss: 0.00453

Update: with learning_rate=0.0001 I get a better-scaled gradient at the output layer, which looks like this:

2023-10-30 22:23:57.374 | DEBUG    | __main__:tft_pipe:132 - module.weight: tensor([[-0.1125, -0.1421, -0.1363,  0.1359, -0.1395, -0.1524, -0.1319,  0.1499,
         -0.1268, -0.1586]], device='cuda:0')
2023-10-30 22:23:57.381 | DEBUG    | __main__:tft_pipe:132 - module.bias: tensor([0.0319], device='cuda:0')

Even so, this doesn't solve the underlying problem, which is underfitting. The loss barely improves after the first epoch. What's more, the predictions look almost like a straight line, i.e. their magnitude is far smaller than that of the actual series. If the cause is not vanishing gradients, what could be behind this underfitting?
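One way to put a number on "looks like a straight line" is to compare the forecast against the best single constant (the target mean, in the MSE sense) and to compare the spread of each forecast with the spread of its target. A minimal NumPy sketch, assuming predictions and targets have been collected as arrays of shape `(n_samples, horizon)` in the same normalized units; `flatness_report` is a hypothetical helper.

```python
import numpy as np


def flatness_report(pred: np.ndarray, target: np.ndarray) -> dict:
    """Compare forecasts with targets; both have shape (n_samples, horizon)."""
    mse_model = float(np.mean((pred - target) ** 2))
    # Baseline: predict the overall target mean everywhere, the best single constant under MSE.
    mse_mean = float(np.mean((target - target.mean()) ** 2))
    # Average spread of each forecast relative to the target's; near 0 means near-flat forecasts.
    std_ratio = float(np.mean(pred.std(axis=1)) / np.mean(target.std(axis=1)))
    return {
        "mse_model": mse_model,
        "mse_mean_baseline": mse_mean,
        "skill": 1.0 - mse_model / mse_mean,  # <= 0: no better than predicting the mean
        "std_ratio": std_ratio,
    }


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    target = np.cumsum(rng.normal(size=(100, 24)), axis=1)  # random-walk stand-in for real targets
    flat = np.full_like(target, target.mean())
    print(flatness_report(flat, target))  # skill and std_ratio should come out at or near 0
```

A `skill` close to 0 together with a `std_ratio` close to 0 would match the flat forecasts described above.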

Note that the data are normalized for training, so the scale in the figure does not match the scale used in training; this does not change the underfitting. I'm also confident it is not a matter of insufficient features: I tried synthetic data in which the future values are computed from past information, and the same flat prediction persists. In the figure, the grey line is the attention score, a plotting style I borrowed from the pytorch-forecasting package.

[Figure: One prediction result]
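A sanity-check dataset of that kind can be as simple as a noise-free linear recurrence. This is a minimal sketch, not the synthetic data actually used here (which isn't shown); `make_recurrence_windows` and its parameters are hypothetical, with the window sizes taken from ENCODER_STEPS = 120 and the 24-step horizon in the config.

```python
import numpy as np


def make_recurrence_windows(n_steps=2000, encoder_steps=120, horizon=24,
                            omega=0.3, noise=0.0, stride=1, seed=0):
    """Build (past, future) windows from x_t = 2*cos(omega)*x_{t-1} - x_{t-2} + noise.

    With noise == 0 every value is an exact function of the two values before it,
    so the encoder window contains all the information needed for the forecast.
    """
    rng = np.random.default_rng(seed)
    c = 2.0 * np.cos(omega)
    x = np.empty(n_steps)
    x[:2] = rng.normal(size=2)  # random initial conditions set phase and amplitude
    for t in range(2, n_steps):
        x[t] = c * x[t - 1] - x[t - 2] + noise * rng.normal()
    starts = range(0, n_steps - encoder_steps - horizon + 1, stride)
    past = np.stack([x[s:s + encoder_steps] for s in starts])
    future = np.stack([x[s + encoder_steps:s + encoder_steps + horizon] for s in starts])
    return past, future  # shapes: (n_windows, encoder_steps), (n_windows, horizon)


if __name__ == "__main__":
    past, future = make_recurrence_windows()
    print(past.shape, future.shape)  # (1857, 120) (1857, 24)
```

With `noise=0` the recurrence produces a pure sinusoid, so a model that uses its encoder window should be able to track it closely; a flat forecast here points at the model or pipeline rather than at the features.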

Update: Noting that the vanishing gradient may not be at the output layer, and that the flat-line prediction looks very much like an intercept-only (zero-feature) OLS fit, which can only predict the mean, I looked at the gradient norms of the weights and biases in all layers. The other layers appear to have gradient norms of 1e-5 or less. Could this be vanishing gradients in those layers? What could be causing it?
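For reference, the reason an intercept-only OLS fit can only predict the mean: with no features, least squares picks a single constant $b$, setting the derivative of the MSE to zero gives the sample mean, and the resulting MSE is the sample variance of the target. A loss that plateaus near the variance of the normalized target would therefore be consistent with a model that has collapsed to predicting roughly the mean.

```latex
\begin{aligned}
\hat{b} &= \arg\min_{b}\ \frac{1}{n}\sum_{i=1}^{n} (y_i - b)^2, \\
0 &= \frac{\partial}{\partial b}\,\frac{1}{n}\sum_{i=1}^{n} (y_i - b)^2
   = -\frac{2}{n}\sum_{i=1}^{n} (y_i - b)
   \;\Longrightarrow\; \hat{b} = \bar{y} = \frac{1}{n}\sum_{i=1}^{n} y_i, \\
\operatorname{MSE}(\hat{b}) &= \frac{1}{n}\sum_{i=1}^{n} (y_i - \bar{y})^2 .
\end{aligned}
```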
