How to disable model/weight serialization fully with AllenNLP settings?

I wish to disable serialization of all model/state weights in standard AllenNLP model training, using only jsonnet config files.

The reason is that I am running automatic hyperparameter optimization with Optuna, and testing dozens of models fills up a drive quickly. I have already disabled the checkpointer by setting num_serialized_models_to_keep to 0:

trainer +: {
    checkpointer +: {
        num_serialized_models_to_keep: 0,
    },
},

I do not want to set serialization_dir to None, because I still want the default behaviour for logging intermediate metrics, etc. I only want to disable the default writing of the model state, the training state, and the best model weights.

Besides the option I set above, are there any default trainer or checkpointer options that disable all serialization of model weights? I checked the API docs and the website but could not find any.

If I need to implement such an option myself, which base function(s) from AllenNLP should I override in my Model subclass?

Alternatively, is there any utility for cleaning up intermediate model state when training has finished?
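
For the cleanup route, one fallback that does not depend on any AllenNLP option is to delete the weight files yourself once a trial has finished and its metrics have been read. The Python sketch below assumes the weights end up in the serialization directory as *.th files (for example best.th and the per-epoch model/training state files) plus a model.tar.gz archive; the function name remove_weight_files and these glob patterns are hypothetical, so check them against what your AllenNLP version actually writes.

```python
from pathlib import Path
from typing import Iterable, List

# Hypothetical default patterns; verify them against the files your
# AllenNLP version actually writes into the serialization directory.
DEFAULT_PATTERNS = ("*.th", "model.tar.gz")


def remove_weight_files(
    serialization_dir: str, patterns: Iterable[str] = DEFAULT_PATTERNS
) -> List[Path]:
    """Delete weight/state files under serialization_dir and keep everything else.

    Logs, config.json and metrics*.json survive because they match none of
    the patterns. Returns the list of deleted paths.
    """
    root = Path(serialization_dir)
    removed: List[Path] = []
    for pattern in patterns:
        # rglob also catches matching files written into subdirectories.
        for path in sorted(root.rglob(pattern)):
            if path.is_file():
                path.unlink()
                removed.append(path)
    return removed
```

In an Optuna objective you could read the validation metric from metrics.json first and then call remove_weight_files(serialization_dir), so the logs and metrics of each trial are kept while its weights are not.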

EDIT: @petew's answer shows the solution for a custom checkpointer, but it is not clear to me how to make this code discoverable by allennlp train for my use case.

I wish to make the custom checkpointer selectable from a config file, as below:

trainer +: {
    checkpointer +: {
        type: 'empty',
    },
},

What is the best practice for loading the checkpointer when calling allennlp train --include-package <$my_package>?

I have my_package with submodules in subdirectories such as my_package/models and my_package/training. I would like to place the custom checkpointer code in my_package/training/custom_checkpointer.py. My main model is located in my_package/models/main_model.py. Do I have to edit or import any code/functions in my main_model class to use the custom checkpointer?
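
On the layout question: in recent AllenNLP versions, --include-package is, as far as I know, implemented by a helper (import_module_and_submodules in allennlp.common.util) that imports the named package and then every submodule below it. If that holds for your version, passing --include-package my_package is enough to import my_package/training/custom_checkpointer.py, which runs the @Checkpointer.register("empty") decorator, and main_model.py does not need to import anything, provided my_package and my_package/training both contain an __init__.py. The stdlib sketch below imitates that recursive import on a throwaway copy of the layout; import_recursively, build_demo_package and the REGISTRY dict are hypothetical stand-ins, not AllenNLP code.

```python
import importlib
import pkgutil
import sys
import tempfile
from pathlib import Path


def import_recursively(package_name: str) -> None:
    """Import a package and all submodules below it (a sketch of what
    --include-package is assumed to do, not AllenNLP's implementation)."""
    module = importlib.import_module(package_name)
    # Only packages have __path__; plain modules end the recursion.
    for info in pkgutil.iter_modules(getattr(module, "__path__", [])):
        import_recursively(f"{package_name}.{info.name}")


def build_demo_package(root: Path) -> None:
    """Lay out a throwaway my_package mirroring the question's structure."""
    pkg = root / "my_package"
    for sub in ("models", "training"):
        (pkg / sub).mkdir(parents=True)
        # Without __init__.py, iter_modules does not treat a directory as a package.
        (pkg / sub / "__init__.py").write_text("")
    (pkg / "__init__.py").write_text("")  # note: does NOT import training
    # Hypothetical stand-in for the Checkpointer registry.
    (pkg / "registry.py").write_text("REGISTRY = {}\n")
    (pkg / "models" / "main_model.py").write_text("# model code, untouched\n")
    # Registers itself as an import side effect, like @Checkpointer.register("empty").
    (pkg / "training" / "custom_checkpointer.py").write_text(
        "from my_package.registry import REGISTRY\n"
        "REGISTRY['empty'] = 'EmptyCheckpointer'\n"
    )


def demo() -> dict:
    """Build the package, import it recursively, and return the registry."""
    with tempfile.TemporaryDirectory() as tmp:
        build_demo_package(Path(tmp))
        sys.path.insert(0, tmp)
        importlib.invalidate_caches()
        try:
            import_recursively("my_package")
            return dict(importlib.import_module("my_package.registry").REGISTRY)
        finally:
            sys.path.remove(tmp)
```

Calling demo() returns {'empty': 'EmptyCheckpointer'} even though neither my_package/__init__.py nor main_model.py mentions the checkpointer. If your version turns out not to recurse, importing the checkpointer module explicitly from my_package/__init__.py should have the same effect.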

Accepted answer by petew:

You could create and register a custom Checkpointer that basically just does nothing:

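from typing import Any, Dict, Optional, Tuple, Union

from allennlp.common import Registrable
from allennlp.training.checkpointer import Checkpointer


# Registered under "empty" so that type: 'empty' in the config selects it.
# Every method is a no-op, so no model or training state is written to disk.
@Checkpointer.register("empty")
class EmptyCheckpointer(Registrable):
    def maybe_save_checkpoint(
        self, trainer: "allennlp.training.trainer.Trainer", epoch: int, batches_this_epoch: int
    ) -> None:
        pass

    def save_checkpoint(
        self,
        epoch: Union[int, str],
        trainer: "allennlp.training.trainer.Trainer",
        is_best_so_far: bool = False,
        save_model_only: bool = False,
    ) -> None:
        pass

    def find_latest_checkpoint(self) -> Optional[Tuple[str, str]]:
        # No checkpoints are ever written, so there is never one to find.
        return None

    def restore_checkpoint(self) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        # Empty model and training states: there is nothing to restore.
        return {}, {}

    def best_model_state(self) -> Dict[str, Any]:
        # Empty dict: no "best" weights are kept.
        return {}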
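
The decorator itself only records EmptyCheckpointer in a name-to-class table belonging to Checkpointer; type: 'empty' in the config is looked up in that table when the trainer is built, which is why the module must have been imported before training starts. The toy below imitates that decorator-registry pattern in plain Python; MiniRegistrable, MiniCheckpointer and from_config are hypothetical names, and AllenNLP's real Registrable/FromParams machinery does considerably more (parameter checking, defaults, nested objects).

```python
from collections import defaultdict
from typing import Any, Callable, Dict


class MiniRegistrable:
    """Hypothetical, simplified imitation of a decorator-based registry;
    not AllenNLP's real Registrable."""

    # One name -> subclass table per base class (e.g. one for checkpointers).
    _registry: Dict[type, Dict[str, type]] = defaultdict(dict)

    @classmethod
    def register(cls, name: str) -> Callable[[type], type]:
        def decorator(subclass: type) -> type:
            table = MiniRegistrable._registry[cls]
            if name in table:
                raise ValueError(f"{name!r} is already registered for {cls.__name__}")
            table[name] = subclass
            return subclass  # the class itself is returned unchanged
        return decorator

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> Any:
        """Look up config['type'] and instantiate it with the remaining keys."""
        params = dict(config)
        name = params.pop("type")
        table = MiniRegistrable._registry[cls]
        if name not in table:
            # What you hit if the module holding the decorator was never imported.
            raise ValueError(f"{name!r} is not registered for {cls.__name__}")
        return table[name](**params)


class MiniCheckpointer(MiniRegistrable):
    """Stand-in base class playing the role of Checkpointer."""


@MiniCheckpointer.register("empty")
class EmptyMiniCheckpointer(MiniCheckpointer):
    def save_checkpoint(self, epoch: int) -> None:
        pass  # deliberately writes nothing


checkpointer = MiniCheckpointer.from_config({"type": "empty"})
print(type(checkpointer).__name__)  # EmptyMiniCheckpointer
```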