Apply different data augmentation to part of the train set based on the category

I'm working on a machine learning process to classify images. My problem is that my dataset is imbalanced: across my 5 image categories, I have about 400 images of one class and only about 20 images of each of the other four.

I would like to balance my train set by applying data augmentation only to certain classes of my train set.

Here's the code I'm using to create the train and validation sets:

import pathlib

import tensorflow as tf

# Import data
data_dir = pathlib.Path(r"C:\Train set")

# Define train and validation sets (80% - 20%)
batch_size = 32
img_height = 240
img_width = 240

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="training",
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size)

val_ds = tf.keras.preprocessing.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="validation",
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size)

And here's my data augmentation pipeline, although as written it would apply to the entire train set:

from tensorflow import keras
from tensorflow.keras import layers

# Apply data augmentation
data_augmentation = keras.Sequential(
  [
    layers.experimental.preprocessing.RandomFlip("horizontal", 
                                                 input_shape=(img_height, 
                                                              img_width,
                                                              3)),
    layers.experimental.preprocessing.RandomRotation(0.1),
    layers.experimental.preprocessing.RandomZoom(0.1),
  ]
)

Is there any way to go into my train set, extract those categories that have fewer images, and apply data augmentation only to them?

Thanks in advance!

Accepted answer by Nicolas Gervais

Rather than image_dataset_from_directory, I suggest building a custom tf.data.Dataset from the image file paths. In a mapping operation, you can then treat each category differently, for example by augmenting only some of them.

Let me demonstrate it. Let's make you a folder with training images:

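import os
from glob import glob

import cv2
import matplotlib.pyplot as plt
import tensorflow as tf
from skimage import data

cat = data.chelsea()
astronaut = data.astronaut()

# Write 5 copies of each sample picture into its own category folder.
for category, picture in zip(['tf_cats', 'tf_astronauts'], [cat, astronaut]):
    os.makedirs(category, exist_ok=True)
    for i in range(5):
        # skimage returns RGB, but OpenCV writes images in BGR order.
        cv2.imwrite(os.path.join(category, category + f'_{i}.jpg'),
                    cv2.cvtColor(picture, cv2.COLOR_RGB2BGR))

# os.path.join uses the platform separator, matching os.sep in preprocess.
files = sorted(glob(os.path.join('tf_*', '*.jpg')))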

Now you have these files (paths shown as they appear on Windows):

['tf_astronauts\\tf_astronauts_0.jpg',
 'tf_astronauts\\tf_astronauts_1.jpg',
 'tf_astronauts\\tf_astronauts_2.jpg',
 'tf_astronauts\\tf_astronauts_3.jpg',
 'tf_astronauts\\tf_astronauts_4.jpg',
 'tf_cats\\tf_cats_0.jpg',
 'tf_cats\\tf_cats_1.jpg',
 'tf_cats\\tf_cats_2.jpg',
 'tf_cats\\tf_cats_3.jpg',
 'tf_cats\\tf_cats_4.jpg']
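Since the goal is a balanced train set, keep in mind that augmentation by itself changes what the minority images look like, not how many of them the model sees per epoch. One option, sketched here under the assumption that each file's parent folder is its category (as in the listing above), is to oversample the file paths before building the dataset. The helper name oversample_paths is hypothetical.

```python
import os
from collections import defaultdict


def oversample_paths(paths):
    """Hypothetical helper: repeat the files of smaller categories until every
    category has as many paths as the largest one.

    Assumes each file's parent folder name is its category.
    """
    by_category = defaultdict(list)
    for path in paths:
        category = os.path.basename(os.path.dirname(path))
        by_category[category].append(path)
    if not by_category:
        return []
    target = max(len(group) for group in by_category.values())
    balanced = []
    for category in sorted(by_category):
        group = by_category[category]
        # Cycle through this category's files until it reaches the target count.
        balanced.extend(group[i % len(group)] for i in range(target))
    return balanced


# Example: 400 "major" files and 20 "minor" files become 400 of each.
example = ([os.path.join('major', f'{i}.jpg') for i in range(400)]
           + [os.path.join('minor', f'{i}.jpg') for i in range(20)])
print(len(oversample_paths(example)))  # 800
```

You could then start from tf.data.Dataset.from_tensor_slices(oversample_paths(files)) with a shuffle buffer at least as large as the list. The repeated copies only become different training images if the per-category augmentation is random; the fixed flips in preprocess would produce identical duplicates.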

Let's apply transformations only to the astronaut category, using the tf.image flip operations:

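def preprocess(filepath):
    # The category is the name of the file's parent folder
    # (index -2 also works when filepath is an absolute path).
    category = tf.strings.split(filepath, os.sep)[-2]
    read_file = tf.io.read_file(filepath)
    decode = tf.image.decode_jpeg(read_file, channels=3)
    resize = tf.image.resize(decode, (200, 200))  # float32, values in [0, 255]
    # Extra leading axis, only so that image[0] can be plotted below.
    image = tf.expand_dims(resize, 0)
    # Transform only the astronaut category. Dataset.map traces this function
    # with AutoGraph, which turns the tensor-valued `if` into a graph conditional.
    if tf.equal(category, 'tf_astronauts'):
        image = tf.image.flip_up_down(image)
        image = tf.image.flip_left_right(image)
    # image = tf.image.convert_image_dtype(image, tf.float32)
    # category = tf.cast(tf.equal(category, 'tf_astronauts'), tf.int32)
    return image, category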

Then, we make the tf.data.Dataset:

train = (tf.data.Dataset.from_tensor_slices(files)
         .shuffle(10)       # buffer of 10 covers all 10 paths: a full shuffle
         .take(4)           # keep 4 random files for the demo
         .map(preprocess)
         .batch(4))

When you iterate over the dataset, you'll see that only the astronaut images are flipped (upside down and left to right):

fig = plt.figure()
plt.subplots_adjust(wspace=.1, hspace=.2)
images, labels = next(iter(train))
for index, (image, label) in enumerate(zip(images, labels)):
    ax = plt.subplot(2, 2, index + 1)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(label.numpy().decode())
    ax.imshow(image[0].numpy().astype(int))
plt.show()
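If you would rather keep the image_dataset_from_directory pipeline from the question, the same per-category idea can be applied to its batches, because its labels are integers (the default label_mode='int', indexing into train_ds.class_names). The sketch below runs the augmentation on the whole batch and uses tf.where to keep the augmented version only for examples whose label is a minority class. The function name augment_minority_only and the minority_labels argument are hypothetical.

```python
import tensorflow as tf


def augment_minority_only(images, labels, augmentation, minority_labels):
    """Hypothetical helper: augment only the examples whose integer label is in
    minority_labels; all other examples pass through unchanged."""
    augmented = augmentation(images, training=True)
    minority = tf.constant(minority_labels, dtype=labels.dtype)
    # (batch, 1) compared with (n_minority,) broadcasts to (batch, n_minority).
    is_minority = tf.reduce_any(
        tf.equal(tf.expand_dims(labels, -1), minority), axis=-1)
    # Reshape to (batch, 1, 1, 1) so the per-example choice broadcasts over pixels.
    is_minority = tf.reshape(is_minority, [-1, 1, 1, 1])
    return tf.where(is_minority, augmented, images), labels
```

For example: train_ds = train_ds.map(lambda x, y: augment_minority_only(x, y, data_augmentation, minority_labels)), where minority_labels lists the indices of the four small classes in train_ds.class_names. This augments every image and then discards the result for the majority class, which costs some extra compute but avoids unbatching the dataset.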

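Please note that for training, preprocess should return an array of floats and an integer label, so uncomment the two commented lines. The tf.cast line turns the category string into 1 for astronauts and 0 otherwise. Because tf.image.resize already returns float32 values in [0, 255], convert_image_dtype leaves them unchanged; divide by 255 if your model expects inputs in [0, 1]. Also drop the tf.expand_dims call (and plot image instead of image[0]): .batch() already adds the batch axis, so keeping it would produce batches of shape (batch, 1, 200, 200, 3).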
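Putting those training changes together, a preprocess variant might look like the sketch below. It also swaps the fixed flips for random tf.image ops, so each epoch (and each oversampled copy) can see a different version of a minority image. CLASS_NAMES, AUGMENT, random_augment and preprocess_for_training are hypothetical names chosen for the demo folders, and the label lookup assumes every category folder is listed in CLASS_NAMES.

```python
import os

import tensorflow as tf

# Hypothetical configuration for the demo folders.
CLASS_NAMES = ['tf_astronauts', 'tf_cats']  # label i means CLASS_NAMES[i]
AUGMENT = ['tf_astronauts']                 # categories to augment


def random_augment(image):
    # Random ops: each call can produce a different variant of the image.
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_flip_up_down(image)
    image = tf.image.random_brightness(image, max_delta=0.1)
    return tf.clip_by_value(image, 0.0, 1.0)


def preprocess_for_training(filepath):
    category = tf.strings.split(filepath, os.sep)[-2]
    image = tf.io.decode_jpeg(tf.io.read_file(filepath), channels=3)
    # resize returns float32 in [0, 255]; scale to [0, 1].
    image = tf.image.resize(image, (200, 200)) / 255.0
    # No expand_dims here: Dataset.batch() adds the batch axis.
    should_augment = tf.reduce_any(tf.equal(category, AUGMENT))
    image = tf.cond(should_augment, lambda: random_augment(image), lambda: image)
    # Integer label = position of the category in CLASS_NAMES
    # (assumes the category is listed there; otherwise argmax returns 0).
    label = tf.argmax(tf.cast(tf.equal(category, CLASS_NAMES), tf.int32))
    return image, label
```

Usage would mirror the demo, for example tf.data.Dataset.from_tensor_slices(files).shuffle(len(files)).map(preprocess_for_training).batch(32).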