tf.train.MonitoredTrainingSession and reinitializable iterator from Dataset

It seems that a MonitoredTrainingSession runs some operations (logging?) before the first call to .run(..), so that when I do:

train_data = reader.traindata() # returns a tf.contrib.data.Dataset
it = tf.contrib.data.Iterator.from_structure(train_data.output_types, train_data.output_shapes)
init_train = it.make_initializer(train_data)
ne = it.get_next()
ts = tf.train.MonitoredTrainingSession(checkpoint_dir=save_path)

... no calls to ts.run ...

ts.run(init_train)

This yields the error:

FailedPreconditionError (see above for traceback): GetNext() failed because the iterator has not been initialized. Ensure that you have run the initializer operation for this iterator before getting the next element

So it seems that the MonitoredTrainingSession runs some operations before the one I pass to it, which makes it impossible to use together with a reinitializable iterator from Dataset.

I am sure I am missing something and would love to hear what :-)

There is 1 answer

jackberry On BEST ANSWER

It looks like there is no direct solution in TensorFlow yet. Yes, it is odd that the Dataset API is not fully supported here.

The reason is that the monitored session skips running init_op when it restores from a checkpoint. Hence the iterator initializer should be run as part of local_init_op, which is run in either case.

The currently suggested workaround is given in this issue: https://github.com/tensorflow/tensorflow/issues/12859

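# Run the iterator initializer (init_train from the question) as part of
# local_init_op, so it also runs when the session restores from a checkpoint.
scaffold = tf.train.Scaffold(
    local_init_op=tf.group(tf.local_variables_initializer(), init_train))
with tf.train.MonitoredTrainingSession(scaffold=scaffold,
                                       checkpoint_dir=checkpoint_dir) as sess:
    while not sess.should_stop():
        sess.run(train_op)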
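
For anyone who wants to try the workaround end to end, here is a minimal self-contained sketch. It is not from the linked issue. It assumes a TensorFlow build that ships the tf.compat.v1 API and uses tf.data in place of tf.contrib.data. The toy dataset, the one-parameter model and the run_training helper are all made up for illustration. Adding tf.tables_initializer() to the group is also my own choice, since a custom local_init_op replaces the default one.

```python
import tempfile

import tensorflow.compat.v1 as tf  # assumption: TF 1.x-style API via compat.v1


def run_training(checkpoint_dir, num_steps=5):
    """Train for num_steps steps and return how many steps actually ran."""
    with tf.Graph().as_default():
        # Toy stand-in for reader.traindata(): an endless stream of batches.
        train_data = tf.data.Dataset.from_tensor_slices(
            [1.0, 2.0, 3.0, 4.0]).repeat().batch(2)

        # Reinitializable iterator, built from the dataset structure.
        it = tf.data.Iterator.from_structure(
            tf.data.get_output_types(train_data),
            tf.data.get_output_shapes(train_data))
        init_train = it.make_initializer(train_data)
        x = it.get_next()

        # Hypothetical one-parameter model, only so that there is a train_op.
        w = tf.get_variable("w", shape=[], initializer=tf.zeros_initializer())
        loss = tf.reduce_mean(tf.square(w * x - 2.0 * x))
        global_step = tf.train.get_or_create_global_step()
        train_op = tf.train.GradientDescentOptimizer(0.01).minimize(
            loss, global_step=global_step)

        # The workaround: the iterator initializer is part of local_init_op.
        scaffold = tf.train.Scaffold(
            local_init_op=tf.group(tf.local_variables_initializer(),
                                   tf.tables_initializer(),
                                   init_train))
        hooks = [tf.train.StopAtStepHook(num_steps=num_steps)]

        steps_run = 0
        with tf.train.MonitoredTrainingSession(
                scaffold=scaffold, checkpoint_dir=checkpoint_dir,
                hooks=hooks) as sess:
            while not sess.should_stop():
                sess.run(train_op)
                steps_run += 1
        return steps_run
```

Calling run_training twice with the same checkpoint_dir (for example one created with tempfile.mkdtemp()) exercises both paths: the first call starts from scratch, the second restores from the checkpoint written by the first, and in both cases the iterator should be initialized before the first train_op runs.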