What is the reason for MultiHeadAttention having a different call convention than Attention and AdditiveAttention?


Attention and AdditiveAttention are called with their input tensors in a list (the same convention as Add, Average, Concatenate, Dot, Maximum, Multiply, and Subtract).

But MultiHeadAttention is called by passing the input tensors as separate arguments.
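
As far as I can tell, the two tensors map onto the named parameters of MultiHeadAttention's call(query, value, key=None, ...) signature, so the layer can apparently also be called with explicit keywords. A minimal sketch (the layer hyperparameters are arbitrary):

from tensorflow.keras.layers import Input, MultiHeadAttention

query = Input(shape=(8, 16))
value = Input(shape=(4, 16))

mha = MultiHeadAttention(num_heads=2, key_dim=2)
out_positional = mha(query, value)           # tensors as separate positional arguments
out_keyword = mha(query=query, value=value)  # equivalent keyword form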

The following minimal example shows the difference in how the inbound_nodes are linked:

import json

from tensorflow.keras.layers import AdditiveAttention, Attention, Input, MultiHeadAttention
from tensorflow.keras.models import Model

inputs = [Input(shape=(8, 16)), Input(shape=(4, 16))]
outputs = [Attention()([inputs[0], inputs[1]]),            # input tensors passed as a list
           AdditiveAttention()([inputs[0], inputs[1]]),    # input tensors passed as a list
           MultiHeadAttention(num_heads=2, key_dim=2)(inputs[0], inputs[1])]  # separate arguments
model = Model(inputs=inputs, outputs=outputs)

print(json.dumps(json.loads(model.to_json()), indent=4))
[...]
                "name": "attention",
                "inbound_nodes": [
                    [
                        [
                            "input_1",
                            0,
                            0,
                            {}
                        ],
                        [
                            "input_2",
                            0,
                            0,
                            {}
                        ]
                    ]
                ]
[...]
                "name": "additive_attention",
                "inbound_nodes": [
                    [
                        [
                            "input_1",
                            0,
                            0,
                            {}
                        ],
                        [
                            "input_2",
                            0,
                            0,
                            {}
                        ]
                    ]
                ]
[...]
                "name": "multi_head_attention",
                "inbound_nodes": [
                    [
                        [
                            "input_1",
                            0,
                            0,
                            {
                                "value": [
                                    "input_2",
                                    0,
                                    0
                                ]
                            }
                        ]
                    ]
                ]
[...]

What is the reason for MultiHeadAttention not following the convention of the other two attention layers, and what does it mean that input_2 is stored under "value" inside the input_1 entry of the inbound_nodes?

(Some context on why this is relevant to me: I'm maintaining this library and would like to implement support for MultiHeadAttention.)
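
For reference, here is a rough sketch of how I imagine the extra kwargs entry would need to be handled when walking inbound_nodes; the function below is only illustrative, based on the JSON above, and not code from the library:

def iter_inbound_tensors(inbound_node):
    """Yield (layer_name, node_index, tensor_index) for every tensor an
    inbound node references, including tensors stored in the call kwargs
    (such as the "value" entry produced by MultiHeadAttention)."""
    for layer_name, node_index, tensor_index, kwargs in inbound_node:
        yield layer_name, node_index, tensor_index
        for kwarg in kwargs.values():
            # Tensor-valued call arguments appear to be serialized as
            # [layer_name, node_index, tensor_index] lists.
            if isinstance(kwarg, list) and len(kwarg) == 3 and isinstance(kwarg[0], str):
                yield tuple(kwarg)

# Applied to the multi_head_attention node above, this yields
# ("input_1", 0, 0) and ("input_2", 0, 0).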
