How to restore a specific checkpoint in TensorFlow 2 (to implement early stopping)?


I used the following code to create a checkpoint manager outside the loop in which I train my model:

checkpoint_path = "./checkpoints/train"

ckpt = tf.train.Checkpoint(object_1=object_1)

ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=1)

Then, while training the model, I call ckpt_save_path = ckpt_manager.save() to save the variables after each epoch.
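One caveat worth noting before trying to restore by epoch: with max_to_keep=1, the CheckpointManager deletes every checkpoint except the newest, so earlier epochs are no longer on disk. Below is a minimal self-contained sketch of a saving setup that keeps one numbered checkpoint per epoch; the toy tf.Variable standing in for object_1 and the temporary directory are assumptions for illustration only:

```python
import tempfile
import tensorflow as tf

# Toy stand-in for the model object tracked in the question.
object_1 = tf.Variable(0.0)

checkpoint_path = tempfile.mkdtemp()
ckpt = tf.train.Checkpoint(object_1=object_1)
# max_to_keep=None keeps every checkpoint instead of only the newest one.
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=None)

for epoch in range(1, 4):
    object_1.assign(float(epoch))  # pretend training updated the variable
    # checkpoint_number=epoch names the file prefix "ckpt-<epoch>",
    # keeping the manager's counter in sync with the epoch number.
    ckpt_manager.save(checkpoint_number=epoch)

# Restore the variables as they were after epoch 2.
ckpt.restore(ckpt_manager.checkpoints[1])  # prefix ".../ckpt-2"
print(object_1.numpy())
```

Because every checkpoint is kept, ckpt_manager.checkpoints lists one prefix per epoch and any of them can be passed to restore().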

Since I want to implement early stopping, I need to restore all the variables as they were after a specific epoch and use them for prediction. How can I restore the variables after, say, epoch e, given that I saved them with the code above (I hope the saving process is correct)? I know I can first recreate the same variables and objects and then use the following code to restore the latest checkpoint, but I have no idea how to restore a specific checkpoint (the variables after epoch e) rather than the latest one.

ckpt.restore(ckpt_manager.latest_checkpoint).assert_consumed()

Thanks,


1 Answer

Answered by EyesBear

Yes, you need to build the checkpoint file-name prefix that contains the epoch number and pass it to restore(). Note that this only works if the checkpoint for that epoch is still on disk: with max_to_keep=1 the manager deletes all but the newest checkpoint, so raise max_to_keep if you want to go back to earlier epochs.

c_manager = tf.train.CheckpointManager(checkpoint, ...)

if EPOCH == '':
    # No epoch requested: fall back to the most recent checkpoint, if any.
    if c_manager.latest_checkpoint:
        tf.print("-----------Restoring from {}-----------".format(
            c_manager.latest_checkpoint))
        checkpoint.restore(c_manager.latest_checkpoint)
        # Recover the epoch number from the prefix, e.g. ".../ckpt-12" -> "12".
        EPOCH = c_manager.latest_checkpoint.split(sep='ckpt-')[-1]
    else:
        tf.print("-----------Initializing from scratch-----------")
else:
    # Build the prefix for the requested epoch; os.path.join avoids relying on
    # CHECKPOINT_SAVE_DIR having a trailing slash.
    checkpoint_fname = os.path.join(CHECKPOINT_SAVE_DIR, 'ckpt-' + str(EPOCH))
    tf.print("-----------Restoring from {}-----------".format(checkpoint_fname))
    checkpoint.restore(checkpoint_fname)
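The prefix construction and parsing above can be isolated into two small helpers. This is a sketch of the same string logic only; the helper names are made up for illustration, and it assumes save() was called once per epoch so the manager's counter matches the epoch number:

```python
import os

CHECKPOINT_SAVE_DIR = "./checkpoints/train"  # example value

def checkpoint_prefix_for_epoch(epoch):
    # Builds e.g. "./checkpoints/train/ckpt-7"; pass this to checkpoint.restore().
    return os.path.join(CHECKPOINT_SAVE_DIR, "ckpt-" + str(epoch))

def epoch_from_prefix(prefix):
    # Inverse of the above: recover the epoch number from a prefix such as
    # the one returned by tf.train.CheckpointManager.latest_checkpoint.
    return int(prefix.split("ckpt-")[-1])

prefix = checkpoint_prefix_for_epoch(7)
print(prefix)
print(epoch_from_prefix(prefix))
```

Keeping the two directions in one place makes it harder for the save-side naming and the restore-side parsing to drift apart.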