RandAugment for Image Classification with Keras/TensorFlow

Data augmentation has long served as a way of replacing a "static" dataset with transformed variants, bolstering the invariance of Convolutional Neural Networks (CNNs), and usually making them more robust to variations in the input.

Note: Invariance boils down to making models blind to certain perturbations when making decisions. An image of a cat is still an image of a cat if you mirror it or rotate it.

While data augmentation, in the form we've been using it, does encode a lack of knowledge about translational variance - knowledge that's important for object detection, semantic and instance segmentation, etc. - the invariance it provides is oftentimes favorable for classification models. Because of that, augmentation is applied more commonly and more aggressively to classification models.

Types of Augmentation

Augmentation started out very simple - small rotations, horizontal and vertical flips, contrast or brightness fluctuations, etc. In recent years, more elaborate methods have been devised, including CutOut (spatial dropout that randomly places black squares in the input images) and MixUp (blending pairs of images together and mixing their labels proportionally), and their combination - CutMix. The newer augmentation methods actually account for labels: CutMix, for instance, changes the label proportions to match the proportions of the image taken up by the parts of each class being mixed.
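
To make the label handling concrete - here's a minimal sketch of the MixUp and CutMix label arithmetic (the images, mixing factor, and patch size are made up for illustration, not taken from any library):

import numpy as np

# Two images (as float arrays) and their one-hot labels - hypothetical example data
img_cat, img_dog = np.random.rand(224, 224, 3), np.random.rand(224, 224, 3)
y_cat, y_dog = np.array([1.0, 0.0]), np.array([0.0, 1.0])

# MixUp: blend whole images and labels by the same factor
lam = 0.7
mixup_img = lam * img_cat + (1 - lam) * img_dog
mixup_label = lam * y_cat + (1 - lam) * y_dog                  # [0.7, 0.3]

# CutMix: paste a patch of one image into the other; the label
# proportions match the area each source image takes up
patch_h, patch_w = 112, 112
cutmix_img = img_cat.copy()
cutmix_img[:patch_h, :patch_w] = img_dog[:patch_h, :patch_w]
area_ratio = (patch_h * patch_w) / (224 * 224)                 # 0.25
cutmix_label = (1 - area_ratio) * y_cat + area_ratio * y_dog   # [0.75, 0.25]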

With a growing list of possible augmentations, some have started applying them randomly (or at least a random subset of them), with the idea that a random set of augmentations will bolster a model's robustness, and replace the original set with a much larger space of input images. This is where RandAugment kicks in!

KerasCV and RandAugment

KerasCV is a separate package, but still an official addition to Keras, developed by the Keras team. This means that it gets the same amount of polish and intuitiveness as the main package, and it integrates seamlessly with regular Keras models and their layers. The only difference you'll ever notice is calling keras_cv.layers... instead of keras.layers....

KerasCV is still in development as of writing, and already includes 27 new preprocessing layers, with RandAugment, CutMix, and MixUp being some of them. Let's take a look at how to apply RandAugment to images, and how we can train a classifier with and without random augmentation.

First, install keras_cv:

$ pip install keras_cv

Note: KerasCV requires TensorFlow 2.9 to work. If you don't already have it, run $ pip install -U tensorflow first.

Now, let's import TensorFlow, Keras and KerasCV, alongside TensorFlow datasets for easy access to Imagenette:

import tensorflow as tf
from tensorflow import keras
import keras_cv
import tensorflow_datasets as tfds

Let's load in an image and display it in its original form:

import matplotlib.pyplot as plt
import cv2

cat_img = cv2.cvtColor(cv2.imread('cat.jpg'), cv2.COLOR_BGR2RGB)
cat_img = cv2.resize(cat_img, (224, 224))
plt.imshow(cat_img)

Now, let's apply RandAugment to it several times and take a look at the results:

rand_augment = keras_cv.layers.RandAugment(value_range=(0, 255))

fig = plt.figure(figsize=(10, 10))
for i in range(16):
    ax = fig.add_subplot(4, 4, i+1)
    aug_img = rand_augment(cat_img)
    # aug_img is a float-based tensor, so we cast it back for plotting
    ax.imshow(aug_img.numpy().astype('int'))

The layer has a magnitude argument, which defaults to 0.5 and can be changed to increase or decrease the effect of augmentation:

rand_augment = keras_cv.layers.RandAugment(value_range=(0, 255), magnitude=0.1)

fig = plt.figure(figsize=(10, 10))
for i in range(16):
    ax = fig.add_subplot(4, 4, i+1)
    aug_img = rand_augment(cat_img)
    ax.imshow(aug_img.numpy().astype('int'))

When set to a low value such as 0.1 - you'll see much less aggressive augmentation:

Being a layer - it can be used within models or in tf.data pipelines while creating datasets. This makes RandAugment pretty flexible! Additional arguments are the augmentations_per_image and rate arguments, which work together.

For each image, the layer samples augmentations_per_image random preprocessing operations and adds them to a pipeline to be applied to that image. With the default of 3, three different operations are added to the pipeline. Then, a random number is sampled for each augmentation in the pipeline, and if it's lower than rate (which defaults to around 0.9), the augmentation is applied.

In essence - there's a 90% probability of each (random) augmentation in the pipeline being applied to the image.

This naturally means that not all augmentations have to be applied, especially if you lower the rate. You can also customize which operations are allowed through a RandomAugmentationPipeline layer, of which RandAugment is a special case. A separate guide on RandomAugmentationPipeline will be published soon.
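
For instance - a quick sketch of tuning these arguments, and of restricting the pipeline to a hand-picked subset of operations (the constructor arguments and layer names here reflect KerasCV as of writing and may change between versions; the chosen layers are just an example subset):

# Fewer operations per image, each applied with only 50% probability
light_augment = keras_cv.layers.RandAugment(value_range=(0, 255),
                                            augmentations_per_image=2,
                                            rate=0.5)

# A custom pipeline that only samples from a chosen subset of layers
custom_pipeline = keras_cv.layers.RandomAugmentationPipeline(
    layers=[
        keras_cv.layers.Grayscale(output_channels=3),
        keras_cv.layers.GridMask(),
    ],
    augmentations_per_image=2,
    rate=0.9,
)

aug_img = custom_pipeline(cat_img)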

Training a Classifier with and without RandAugment

To simplify the data preparation/loading aspect and focus on RandAugment, let's use tfds to load in a portion of Imagenette:

(train, valid_set, test_set), info = tfds.load("imagenette", 
                                           split=["train[:70%]", "validation", "train[70%:]"],
                                           as_supervised=True, with_info=True)

class_names = info.features["label"].names
n_classes = info.features["label"].num_classes
print(f'Class names: {class_names}') # Class names: ['n01440764', 'n02102040', 'n02979186', 'n03000684', 'n03028079', 'n03394916', 'n03417042', 'n03425413', 'n03445777', 'n03888257']
print('Num of classes:', n_classes) # Num of classes: 10

print("Train set size:", len(train)) # Train set size: 6628
print("Test set size:", len(test_set)) # Test set size: 2841
print("Valid set size:", len(valid_set)) # Valid set size: 3925

Advice: For more on loading datasets and working with tfds, as well as their splits - read our "Split Train, Test and Validation Sets with Tensorflow Datasets - tfds"

We've only loaded a portion of the training data in, to make it easier to overfit the dataset in fewer epochs (making our experiment run faster, in effect). Since the images in Imagenette come in different sizes, let's create a preprocess() function that resizes them (which we'll map the dataset with), as well as an augment() function that augments images in a tf.data.Dataset:

def preprocess(images, labels):
  # Resize to a fixed input size and one-hot encode the labels for 10 classes
  return tf.image.resize(images, (224, 224)), tf.one_hot(labels, 10)

def augment(images, labels):
  # RandAugment takes (and returns) a dictionary of images and labels
  inputs = {"images": images, "labels": labels}
  outputs = keras_cv.layers.RandAugment(value_range=(0, 255))(inputs)
  return outputs['images'], outputs['labels']

Note that we one-hot encoded the labels. We didn't necessarily have to for RandAugment alone, but for augmentations like CutMix that tamper with labels and their proportions, you will have to. Since you might want to apply those as well, and since RandAugment works really well with them for creating robust classifiers - let's leave the one-hot encoding in. Additionally, RandAugment takes in a dictionary with images and labels exactly because of this - some augmentations that you can add will actually change the labels, so they're mandatory. You can extract the augmented images and labels from the outputs dictionary easily, so this is an extra, yet simple, step to take during augmentation.

Let's map the existing datasets returned from tfds with the preprocess() function, batch them and augment only the training set:

valid_set = valid_set.map(preprocess).batch(32).prefetch(tf.data.AUTOTUNE)
train_set = train.map(preprocess).batch(32).prefetch(tf.data.AUTOTUNE)
train_set_aug = train.map(preprocess).map(augment,
                                          num_parallel_calls=tf.data.AUTOTUNE).batch(32).prefetch(tf.data.AUTOTUNE)
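
Since the labels are already one-hot encoded, you could also chain a batch-level augmentation such as CutMix into this pipeline. Here's a hedged sketch (not part of the pipeline we'll train on below) - it assumes keras_cv.layers.CutMix, which mixes samples within a batch, so it's applied after batching:

cut_mix = keras_cv.layers.CutMix()

def augment_batch(images, labels):
    # CutMix mixes images and their label proportions within a batch
    outputs = cut_mix({"images": images, "labels": labels})
    return outputs["images"], outputs["labels"]

train_set_aug_cutmix = (train.map(preprocess)
                             .map(augment, num_parallel_calls=tf.data.AUTOTUNE)
                             .batch(32)
                             .map(augment_batch, num_parallel_calls=tf.data.AUTOTUNE)
                             .prefetch(tf.data.AUTOTUNE))

For the comparison below, we'll stick with train_set_aug (RandAugment only).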

Let's train a network! keras_cv.models has some built-in networks, similar to keras.applications. While the list is still short, it'll expand over time and eventually take over from keras.applications. The API is very similar, so porting code will be fairly easy for most practitioners:

# rescaling to [0..1]
effnet = keras_cv.models.EfficientNetV2B0(include_rescaling=True, include_top=True, classes=10)
          
effnet.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy']
)

history = effnet.fit(train_set, epochs=25, validation_data = valid_set)

Alternatively, you can use the current keras.applications:

effnet = keras.applications.EfficientNetV2B0(weights=None, classes=10)

effnet.compile(
  loss='categorical_crossentropy',
  optimizer='adam',
  metrics=['accuracy']
)

history1 = effnet.fit(train_set, epochs=50, validation_data=valid_set)

This results in a model that doesn't really do super well:

Epoch 1/50
208/208 [==============================] - 60s 238ms/step - loss: 2.7742 - accuracy: 0.2313 - val_loss: 3.2200 - val_accuracy: 0.3085
...
Epoch 50/50
208/208 [==============================] - 48s 229ms/step - loss: 0.0272 - accuracy: 0.9925 - val_loss: 2.0638 - val_accuracy: 0.6887

Now, let's train the same network setup on the augmented dataset. Each batch is individually augmented, so whenever the same batch of images comes around (in the next epoch), it'll have different augmentations applied.
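
If you'd like to verify that, you can pull the first batch from the pipeline twice and compare - a quick sketch (not part of the original training code; since the dataset isn't shuffled, both passes see the same underlying images):

# Each pass re-executes the augmentation in the tf.data pipeline,
# so the same underlying images come out augmented differently
first_batch = next(iter(train_set_aug))[0]
second_batch = next(iter(train_set_aug))[0]
print(tf.reduce_mean(tf.abs(first_batch - second_batch)).numpy())  # non-zero

With that in mind, let's fit the model on train_set_aug: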

effnet = keras.applications.EfficientNetV2B0(weights=None, classes=10)
effnet.compile(
  loss='categorical_crossentropy',
  optimizer='adam',
  metrics=['accuracy']
)

history2 = effnet.fit(train_set_aug, epochs=50, validation_data=valid_set)

Epoch 1/50
208/208 [==============================] - 141s 630ms/step - loss: 2.9966 - accuracy: 0.1314 - val_loss: 2.7398 - val_accuracy: 0.2395
...
Epoch 50/50
208/208 [==============================] - 125s 603ms/step - loss: 0.7313 - accuracy: 0.7583 - val_loss: 0.6101 - val_accuracy: 0.8143

Much better! While you'd still want to apply other augmentations, such as CutMix and MixUp, alongside other training techniques to maximize the network's accuracy - RandAugment alone helped significantly, and can be comparable to a longer, hand-crafted augmentation pipeline.

If you compare the curves, including both the training and validation curves, it becomes clear just how much RandAugment helps.
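
You can plot them from the two history objects - a minimal sketch, assuming the history1 and history2 objects from the runs above:

import matplotlib.pyplot as plt

# Training/validation accuracy for the non-augmented (history1)
# and augmented (history2) runs, side by side
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5), sharey=True)

ax1.plot(history1.history['accuracy'], label='train accuracy')
ax1.plot(history1.history['val_accuracy'], label='val accuracy')
ax1.set_title('No augmentation')
ax1.set_xlabel('Epoch')
ax1.legend()

ax2.plot(history2.history['accuracy'], label='train accuracy')
ax2.plot(history2.history['val_accuracy'], label='val accuracy')
ax2.set_title('RandAugment')
ax2.set_xlabel('Epoch')
ax2.legend()

plt.show()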

In the non-augmented pipeline, the network overfits (the training accuracy hits the ceiling) while the validation accuracy stays low. In the augmented pipeline, the training accuracy stays lower than the validation accuracy from start to end.

With a higher training loss, the network is much more aware of the mistakes it still makes, and more updates can be made to make it invariant to the transformations. The former network sees no need for further updates, while the latter does - and keeps raising its ceiling of potential.

Conclusion

KerasCV is a separate package, but still an official addition to Keras, developed by the Keras team and aimed at bringing industry-strength CV to your Keras projects. KerasCV is still in development as of writing, and already includes 27 new preprocessing layers, with RandAugment, CutMix, and MixUp being some of them.

In this short guide, we've taken a look at how you can use RandAugment to apply a number of random transformations from a given list of transformations, and how easy it is to include in any Keras training pipeline.
