Error saving and using a TensorForestEstimator model on Android

I use the random forest estimator implemented in TensorFlow to predict whether a text is English or not. I trained and saved my model (on a dataset with 2k samples and 2 class labels, 0/1 for Not English/English) using the following code (the train_input_fn function returns the features and class labels):

model_path = 'model/'
estimator = TensorForestEstimator(params, model_dir=model_path)
estimator.fit(input_fn=train_input_fn, max_steps=1)

After running the above code, the graph.pbtxt and checkpoints are saved in the model folder. Now I want to use it on Android. I have 2 problems:

  1. As the first step, I need to freeze the graph and checkpoints into a .pb file to use it on Android. I tried freeze_graph (I used the code here: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py). When I call freeze_graph on my model, I get the following error and the final .pb graph is not created:

    File "/Users/XXXXXXX/freeze_graph.py", line 105, in freeze_graph
      _ = tf.import_graph_def(input_graph_def, name="")
    File "/anaconda/envs/tensorflow/lib/python2.7/site-packages/tensorflow/python/framework/importer.py", line 258, in import_graph_def
      op_def = op_dict[node.op]
    KeyError: u'CountExtremelyRandomStats'

This is how I call freeze_graph:

import os
import freeze_graph  # local copy of tensorflow/python/tools/freeze_graph.py

def save_model_android():
    checkpoint_state_name = "model.ckpt-1"
    input_graph_name = "graph.pbtxt"
    output_graph_name = "output_graph.pb"
    checkpoint_path = os.path.join(model_path, checkpoint_state_name)

    input_graph_path = os.path.join(model_path, input_graph_name)
    input_saver_def_path = None
    input_binary = False
    output_node_names = "output"
    restore_op_name = "save/restore_all"
    filename_tensor_name = "save/Const:0"
    output_graph_path = os.path.join(model_path, output_graph_name)
    clear_devices = True

    freeze_graph.freeze_graph(input_graph_path, input_saver_def_path,
                              input_binary, checkpoint_path,
                              output_node_names, restore_op_name,
                              filename_tensor_name, output_graph_path,
                              clear_devices, "")

I also tried freezing a model trained on the iris dataset (tf.contrib.learn.datasets.load_iris) and got the same error, so I believe it is not related to the dataset.

  2. As a second step, I need to use the .pb file on the phone to classify a text. I found Google's camera demo example, but it contains a lot of code. Is there a step-by-step tutorial on how to use a TensorFlow model on Android by passing in a feature vector and getting back the class label?

Thanks in advance!

UPDATE

With the most recent version of TensorFlow (0.12), the problem is solved. However, I now don't know what to pass to output_node_names. How can I find out which nodes in the graph are the output nodes?
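One possible way to get a first list of candidates is to scan graph.pbtxt for nodes that no other node consumes as an input. The sketch below is only a heuristic, written in plain Python so it does not depend on any particular TensorFlow version. It assumes the usual text-format layout with one field per line. The helper names (parse_nodes, candidate_output_nodes) and the three-node sample graph are made up for illustration and are not what TensorForest actually produces. On a real estimator graph the result will probably also include saver, summary, and training nodes, so the prediction node still has to be picked out by inspection.

```python
import re

# Matches a double-quoted protobuf text-format string, including escapes.
_STRING = re.compile(r'"(?:\\.|[^"\\])*"')
# Matches the three node fields we care about, one per line.
_FIELD = re.compile(r'^\s*(name|op|input):\s*"((?:\\.|[^"\\])*)"\s*$')


def parse_nodes(pbtxt_text):
    """Return a list of {'name', 'op', 'inputs'} dicts for top-level node blocks."""
    nodes, current, depth = [], None, 0
    for line in pbtxt_text.splitlines():
        # Only read fields directly inside a top-level node block (depth 1),
        # not fields nested inside attr blocks.
        if depth == 1 and current is not None:
            match = _FIELD.match(line)
            if match:
                key, value = match.groups()
                if key == "input":
                    current["inputs"].append(value)
                else:
                    current[key] = value
                continue
        if depth == 0 and line.strip() == "node {":
            current = {"name": None, "op": None, "inputs": []}
        # Ignore braces inside quoted strings when tracking nesting depth.
        bare = _STRING.sub("", line)
        depth += bare.count("{") - bare.count("}")
        if depth == 0 and current is not None:
            nodes.append(current)
            current = None
    return nodes


def candidate_output_nodes(nodes):
    """Return names of nodes that are not used as an input by any other node."""
    consumed = set()
    for node in nodes:
        for inp in node["inputs"]:
            # Inputs look like "name", "name:1" or "^name" (control dependency).
            consumed.add(inp.lstrip("^").split(":")[0])
    return [n["name"] for n in nodes if n["name"] not in consumed]


# Hypothetical three-node graph, used only to exercise the functions.
SAMPLE = '''
node {
  name: "features"
  op: "Placeholder"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
}
node {
  name: "scores"
  op: "Identity"
  input: "features"
}
node {
  name: "output"
  op: "Identity"
  input: "scores:0"
}
'''

nodes = parse_nodes(SAMPLE)
print(candidate_output_nodes(nodes))  # prints ['output']
```

To try it on the saved model, read the file with open(os.path.join(model_path, "graph.pbtxt")).read() and pass the text to parse_nodes. Printing each node's op next to its name can help narrow the list down.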

There are 2 answers

Alexandre Passos

Re (1): It looks like you are running freeze_graph on a build of TensorFlow that does not have access to the contrib ops. Maybe try explicitly importing tensorforest before calling freeze_graph?

Re (2): I don't know of a simpler example.

Gilbert Hendry

CountExtremelyRandomStats is one of TensorForest's custom ops and lives in tensorflow/contrib. As was pointed out, TF switched to including contrib ops by default at some point. I don't think there's an easy way to include the contrib custom ops in the global registry in earlier releases, because TensorForest builds a .so file that is included as a data file and loaded at runtime (this was the standard method when TensorForest was created, but may not be any longer). So there are no easily included Python build rules that will properly link in the C++ custom ops. You can try including tensorflow/contrib/tensor_forest:ops_lib as a dep in your build rule, but I don't think it will work.

In any case, you can try installing the nightly build of TensorFlow. The alternative involves modifying how the TensorForest custom ops are built, which is pretty nasty.