I have a tf.keras model that internally contains a custom tf.keras.layers.MultiHeadAttention layer. That is, I have split the multi-head attention layer into two parts: (1) a first layer that returns the values of each head, and (2) a second layer that takes the heads from (1) as input and returns the final output of the multi-head attention layer.
I have done this because I want to manually change the output values of (1) and then pass them as input to (2). Is there a way to do this with monkey patching? If not, is there another way?
Suppose the layer that returns the per-head values is named "my__multi_head_attention_prev_1". The code below is what I'm using for patching. The lines from the definitions of "input1" and "input2" onwards are where I rebuild the part of the model that comes after the layer I'm patching, and that is the part I'd like to automate, since I'm currently splitting the full model into two pieces and manually defining the second one.
# Build the submodel up to the patched layer once, then run it on both images
prev_model = tf.keras.models.Model(inputs=model.input, outputs=model.get_layer('my__multi_head_attention_prev_1').output)
heads_clean, encoded_patches_clean = prev_model.predict(clean_image)
heads_corrupted, encoded_patches_corrupted = prev_model.predict(image_corrupted)
patched_heads = patch(head_to_modify = head_idx, attention_scores = heads_corrupted, replacement_values = heads_clean)
input1 = layers.Input(patched_heads.shape[1:])
input2 = layers.Input(encoded_patches_corrupted.shape[1:])
x2, encoded_patches = My_MultiHeadAttention_last_step(num_heads, projection_dim)(input1, input2)
x3 = layers.LayerNormalization(epsilon = 1e-6)(x2)
x3 = mlp(x3, hidden_units = transformer_units, dropout_rate = 0.1)
encoded_patches = layers.Add()([x3, x2])
representation = layers.LayerNormalization(epsilon = 1e-6)(encoded_patches)
representation = layers.Flatten()(representation)
representation = layers.Dropout(0.5)(representation)
features = mlp(representation, hidden_units = mlp_head_units, dropout_rate = 0.5)
logits = layers.Dense(n_classes)(features)
rest_of_model = tf.keras.Model(inputs = [input1,input2], outputs = logits)
rest_of_model([patched_heads,encoded_patches_corrupted])
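For context, the kind of monkey patching I was imagining looks roughly like this. This is a toy sketch with a stand-in layer (not my real attention layer), and it assumes that Keras looks up self.call on the instance at call time, so an instance-level override is picked up; that seems to hold in the versions I've tried, but I'm not certain it holds everywhere:

```python
import tensorflow as tf

# Stand-in for the first half of the split attention layer:
# it returns per-head values of shape (batch, num_heads, seq, dim).
class HeadsLayer(tf.keras.layers.Layer):
    def call(self, x):
        # Toy "two heads": head 0 is x, head 1 is 2*x
        return tf.stack([x, x * 2.0], axis=1)

layer = HeadsLayer()

# Monkey patch: wrap the instance's call so its output can be edited
# before the rest of the model sees it.
original_call = layer.call

def patched_call(x):
    heads = original_call(x)
    # Example "patch" step: zero out head 0, keep head 1 unchanged
    mask = tf.constant([0.0, 1.0])[None, :, None, None]
    return heads * mask

layer.call = patched_call

out = layer(tf.ones((1, 3, 4)))  # head 0 is now all zeros
```

If this instance-level override works for the real layer, the patching could happen inside the unmodified full model, without rebuilding the second half by hand.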
Thank you!