How to extend Keras GPT2 model (MoE example)


I was playing around with the Keras GPT2 model, in an attempt to build a Mixture of Experts (MoE) and achieve AGI.

Link to Keras docs: https://keras.io/api/keras_nlp/models/gpt2/

Final Edit

Got it to work properly. The code below runs. Feel free to leave any feedback or improvements. I feel the AGI.

Some thoughts: the gating network does not actually need TimeDistributed, since Dense layers apply per time step on 3D tensors anyway (see the check below). However, I have no idea how big this network should be for a base GPT2 model with 2, 4, etc. experts.
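
For reference, here is a minimal standalone check (not part of the model below) showing that a Dense layer applied to a 3D tensor acts on the last axis per time step, so it matches the TimeDistributed version:

import tensorflow as tf

x = tf.random.normal((2, 64, 768))  # (batch, sequence_length, feature_dim)

dense = tf.keras.layers.Dense(4, activation="softmax")
wrapped = tf.keras.layers.TimeDistributed(
    tf.keras.layers.Dense(4, activation="softmax")
)

print(dense(x).shape)    # (2, 64, 4)
print(wrapped(x).shape)  # (2, 64, 4)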

Also, it seems this implementation does not return the chosen experts per query token. Maybe that wasn't a thing when I wrote it. One possible way to expose them is sketched below.
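
A sketch of how the choices could be returned (the expert_choices helper is hypothetical, not a tested part of the code below): expose the gating coefficients as an extra model output and take the top-k experts per token.

import tensorflow as tf

def expert_choices(gating_coefficients, top_k=1):
    # gating_coefficients: (batch, seq_len, num_experts)
    # Indices of the top-k experts chosen per token: (batch, seq_len, top_k).
    return tf.math.top_k(gating_coefficients, k=top_k).indices

# Inside CustomGPT2Model.build one could then write:
# self.model = tf.keras.Model(
#     inputs=inputs,
#     outputs=[moe_output, labels, expert_choices(gating_coefficients)],
# )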

I think a lot of the issues were happening because I was low on memory, on top of all the bugs.

Edit 2

Running this in Colab gives another clue. I didn't understand why the loss expected labels in [0, 768) when token id values range from 0 up to the vocabulary size:

Received a label value of 50256 which is outside the valid range of [0, 768).  Label values: 31373 11 703 389 345 30 50256 0 0 0 0... 

The problem was that I called the backbone model inside the expert layer instead of GPT2CausalLM. The backbone returns 768-dimensional hidden states, not vocabulary logits, so the loss treated those 768 features as class scores; GPT2CausalLM adds the language-modeling head on top and outputs logits over the vocabulary.
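
A quick standalone way to see the difference (a sketch, assuming the same keras_nlp preset as the code below):

import keras_nlp

preset = "gpt2_base_en"
preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset(
    preset, sequence_length=8
)
backbone = keras_nlp.models.GPT2Backbone.from_preset(preset)
causal_lm = keras_nlp.models.GPT2CausalLM.from_preset(
    preset, preprocessor=preprocessor
)

x, y, w = preprocessor(["hello, how are you?"])
print(backbone(x).shape)   # (1, 8, 768): hidden states, hence the [0, 768) range
print(causal_lm(x).shape)  # (1, 8, 50257): vocabulary logits, what the loss needs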

Edit 1

My general question is: what is the best way to chain or extend the Keras GPT2 model, e.g. to implement a bigger model such as an MoE?

Here is the updated and working code:

import tensorflow as tf
import keras_nlp


def create_gating_network(sequence_length, num_experts, feature_dim=768):
    # Note: TimeDistributed is redundant here, since Dense already applies
    # per time step on 3D inputs (see the check above), but it is kept as written.
    inputs = tf.keras.layers.Input(shape=(sequence_length, feature_dim))
    x = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(64, activation="relu"))(
        inputs
    )
    outputs = tf.keras.layers.TimeDistributed(
        tf.keras.layers.Dense(num_experts, activation="softmax")
    )(x)
    gating_model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
    return gating_model


def moe_function(args):
    # expert_outputs: (batch, num_experts, seq_len, vocab_size)
    # gating_coefficients: (batch, seq_len, num_experts)
    expert_outputs, gating_coefficients = args
    # Move the expert axis to match the stacked outputs: (batch, num_experts, seq_len, 1)
    gates = tf.expand_dims(tf.transpose(gating_coefficients, perm=[0, 2, 1]), axis=-1)
    # Per-token weighted sum over experts -> (batch, seq_len, vocab_size)
    weighted_sum = tf.reduce_sum(expert_outputs * gates, axis=1)
    return weighted_sum


class ExpertGPT2Layer(tf.keras.layers.Layer):
    # The preset argument is called `preset` rather than `name` so it does not
    # shadow the reserved Keras layer `name` kwarg.
    def __init__(self, preset="gpt2_base_en", sequence_length=128, **kwargs):
        super().__init__(**kwargs)
        self.sequence_length = sequence_length
        self.preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset(
            preset, sequence_length=sequence_length
        )
        # GPT2CausalLM (not the backbone!) so the output is vocabulary logits.
        self.gpt2_model = keras_nlp.models.GPT2CausalLM.from_preset(
            preset,
            preprocessor=self.preprocessor,
        )

    def call(self, inputs, training=False):
        x, _, _ = self.preprocessor(inputs)
        # Vocabulary logits of shape (batch, seq_len, vocab_size).
        return self.gpt2_model(x, training=training)


class CustomGPT2Model(tf.keras.Model):
    def __init__(
        self,
        gating_network,
        preset="gpt2_base_en",
        sequence_length=128,
        feature_dim=768,
        num_experts=4,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.sequence_length = sequence_length
        self.feature_dim = feature_dim
        self.num_experts = num_experts
        self.tokenizer = keras_nlp.models.GPT2Tokenizer.from_preset(preset)
        self.preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset(
            preset, sequence_length=sequence_length
        )
        self.expert_layers = [
            ExpertGPT2Layer(preset=preset, sequence_length=sequence_length)
            for _ in range(num_experts)
        ]
        self.gating_network = gating_network

    def apply_expert(self, expert, inputs, training):
        result = expert(inputs, training=training)
        return result

    def build(self, input_shape):
        inputs = tf.keras.layers.Input(
            shape=input_shape, dtype=tf.string, name="text-input"
        )

        # Preprocessor returns x, y, w
        # https://github.com/keras-team/keras-nlp/blob/v0.6.2/keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py#L127
        x, labels, w = self.preprocessor(inputs)
        # The gating network expects (batch, seq_len, feature_dim) inputs, so the
        # raw token ids are tiled across the feature axis to match that shape.
        time_dim_token_ids = tf.expand_dims(x["token_ids"], axis=-1)
        replicated_token_ids = tf.tile(time_dim_token_ids, [1, 1, self.feature_dim])

        # Each expert returns vocabulary logits; stack them to
        # (batch, num_experts, seq_len, vocab_size).
        expert_outputs = [
            self.apply_expert(expert, inputs, training=True)
            for expert in self.expert_layers
        ]
        stacked_expert_outputs = tf.stack(expert_outputs, axis=1)

        # Per-token gating coefficients: (batch, seq_len, num_experts)
        gating_coefficients = self.gating_network(replicated_token_ids)

        moe_output = moe_function([stacked_expert_outputs, gating_coefficients])

        # Labels are exposed as a second output so train_step can reach them.
        self.model = tf.keras.Model(inputs=inputs, outputs=[moe_output, labels])
        super().build(input_shape)

    def call(self, inputs, training=False):
        return self.model(inputs, training=training)

    @tf.function
    def train_step(self, data):
        x = data

        with tf.GradientTape() as tape:
            y_pred, y_true = self.model(x, training=True)
            loss = self.compiled_loss(y_true, y_pred, regularization_losses=self.losses)

        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        self.compiled_metrics.update_state(y_true, y_pred)
        return {m.name: m.result() for m in self.metrics}


def main():
    text = ["hello, how are you?", "I am good"]
    batch_size = 1
    num_experts = 2
    sequence_length = 64

    dataset = tf.data.Dataset.from_tensor_slices(text)
    dataset = dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)

    gating_network = create_gating_network(sequence_length, num_experts)
    moe_model = CustomGPT2Model(
        gating_network, sequence_length=sequence_length, num_experts=num_experts
    )

    moe_model.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=tf.keras.optimizers.Adam(2e-5),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    )
    moe_model.build(input_shape=(1,))
    moe_model.summary()
    moe_model.fit(dataset, epochs=3, verbose=1)


if __name__ == "__main__":
    main()