HuggingFace BetterTransformer in `with` context - cannot disable after context


I am writing a custom `with` context manager to temporarily turn the model into a BetterTransformer model while calling trainer.evaluate().

I evaluated before, inside, and after the `with` context, and noticed that the evaluation after the context still uses BetterTransformer. This is a problem because the subsequent trainer.train() call will then also run on BetterTransformer, resulting in poor training due to padding.

How do I create a custom with context that only uses BetterTransformer inside the context, not afterwards?

Please find the MWE gist here.

I created a custom context manager:

from optimum.bettertransformer import BetterTransformer


class BetterTransformerContext:
    """Temporarily replace a model with a BetterTransformer model."""

    def __init__(self, model):
        self.model = model
        self.original_model = None

    def __enter__(self):
        # Keep a reference to the incoming model, then transform it
        self.original_model = self.model
        self.model = BetterTransformer.transform(self.model)
        return self.model

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore the saved reference on exit
        self.model = self.original_model
        # self.model = BetterTransformer.reverse(self.model)  # NOTE: same result
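
The linked gist is not reproduced here; roughly, the driver around the context manager looks like this. This is a minimal sketch: trainer is assumed to be an already-configured transformers.Trainer built on the same model, and is_bettertransformer() is an illustrative helper, not taken from the gist.

def is_bettertransformer(model):
    # Illustrative heuristic: Optimum swaps layers in for classes whose names
    # contain "BetterTransformer"; the gist may check this differently.
    return any("BetterTransformer" in type(module).__name__ for module in model.modules())

print("=" * 10, "Without Optimum (-> should be slow)", "=" * 10)
print("BT before context: ", is_bettertransformer(trainer.model))
print(trainer.evaluate()["eval_accuracy"])

print("=" * 10, "With Optimum (-> should be fast)", "=" * 10)
with BetterTransformerContext(trainer.model):
    print("BT in context: ", is_bettertransformer(trainer.model))
    print(trainer.evaluate()["eval_accuracy"])

print("=" * 10, "Without Optimum (-> should be slow)", "=" * 10)
print("BT after context: ", is_bettertransformer(trainer.model))
print(trainer.evaluate()["eval_accuracy"])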

The output is as follows. Evaluation without BetterTransformer runs at roughly 100 it/s; with BetterTransformer it runs at roughly 115 it/s. As you can see, evaluation after the context still reaches ~115 it/s, so BetterTransformer is still active.

========== Without Optimum (-> should be slow) ==========
BT before context:  False
100%|█████████████████████████████| 204/204 [00:01<00:00, 103.09it/s]
0.3161764705882353
========== With Optimum (-> should be fast) ==========
BT in context:  True
100%|█████████████████████████████| 204/204 [00:01<00:00, 116.68it/s]
0.3161764705882353
========== Without Optimum (-> should be slow) ==========
BT after context:  True
100%|█████████████████████████████| 204/204 [00:01<00:00, 116.53it/s]
0.3161764705882353

1 Answer

Answer by Arthur Thuy (accepted):

I found a solution by using a custom context manager on the trainer object, as opposed to applying it to a model object. The important differences are that transform() is called with keep_original_model=True and that __exit__ calls BetterTransformer.reverse(), so trainer.model is a genuinely untransformed model once the context ends.

The custom context manager is as follows:

from optimum.bettertransformer import BetterTransformer


class BetterTransformerTrainerContext:
    """Context manager to wrap trainer.model with BetterTransformer."""

    def __init__(self, trainer):
        self.trainer = trainer

    def __enter__(self):
        # keep_original_model=True avoids modifying the original model in place
        self.trainer.model = BetterTransformer.transform(
            self.trainer.model, keep_original_model=True
        )
        return self.trainer

    def __exit__(self, exc_type, exc_val, exc_tb):
        # reverse() converts the BetterTransformer model back to a vanilla
        # transformers model, so later evaluate()/train() calls are unaffected
        self.trainer.model = BetterTransformer.reverse(self.trainer.model)

It can be used as follows:

print("=" * 10, "With Optimum (-> should be fast)", "=" * 10)
with BetterTransformerTrainerContext(trainer) as _optimum_trainer:
    eval_accuracy = _optimum_trainer.evaluate()["eval_accuracy"]
    print(eval_accuracy)
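
After the context exits, trainer.model has been reversed to the vanilla model, so a subsequent trainer.train() no longer runs on BetterTransformer. A short sketch of the check, reusing the illustrative is_bettertransformer() helper from above (not part of the original post):

print("=" * 10, "Without Optimum (-> should be slow)", "=" * 10)
print("BT after context: ", is_bettertransformer(trainer.model))  # expected: False
print(trainer.evaluate()["eval_accuracy"])

# Training proceeds on the vanilla model, avoiding the padding issue.
trainer.train()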

I hope this might be helpful to someone else.