TensorFlow `AssertionError` on `fit()` method


I get an `AssertionError` when passing my `tf.data.Dataset` into the Keras model's `fit()` method.

I am using tensorflow==2.0.0.

I checked that my dataset works by iterating over it:

for x, y in dataset:
    print(x.shape, y.shape)

which yields the correct shapes for the model's input data.

The full trace is:

Traceback (most recent call last):
  File "/anaconda3/envs/ml36/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/anaconda3/envs/ml36/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/me/train.py", line 102, in <module>
    start_training(**arguments)
  File "/me/train.py", line 66, in start_training
    steps_per_epoch=TRAIN_STEPS_PER_EPOCH,
  File "/anaconda3/envs/ml36/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training.py", line 728, in fit
    use_multiprocessing=use_multiprocessing)
  File "/anaconda3/envs/ml36/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training_distributed.py", line 789, in fit
    *args, **kwargs)
  File "/anaconda3/envs/ml36/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training_distributed.py", line 776, in wrapper
    mode=dc.CoordinatorMode.INDEPENDENT_WORKER)
  File "/anaconda3/envs/ml36/lib/python3.6/site-packages/tensorflow_core/python/distribute/distribute_coordinator.py", line 782, in run_distribute_coordinator
    rpc_layer)
  File "/anaconda3/envs/ml36/lib/python3.6/site-packages/tensorflow_core/python/distribute/distribute_coordinator.py", line 344, in _run_single_worker
    assert strategy
AssertionError

There are 2 answers

Antony Harfield (BEST ANSWER)

I had the same error when running `gcloud ai-platform local train` on the final release of TensorFlow 2.0.0. However, it was working on earlier releases. Try downgrading to 2.0.0b1:

pip install tensorflow==2.0.0b1

--

I also found that you don't get this error if you run the script directly in Python, or if you run it in the cloud.

Aleksey Vlasenko

If you are training locally without using any distributed strategy, you can add the following lines to your code (before calling `fit()`) to work around this issue:

import os

# Remove TF_CONFIG so Keras does not enter the
# distribute-coordinator code path without a strategy.
if os.environ.get('TF_CONFIG'):
    os.environ.pop('TF_CONFIG')
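
The same cleanup can be written more compactly with `pop`'s default argument, which makes the removal a no-op when `TF_CONFIG` is unset instead of needing a separate existence check. A minimal sketch (the JSON value below is just an illustrative placeholder, not a real cluster config):

```python
import os

# Simulate an environment where TF_CONFIG was set, e.g. by a launcher.
os.environ['TF_CONFIG'] = '{"cluster": {}, "task": {}}'

# Passing a default of None removes the variable if present
# and silently does nothing otherwise (no KeyError).
os.environ.pop('TF_CONFIG', None)
print('TF_CONFIG' in os.environ)

# Safe to call again even though the variable is already gone.
os.environ.pop('TF_CONFIG', None)
```

This variant is convenient if you are not sure whether the launcher (e.g. `gcloud ai-platform local train`) set the variable at all.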