I was writing my own callback to stop training based on some custom condition. EarlyStopping uses this to stop training once the condition is met:

```python
self.model.stop_training = True
```
e.g. from https://www.tensorflow.org/guide/keras/custom_callback
```python
import numpy as np
from tensorflow import keras


class EarlyStoppingAtMinLoss(keras.callbacks.Callback):
    """Stop training when the loss is at its min, i.e. the loss stops decreasing.

    Arguments:
        patience: Number of epochs to wait after min has been hit. After this
            number of epochs with no improvement, training stops.
    """

    def __init__(self, patience=0):
        super(EarlyStoppingAtMinLoss, self).__init__()
        self.patience = patience
        # best_weights to store the weights at which the minimum loss occurs.
        self.best_weights = None

    def on_train_begin(self, logs=None):
        # The number of epochs it has waited while loss is no longer at a minimum.
        self.wait = 0
        # The epoch the training stops at.
        self.stopped_epoch = 0
        # Initialize the best as infinity.
        self.best = np.inf

    def on_epoch_end(self, epoch, logs=None):
        current = logs.get("loss")
        if np.less(current, self.best):
            self.best = current
            self.wait = 0
            # Record the best weights if the current result is better (lower).
            self.best_weights = self.model.get_weights()
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self.model.stop_training = True
                print("Restoring model weights from the end of the best epoch.")
                self.model.set_weights(self.best_weights)

    def on_train_end(self, logs=None):
        if self.stopped_epoch > 0:
            print("Epoch %05d: early stopping" % (self.stopped_epoch + 1))
```
The thing is, it doesn't work for tensorflow 2.2 and 2.3. Any idea for a workaround? How else can one stop the training of a model in tf 2.3?
I copied your code and added a few print statements to see what is going on. I also changed the loss being monitored from training loss to validation loss, because training loss tends to keep decreasing over many epochs while validation loss tends to level out sooner. It is better to monitor validation loss for early stopping and for saving weights than to use training loss. Your code runs fine and does stop training if the loss does not improve for `patience` epochs. Make sure you have the code below:
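Something along these lines (the data names, epoch count, and patience value here are placeholders, not the exact call from the original answer):

```python
# Sketch of the fit() call meant above: the callback must be passed in the
# callbacks list, and validation_data is needed if "val_loss" is monitored.
# x_train, y_train, x_val, y_val, and patience=3 are placeholder values.
model.fit(
    x_train,
    y_train,
    epochs=100,
    batch_size=64,
    validation_data=(x_val, y_val),
    callbacks=[EarlyStoppingAtMinLoss(patience=3)],
)
```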
Here is your code modified with print statements so you can see what is going on.
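A minimal sketch of that kind of instrumentation, assuming only `on_epoch_end` changes and that `val_loss` is what gets monitored (the print wording is illustrative, not the original listing):

```python
def on_epoch_end(self, epoch, logs=None):
    current = logs.get("val_loss")
    # Show the monitored value and the callback state every epoch.
    print("epoch %d: val_loss = %.5f, best = %.5f, wait = %d"
          % (epoch + 1, current, self.best, self.wait))
    if np.less(current, self.best):
        self.best = current
        self.wait = 0
        self.best_weights = self.model.get_weights()
        print("  val_loss improved, saving weights.")
    else:
        self.wait += 1
        print("  no improvement for %d epoch(s)." % self.wait)
        if self.wait >= self.patience:
            self.stopped_epoch = epoch
            self.model.stop_training = True
            print("Restoring model weights from the end of the best epoch.")
            self.model.set_weights(self.best_weights)
```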
I copied your new code and ran it. Apparently tensorflow does not evaluate model.stop_training during batches. So even though model.stop_training gets set to True in on_train_batch_end, it continues processing the batches until all batches for the epoch are completed. Then, at the end of the epoch, tensorflow evaluates model.stop_training and training does stop.
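A minimal sketch that reproduces this behaviour, assuming a throwaway model on random data (all names and sizes are arbitrary):

```python
import numpy as np
from tensorflow import keras


class StopOnFirstBatch(keras.callbacks.Callback):
    """Requests a stop during the first batch; fit() only honours it at epoch end."""

    def on_train_batch_end(self, batch, logs=None):
        if batch == 0:
            self.model.stop_training = True
            print("batch %d: stop_training set to True" % batch)

    def on_epoch_end(self, epoch, logs=None):
        # In tf 2.2/2.3, fit() checks stop_training here, after the batch loop,
        # so this epoch finishes all its batches and no further epoch starts.
        print("epoch %d finished, stop_training = %s"
              % (epoch, self.model.stop_training))


x = np.random.rand(256, 4)
y = np.random.rand(256, 1)

model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

# Only epoch 0 runs (all of its batches complete), then training stops.
model.fit(x, y, batch_size=32, epochs=5, callbacks=[StopOnFirstBatch()], verbose=0)
```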