I just upgraded to TensorFlow 2.3 and want to write my own data generator for training. With TensorFlow 1.x, I did this:
import numpy as np

# load_item_list, get_random_augmented_sample, BATCH_SIZE and model are defined elsewhere
def get_data_generator(test_flag):
    item_list = load_item_list(test_flag)
    print('data loaded')
    while True:
        X = []
        Y = []
        for _ in range(BATCH_SIZE):
            x, y = get_random_augmented_sample(item_list)
            X.append(x)
            Y.append(y)
        yield np.asarray(X), np.asarray(Y)

data_generator_train = get_data_generator(False)
data_generator_test = get_data_generator(True)
model.fit_generator(data_generator_train, validation_data=data_generator_test,
                    epochs=10000, verbose=2,
                    use_multiprocessing=True,
                    workers=8,
                    validation_steps=100,
                    steps_per_epoch=500,
                    )
This code worked fine with TensorFlow 1.x: 8 processes were created, the CPU and GPU were fully loaded, and "data loaded" was printed 8 times.
With TensorFlow 2.3, I get this warning:
WARNING: tensorflow: multiprocessing can interact badly with TensorFlow, causing nondeterministic deadlocks. For high performance data pipelines tf.data is recommended.
"data loaded" is printed only once (it should be printed 8 times), and the GPU is not fully utilized. There is also a memory leak every epoch, so training stops after several epochs. Setting the use_multiprocessing flag did not help.
How can I make a generator/iterator in TensorFlow (Keras) 2.x that can easily be parallelized across multiple CPU processes? Deadlocks and data order are not important.
With a tf.data pipeline, there are several places where you can parallelize. Depending on how your data are stored and read, you can parallelize reading. You can also parallelize augmentation, and you can prefetch data while you train, so your GPU (or other hardware) is never starved for data. In the code below, I demonstrate how you can parallelize augmentation and add prefetching.
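As a separate minimal sketch of the two ideas just described (parallel augmentation with `Dataset.map(..., num_parallel_calls=...)` and prefetching with `Dataset.prefetch`), assuming the data fit in memory as NumPy arrays: `images`, `labels` and `augment` are hypothetical stand-ins for the question's `load_item_list` and `get_random_augmented_sample`, and the flip/brightness augmentation is only an illustration. Note that `num_parallel_calls` runs the map function on a thread pool inside the TensorFlow runtime, not in separate OS processes; this parallelizes well when the augmentation is written with TF ops (which do not hold the Python GIL), but pure-Python augmentation wrapped in `tf.py_function` would still be limited by the GIL. `tf.data.experimental.AUTOTUNE` is the TF 2.3 spelling; newer releases also expose `tf.data.AUTOTUNE`.

```python
import numpy as np
import tensorflow as tf

AUTOTUNE = tf.data.experimental.AUTOTUNE
BATCH_SIZE = 32  # hypothetical batch size

# Hypothetical in-memory data standing in for load_item_list():
# 1000 random 32x32 RGB images in [0, 1) with integer labels.
images = np.random.rand(1000, 32, 32, 3).astype(np.float32)
labels = np.random.randint(0, 10, size=1000).astype(np.int32)


def augment(image, label):
    # Hypothetical augmentation written with TF ops, so map() can run
    # several calls concurrently on CPU threads.
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, max_delta=0.1)
    image = tf.clip_by_value(image, 0.0, 1.0)  # keep pixels in [0, 1]
    return image, label


def make_dataset(images, labels, augment_data=True):
    ds = tf.data.Dataset.from_tensor_slices((images, labels))
    # Shuffle and repeat forever, like the infinite generator in the question.
    ds = ds.shuffle(buffer_size=len(images)).repeat()
    if augment_data:
        # Parallel augmentation; AUTOTUNE lets tf.data pick the parallelism.
        ds = ds.map(augment, num_parallel_calls=AUTOTUNE)
    ds = ds.batch(BATCH_SIZE)
    # Prepare the next batches on the CPU while the current one trains.
    return ds.prefetch(AUTOTUNE)


train_ds = make_dataset(images, labels, augment_data=True)
test_ds = make_dataset(images, labels, augment_data=False)

x, y = next(iter(train_ds))
print(tuple(x.shape), tuple(y.shape))
```

Because both datasets repeat indefinitely, they can be passed straight to `model.fit(train_ds, validation_data=test_ds, steps_per_epoch=500, validation_steps=100, epochs=...)` without `use_multiprocessing` or `workers`. If the data are instead stored in many files, reading can be parallelized in a similar way with `Dataset.interleave(..., num_parallel_calls=AUTOTUNE)`.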