How to use feedable iterator from Tensorflow Dataset API along with MonitoredTrainingSession?


The TensorFlow programmer's guide recommends using a feedable iterator to switch between the training and validation datasets without reinitializing the iterator. This mainly requires feeding the string handle of the desired iterator to choose between them.

How can it be used together with tf.train.MonitoredTrainingSession?

The following approach fails with "RuntimeError: Graph is finalized and cannot be modified.":

with tf.train.MonitoredTrainingSession() as sess:
    training_handle = sess.run(training_iterator.string_handle())
    validation_handle = sess.run(validation_iterator.string_handle())

How can I get both the convenience of MonitoredTrainingSession and the ability to alternate between the training and validation datasets?


There are 3 answers

jackberry (best answer)

I got the answer from the Tensorflow GitHub issue - https://github.com/tensorflow/tensorflow/issues/12859

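The solution is to call iterator.string_handle() before creating the MonitoredTrainingSession. The session finalizes the graph when it is created, so the string_handle ops must already exist; sess.run then only evaluates them instead of adding new ops.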

import tensorflow as tf
from tensorflow.contrib.data import Dataset, Iterator

dataset_train = Dataset.range(10)
dataset_val = Dataset.range(90, 100)

iter_train_handle = dataset_train.make_one_shot_iterator().string_handle()
iter_val_handle = dataset_val.make_one_shot_iterator().string_handle()

handle = tf.placeholder(tf.string, shape=[])
iterator = Iterator.from_string_handle(
    handle, dataset_train.output_types, dataset_train.output_shapes)
next_batch = iterator.get_next()

with tf.train.MonitoredTrainingSession() as sess:
    handle_train, handle_val = sess.run([iter_train_handle, iter_val_handle])

    for step in range(10):
        print('train', sess.run(next_batch, feed_dict={handle: handle_train}))

        if step % 3 == 0:
            print('val', sess.run(next_batch, feed_dict={handle: handle_val}))

Output (first lines):
('train', 0)
('val', 90)
('train', 1)
('train', 2)
('train', 3)
('val', 91)
Max F.

@Michael Jaison G's answer is correct. However, it does not work if you also use session run hooks that need to evaluate parts of the graph, e.g. LoggingTensorHook or SummarySaverHook. The example below causes an error:

import tensorflow as tf

dataset_train = tf.data.Dataset.range(10)
dataset_val = tf.data.Dataset.range(90, 100)

iter_train_handle = dataset_train.make_one_shot_iterator().string_handle()
iter_val_handle = dataset_val.make_one_shot_iterator().string_handle()

handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, dataset_train.output_types, dataset_train.output_shapes)
feature = iterator.get_next()

pred = feature * feature
tf.summary.scalar('pred', pred)
global_step = tf.train.create_global_step()

summary_hook = tf.train.SummarySaverHook(save_steps=5,
                                         output_dir="summaries", summary_op=tf.summary.merge_all())

with tf.train.MonitoredTrainingSession(hooks=[summary_hook]) as sess: 
    handle_train, handle_val = sess.run([iter_train_handle, iter_val_handle])

    for step in range(10):
        feat = sess.run(feature, feed_dict={handle: handle_train})
        pred_ = sess.run(pred, feed_dict={handle: handle_train})
        print('train: ', feat)
        print('pred: ', pred_)

        if step % 3 == 0:
            print('val', sess.run(feature, feed_dict={handle: handle_val}))

This will fail with error:

InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'Placeholder' with dtype string
     [[Node: Placeholder = Placeholder[dtype=DT_STRING, shape=[], _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]
     [[Node: cond/Switch_1/_15 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device_incarnation=1, tensor_name="edge_18_cond/Switch_1", tensor_type=DT_INT64, _device="/job:localhost/replica:0/task:0/device:GPU:0"]()]]

The reason is that the hook already tries to evaluate its summary op during the very first call, sess.run([iter_train_handle, iter_val_handle]), whose feed_dict cannot contain a handle yet because the handles are what that call is fetching.

One workaround is to override the hooks that cause the problem and change their before_run and after_run methods so that they only evaluate on session.run calls whose feed_dict contains the handle. The feed_dict of the current session.run call is available through the run_context argument of before_run and after_run (run_context.original_args.feed_dict).
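A minimal sketch of that workaround, written as a generic wrapper instead of editing each hook. `FeedConditionalHook` is a hypothetical name, not a TensorFlow class: it forwards to the wrapped hook only on run calls whose feed_dict contains the handle placeholder as a key (feeds given by tensor name are not detected). It assumes a TensorFlow version that ships `tf.compat.v1` (1.15 or 2.x) and reuses the graph from the failing example, plus a global-step increment.

```python
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()


class FeedConditionalHook(tf.train.SessionRunHook):
    """Hypothetical wrapper: delegates to `hook` only when `required_feed` is fed."""

    def __init__(self, hook, required_feed):
        self._hook = hook
        self._required_feed = required_feed
        self._active = False

    def begin(self):
        self._hook.begin()

    def after_create_session(self, session, coord):
        self._hook.after_create_session(session, coord)

    def before_run(self, run_context):
        # feed_dict passed to the current session.run call (None if nothing was fed)
        feed_dict = run_context.original_args.feed_dict or {}
        self._active = self._required_feed in feed_dict
        return self._hook.before_run(run_context) if self._active else None

    def after_run(self, run_context, run_values):
        # Skip the wrapped hook on runs where it requested nothing
        if self._active:
            self._hook.after_run(run_context, run_values)

    def end(self, session):
        self._hook.end(session)


# Graph from the failing example, plus a global-step increment
dataset_train = tf.data.Dataset.range(10)
dataset_val = tf.data.Dataset.range(90, 100)
iter_train_handle = dataset_train.make_one_shot_iterator().string_handle()
iter_val_handle = dataset_val.make_one_shot_iterator().string_handle()

handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, dataset_train.output_types, dataset_train.output_shapes)
feature = iterator.get_next()
pred = feature * feature
tf.summary.scalar('pred', pred)
global_step = tf.train.create_global_step()
increment_step = tf.assign_add(global_step, 1)

summary_hook = tf.train.SummarySaverHook(
    save_steps=5, output_dir=tempfile.mkdtemp(), summary_op=tf.summary.merge_all())

train_values, val_values = [], []
with tf.train.MonitoredTrainingSession(
        hooks=[FeedConditionalHook(summary_hook, handle)]) as sess:
    # Nothing is fed here, so the wrapped SummarySaverHook is skipped
    handle_train, handle_val = sess.run([iter_train_handle, iter_val_handle])
    for step in range(10):
        feat, pred_, _ = sess.run([feature, pred, increment_step],
                                  feed_dict={handle: handle_train})
        train_values.append(int(feat))
        print('train:', feat, 'pred:', pred_)
        if step % 3 == 0:
            val_values.append(int(sess.run(feature, feed_dict={handle: handle_val})))
            print('val', val_values[-1])
```

Wrapping keeps the original hook classes untouched; the trade-off is that the wrapper decides purely on the presence of the placeholder, so validation runs that feed the handle also trigger the summary hook.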

Alternatively, you can use the latest master of TensorFlow (post-1.4), which adds a run_step_fn method to MonitoredSession. It lets you pass a step_fn like the following, which avoids the error (at the expense of evaluating the if statement once per training iteration):

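handles = {}  # cached string handles, filled on the first step

def step_fn(step_context):
  # step_context.session is the raw tf.Session, so fetching the handles through it bypasses the hooks
  if not handles:
    handles['train'], handles['val'] = step_context.session.run([iter_train_handle, iter_val_handle])
  return step_context.run_with_hooks(fetches=..., feed_dict={handle: handles['train']})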
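For reference, a complete sketch of the run_step_fn approach applied to the failing SummarySaverHook example. It assumes a TensorFlow version that ships `tf.compat.v1` (1.15 or 2.x); `train_step_fn`, `handles`, and the chosen fetches are illustrative, not taken from the original answer. Note that step_fn's parameter must be named `step_context`.

```python
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

dataset_train = tf.data.Dataset.range(10)
dataset_val = tf.data.Dataset.range(90, 100)
iter_train_handle = dataset_train.make_one_shot_iterator().string_handle()
iter_val_handle = dataset_val.make_one_shot_iterator().string_handle()

handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, dataset_train.output_types, dataset_train.output_shapes)
feature = iterator.get_next()
pred = feature * feature
tf.summary.scalar('pred', pred)
global_step = tf.train.create_global_step()
increment_step = tf.assign_add(global_step, 1)

summary_hook = tf.train.SummarySaverHook(
    save_steps=5, output_dir=tempfile.mkdtemp(), summary_op=tf.summary.merge_all())

handles = {}  # string handles, fetched on the first step


def train_step_fn(step_context):
    if not handles:
        # step_context.session is the raw tf.Session, so this run skips the hooks
        handles['train'], handles['val'] = step_context.session.run(
            [iter_train_handle, iter_val_handle])
    # This run goes through the hooks, with the handle already in the feed
    return step_context.run_with_hooks(
        fetches=[feature, pred, increment_step],
        feed_dict={handle: handles['train']})


train_values, val_values = [], []
with tf.train.MonitoredTrainingSession(hooks=[summary_hook]) as sess:
    for step in range(10):
        feat, pred_, _ = sess.run_step_fn(train_step_fn)
        train_values.append(int(feat))
        print('train:', feat, 'pred:', pred_)
        if step % 3 == 0:
            # The handles exist by now, so a plain hooked run with a feed works
            val_values.append(int(sess.run(feature, feed_dict={handle: handles['val']})))
            print('val', val_values[-1])
```

Only the one-off handle fetch needs to bypass the hooks; every later call, including validation, can go through the normal hooked sess.run.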
spark

There is a demo of using a placeholder in a MonitoredTrainingSession together with a SessionRunHook. It switches between datasets by feeding different string handles.

BTW, I have tried all the solutions, but only this one worked for me.

dataset_switching
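The linked demo is not reproduced here. As an illustration only, and not necessarily how the demo does it, a hook can supply the handle itself: `HandleFeedHook` (a hypothetical name) fetches the handles in `after_create_session`, which receives the raw session, and adds the handle for the current `mode` to every run through `SessionRunArgs.feed_dict`. Because the feed is part of every run, other hooks such as SummarySaverHook always see it. This assumes a TensorFlow version that ships `tf.compat.v1` (1.15 or 2.x).

```python
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()


class HandleFeedHook(tf.train.SessionRunHook):
    """Hypothetical hook that feeds the iterator handle for the current `mode`."""

    def __init__(self, handle_placeholder, handle_ops):
        self._placeholder = handle_placeholder
        self._handle_ops = handle_ops  # dict: mode name -> string_handle() tensor
        self._handles = None
        self.mode = 'train'

    def after_create_session(self, session, coord):
        # `session` is the raw tf.Session, so fetching the handles runs no hooks
        self._handles = session.run(self._handle_ops)

    def before_run(self, run_context):
        # Feed the handle on every run; the caller must not feed it as well
        return tf.train.SessionRunArgs(
            fetches=None, feed_dict={self._placeholder: self._handles[self.mode]})


dataset_train = tf.data.Dataset.range(10)
dataset_val = tf.data.Dataset.range(90, 100)
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, dataset_train.output_types, dataset_train.output_shapes)
feature = iterator.get_next()
tf.summary.scalar('pred', feature * feature)
global_step = tf.train.create_global_step()
increment_step = tf.assign_add(global_step, 1)

# The string_handle() ops must exist before the session finalizes the graph
feed_hook = HandleFeedHook(handle, {
    'train': dataset_train.make_one_shot_iterator().string_handle(),
    'val': dataset_val.make_one_shot_iterator().string_handle(),
})
summary_hook = tf.train.SummarySaverHook(
    save_steps=5, output_dir=tempfile.mkdtemp(), summary_op=tf.summary.merge_all())

train_values, val_values = [], []
with tf.train.MonitoredTrainingSession(hooks=[feed_hook, summary_hook]) as sess:
    for step in range(10):
        feed_hook.mode = 'train'
        feat, _ = sess.run([feature, increment_step])
        train_values.append(int(feat))
        if step % 3 == 0:
            feed_hook.mode = 'val'
            val_values.append(int(sess.run(feature)))
print('train', train_values)
print('val', val_values)
```

Switching datasets then only means setting `feed_hook.mode` before a run; no feed_dict appears in the training loop at all.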