Attention and AdditiveAttention are called with their input tensors packed into a list (the same convention as Add, Average, Concatenate, Dot, Maximum, Multiply and Subtract).
MultiHeadAttention, however, is called with the input tensors passed as separate arguments.
The following minimal example shows the difference in how the inbound_nodes are linked:
import json
from tensorflow.keras.layers import AdditiveAttention, Attention, Input, MultiHeadAttention
from tensorflow.keras.models import Model
inputs = [Input(shape=(8, 16)), Input(shape=(4, 16))]
outputs = [Attention()([inputs[0], inputs[1]]),          # query/value passed in a list
           AdditiveAttention()([inputs[0], inputs[1]]),  # query/value passed in a list
           MultiHeadAttention(num_heads=2, key_dim=2)(inputs[0], inputs[1])]  # query, value as separate arguments
model = Model(inputs=inputs, outputs=outputs)
print(json.dumps(json.loads(model.to_json()), indent=4))
[...]
"name": "attention",
"inbound_nodes": [
[
[
"input_1",
0,
0,
{}
],
[
"input_2",
0,
0,
{}
]
]
]
[...]
"name": "additive_attention",
"inbound_nodes": [
[
[
"input_1",
0,
0,
{}
],
[
"input_2",
0,
0,
{}
]
]
]
[...]
"name": "multi_head_attention",
"inbound_nodes": [
[
[
"input_1",
0,
0,
{
"value": [
"input_2",
0,
0
]
}
]
]
]
[...]
Why doesn't MultiHeadAttention follow the calling convention of the other two attention layers, and what does it mean that input_2 is stored under the "value" key inside input_1's entry in inbound_nodes?
(Some context on why this is relevant to me: I maintain this library and would like to add support for MultiHeadAttention.)
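
For reference, this is roughly how I would tentatively extract the inbound layers from both node formats when walking the parsed JSON. The handling of the kwargs dict is only my guess based on the output above (I haven't checked it against the Keras source), and the helper name inbound_layer_names is just for illustration:

def inbound_layer_names(inbound_nodes):
    """Collect the names of all layers feeding a node, for both the
    list-style format (Attention, AdditiveAttention, ...) and the
    kwargs-style format used by MultiHeadAttention."""
    names = []
    for node in inbound_nodes:
        for layer_name, node_index, tensor_index, kwargs in node:
            names.append(layer_name)
            # Guess: tensor-valued call kwargs (e.g. "value" above) appear as
            # [layer_name, node_index, tensor_index] inside the kwargs dict.
            for arg in kwargs.values():
                if isinstance(arg, list) and len(arg) == 3 and isinstance(arg[0], str):
                    names.append(arg[0])
    return names

config = json.loads(model.to_json())["config"]
for layer in config["layers"]:
    print(layer["name"], "<-", inbound_layer_names(layer.get("inbound_nodes", [])))

If my interpretation is right, this would report input_1 and input_2 as the inbound layers for all three attention layers in the example above.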