Using .generate function for beam search over predictions in custom model extending TFPreTrainedModel class


I want to use Hugging Face's .generate() functionality on my model's predictions. My model is a custom model inheriting from the "TFPreTrainedModel" class. It has a custom transformer inheriting from tf.keras.layers, followed by a few hidden layers and a final dense layer (also inherited from tf.keras.layers).

I am not able to use .generate() in spite of adding a get_lm_head() function (as described here: https://huggingface.co/docs/transformers/main_classes/model) that returns my last dense layer. When I call .generate(), it throws: TypeError: The current model class (NextCateModel) is not compatible with .generate(), as it doesn't have a language model head.

Can anyone suggest how to use Hugging Face's .generate() functionality with a custom transformer-based model, without using one of the models from Hugging Face's list of pre-trained models?
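For reference, the behaviour I am after is ordinary beam search over the model's next-token scores. Below is a minimal, framework-free sketch of that idea, not Hugging Face's implementation. The names `beam_search` and `step_fn` are hypothetical; `step_fn` is assumed to take a prefix of token ids and return log-probabilities over the vocabulary. The sketch applies no length penalty or other score normalization.

```python
from typing import Callable, List, Optional, Sequence, Tuple

# Hypothetical interface: token-id prefix -> log-probabilities over the vocabulary.
StepFn = Callable[[Sequence[int]], Sequence[float]]


def beam_search(
    step_fn: StepFn,
    start_ids: Sequence[int],
    num_beams: int = 3,
    max_new_tokens: int = 5,
    eos_id: Optional[int] = None,
) -> List[Tuple[float, List[int]]]:
    """Return up to num_beams (score, token_ids) pairs, best first."""
    beams: List[Tuple[float, List[int]]] = [(0.0, list(start_ids))]
    finished: List[Tuple[float, List[int]]] = []

    for _ in range(max_new_tokens):
        # Expand every live beam by every possible next token.
        candidates = []
        for score, ids in beams:
            for token, log_prob in enumerate(step_fn(ids)):
                candidates.append((score + log_prob, ids + [token]))

        # Keep the num_beams highest-scoring expansions.
        candidates.sort(key=lambda c: c[0], reverse=True)
        beams = []
        for score, ids in candidates[:num_beams]:
            if eos_id is not None and ids[-1] == eos_id:
                finished.append((score, ids))  # this hypothesis is complete
            else:
                beams.append((score, ids))
        if not beams:
            break

    # Hypotheses still open when the token budget runs out also count.
    finished.extend(beams)
    finished.sort(key=lambda c: c[0], reverse=True)
    return finished[:num_beams]
```

For a Keras model, `step_fn` would presumably wrap a forward pass and a log-softmax over the logits at the last position; that wrapper is an assumption and is not shown here.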

PS: The check looks for models among the Hugging Face pretrained ones, which are listed in their generation_tf_utils.py:

generate_compatible_mappings = [
    TF_MODEL_FOR_CAUSAL_LM_MAPPING,
    TF_MODEL_FOR_VISION_2_SEQ_MAPPING,
    TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
    TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
]
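To make my reading of that code concrete, here is a simplified, self-contained paraphrase of how the error message seems to be assembled. It is a sketch, not the library's source. The plain dicts stand in for the real mapping objects, the lookup by config name is my assumption, and `build_generate_error` is a hypothetical helper name. The excerpts I have do not show the condition that triggers the check in the first place, so that part is left out.

```python
from typing import Dict, List

# Hypothetical stand-ins for the library's mapping objects:
# config name -> name of a model class that supports generation.
CAUSAL_LM: Dict[str, str] = {"gpt2": "TFGPT2LMHeadModel", "bert": "TFBertLMHeadModel"}
SEQ_TO_SEQ: Dict[str, str] = {"t5": "TFT5ForConditionalGeneration"}


def build_generate_error(
    model_class_name: str, config_name: str, mappings: List[Dict[str, str]]
) -> str:
    """Assemble the TypeError text, adding suggestions when a mapping knows the config."""
    compatible = set()
    for mapping in mappings:
        supported = mapping.get(config_name)  # assumed lookup key
        if supported is not None:
            compatible.add(supported)

    message = (
        f"The current model class ({model_class_name}) is not compatible with "
        "`.generate()`, as it doesn't have a language model head."
    )
    if compatible:
        message += f" Please use one of the following classes instead: {sorted(compatible)}"
    return message


# A custom config appears in none of the mappings, so no alternative is suggested.
print(build_generate_error("NextCateModel", "next-cate", [CAUSAL_LM, SEQ_TO_SEQ]))
```

Under these assumptions a custom model gets the bare message with no suggested classes, which matches the error I see.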

I do not intend to use the pretrained models given in the above mappings (one of the mappings is shown below):

TF_MODEL_FOR_CAUSAL_LM_MAPPING = [
    ("bert", "TFBertLMHeadModel"),
    ("camembert", "TFCamembertForCausalLM"),
    ("ctrl", "TFCTRLLMHeadModel"),
    ("gpt2", "TFGPT2LMHeadModel"),
    ("gptj", "TFGPTJForCausalLM"),
    ("openai-gpt", "TFOpenAIGPTLMHeadModel"),
    ("opt", "TFOPTForCausalLM"),
    ("rembert", "TFRemBertForCausalLM"),
    ("roberta", "TFRobertaForCausalLM"),
    ("roformer", "TFRoFormerForCausalLM"),
    ("transfo-xl", "TFTransfoXLLMHeadModel"),
    ("xglm", "TFXGLMForCausalLM"),
    ("xlm", "TFXLMWithLMHeadModel"),
    ("xlnet", "TFXLNetLMHeadModel"),
]

The traceback ends with:

   1340             if generate_compatible_classes:
   1341                 exception_message += f" Please use one of the following classes instead: {generate_compatible_classes}"
-> 1342             raise TypeError(exception_message)

There are 0 answers