tensorflow: load checkpoint

I've been training a model which looks a bit like:

base_model = tf.keras.applications.ResNet50(weights=weights, include_top=False, input_tensor=input_tensor)

for layer in base_model.layers:
    layer.trainable = False

x = tf.keras.layers.GlobalMaxPool2D()(base_model.output)

output = tf.keras.Sequential()
output.add(tf.keras.layers.Dense(2, activation='linear'))
output.add(tf.keras.layers.Dense(2, activation='linear'))
output.add(tf.keras.layers.Dense(2, activation='linear'))
output.add(tf.keras.layers.Dense(2, activation='linear'))
output.add(tf.keras.layers.Dense(2, activation='linear'))

return output(x)

I setup checkpoints saving with code like:

cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    verbose=1,
    save_weights_only=True,
    save_freq=batch_size*5)

Yesterday I started a fit to run for 11 epochs. I'm not sure why, but the machine restarted during the 7th epoch. Naturally I want to resume fitting from the start of epoch 7.

The checkpoint code above created three files:

[screenshot of the checkpoint directory listing]

The contents of checkpoint are:

model_checkpoint_path: "checkpoint"
all_model_checkpoint_paths: "checkpoint"

The other two files are binary. I tried to load the checkpoint weights with both:

model.load_weights('./2022-03-16_21-10/checkpoints/checkpoint.data-00000-of-00001')
model.load_weights('./2022-03-16_21-10/checkpoints/')

Both fail with NotFoundError: Unsuccessful TensorSliceReader constructor: Failed to find any matching files.

How can I restore this checkpoint and as a result resume fitting?

I'm using tensorflow 2.4.

1 Answer

Answered by elbe:

These might help: the Training checkpoints guide and the tf.train.Checkpoint documentation. According to the documentation, you should be able to load the model with something like this:

model = tf.keras.Model(...)
checkpoint = tf.train.Checkpoint(model)
# Restore the checkpointed values to the `model` object.
# Note: save_path is the checkpoint *prefix* (no .index/.data suffix).
checkpoint.restore(save_path)

I am not sure it will work if the checkpoint contains other variables. You might have to use checkpoint.restore(path).expect_partial().

You can also check the content that has been saved, as described under "Manually inspecting checkpoints" in the documentation:

reader = tf.train.load_checkpoint('./tf_ckpts/')
shape_from_key = reader.get_variable_to_shape_map()
dtype_from_key = reader.get_variable_to_dtype_map()

print(sorted(shape_from_key.keys()))