TensorFlow: how to save a graph with pruned weights


I've applied iterative pruning to my neural network, as in this work, to reduce my network size by about 90%. The idea is to remove edges that are not important to the network (weights close to zero). I use the following code to save the model and computation graph (the same code as in train.py), where sparse_w is a dict of the variable names we wish to store (i.e. all variables not in this dict are pruned and can be thrown away):
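The pruning idea described above (dropping weights whose magnitude is close to zero) can be sketched as a boolean mask over a weight matrix. This is a minimal NumPy sketch assuming simple one-shot magnitude pruning at a fixed sparsity; the iterative schedule in the referenced work may differ, and the helper name magnitude_prune is hypothetical.

```python
import numpy as np

def magnitude_prune(weights, sparsity):
    """Zero out the smallest-magnitude fraction `sparsity` of entries.

    Hypothetical helper: returns the pruned weights and a boolean mask
    marking the entries that were kept.
    """
    # Threshold at the `sparsity` quantile of the absolute values
    threshold = np.quantile(np.abs(weights), sparsity)
    mask = np.abs(weights) > threshold
    return weights * mask, mask

rng = np.random.default_rng(0)
w = rng.normal(size=(100, 100))
pruned, mask = magnitude_prune(w, 0.9)  # remove about 90% of the weights
print(f"kept {mask.mean():.1%} of weights")
```

In the real network the mask would be applied to each weight variable, with retraining between pruning rounds.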

import tensorflow as tf

# Save model objects to serialized format (only the variables listed in sparse_w)
final_saver = tf.train.Saver(sparse_w)
final_saver.save(sess, "model_ckpt_sparse_retrained")

# Save the graph definition; this is probably where I go wrong
tf.train.write_graph(sess.graph_def, '.', "my_graph.pb", as_text=False)

I run into trouble when I try to freeze my graph for inference, or more generally whenever I load the saved graph and model. When I load the saved graph to freeze it:

from tensorflow.python.tools import freeze_graph
import tensorflow as tf

freeze_graph.freeze_graph(input_graph="my_graph.pb",
                          input_saver="",
                          input_binary=True,
                          input_checkpoint="model_ckpt_sparse_retrained",
                          output_node_names="y_",
                          restore_op_name="save/restore_all",
                          filename_tensor_name="save/Const:0",
                          output_graph="frozen_graph.pb",
                          clear_devices=True,
                          initializer_nodes="")

I get the error "Attempting to use uninitialized value Variable_2". I suspect this is expected and comes from TensorFlow trying to use the weights that I didn't save. That seems likely, because I don't get this error for the dense model (where all weight variables are saved).
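One way to see why the error can appear: as far as I understand, freezing (via convert_variables_to_constants in TF 1.x) keeps the subgraph that the output node depends on and needs a value for every variable in it. If the forward pass still reads a variable that was left out of the checkpoint, that variable is never restored. The toy, pure-Python sketch below models the graph as a dict of node inputs; the node names and the helper reachable_variables are made up for illustration and are not TensorFlow APIs.

```python
def reachable_variables(graph, outputs, variables):
    """Return the variable nodes that `outputs` transitively depend on.

    `graph` maps each node name to the list of its input node names.
    """
    seen, stack = set(), list(outputs)
    while stack:
        node = stack.pop()
        if node in seen:
            continue
        seen.add(node)
        stack.extend(graph.get(node, []))
    return seen & set(variables)

# Toy graph: the output y_ still depends on Variable_2, which was not saved
graph = {
    "y_": ["add"],
    "add": ["matmul_1", "Variable_3"],
    "matmul_1": ["relu", "Variable_2"],
    "relu": ["matmul_0"],
    "matmul_0": ["x", "Variable_0"],
}
variables = ["Variable_0", "Variable_2", "Variable_3"]
saved = {"Variable_0", "Variable_3"}  # names present in the sparse checkpoint

missing = reachable_variables(graph, ["y_"], variables) - saved
print(sorted(missing))  # ['Variable_2']
```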

So I suppose I need a way to construct a new graph without the pruned weights and then save that graph to a .pb file. How do I do that?
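A graph "without the pruned weights" amounts to storing only the surviving (nonzero) weights and computing with them directly. The NumPy sketch below shows that idea with a coordinate (COO) representation; to_coo and coo_matmul are hypothetical helpers. In TensorFlow 1.x the analogous pieces would presumably be tf.SparseTensor(indices, values, dense_shape) and tf.sparse_tensor_dense_matmul inside a freshly built inference graph, but that TensorFlow part is an assumption and is not exercised here.

```python
import numpy as np

def to_coo(dense):
    """Keep only the nonzero entries of a pruned weight matrix (COO format)."""
    rows, cols = np.nonzero(dense)
    indices = np.stack([rows, cols], axis=1)
    return indices, dense[rows, cols], dense.shape

def coo_matmul(indices, values, shape, x):
    """Compute W @ x using only the stored nonzeros of W."""
    out = np.zeros((shape[0], x.shape[1]), dtype=np.result_type(values, x))
    # Each stored weight W[i, j] adds W[i, j] * x[j, :] to row i of the output
    np.add.at(out, indices[:, 0], values[:, None] * x[indices[:, 1]])
    return out

rng = np.random.default_rng(0)
w = rng.normal(size=(64, 32))
w[np.abs(w) < 1.5] = 0.0  # crude pruning, just for the demo
indices, values, shape = to_coo(w)
x = rng.normal(size=(32, 8))
print(f"stored {values.size} of {w.size} weights")
```

The sparse product matches the dense one, so only indices, values and the shape need to be kept for inference.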

