I get an AssertionError when passing my tf.data.Dataset to the fit() method of a tf.keras Model.
I am using tensorflow==2.0.0.
I checked that my dataset works with:
    for x, y in dataset:
        print(x.shape, y.shape)
This prints the correct shapes for the model's input data.
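A self-contained version of the shape check above, as a sketch only: the toy features, labels, batch size and one-layer model are hypothetical stand-ins, since the question does not show the real dataset or model. It inspects just the first batch with take(1), then passes the same dataset to fit() with steps_per_epoch, mirroring the call in the traceback.

```python
import tensorflow as tf

# Hypothetical toy data: 100 examples, 4 features and 1 label each.
features = tf.random.uniform((100, 4))
labels = tf.random.uniform((100, 1))
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(10)


def first_batch_shapes(ds):
    """Return the (x, y) shapes of the first batch without iterating the whole dataset."""
    for x, y in ds.take(1):
        return x.shape.as_list(), y.shape.as_list()


print(first_batch_shapes(dataset))  # ([10, 4], [10, 1])

# Hypothetical model whose input shape matches the 4 features above.
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer="adam", loss="mse")

# repeat() makes the dataset endless so steps_per_epoch sets the epoch length,
# as in the question's fit(..., steps_per_epoch=TRAIN_STEPS_PER_EPOCH) call.
history = model.fit(dataset.repeat(), epochs=1, steps_per_epoch=10, verbose=0)
```

If the shapes match the model's input and fit() still fails only under a particular launcher, that suggests the problem lies in the environment rather than in the dataset.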
The full trace is:
Traceback (most recent call last):
  File "/anaconda3/envs/ml36/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/anaconda3/envs/ml36/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/me/train.py", line 102, in <module>
    start_training(**arguments)
  File "/me/train.py", line 66, in start_training
    steps_per_epoch=TRAIN_STEPS_PER_EPOCH,
  File "/anaconda3/envs/ml36/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training.py", line 728, in fit
    use_multiprocessing=use_multiprocessing)
  File "/anaconda3/envs/ml36/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training_distributed.py", line 789, in fit
    *args, **kwargs)
  File "/anaconda3/envs/ml36/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/training_distributed.py", line 776, in wrapper
    mode=dc.CoordinatorMode.INDEPENDENT_WORKER)
  File "/anaconda3/envs/ml36/lib/python3.6/site-packages/tensorflow_core/python/distribute/distribute_coordinator.py", line 782, in run_distribute_coordinator
    rpc_layer)
  File "/anaconda3/envs/ml36/lib/python3.6/site-packages/tensorflow_core/python/distribute/distribute_coordinator.py", line 344, in _run_single_worker
    assert strategy
AssertionError
I had the same error when running gcloud ai-platform local train on the final release of tensorflow 2.0.0, although it worked on earlier releases. Try downgrading to 2.0.0b1:
    pip install tensorflow==2.0.0b1
I also found that the error does not occur when running the script directly with python, or when running the job in the cloud.