How to convert TensorFlow checkpoint files to TensorFlowJS?


I have a project that was developed with TensorFlow v1 (I think). It runs under Python 3.8 and restores its weights like this:

 ...
 saver = tf.train.Saver(var_list=vars)
 ...
 saver.restore(self.sess, tf.train.latest_checkpoint(checkpoint_dir))
 ...

The checkpoint files reside in the directory given by checkpoint_dir.

I would like to use this with TFjs but I can't figure out how to transform the checkpoint files to something that can be loaded with TFjs.

What should I do?

thanks,

John


There is 1 answer

coding-dude.com On

Ok, I figured it out. Hope this helps other beginners like me too.

The checkpoint files do not contain the model itself; they only contain the values (weights, etc.) of the model's variables.
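To see that a checkpoint holds only named values, you can list what is stored in it. This is a minimal sketch, assuming TensorFlow is installed and the directory contains the checkpoint state file that tf.train.latest_checkpoint relies on. describe_checkpoint is a hypothetical helper name, not something from the original project:

```python
import tensorflow as tf


def describe_checkpoint(checkpoint_dir):
    """Return (variable_name, shape) pairs stored in the latest checkpoint.

    Only names and shapes of saved values come back; there is no graph
    structure in here, which is why the model code is still needed.
    """
    return tf.train.list_variables(checkpoint_dir)
```

Calling describe_checkpoint(checkpoint_dir) and printing each pair gives a quick overview of what the saver will restore.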

The model is actually built in the code. So, here are the steps to convert TensorFlow v1 checkpoint files to a model that TensorFlow.js can load:

  1. First I saved the checkpoint again because the .meta file was missing. That file holds the graph definition that goes with the values in the checkpoint. To save the checkpoint together with its .meta file, I added this right after the saver.restore(...) call:
...
saver.save(self.sess,save_path='./newcheckpoint/')
...
  2. Save the model as a frozen model file like this:
import tensorflow.compat.v1 as tf

meta_path = './newcheckpoint/.meta' # Your .meta file
output_node_names = ['name_of_the_output_node']    # Output nodes

with tf.Session() as sess:
    # Restore the graph
    saver = tf.train.import_meta_graph(meta_path)

    # Load weights
    saver.restore(sess,tf.train.latest_checkpoint('./newcheckpoint/'))

    # Freeze the graph
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(
        sess,
        sess.graph_def,
        output_node_names)

    # Save the frozen graph (create the output directory first if it is missing)
    tf.io.gfile.makedirs('./freeze')
    with open('./freeze/output_graph.pb', 'wb') as f:
        f.write(frozen_graph_def.SerializeToString())

This writes the frozen model to ./freeze/output_graph.pb.

  3. Use tensorflowjs_converter to convert the frozen model to a web model like this:

tensorflowjs_converter --input_format=tf_frozen_model --output_node_names='final_add' --skip_op_check ./freeze/output_graph.pb ./web_model/

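I had to use --skip_op_check because the conversion otherwise reported missing op errors/warnings. The value of --output_node_names must be the output node(s) of your own model, the same ones passed as output_node_names when freezing (final_add in my case).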

As a result of step 3, the ./web_model/ folder will contain the JSON and binary files required by the TensorFlow.js library.

Here's how I load the model using tfjs 2.x:

const model = await tf.loadGraphModel('web_model/model.json');
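If you do not know which output node name to pass when freezing and to the converter, one way to narrow it down is to look for nodes in the graph that no other node consumes. The sketch below is a heuristic under my own assumptions, not part of the original project: candidate_output_nodes and load_meta_graph_def are hypothetical helper names, and a training graph will usually also report save, init, and optimizer nodes, so the list still has to be checked by hand.

```python
import tensorflow.compat.v1 as tf


def load_meta_graph_def(meta_path):
    """Parse a .meta file (a serialized MetaGraphDef) and return its GraphDef."""
    meta = tf.MetaGraphDef()
    with open(meta_path, "rb") as f:
        meta.ParseFromString(f.read())
    return meta.graph_def


def candidate_output_nodes(graph_def):
    """Return names of nodes whose outputs are not used as input by any node."""
    consumed = set()
    for node in graph_def.node:
        for inp in node.input:
            # Inputs look like "name", "name:1", or "^name" (control dependency).
            consumed.add(inp.lstrip("^").split(":")[0])
    return [node.name for node in graph_def.node if node.name not in consumed]
```

For example, candidate_output_nodes(load_meta_graph_def('./newcheckpoint/.meta')) returns a list of names to inspect. The actual prediction output is usually the last op created by the model-building code.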