I am trying to understand some surprising results I see when implementing a TF graph. The graph I am working with is just a forest (a bunch of trees). It is a plain forward-inference graph, with nothing related to training. I am sharing snippets for two implementations.
Code snippet 1:
with tf.name_scope("main"):
    def get_tree_output(offset):
        # Walk a single tree, starting from its root offset.
        loop_vars = (offset,)
        leaf_indice = tf.while_loop(cond,
                                    body,
                                    loop_vars,
                                    back_prop=False,
                                    parallel_iterations=1,
                                    name="while_loop")
        # map_fn needs the per-tree leaf index back; the scores are
        # gathered and summed once, after map_fn.
        return leaf_indice
    leaf_indices = tf.map_fn(get_tree_output,
                             tree_offsets_tensor,
                             dtype=INT_TYPE,
                             parallel_iterations=n_trees,
                             back_prop=False,
                             name="tree-scores")
    tree_scores = tf.gather(score_tensor, leaf_indices, name="tree-scores")
    output = tf.reduce_sum(tree_scores, name="sum-output")
    output = tf.sigmoid(output, name="sigmoid-output")
Code snippet 2:
with tf.name_scope("main"):
    tree_offsets_tensor = tf.constant(tree_offsets, dtype=INT_TYPE, name="tree_offsets_tensor")
    loop_vars = (tree_offsets_tensor,)
    leaf_indices = tf.while_loop(cond,
                                 body,
                                 loop_vars,
                                 back_prop=False,
                                 parallel_iterations=n_trees,
                                 name="while_loop")
    tree_scores = tf.gather(score_tensor, leaf_indices, name="tree-scores")
    output = tf.reduce_sum(tree_scores, name="sum-output")
    output = tf.sigmoid(output, name="sigmoid-output")
The rest of the code is exactly the same in both cases: the constant tensors, the variables, and the condition and body of the while loop. The thread and parallelism settings were also the same. Code snippet 2 takes about 500 microseconds per inference, while code snippet 1 takes about 12 milliseconds.
The difference is that in snippet 1 I use map_fn to operate on tree_offsets_tensor, whereas in snippet 2 I drop map_fn and pass that tensor to the loop directly. As I understand it, in snippet 1 get_tree_output is called with one element of tree_offsets_tensor at a time, so there is a separate while_loop for each individual offset value, whereas in snippet 2 there is a single while_loop that takes all the offset values at once (the whole tree_offsets_tensor).
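To make the structural difference concrete, here is a minimal pure-Python sketch of the two traversal shapes. It is not the TensorFlow code: the flat forest layout (FEATURE, THRESHOLD, LEFT, RIGHT) and the function names are assumptions made up for illustration, and the counters only stand in for "how many loops are started" and "how many loop bodies run".

```python
# Hypothetical flat forest: node i is a leaf when LEFT[i] == -1.
# Tree 0 occupies nodes 0..2, tree 1 occupies nodes 3..7.
FEATURE   = [0, -1, -1, 1, 0, -1, -1, -1]
THRESHOLD = [0.5, 0.0, 0.0, 0.3, 0.8, 0.0, 0.0, 0.0]
LEFT      = [1, -1, -1, 4, 6, -1, -1, -1]
RIGHT     = [2, -1, -1, 5, 7, -1, -1, -1]


def step(node, x):
    """Move one level down; a leaf stays where it is."""
    if LEFT[node] == -1:
        return node
    return LEFT[node] if x[FEATURE[node]] <= THRESHOLD[node] else RIGHT[node]


def walk_per_tree(offsets, x):
    """Snippet-1 shape: one while loop per tree offset."""
    leaves, n_loops, n_steps = [], 0, 0
    for node in offsets:
        n_loops += 1
        while LEFT[node] != -1:
            node = step(node, x)
            n_steps += 1
        leaves.append(node)
    return leaves, n_loops, n_steps


def walk_batched(offsets, x):
    """Snippet-2 shape: a single while loop advancing all trees together."""
    nodes, n_loops, n_steps = list(offsets), 1, 0
    while any(LEFT[n] != -1 for n in nodes):
        nodes = [step(n, x) for n in nodes]
        n_steps += 1
    return nodes, n_loops, n_steps


x = [0.7, 0.1]
print(walk_per_tree([0, 3], x))
print(walk_batched([0, 3], x))
```

Both functions reach the same leaves, but the per-tree version starts one loop per tree and runs a body for every level of every tree, while the batched version starts one loop and runs as many bodies as the deepest tree has levels. If each loop and each loop iteration in the graph carries a roughly fixed cost, that alone could account for part of the gap, though I have not confirmed this.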
I also tried another variation of snippet 1: instead of using map_fn, I wrote a plain Python for loop.
Code snippet 1 (for-loop variation):
output = 0
with tf.name_scope("main"):
    for offset in tree_offsets:
        loop_vars = (offset,)
        leaf_indice = tf.while_loop(cond,
                                    body,
                                    loop_vars,
                                    back_prop=False,
                                    parallel_iterations=1,
                                    name="while_loop")
        tree_score = tf.gather(score_tensor, leaf_indice, name="tree-scores")
        output = tf.add(tree_score, output)
    #leaf_indices = tf.map_fn(get_tree_output,
    #                         tree_offsets_tensor, dtype=INT_TYPE,
    #                         parallel_iterations=n_trees, back_prop=False,
    #                         name="tree-scores")
    #tree_scores = tf.gather(score_tensor, leaf_indices, name="tree-scores")
    #output = tf.reduce_sum(tree_scores, name="sum-output")
    output = tf.sigmoid(output, name="sigmoid-output")
This gives a minor improvement: 9 milliseconds.
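For anyone who wants to compare the variants themselves, a minimal sketch of a timing harness follows. It assumes each variant is wrapped in a zero-argument callable (for the TF graphs, something like a lambda around the session run call, which is an assumption about how the graph is executed); the name median_latency_us and its parameters are hypothetical.

```python
import statistics
import time


def median_latency_us(fn, warmup=10, runs=100):
    """Return the median wall-clock latency of fn() in microseconds."""
    # Warm-up calls are discarded so one-time setup cost is not measured.
    for _ in range(warmup):
        fn()
    samples = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - start) * 1e6)
    return statistics.median(samples)


# Stand-in workload; replace with a callable that runs one inference.
print(median_latency_us(lambda: sum(range(1000))))
```

Using the median rather than the mean keeps a few slow outlier runs from dominating the reported figure.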