Visualizing self-attention weights for a sequence-addition problem with an LSTM?

350 views Asked by At

I am using the Self Attention layer from here for a simple problem: adding all the numbers in a sequence that come before a delimiter. With training, I expect the neural network to learn which numbers to add, and, using the Self Attention layer, I expect to be able to visualize where the model is focusing. The code to reproduce the results is as follows:

import os
import sys

import matplotlib.pyplot as plt
import numpy
import numpy as np
from keract import get_activations
from tensorflow.keras import Sequential
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.layers import Dense, Dropout, LSTM

from attention import Attention  #

def add_numbers_before_delimiter(n: int, seq_length: int, delimiter: float = 0.0,
                                 index_1: int = None) -> (np.array, np.array):
    """Generate samples for the task "add all the numbers that come before the delimiter".

    Every row of ``x`` is drawn uniformly from [0, 1). One position per row is
    overwritten with ``delimiter``, and the target is the sum of all values
    strictly before that position.

    Example: x = [1, 2, 3, 0, 4, 5, 6, 7, 8, 9]. Result is y = 6.

    @param n: number of samples in (x, y).
    @param seq_length: length of the sequence of x.
    @param delimiter: value written at the delimiter position. Default is 0.0.
    @param index_1: fixed position of the delimiter for every sample. When None,
        a random position in [1, seq_length) is drawn per sample.
    @return: two numpy arrays x and y of shape (n, seq_length, 1) and (n, 1).
    @raise ValueError: if index_1 is given but is not in [0, seq_length).
    """
    if index_1 is not None and not 0 <= index_1 < seq_length:
        raise ValueError(f'index_1 must be in [0, {seq_length}), got {index_1}')

    x = np.random.uniform(0, 1, (n, seq_length))
    y = np.zeros(shape=(n, 1))
    for i in range(n):
        if index_1 is None:
            # A plain int (not a size-1 array) so it can be used as a slice bound.
            # Drawing from [1, seq_length) guarantees at least one number before
            # the delimiter.
            a = int(np.random.randint(1, seq_length))
        else:
            a = index_1
        # Sum first: the delimiter overwrites position a, which is excluded anyway.
        y[i] = np.sum(x[i, :a])
        x[i, a] = delimiter

    x = np.expand_dims(x, axis=-1)
    return x, y

def main():
    """Train an LSTM + attention regressor on the add-before-delimiter task.

    The number of epochs is read from ``sys.argv[1`` (default 200). After every
    epoch, an image is written to ``task_add_two_numbers/epoch_XXX.png``. Its top
    rows are the attention activations for the 10 test samples, and its bottom
    rows are the mask the attention is expected to resemble.
    """
    # data. definition of the problem.
    seq_length = 20
    x_train, y_train = add_numbers_before_delimiter(20_000, seq_length)
    x_val, y_val = add_numbers_before_delimiter(4_000, seq_length)

    # Fixed delimiter position for the test samples, for visual purposes: every
    # row then has the same expected pattern, which is easier to see than random.
    test_index_1 = 4
    x_test, _ = add_numbers_before_delimiter(10, seq_length, 0, test_index_1)
    # The target depends only on the values *before* the delimiter
    # (positions 0 .. test_index_1 - 1), so that is the mask we expect the
    # attention map to look like. Marking only the delimiter column would not
    # match the task.
    x_test_mask = np.zeros_like(x_test[..., 0])
    x_test_mask[:, :test_index_1] = 1

    model = Sequential([
        LSTM(100, input_shape=(seq_length, 1), return_sequences=True),
        # The callback reads the layer named 'attention_weight', and y has shape
        # (n, 1), so a layer that collapses the time axis must sit between the
        # LSTM and the Dense head.
        # NOTE(review): confirm that the installed `attention` package accepts
        # `name=` and what its output shape is for this layer.
        Attention(name='attention_weight'),
        Dense(1, activation='linear'),
    ])

    model.compile(loss='mse', optimizer='adam')

    output_dir = 'task_add_two_numbers'
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_dir, exist_ok=True)

    max_epoch = int(sys.argv[1]) if len(sys.argv) > 1 else 200

    class VisualiseAttentionMap(Callback):
        """Save the test-set attention map above the expected mask after each epoch."""

        def on_epoch_end(self, epoch, logs=None):
            attention_map = get_activations(model, x_test, layer_names='attention_weight')['attention_weight']

            # top is attention map.
            # bottom is ground truth.
            # Assumes attention_map is (10, seq_length), like x_test_mask. TODO: confirm.
            plt.imshow(np.concatenate([attention_map, x_test_mask]), cmap='hot')

            iteration_no = str(epoch).zfill(3)
            plt.axis('off')
            plt.title(f'Iteration {iteration_no} / {max_epoch}')
            plt.savefig(f'{output_dir}/epoch_{iteration_no}.png')
            # Free the figure so one figure is not accumulated per epoch.
            plt.close()

    model.fit(x_train, y_train, validation_data=(x_val, y_val), epochs=max_epoch,
              batch_size=64, callbacks=[VisualiseAttentionMap()])

if __name__ == '__main__':
    # Run the training/visualisation script only when executed directly.
    main()

However, I get the following attention weights:

[Please click the link][1] to view the weights during training.

I expect the attention to focus on all values before the delimiter. In each image, the white cells in the bottom half represent the ground truth, while the upper half shows the attention weights for the 10 test samples.


There are 0 answers