Callbacks in TensorFlow 2.3


I am writing my own callback to stop training based on a custom condition. EarlyStopping stops training once its condition is met by setting:

self.model.stop_training = True

e.g. from https://www.tensorflow.org/guide/keras/custom_callback

import numpy as np
from tensorflow import keras


class EarlyStoppingAtMinLoss(keras.callbacks.Callback):
    """Stop training when the loss is at its min, i.e. the loss stops decreasing.

    Arguments:
        patience: Number of epochs to wait after min has been hit. After this
            number of no improvement, training stops.
    """

    def __init__(self, patience=0):
        super(EarlyStoppingAtMinLoss, self).__init__()
        self.patience = patience
        # best_weights to store the weights at which the minimum loss occurs.
        self.best_weights = None

    def on_train_begin(self, logs=None):
        # The number of epochs it has waited since the loss was last at its minimum.
        self.wait = 0
        # The epoch the training stops at.
        self.stopped_epoch = 0
        # Initialize the best as infinity.
        self.best = np.inf

    def on_epoch_end(self, epoch, logs=None):
        current = logs.get("loss")
        if np.less(current, self.best):
            self.best = current
            self.wait = 0
            # Record the best weights if the current result is better (less).
            self.best_weights = self.model.get_weights()
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self.model.stop_training = True
                print("Restoring model weights from the end of the best epoch.")
                self.model.set_weights(self.best_weights)

    def on_train_end(self, logs=None):
        if self.stopped_epoch > 0:
            print("Epoch %05d: early stopping" % (self.stopped_epoch + 1))

The problem is that this doesn't work for me in TensorFlow 2.2 and 2.3. Is there a workaround? How else can one stop the training of a model in TF 2.3?
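To make the stopping rule itself concrete, here is a minimal sketch of the patience logic with Keras taken out of the picture. The function name first_stop_epoch is hypothetical (it is not part of the guide or of Keras); it just replays a list of per-epoch losses through the same wait/patience bookkeeping as the callback and reports the 0-based epoch at which a stop would be requested.

```python
import math


def first_stop_epoch(losses, patience):
    """Return the 0-based epoch at which a stop would be requested, or None.

    Hypothetical helper: mirrors the wait/patience bookkeeping of the
    callback, but works on a plain list of losses instead of a model.
    """
    best = math.inf
    wait = 0
    for epoch, loss in enumerate(losses):
        if loss < best:
            # Improvement: remember it and reset the counter.
            best = loss
            wait = 0
        else:
            # No improvement: count it and stop once patience is used up.
            wait += 1
            if wait >= patience:
                return epoch
    return None


losses = [1.0, 0.8, 0.9, 0.85, 0.7]
print(first_stop_epoch(losses, patience=2))  # 3
print(first_stop_epoch(losses, patience=3))  # None
```

Replaying a recorded loss history this way is a quick check of whether the callback should have fired at all for a given patience.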


There are 2 answers

Gerry P

I copied your code and added a few print statements to see what is going on. I also changed the monitored quantity from training loss to validation loss, because training loss tends to keep decreasing over many epochs while validation loss tends to level out sooner. It is better to monitor validation loss than training loss for early stopping and for saving weights. Your code runs fine and does stop training if the loss has not improved after patience epochs. Make sure you pass the callback to model.fit, as below:

patience=3 # set patience value
callbacks=[EarlyStoppingAtMinLoss(patience)]
# in model.fit include callbacks=callbacks

Here is your code, modified with print statements so you can see what is going on:

class EarlyStoppingAtMinLoss(keras.callbacks.Callback):
    def __init__(self, patience=0):
        super(EarlyStoppingAtMinLoss, self).__init__()
        self.patience = patience
        # best_weights to store the weights at which the minimum loss occurs.
        self.best_weights = None

    def on_train_begin(self, logs=None):
        # The number of epoch it has waited when loss is no longer minimum.
        self.wait = 0
        # The epoch the training stops at.
        self.stopped_epoch = 0
        # Initialize the best as infinity.
        self.best = np.inf

    def on_epoch_end(self, epoch, logs=None):
        current = logs.get("val_loss")
        print('epoch = ', epoch +1, '   loss= ', current, '   best_loss = ', self.best, '   wait = ', self.wait)
        if np.less(current, self.best):
            self.best = current
            self.wait = 0
            print ( ' loss improved setting wait to zero and saving weights')
            # Record the best weights if current results is better (less).
            self.best_weights = self.model.get_weights()
        else:
            self.wait += 1
            print ( ' for epoch ', epoch +1, '  loss did not improve setting wait to ', self.wait)
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self.model.stop_training = True
                print("Restoring model weights from the end of the best epoch.")
                self.model.set_weights(self.best_weights)

    def on_train_end(self, logs=None):
        if self.stopped_epoch > 0:
            print("Epoch %05d: early stopping" % (self.stopped_epoch + 1))
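One thing to watch with the val_loss version: Keras only puts "val_loss" into logs when validation data is passed to model.fit (via validation_data or validation_split). Without it, logs.get("val_loss") returns None and the comparison against self.best would fail. A small sketch of a fallback, where monitored_value is a hypothetical helper name and not a Keras function:

```python
def monitored_value(logs, monitor="val_loss", fallback="loss"):
    """Return logs[monitor] if present, else logs[fallback], else None.

    Hypothetical helper: `logs` is the dict Keras hands to callback hooks
    (it may also be None).
    """
    logs = logs or {}
    value = logs.get(monitor)
    if value is None:
        # No validation data was supplied, so fall back to the training loss.
        value = logs.get(fallback)
    return value


print(monitored_value({"loss": 0.5, "val_loss": 0.7}))  # 0.7
print(monitored_value({"loss": 0.5}))                   # 0.5
```

Inside on_epoch_end this could stand in for the logs.get("val_loss") call, with an early return when the result is None.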

I copied your new code and ran it. Apparently TensorFlow does not check model.stop_training between batches. So even though model.stop_training is set to True in on_train_batch_end, it keeps processing batches until all batches for the epoch are completed. Only at the end of the epoch does TensorFlow check model.stop_training, and then training does stop.
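The behaviour described above can be pictured with a toy loop. This is a simplified stand-in and not TensorFlow's actual source; ToyModel, toy_fit, and the check_per_batch switch are all hypothetical names made up for the illustration. It compares a loop that looks at the stop flag only after each epoch with one that also looks after each batch.

```python
class ToyModel:
    """Stand-in for a Keras model; it only carries the stop flag."""

    def __init__(self):
        self.stop_training = False


def toy_fit(model, epochs, steps_per_epoch, on_batch_end, check_per_batch=False):
    """Run a toy training loop and return the (epoch, batch) pairs processed."""
    processed = []
    for epoch in range(epochs):
        for batch in range(steps_per_epoch):
            processed.append((epoch, batch))
            on_batch_end(model, batch)
            # Optional per-batch check of the flag (off by default).
            if check_per_batch and model.stop_training:
                break
        # The flag is always checked once the epoch is over.
        if model.stop_training:
            break
    return processed


def stop_at_batch_2(model, batch):
    """Toy batch hook: request a stop on the third batch."""
    if batch == 2:
        model.stop_training = True


epoch_only = toy_fit(ToyModel(), 3, 5, stop_at_batch_2)
per_batch = toy_fit(ToyModel(), 3, 5, stop_at_batch_2, check_per_batch=True)
print(len(epoch_only))  # 5: the epoch runs to completion before stopping
print(len(per_batch))   # 3: the loop stops right after the requesting batch
```

With the epoch-only check, the hook keeps being called for the remaining batches after the flag is set, which matches seeing the "Restoring model weights" message while training carries on.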

marbohu

Thank you, the code works as it is and shows what's happening inside. I wanted to convert this callback to a batch version, though:

class EarlyStoppingAtMinLoss(tf.keras.callbacks.Callback):
    def __init__(self, patience=0):
        super(EarlyStoppingAtMinLoss, self).__init__()
        self.patience = patience
        # best_weights to store the weights at which the minimum loss occurs.
        self.best_weights = None

    def on_train_begin(self, logs=None):
        # The number of batches it has waited since the loss was last at its minimum.
        self.wait = 0
        # The batch the training stops at.
        self.stopped_batch = 0
        # Initialize the best as infinity.
        self.best = np.inf

    def on_train_batch_end(self, batch, logs=None):
        current = logs.get("loss")
        print('batch = ', batch +1, '   loss= ', current, '   best_loss = ', self.best, '   wait = ', self.wait)
        if np.less(current, self.best):
            self.best = current
            self.wait = 0
            print ( ' loss improved setting wait to zero and saving weights')
            # Record the best weights if current results is better (less).
            self.best_weights = self.model.get_weights()
        else:
            self.wait += 1
            print ( ' for batch ', batch +1, '  loss did not improve setting wait to ', self.wait)
            print('wait:', self.wait)
            print('patience:', self.patience)
            if self.wait >= self.patience:
                self.stopped_batch = batch
                self.model.stop_training = True
                print("Restoring model weights from the end of the best batch.")
                self.model.set_weights(self.best_weights)


    def on_train_end(self, logs=None):
        if self.stopped_batch > 0:
            print("Batch %05d: early stopping" % (self.stopped_batch + 1))

What I get on some data is this:

batch = 42 loss= 709.771484375 best_loss = 27.087162017822266 wait = 40
for batch 42 loss did not improve setting wait to 41
wait: 41
patience: 3
Restoring model weights from the end of the best batch.

It is as if converting on_epoch_end to on_train_batch_end made the script ignore the line "self.model.stop_training = True": it prints that it is stopping, but training goes on. (This is still TF 2.3.0.)

Is there any difference between epoch and batch callbacks then?
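If stopping in the middle of an epoch is what is needed, one possible workaround is to drive the batches yourself instead of going through fit and a callback. The sketch below is only an illustration under assumptions: train_until_no_improvement is a hypothetical function name, and model is assumed to expose train_on_batch, get_weights, and set_weights the way a compiled Keras model does (train_on_batch returning either a scalar loss or a list whose first entry is the loss). A small FakeModel is used so the sketch runs without TensorFlow.

```python
import math


def train_until_no_improvement(model, batches, patience):
    """Train batch by batch; return the 0-based step where training stopped.

    Returns None if the batches ran out before patience was exhausted.
    """
    best = math.inf
    best_weights = None
    wait = 0
    for step, (x, y) in enumerate(batches):
        result = model.train_on_batch(x, y)
        # With metrics compiled in, the loss is the first entry of a list.
        loss = float(result[0] if isinstance(result, (list, tuple)) else result)
        if loss < best:
            best, wait = loss, 0
            best_weights = model.get_weights()
        else:
            wait += 1
            if wait >= patience:
                # Restore once, then leave the loop immediately.
                if best_weights is not None:
                    model.set_weights(best_weights)
                return step
    return None


class FakeModel:
    """Test double that replays a fixed sequence of losses."""

    def __init__(self, losses):
        self._losses = iter(losses)
        self.weights = 0

    def train_on_batch(self, x, y):
        self.weights += 1  # pretend each batch changes the weights
        return next(self._losses)

    def get_weights(self):
        return self.weights

    def set_weights(self, weights):
        self.weights = weights


fake = FakeModel([3.0, 2.0, 2.5, 2.6, 2.7, 1.0])
stopped_at = train_until_no_improvement(fake, [(None, None)] * 6, patience=3)
print(stopped_at)    # 4
print(fake.weights)  # 2: weights from the best batch were restored
```

Because the loop returns as soon as patience is used up, the restore happens exactly once, unlike the batch callback where the hook keeps firing until the epoch ends.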