Using Data Augmentation in TensorFlow 2


I am trying to use the new TensorFlow Object Detection API, released in June, but I am having difficulty using the data augmentation utils it provides. They import tf.contrib.image, which only exists in TF 1.x. So my question is: does anyone know how I can use these data augmentation utils in TF 2.x?

Best regards.

Answered by AudioBubble

The TensorFlow site has a data augmentation tutorial that uses TF 2.x: https://www.tensorflow.org/tutorials/images/data_augmentation
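As a rough sketch of one approach covered in that tutorial, augmentation can be written with Keras preprocessing layers. This assumes a recent TF 2.x release where the layers live under tf.keras.layers (older 2.x releases keep them under tf.keras.layers.experimental.preprocessing). The random array stands in for real images, and the chosen layers and factors are illustrative only, not taken from the Object Detection API.

```python
import numpy as np
import tensorflow as tf

# Augmentation pipeline built from Keras preprocessing layers.
# These layers only modify their input when called with training=True.
data_augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomFlip("horizontal"),  # random left-right flip
    tf.keras.layers.RandomRotation(0.1),       # rotate by up to +/-10% of a full turn (36 degrees)
    tf.keras.layers.RandomZoom(0.2),           # zoom in or out by up to 20%
])

# Dummy batch of 4 RGB images with values in [0, 1]; replace with real data.
images = np.random.rand(4, 64, 64, 3).astype("float32")

augmented = data_augmentation(images, training=True)
print(tuple(augmented.shape))  # (4, 64, 64, 3)
```

To apply it inside a tf.data pipeline, one option is dataset.map(lambda x, y: (data_augmentation(x, training=True), y)).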

You can also use the ImageDataGenerator class to perform data augmentation in TF 2.x:

tf.keras.preprocessing.image.ImageDataGenerator

Sample code using ImageDataGenerator:

import os
import tensorflow as tf

# Load the source image as a PIL image
image = tf.keras.preprocessing.image.load_img('flower.jpeg')

image_datagen = tf.keras.preprocessing.image.ImageDataGenerator(rotation_range=40,
                                                                width_shift_range=0.2,
                                                                height_shift_range=0.2,
                                                                rescale=1./255,
                                                                shear_range=0.2,
                                                                zoom_range=0.2,
                                                                horizontal_flip=True,
                                                                fill_mode='nearest')
# Convert the image to a NumPy array and add a batch dimension: (1, height, width, channels)
im_array = tf.keras.preprocessing.image.img_to_array(image)
img = im_array.reshape((1,) + im_array.shape)

# save_to_dir must already exist
os.makedirs('image_gen', exist_ok=True)

# Generate and save 5 augmented images; flow() yields batches forever, so stop manually
count = 0
for batch in image_datagen.flow(img, batch_size=1, save_to_dir='image_gen', save_prefix='flower', save_format='jpeg'):
  count += 1
  if count == 5:
    break

# Input image
import matplotlib.pyplot as plt
image = plt.imread('flower.jpeg')
plt.imshow(image)
plt.show()


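# After augmentation
import glob
import matplotlib.pyplot as plt
# Saved file names include a random number (e.g. flower_0_1167.jpeg), so look them up with glob
augmented_file = sorted(glob.glob('image_gen/flower_*.jpeg'))[0]
image = plt.imread(augmented_file)
plt.imshow(image)
plt.show()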

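One caveat for the object detection use case in the question: ImageDataGenerator transforms only the pixels, so bounding boxes have to be transformed alongside the image. Below is a minimal NumPy sketch of a horizontal flip that does both. It assumes boxes are stored as normalized [ymin, xmin, ymax, xmax] rows, which is an assumption about your data layout; the function name flip_left_right_with_boxes is hypothetical and not part of any library.

```python
import numpy as np

def flip_left_right_with_boxes(image, boxes):
    """Horizontally flip an image and its bounding boxes together (hypothetical helper).

    image: array of shape (height, width, channels)
    boxes: array of shape (num_boxes, 4); each row is assumed to be
           normalized [ymin, xmin, ymax, xmax] with values in [0, 1]
    """
    flipped_image = image[:, ::-1, :]  # reverse the width axis
    ymin, xmin, ymax, xmax = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    # Mirroring maps x -> 1 - x, so the old right edge becomes the new left edge.
    flipped_boxes = np.stack([ymin, 1.0 - xmax, ymax, 1.0 - xmin], axis=1)
    return flipped_image, flipped_boxes

# Toy example: a 4x6 image with a bright strip in the two leftmost columns,
# and a box covering exactly that strip.
image = np.zeros((4, 6, 3), dtype=np.uint8)
image[:, :2, :] = 255
boxes = np.array([[0.0, 0.0, 1.0, 1.0 / 3.0]])

flipped_image, flipped_boxes = flip_left_right_with_boxes(image, boxes)
print(flipped_boxes)
```

After the flip, both the strip and its box sit in the two rightmost columns, which is the property any box-aware augmentation has to preserve.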