I am writing a custom `with` context manager to temporarily turn the model into a BetterTransformer model while calling `trainer.evaluate()`.

I evaluated before, inside, and after the `with` context, and noticed that the evaluation after the `with` context still uses BetterTransformer. This is a problem because the subsequent `trainer.train()` call will then also use BetterTransformer, resulting in poor training due to padding.

How do I create a custom `with` context that uses BetterTransformer only inside the context, not afterwards?
Please find the MWE gist here.
I created a custom context manager:
from optimum.bettertransformer import BetterTransformer


class BetterTransformerContext:
    """Temporarily replace a model with a BetterTransformer model."""

    def __init__(self, model):
        self.model = model
        self.original_model = None

    def __enter__(self):
        self.original_model = self.model
        self.model = BetterTransformer.transform(self.model)
        return self.model

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.model = self.original_model
        # self.model = BetterTransformer.reverse(self.model)  # NOTE: same result
The output is as follows. Evaluation without BetterTransformer runs at approximately 100 it/s, while with BetterTransformer it runs at approximately 115 it/s. As you can see, evaluation after the context still reaches roughly 115 it/s.
========== Without Optimum (-> should be slow) ==========
BT before context: False
100%|█████████████████████████████| 204/204 [00:01<00:00, 103.09it/s]
0.3161764705882353
========== With Optimum (-> should be fast) ==========
BT in context: True
100%|█████████████████████████████| 204/204 [00:01<00:00, 116.68it/s]
0.3161764705882353
========== Without Optimum (-> should be slow) ==========
BT after context: True
100%|█████████████████████████████| 204/204 [00:01<00:00, 116.53it/s]
0.3161764705882353
I found a solution by applying a custom context manager to the trainer object, as opposed to applying it to the model object: the trainer keeps its own reference to the model, so rebinding `self.model` inside a model-level context manager does not change which model the trainer uses. The trainer-level context manager instead swaps `trainer.model` for a BetterTransformer copy on entry and restores the original model on exit.
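In outline, it looks like this (the class name and the `keep_original_model=True` argument are one way to do it, not the only one):

```python
from optimum.bettertransformer import BetterTransformer


class BetterTransformerTrainerContext:
    """Temporarily swap the trainer's model for a BetterTransformer model."""

    def __init__(self, trainer):
        self.trainer = trainer
        self.original_model = None

    def __enter__(self):
        # Keep a handle on the untouched model and give the trainer a transformed copy.
        self.original_model = self.trainer.model
        self.trainer.model = BetterTransformer.transform(
            self.trainer.model, keep_original_model=True
        )
        return self.trainer

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Hand the untouched model back, so later evaluate()/train() calls use it.
        self.trainer.model = self.original_model
```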
It can be used as follows:
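For instance, assuming `trainer` is an already constructed `transformers.Trainer`:

```python
trainer.evaluate()   # plain model (slow)

with BetterTransformerTrainerContext(trainer):
    trainer.evaluate()   # BetterTransformer model (fast)

trainer.evaluate()   # plain model again (slow)
trainer.train()      # training also runs on the plain model
```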
I hope this might be helpful to someone else.