TensorFlow MonitoredSession usage


I have the following code, which performs simple arithmetic calculations. I am trying to make it fault-tolerant by using a MonitoredTrainingSession.

import tensorflow as tf

global_step_tensor = tf.Variable(10, trainable=False, name='global_step')

cluster = tf.train.ClusterSpec({"local": ["localhost:2222", "localhost:2223","localhost:2224", "localhost:2225"]})
x = tf.constant(2)

with tf.device("/job:local/task:0"):
    y1 = x + 300

with tf.device("/job:local/task:1"):
    y2 = x**2

with tf.device("/job:local/task:2"):
    y3 = 5*x

with tf.device("/job:local/task:3"):
    y0 = x - 66
    y = y0 + y1 + y2 + y3

ChiefSessionCreator = tf.train.ChiefSessionCreator(scaffold=None, master='localhost:2222', config='grpc://localhost:2222', checkpoint_dir='/home/tensorflow/codes/checkpoints')
saver_hook = tf.train.CheckpointSaverHook(checkpoint_dir='/home/tensorflow/codes/checkpoints', save_secs=10, save_steps=None, saver=y, checkpoint_basename='model.ckpt', scaffold=None)
summary_hook = tf.train.SummarySaverHook(save_steps=None, save_secs=10, output_dir='/home/tensorflow/codes/savepoints', summary_writer=None, scaffold=None, summary_op=y)

with tf.train.MonitoredTrainingSession(master='localhost:2222', is_chief=True, checkpoint_dir='/home/tensorflow/codes/checkpoints', 
    scaffold=None, hooks=[saver_hook, summary_hook], chief_only_hooks=None, save_checkpoint_secs=10, save_summaries_steps=None, config='grpc://localhost:2222') as sess:

    while not sess.should_stop():
        sess.run(model)

    while not sess.should_stop():
        print(sess.run(y0))
        print('\n')

    while not sess.should_stop():
        print(sess.run(y1))
        print('\n')

    while not sess.should_stop():
        print(sess.run(y2))
        print('\n')

    while not sess.should_stop():
        print(sess.run(y3))
        print('\n')

    while not sess.should_stop():
        result = sess.run(y)
        print(result)

But it throws the following error:

 Traceback (most recent call last):
  File "add_1.py", line 36, in <module>
    scaffold=None, hooks=[saver_hook, summary_hook], chief_only_hooks=None, save_checkpoint_secs=10, save_summaries_steps=None, config='grpc://localhost:2222') as sess:
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 289, in MonitoredTrainingSession
    return MonitoredSession(session_creator=session_creator, hooks=hooks)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 447, in __init__
    self._sess = _RecoverableSession(self._coordinated_creator)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 618, in __init__
    _WrappedSession.__init__(self, self._sess_creator.create_session())
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 505, in create_session
    self.tf_sess = self._session_creator.create_session()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/monitored_session.py", line 341, in create_session
    init_fn=self._scaffold.init_fn)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/session_manager.py", line 227, in prepare_session
    config=config)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/session_manager.py", line 153, in _restore_checkpoint
    sess = session.Session(self._target, graph=self._graph, config=config)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 1186, in __init__
    super(Session, self).__init__(target, graph, config=config)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 540, in __init__
    % type(config))
TypeError: config must be a tf.ConfigProto, but got <type 'str'>
Exception AttributeError: "'Session' object has no attribute '_session'" in <bound method Session.__del__ of <tensorflow.python.client.session.Session object at 0x7fb1bac14ed0>> ignored

which I believe is due to an incorrect argument being passed to config. Am I using the parameters correctly? Please advise.

1 Answer

Answered by user1454804:

The first issue is in the following lines: they use a local session for a distributed (device-assigned) op. Why do you need that?

sess = tf.Session()
tf.train.global_step(sess, global_step_tensor)
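A minimal sketch of the fix, written with `tf.compat.v1` so it also runs under TF 2.x (on old TF 1.x, plain `import tensorflow as tf` exposes the same names): read the global step through the monitored session itself, rather than opening a separate local `tf.Session`.

```python
import tensorflow.compat.v1 as tf  # on old TF 1.x: `import tensorflow as tf`
tf.disable_eager_execution()

global_step_tensor = tf.Variable(10, trainable=False, name='global_step')

# Read the global step through the monitored session (which is the session
# actually connected to the cluster) instead of a separate local tf.Session():
with tf.train.MonitoredTrainingSession(master='', is_chief=True) as sess:
    step = tf.train.global_step(sess, global_step_tensor)
```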

Second issue: the code uses a WorkerSessionCreator. One machine should be the chief, so a ChiefSessionCreator should be used here. I recommend using tf.train.MonitoredTrainingSession instead.
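The TypeError in the traceback comes from passing the gRPC URL string as `config`: `config` must be a `tf.ConfigProto` (or `None`), and the gRPC target belongs in the `master` argument. A hedged sketch under that assumption (the temp directory stands in for the question's checkpoint path; the cluster target is shown only as a comment):

```python
import tempfile

import tensorflow.compat.v1 as tf  # on old TF 1.x: `import tensorflow as tf`
tf.disable_eager_execution()

x = tf.constant(2)
y = x + 300

# The checkpointing hooks require a global step in the graph.
tf.train.get_or_create_global_step()

ckpt_dir = tempfile.mkdtemp()  # stand-in for the question's checkpoint path

with tf.train.MonitoredTrainingSession(
        master='',  # in-process; for the cluster chief use 'grpc://localhost:2222'
        is_chief=True,
        checkpoint_dir=ckpt_dir,
        save_checkpoint_secs=10,
        save_summaries_steps=None,  # this graph has no summary ops
        save_summaries_secs=None,
        config=tf.ConfigProto(allow_soft_placement=True)  # a ConfigProto, not a URL string
) as sess:
    result = sess.run(y)
```

With `checkpoint_dir` set and `is_chief=True`, the session creates its own checkpoint saver hook, so the hand-built `CheckpointSaverHook` (whose `saver=y` argument is also invalid, since `saver` expects a `tf.train.Saver`) is not needed.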

Third issue: sess.should_stop() should be checked before each run.
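The usual pattern is a single loop that checks `should_stop()` before every run and fetches everything in one `sess.run` call, rather than a chain of separate `while` loops. A sketch using the question's own ops (note the explicit `break`: with no stopping hook and no step that advances, the session never decides to stop on its own):

```python
import tensorflow.compat.v1 as tf  # on old TF 1.x: `import tensorflow as tf`
tf.disable_eager_execution()

x = tf.constant(2)
y0, y1, y2, y3 = x - 66, x + 300, x ** 2, 5 * x
y = y0 + y1 + y2 + y3

with tf.train.MonitoredTrainingSession(master='', is_chief=True) as sess:
    # Check should_stop() before each run; fetch all ops in a single call
    # so the printed values come from one consistent execution.
    while not sess.should_stop():
        r0, r1, r2, r3, total = sess.run([y0, y1, y2, y3, y])
        print(r0, r1, r2, r3, total)
        break  # nothing here triggers should_stop(), so stop manually
```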