How can I add tf.keras.layers.AdditiveAttention in my model?


I am working on a machine translation problem. The model I am using is:

    from tensorflow.keras.models import Sequential
    from tensorflow.keras.layers import Embedding, LSTM, RepeatVector, Dense

    model = Sequential([
        Embedding(english_vocab_size, 256, input_length=english_max_len, mask_zero=True),
        LSTM(256, activation='relu'),
        RepeatVector(german_max_len),
        LSTM(256, activation='relu', return_sequences=True),
        Dense(german_vocab_size, activation='softmax')
    ])

Here, english_vocab_size and english_max_len are the number of words in the English vocabulary and the maximum number of words in an English sentence, respectively. The same goes for german_vocab_size and german_max_len.

Now, how can I add a tf.keras.layers.AdditiveAttention layer to this model?

Edit - I tried hard to find good tutorials on using the tf.keras.layers.AdditiveAttention layer for an NLP task, but couldn't find any. So if someone can explain how to put the tf.keras.layers.AdditiveAttention layer into this model, that would, as far as I know, be the first really clear explanation of how to use tf.keras.layers.AdditiveAttention!


There are 2 answers

Answer from ML85

This answer to a related question may help:

How to build a attention model with keras?

context_vector, attention_weights = Attention(32)(lstm, state_h)  # "Attention" here is the custom Bahdanau layer defined in the linked answer

or

This is how to use Luong-style attention:

attention = tf.keras.layers.Attention()([query, value])

And this is how to use Bahdanau-style (additive) attention:

attention = tf.keras.layers.AdditiveAttention()([query, value])

And adapted to the tensors from the first snippet:

weights = tf.keras.layers.Attention()([lstm, state_h])
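
Putting this together with the model from the question, here is a minimal sketch of how the Sequential model could be rewritten with the functional API so that an AdditiveAttention layer sits between the encoder and the decoder. The query/value wiring (decoder outputs as query, encoder outputs as value) and the final Concatenate step are my assumptions about one reasonable setup, not the only way to do it; variable names such as english_vocab_size come from the question:

    import tensorflow as tf
    from tensorflow.keras import layers

    # Placeholder sizes -- use the real vocabulary sizes and sentence lengths.
    english_vocab_size, german_vocab_size = 10000, 12000
    english_max_len, german_max_len = 20, 25

    inputs = layers.Input(shape=(english_max_len,))
    x = layers.Embedding(english_vocab_size, 256, mask_zero=True)(inputs)

    # Encoder: return_sequences=True so the per-timestep outputs can serve as
    # the attention "value"; return_state=True so the final state can seed the
    # decoder, as the encoder output fed to RepeatVector did in the original model.
    encoder_outputs, state_h, state_c = layers.LSTM(
        256, activation='relu', return_sequences=True, return_state=True)(x)

    # Decoder: the same RepeatVector idea as in the question.
    decoder_inputs = layers.RepeatVector(german_max_len)(state_h)
    decoder_outputs = layers.LSTM(256, activation='relu', return_sequences=True)(
        decoder_inputs, initial_state=[state_h, state_c])

    # Bahdanau-style attention: query = decoder outputs, value = encoder outputs.
    context = layers.AdditiveAttention()([decoder_outputs, encoder_outputs])

    # Combine the attention context with the decoder outputs before the softmax.
    combined = layers.Concatenate()([decoder_outputs, context])
    outputs = layers.Dense(german_vocab_size, activation='softmax')(combined)

    model = tf.keras.Model(inputs, outputs)
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')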
Answer from taichi_tiger

The first thing you need to know is what the input to your tf.keras.layers.AdditiveAttention() layer is. Then you need to know how to use the output of that attention layer. Once these two points are clear, it shouldn't be difficult to build your model with an attention layer.
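
To make those two points concrete, here is a toy sketch (batch size and sequence lengths are made up) of what tf.keras.layers.AdditiveAttention expects as input and what it returns:

    import tensorflow as tf

    # query: e.g. decoder LSTM outputs, shape (batch, target_steps, units)
    # value: e.g. encoder LSTM outputs, shape (batch, source_steps, units)
    query = tf.random.normal((4, 25, 256))
    value = tf.random.normal((4, 20, 256))

    context = tf.keras.layers.AdditiveAttention()([query, value])
    print(context.shape)  # (4, 25, 256): one context vector per target step

The returned context tensor can then be concatenated with the decoder outputs before the final Dense softmax layer, as in the sketch in the first answer.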