Machine Learning: Overfitting Is Your Friend, Not Your Foe

Let me preface the potentially provocative title with:

It's true: nobody wants an overfit end model, just like nobody wants an underfit one.

Overfit models perform great on training data, but can't generalize well to new instances. What you end up with is a model that's approaching a fully hard-coded model tailored to a specific dataset.

Underfit models can't generalize to new data, but they can't model the original training set either.

The right model is one that fits the data in such a way that it performs well predicting values in the training, validation and test set, as well as new instances.

Overfitting vs. Data Scientists

Battling overfitting is given a spotlight because it's more illusory, and because it's more tempting for a rookie to create overfit models when they start their Machine Learning journey. Throughout books, blog posts and courses, a common scenario is given:

"This model has a 100% accuracy rate! It's perfect! Or not. Actually, it just badly overfits the dataset, and when testing it on new instances, it performs with just X%, which is equal to random guessing."

After these sections, entire book and course chapters are dedicated to battling overfitting and how to avoid it. The word itself became stigmatized as a generally bad thing. And this is where the general conception arises:

"I must avoid overfitting at all costs."

It's given much more spotlight than underfitting, which is equally "bad". It's worth noting that "bad" is an arbitrary term, and neither condition is inherently "good" or "bad". Some may claim that overfit models are technically more useful, because they at least perform well on some data while underfit models perform well on no data, but the illusion of success may well outweigh this benefit.

For reference, let's consult Google Trends and the Google Ngram Viewer. Google Trends displays trends in search data, while the Google Ngram Viewer counts the number of occurrences of n-grams (sequences of n items, such as words) in literature, parsing through a vast number of books through the ages.

Everybody talks about overfitting and mostly in the context of avoiding it - which oftentimes leads people to a general notion that it's inherently a bad thing.

This is true, to a degree. Yes - you don't want the end model to overfit badly, otherwise it's practically useless. But you don't arrive at the end model right away - you tweak it numerous times, with various hyperparameters. During this process, you shouldn't mind seeing overfitting happen - it's a good sign, though not a good result.

How Overfitting Isn’t as Bad as It’s Made Out to Be

A model and architecture that has the ability to overfit is more likely to be able to generalize well to new instances if you simplify it (and/or tweak the data).

  • Sometimes, it isn't just about the model, as we'll see a bit later.

If a model can overfit, it has enough entropic capacity to extract features (in a meaningful and non-meaningful way) from data. From there, it's either that the model has more than required entropic capacity (complexity/power) or that the data itself isn't enough (very common case).
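To make the capacity-versus-data point concrete outside of neural networks, here's a toy sketch using NumPy polynomial fits, where the polynomial degree stands in for a model's capacity. The sine target, the noise level (standard deviation 0.1) and the sample sizes are arbitrary assumptions for illustration. A degree-9 polynomial has 10 coefficients, so it can pass through 10 training points exactly; the same capacity is then given 20x more data:

```python
import numpy as np

rng = np.random.default_rng(0)

def target(x):
    # Hypothetical "true" signal for this toy example
    return np.sin(3 * x)

def mse(coeffs, x, y):
    return float(np.mean((np.polyval(coeffs, x) - y) ** 2))

# Shared held-out set standing in for "new instances"
x_test = rng.uniform(-1, 1, 1000)
y_test = target(x_test) + rng.normal(0, 0.1, x_test.size)

# Scenario 1: degree 9 (10 coefficients) on only 10 noisy points -> it can interpolate them
x_small = np.linspace(-1, 1, 10)
y_small = target(x_small) + rng.normal(0, 0.1, x_small.size)
small_fit = np.polyfit(x_small, y_small, deg=9)

# Scenario 2: the same capacity, constrained by 200 noisy points
x_big = rng.uniform(-1, 1, 200)
y_big = target(x_big) + rng.normal(0, 0.1, x_big.size)
big_fit = np.polyfit(x_big, y_big, deg=9)

small_train, small_test = mse(small_fit, x_small, y_small), mse(small_fit, x_test, y_test)
big_train, big_test = mse(big_fit, x_big, y_big), mse(big_fit, x_test, y_test)
print(f"10 points:  train MSE {small_train:.2e}, test MSE {small_test:.2e}")
print(f"200 points: train MSE {big_train:.2e}, test MSE {big_test:.2e}")
```

With 10 points, the essentially zero training error says nothing about performance on new data. With 200 points, the very same model typically lands much closer to the noise floor (a variance of 0.01 here) on both sets: the capacity was fine, the data wasn't enough.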

The reverse statement can also be true, but more rarely. If a given model or architecture underfits, you can try tweaking the model to see if it picks up certain features, but the type of model might just be plain wrong for the task and you won't be able to fit the data with it no matter what you do. Some models just get stuck at some level of accuracy, as they simply can't extract enough features to distinguish between certain classes, or predict values.

In cooking, a reverse analogy can be created. It's better to use less salt in a stew early on, as you can always add salt later to taste, but it's hard to take it away once already put in.

In Machine Learning - it's the opposite. It's better to have a model overfit, then simplify it, change hyperparameters, augment the data, etc. to make it generalize well, but it's harder (in practical settings) to do the opposite. Avoiding overfitting before it happens might very well keep you away from finding the right model and/or architecture for a longer period of time.
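The overfit-then-simplify workflow can be sketched with the same kind of toy polynomial setup (again an illustrative assumption, not a neural network): start with a deliberately oversized model, step its capacity down one notch at a time, and keep the version with the lowest validation error:

```python
import numpy as np

rng = np.random.default_rng(1)

def target(x):
    return np.sin(3 * x)  # hypothetical underlying signal

def mse(coeffs, x, y):
    return float(np.mean((np.polyval(coeffs, x) - y) ** 2))

x_train = rng.uniform(-1, 1, 30)
y_train = target(x_train) + rng.normal(0, 0.1, x_train.size)
x_valid = rng.uniform(-1, 1, 200)
y_valid = target(x_valid) + rng.normal(0, 0.1, x_valid.size)

# Start from an oversized model (degree 12), then simplify step by step,
# recording (training error, validation error) for each capacity level.
errors_by_degree = {}
for degree in range(12, 0, -1):
    coeffs = np.polyfit(x_train, y_train, deg=degree)
    errors_by_degree[degree] = (mse(coeffs, x_train, y_train), mse(coeffs, x_valid, y_valid))

best_degree = min(errors_by_degree, key=lambda d: errors_by_degree[d][1])
for degree, (train_err, valid_err) in sorted(errors_by_degree.items(), reverse=True):
    print(f"degree {degree:2d}: train {train_err:.4f}, valid {valid_err:.4f}")
print("kept degree:", best_degree)
```

Training error only goes down as capacity grows, so it can't guide the choice on its own; the validation error decides how far to simplify.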

In practice, and in some of the most fascinating use cases of Machine Learning and Deep Learning, you'll be working on datasets that you'll have trouble overfitting. These will be datasets that you'll routinely underfit, without being able to find models and architectures that can generalize well and extract features.

It's also worth noting the difference between what I call true overfitting and partial overfitting. A model that achieves 60% accuracy on the training set, with only 40% on the validation and test sets, is overfitting a part of the data. However, it's not truly overfitting in the sense of eclipsing the entire dataset and achieving a near 100% (false) accuracy rate while its validation and test accuracy sit low at, say, ~40%.
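The distinction can be written down as a rough rule of thumb. The thresholds below (a 5-point train/validation gap and 95% training accuracy) are hypothetical choices for illustration, not established cut-offs:

```python
def diagnose_fit(train_acc, valid_acc, gap_tol=0.05, high_train=0.95):
    """Rough label for a training run. The thresholds are hypothetical rules of thumb."""
    gap = train_acc - valid_acc
    if gap <= gap_tol:
        return "no significant overfitting"
    if train_acc >= high_train:
        return "true overfitting"      # near-perfect on training data, much worse elsewhere
    return "partial overfitting"       # overfits part of the data, but can't fit all of it

print(diagnose_fit(0.99, 0.40))  # true overfitting
print(diagnose_fit(0.60, 0.40))  # partial overfitting
```

A small gap on its own doesn't tell you whether the model fits well or underfits both sets equally; that still depends on what accuracy is achievable for the task.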

A model that partially overfits isn't one that'll be able to generalize well with simplification, as it doesn't have enough entropic capacity to truly (over)fit. Once it does, my argument applies, though it doesn't guarantee success, as clarified in the following sections.

Case Study - Friendly Overfitting Argument

The MNIST handwritten digits dataset, compiled by Yann LeCun and colleagues, is one of the classical benchmark datasets used for training classification models. LeCun is widely considered one of the founding fathers of Deep Learning, with contributions to the field that few can match, and the MNIST handwritten digits dataset was one of the first major benchmarks used in the early stages of Convolutional Neural Networks.

It's also the most overused dataset, potentially ever.

Nothing's wrong with the dataset itself, nor with LeCun who created it - it's actually pretty good, but finding example upon example on the same dataset online is boring. At some point, we overfit ourselves looking at it. How much? Here's my attempt at listing the first ten MNIST digits off the top of my head:

5, 0, 4, 1, 9, 2, 2, 4, 3

What did I do?

from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt

# Import and normalize the images, splitting out a validation set
(X_train_full, Y_train_full), (X_test, Y_test) = keras.datasets.mnist.load_data()

X_valid, X_train = X_train_full[:5000]/255.0, X_train_full[5000:]/255.0
Y_valid, Y_train = Y_train_full[:5000], Y_train_full[5000:]

X_test = X_test/255.0

# Print out the first ten digits
fig, ax = plt.subplots(1, 10, figsize=(10,2))
for i in range(10):
    ax[i].imshow(X_train_full[i], cmap='gray')
    ax[i].axis('off')
plt.subplots_adjust(wspace=1)

plt.show()

Almost there.

I'll use this chance to make a public plea to all content creators to not overuse this dataset beyond the introductory parts, where the simplicity of the dataset can be used to lower the barrier to entry. Please.

Additionally, this dataset makes it hard to build a model that underfits. It's just too simple - even a fairly small Multilayer Perceptron (MLP) classifier, built with an intuitive number of layers and neurons per layer, can easily reach upwards of 98% accuracy on the training, testing and validation sets. Here's a Jupyter Notebook of a simple MLP achieving ~98% accuracy on the training, validation and testing sets, which I spun up with sensible defaults.

I haven't even bothered to try tuning it to perform better than the initial setup.

The CIFAR10 and CIFAR100 Datasets

Let's use a dataset that's more complicated than MNIST handwritten digits: one that makes a simple MLP underfit, but is simple enough to let a decently-sized CNN truly overfit on it. A good candidate is the CIFAR dataset.

There are 10 classes of images in CIFAR10, and 100 in CIFAR100. Additionally, the CIFAR100 dataset groups its classes into 20 families of similar classes, which means the network additionally has to learn the minute differences between similar, but different classes. These are known as "fine labels" (100) and "coarse labels" (20), and predicting them amounts to predicting either the specific class or just the family it belongs to.

For instance, here's a superclass (coarse label) and its subclasses (fine labels):

Superclass | Subclasses
food containers | bottles, bowls, cans, cups, plates

A cup is a cylinder, similar to a soda can, and some bottles may be too. Since these low-level features are relatively similar, it's easy to chuck them all into the "food container" category, but higher-level abstraction is required to properly guess whether something is a "cup" or a "can".
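One way to see the gap between coarse and fine labels is to score the same predictions at both levels. The mapping below is a small hand-written subset of the CIFAR100 hierarchy keyed by class name (for the real data, Keras' cifar100.load_data() accepts label_mode="fine" or label_mode="coarse"), and the predictions are made up for illustration:

```python
# Hand-written subset of the CIFAR100 fine -> coarse hierarchy (two superclasses only)
FINE_TO_COARSE = {
    "bottle": "food containers", "bowl": "food containers", "can": "food containers",
    "cup": "food containers", "plate": "food containers",
    "aquarium_fish": "fish", "flatfish": "fish", "ray": "fish", "shark": "fish", "trout": "fish",
}

def fine_and_coarse_accuracy(true_fine, pred_fine):
    """Score the same fine-label predictions at both levels of the hierarchy."""
    pairs = list(zip(true_fine, pred_fine))
    fine = sum(t == p for t, p in pairs) / len(pairs)
    coarse = sum(FINE_TO_COARSE[t] == FINE_TO_COARSE[p] for t, p in pairs) / len(pairs)
    return fine, coarse

# Made-up predictions: every mistake stays within the correct superclass
true_labels = ["cup", "can", "bottle", "shark", "trout"]
pred_labels = ["can", "can", "cup", "shark", "ray"]
print(fine_and_coarse_accuracy(true_labels, pred_labels))  # (0.4, 1.0)
```

Confusing a cup with a can costs fine-label accuracy but no coarse-label accuracy, which is exactly why the fine task is the harder one.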

What makes this job even harder is that CIFAR10 has 6000 images per class, while CIFAR100 has only 600 images per class, giving the network fewer images to learn the ever so subtle differences from. Cups without handles exist, and so do cans without ridges. Seen in profile, they might not be easy to tell apart.
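The per-class numbers are easy to check. Both datasets have 50,000 training and 10,000 test images; the code in this article additionally holds out the first 5,000 training images for validation, which leaves roughly 90% of each class for training (roughly, since that split isn't stratified by class):

```python
def per_class_counts(num_classes, train=50_000, test=10_000, valid=5_000):
    """Average images per class: overall, in the official training split,
    and left for training after holding out `valid` images for validation."""
    total = (train + test) // num_classes
    official_train = train // num_classes
    after_split = (train - valid) / num_classes
    return total, official_train, after_split

for name, classes in [("CIFAR10", 10), ("CIFAR100", 100)]:
    total, official_train, after_split = per_class_counts(classes)
    print(f"{name}: {total} per class, {official_train} in training, ~{after_split:.0f} after the validation split")
```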

This is where, say, a Multilayer Perceptron simply doesn't have the abstraction power to learn, and it's doomed to fail, horribly underfitting. Convolutional Neural Networks were inspired by the Neocognitron, which took hints from neuroscience and the hierarchical pattern recognition that the brain performs. These networks are able to extract features like these, and excel at the task - so much so that they often overfit badly and can't be used as-is in the end, where we typically sacrifice some accuracy for the sake of generalization ability.

Let's train two different network architectures on the CIFAR10 and CIFAR100 datasets as an illustration of my point.

This is also where we'll see that even when a network overfits, there's no guarantee that it will generalize well once simplified. There is a tendency, but it might still fail to generalize: the network might be right, but the data might not be enough.

In the case of CIFAR100, just 500 images for training (and 100 for testing) per class is not enough for a simple CNN to really generalize well across all 100 classes, and we'll have to perform data augmentation to help it along. Even with data augmentation, we might not get a highly accurate network, as there's only so much you can do to the data. If the same architecture performs well on CIFAR10 but not on CIFAR100, it means it simply can't pick up on some of the fine-grained details that distinguish the cylindrical objects we call a "cup", "can" and "bottle", for instance.

The vast majority of advanced network architectures that achieve a high accuracy on the CIFAR100 dataset perform data augmentation or otherwise expand the training set.

Most of them have to, and that's not a sign of bad engineering. In fact, the ability to expand these datasets and help networks generalize better is a sign of engineering ingenuity.

Additionally, I'd invite any human to try and guess what these are, if they're convinced that image classification isn't too hard with images as small as 32x32:

Is Image 4 a few oranges? Ping-pong balls? Egg yolks? Well, probably not egg yolks, but that requires prior knowledge of what "eggs" are and whether you're likely to find yolks sitting on the table, which a network won't have. Consider the amount of prior knowledge you have about the world and how much it affects what you see.

Importing the Data

We'll be using Keras as the deep learning library of choice, but you can follow along with other libraries or even your custom models if you're up for it.

But first off, let's load the data, split it into training, testing and validation sets, and normalize the pixel values to 0..1:

from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt

# Starting with CIFAR10
(X_train_full, Y_train_full), (X_test, Y_test) = keras.datasets.cifar10.load_data()

X_valid, X_train = X_train_full[:5000]/255.0, X_train_full[5000:]/255.0
Y_valid, Y_train = Y_train_full[:5000], Y_train_full[5000:]

X_test = X_test/255.0
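A side note on the normalization above: dividing a uint8 array by 255.0 produces a float64 array, eight bytes per pixel value. For the 45,000 training images that's about 1.1 GB (45,000 x 32 x 32 x 3 x 8 bytes), versus about 553 MB in float32. A small NumPy sketch on a random stand-in batch (not the real data) shows the difference:

```python
import numpy as np

# Random stand-in batch shaped like CIFAR images; uint8 pixels, as load_data() returns them
images = np.random.default_rng(0).integers(0, 256, size=(1000, 32, 32, 3), dtype=np.uint8)

as_float64 = images / 255.0                      # what dividing the raw arrays by 255.0 produces
as_float32 = images.astype(np.float32) / 255.0   # same scaling, half the memory

print(as_float64.dtype, as_float64.nbytes)  # float64 24576000
print(as_float32.dtype, as_float32.nbytes)  # float32 12288000
```

Since Keras works in float32 by default, casting with astype("float32") before dividing should lose nothing in practice while halving memory use.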

Then, let's visualize some of the images in the dataset to get an idea of what we're up against:

fig, ax = plt.subplots(5, 5, figsize=(10, 10))
ax = ax.ravel()

# Labels come as numbers of [0..9], so here are the class names for humans
class_names = ['Airplane', 'Automobile', 'Bird', 'Cat', 'Deer', 'Dog', 'Frog', 'Horse', 'Ship', 'Truck']

for i in range(25):
    ax[i].imshow(X_train_full[i])
    ax[i].set_title(class_names[Y_train_full[i][0]])
    ax[i].axis('off')
plt.subplots_adjust(wspace=1)

plt.show()

Underfitting Multilayer Perceptron

Pretty much no matter what we do, the MLP won't perform that well. It'll definitely reach some level of accuracy based on the raw sequences of information coming in - but this number is capped and probably won't be too high.

The network will start overfitting at some point, learning the concrete sequences of data denoting images, but will still have low accuracy on the training set even when overfitting. That's the prime time to stop training it, since it simply can't fit the data well. Training networks has a carbon footprint, you know.

Let's add in an EarlyStopping callback to avoid running the network beyond the point of common sense, and set the epochs to a number beyond what we'll run it for (so EarlyStopping can kick in).

We'll use the Sequential API to add a couple of layers with BatchNormalization and a bit of Dropout. They help with generalization and we want to at least try to get this model to learn something.

The main hyperparameters we can tweak here are the number of layers, their sizes, activation functions, kernel initializers and dropout rates, and here's a "decently" performing setup:

checkpoint = keras.callbacks.ModelCheckpoint("simple_dense.h5", save_best_only=True)
early_stopping = keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)

model = keras.Sequential([
  keras.layers.Flatten(input_shape=[32, 32, 3]),
  keras.layers.BatchNormalization(),
  keras.layers.Dense(75),  # no activation, so this layer is linear
    
  keras.layers.Dense(50, activation='elu'),
  keras.layers.BatchNormalization(),
  keras.layers.Dropout(0.1),
    
  keras.layers.Dense(50, activation='elu'),
  keras.layers.BatchNormalization(),
  keras.layers.Dropout(0.1),
    
  keras.layers.Dense(10, activation='softmax')
])

model.compile(loss="sparse_categorical_crossentropy",
              optimizer=keras.optimizers.Nadam(learning_rate=1e-4),
              metrics=["accuracy"])

history = model.fit(X_train, 
                    Y_train, 
                    epochs=150, 
                    validation_data=(X_valid, Y_valid),
                    callbacks=[checkpoint, early_stopping])
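For intuition about what EarlyStopping(patience=10, restore_best_weights=True) does with its default monitored quantity (val_loss), here's a simplified pure-Python mimic that works on a list of per-epoch validation losses. It's a sketch under those assumptions, not Keras' implementation, which has more options:

```python
def early_stopping_epochs(val_losses, patience=10):
    """Simplified mimic of patience-based early stopping on val_loss.
    Returns (epochs_run, best_epoch), both 1-indexed."""
    best_loss = float("inf")
    best_epoch = 0
    wait = 0  # epochs since the last improvement
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            best_loss, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch  # stop training here
    return len(val_losses), best_epoch  # ran out of epochs without stopping

# Made-up validation losses that improve, then stall
losses = [1.60, 1.50, 1.45, 1.44, 1.46, 1.47, 1.45, 1.48, 1.43]
print(early_stopping_epochs(losses, patience=3))  # (7, 4)
```

When training is stopped this way with restore_best_weights=True, the weights you keep correspond to the best epoch (4 here), not the last one that ran. Note that the made-up loss at epoch 9 would have been a new best, but training never gets there: patience trades a little potential improvement for not wasting epochs.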

Let's see if the starting hypothesis is true - it'll start out learning and generalizing to some extent but will end up having low accuracy on both the training set as well as the testing and validation set, resulting in an overall low accuracy.

For CIFAR10, the network performs okay-ish:

Epoch 1/150
1407/1407 [==============================] - 5s 3ms/step - loss: 1.9706 - accuracy: 0.3108 - val_loss: 1.6841 - val_accuracy: 0.4100
...
Epoch 50/150
1407/1407 [==============================] - 4s 3ms/step - loss: 1.2927 - accuracy: 0.5403 - val_loss: 1.3893 - val_accuracy: 0.5122

Let's take a look at the history of its learning:

import pandas as pd

pd.DataFrame(history.history).plot()
plt.show()

model.evaluate(X_test, Y_test)

313/313 [==============================] - 0s 926us/step - loss: 1.3836 - accuracy: 0.5058
[1.383605718612671, 0.5058000087738037]

The overall accuracy gets up to ~50%, and the network gets there pretty quickly before plateauing. Correctly classifying 5 out of 10 images sounds like tossing a coin, but remember that there are 10 classes here, so a network guessing at random would, on average, get only one image in ten right. Let's switch to the CIFAR100 dataset, which also necessitates a network with at least a tiny bit more power, as there are fewer training instances per class, as well as a vastly higher number of classes:
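For reference, with integer labels and a softmax output, the accuracy Keras reports is essentially the fraction of samples whose highest-scoring class matches the label. A NumPy sketch of that computation, plus a simulated random guesser on 10 classes (random scores standing in for a model), shows where the ~10% baseline comes from:

```python
import numpy as np

def accuracy_from_probs(probs, labels):
    """Fraction of rows whose highest-scoring class equals the label.
    `labels` may be shaped (n,) or (n, 1), like the CIFAR label arrays."""
    return float(np.mean(np.argmax(probs, axis=1) == np.ravel(labels)))

rng = np.random.default_rng(42)
n, num_classes = 100_000, 10
labels = rng.integers(0, num_classes, size=(n, 1))

# Stand-in for a model that guesses at random; argmax only cares about
# the ordering of the scores, so they don't need to sum to 1
random_scores = rng.random((n, num_classes))
print(accuracy_from_probs(random_scores, labels))  # close to 1 / num_classes = 0.1
```

The same function should work on real predictions, e.g. accuracy_from_probs(model.predict(X_test), Y_test), since np.ravel flattens the (n, 1) label shape.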

checkpoint = keras.callbacks.ModelCheckpoint("bigger_dense.h5", save_best_only=True)
early_stopping = keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)

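# Changing the loaded data, then splitting and normalizing it as before
(X_train_full, Y_train_full), (X_test, Y_test) = keras.datasets.cifar100.load_data()

X_valid, X_train = X_train_full[:5000]/255.0, X_train_full[5000:]/255.0
Y_valid, Y_train = Y_train_full[:5000], Y_train_full[5000:]

X_test = X_test/255.0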

# Modify the model
model1 = keras.Sequential([
  keras.layers.Flatten(input_shape=[32, 32, 3]),
  keras.layers.BatchNormalization(),
  keras.layers.Dense(256, activation='relu', kernel_initializer="he_normal"),
    
  keras.layers.Dense(128, activation='relu'),
  keras.layers.BatchNormalization(),
  keras.layers.Dropout(0.1),

  keras.layers.Dense(100, activation='softmax')
])


model1.compile(loss="sparse_categorical_crossentropy",
              optimizer=keras.optimizers.Nadam(learning_rate=1e-4),
              metrics=["accuracy"])

history = model1.fit(X_train, 
                    Y_train, 
                    epochs=150, 
                    validation_data=(X_valid, Y_valid),
                    callbacks=[checkpoint, early_stopping])

The network performs fairly badly:

Epoch 1/150
1407/1407 [==============================] - 13s 9ms/step - loss: 4.2260 - accuracy: 0.0836 - val_loss: 3.8682 - val_accuracy: 0.1238
...
Epoch 24/150
1407/1407 [==============================] - 12s 8ms/step - loss: 2.3598 - accuracy: 0.4006 - val_loss: 3.3577 - val_accuracy: 0.2434

And let's plot the history of its progress, as well as evaluate it on the testing set (which will likely score about the same as the validation set):

pd.DataFrame(history.history).plot()
plt.show()

model1.evaluate(X_test, Y_test)
313/313 [==============================] - 0s 2ms/step - loss: 3.2681 - accuracy: 0.2408
[3.2681326866149902, 0.24079999327659607]

As expected, the network wasn't able to grasp the data well. It ended up with a (partially overfit) training accuracy of ~40%, and an actual test accuracy of ~24%.

The accuracy capped at 40% - it wasn't really able to overfit the dataset, even if it overfit some parts of it that it was able to discern given the limited architecture. This model doesn't have the necessary entropic capacity required for it to truly overfit for the sake of my argument.

This model and its architecture simply aren't well suited for this task - and while we could technically get it to (over)fit more, it'll still have issues in the long run. For instance, let's turn it into a bigger network, which would theoretically let it recognize more complex patterns:

model2 = keras.Sequential([
  keras.layers.Flatten(input_shape=[32, 32, 3]),
  keras.layers.BatchNormalization(),
  keras.layers.Dense(512, activation='relu', kernel_initializer="he_normal"),
    
  keras.layers.Dense(256, activation='relu'),
  keras.layers.BatchNormalization(),
  keras.layers.Dropout(0.1),
    
  keras.layers.Dense(128, activation='relu'),
  keras.layers.BatchNormalization(),
  keras.layers.Dropout(0.1),

  keras.layers.Dense(100, activation='softmax')
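])

# Compiled and trained with the same settings as model1
model2.compile(loss="sparse_categorical_crossentropy",
               optimizer=keras.optimizers.Nadam(learning_rate=1e-4),
               metrics=["accuracy"])

history = model2.fit(X_train,
                     Y_train,
                     epochs=150,
                     validation_data=(X_valid, Y_valid),
                     callbacks=[early_stopping])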

However, this doesn't do much better at all:

Epoch 24/150
1407/1407 [==============================] - 28s 20ms/step - loss: 2.1202 - accuracy: 0.4507 - val_loss: 3.2796 - val_accuracy: 0.2528

It has roughly twice as many parameters, yet it simply can't extract much more:

model1.summary()
model2.summary()
Model: "sequential_17"
...
Total params: 845,284
Trainable params: 838,884
Non-trainable params: 6,400
_________________________________________________________________
Model: "sequential_18"
...
Total params: 1,764,324
Trainable params: 1,757,412
Non-trainable params: 6,912
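The totals in these summaries can be reproduced by hand. A Dense layer has n_in * n_out weights plus n_out biases, and a BatchNormalization layer holds four values per feature: two trainable (gamma, beta) and two non-trainable (moving mean and variance), which is where the "Non-trainable params" come from. Flatten and Dropout have no parameters:

```python
FLAT = 32 * 32 * 3  # 3,072 inputs after Flatten

def dense(n_in, n_out):
    """(trainable, non_trainable) for a Dense layer: weights plus biases."""
    return n_in * n_out + n_out, 0

def batchnorm(n):
    """(trainable, non_trainable): gamma/beta are learned, moving mean/variance are not."""
    return 2 * n, 2 * n

def count(layers):
    trainable = sum(t for t, _ in layers)
    non_trainable = sum(nt for _, nt in layers)
    return trainable + non_trainable, trainable, non_trainable

model1_layers = [batchnorm(FLAT), dense(FLAT, 256), dense(256, 128), batchnorm(128), dense(128, 100)]
model2_layers = [batchnorm(FLAT), dense(FLAT, 512), dense(512, 256), batchnorm(256),
                 dense(256, 128), batchnorm(128), dense(128, 100)]

print(count(model1_layers))  # (845284, 838884, 6400)
print(count(model2_layers))  # (1764324, 1757412, 6912)
```

Nearly 90% of model2's parameters (1,573,376 of 1,764,324) sit in its first Dense layer, which connects every one of the 3,072 raw pixel values to every neuron. That's where most of the extra size goes, and it doesn't give the MLP any notion of spatial structure.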

Overfitting Convolutional Neural Network on CIFAR10

Now, let's try doing something different. Switching to a CNN will significantly help with extracting features from the dataset, thereby allowing the model to truly overfit, reaching much higher (illusory) accuracy.

We'll kick out the EarlyStopping callback to let it do its thing. Additionally, we won't be using Dropout layers, and instead try to force the network to learn the features through more layers.

Note: Outside of the context of trying to prove the argument, this would be horrible advice. This is the opposite of what you'd want to do by the end. Dropout helps networks generalize better, by forcing the non-dropped neurons to pick up the slack. Forcing the network to learn through more layers is more likely to lead to an overfit model.

The reason I'm purposefully doing this is to allow the network to horribly overfit as a sign of its ability to actually discern features, before simplifying it and adding Dropout to really allow it to generalize. If it reaches high (illusory) accuracy, it can extract much more than the MLP model, which means we can start simplifying it.
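Since Dropout plays a central role in this argument, here's a NumPy sketch of inverted dropout, the scheme Keras' Dropout layer follows: during training, each activation is zeroed with probability rate and the survivors are scaled up by 1 / (1 - rate); at inference time, the layer passes its input through unchanged. The all-ones activations are a stand-in:

```python
import numpy as np

def dropout(x, rate, training, rng):
    """Inverted dropout. In training, zero each unit with probability `rate` and scale
    the survivors by 1 / (1 - rate) so the expected activation stays the same.
    At inference time, return the input unchanged."""
    if not training or rate == 0.0:
        return x
    keep = rng.random(x.shape) >= rate
    return x * keep / (1.0 - rate)

rng = np.random.default_rng(0)
activations = np.ones((1000, 256))  # stand-in activations

train_out = dropout(activations, rate=0.4, training=True, rng=rng)
infer_out = dropout(activations, rate=0.4, training=False, rng=rng)

print((train_out == 0).mean())                 # about 0.4 of the units are dropped
print(train_out.mean())                        # about 1.0, matching the input's mean
print(np.array_equal(infer_out, activations))  # True
```

Because a different random subset of neurons disappears on every step, no single neuron can be relied upon, which is the "picking up the slack" effect described above.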

Let's once again use the Sequential API to build a CNN, firstly on the CIFAR10 dataset:

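# Reload CIFAR10, since the MLP section above switched the data to CIFAR100
(X_train_full, Y_train_full), (X_test, Y_test) = keras.datasets.cifar10.load_data()
X_valid, X_train = X_train_full[:5000]/255.0, X_train_full[5000:]/255.0
Y_valid, Y_train = Y_train_full[:5000], Y_train_full[5000:]
X_test = X_test/255.0

checkpoint = keras.callbacks.ModelCheckpoint("overcomplicated_cnn_cifar10.h5", save_best_only=True)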

model = keras.models.Sequential([
    keras.layers.Conv2D(64, 3, activation='relu', 
                        kernel_initializer="he_normal", 
                        kernel_regularizer=keras.regularizers.l2(l=0.01), 
                        padding='same', 
                        input_shape=[32, 32, 3]),
    keras.layers.Conv2D(64, 3, activation='relu', padding='same'),
    keras.layers.MaxPooling2D(2),
    
    keras.layers.Conv2D(128, 2, activation='relu', padding='same'),
    keras.layers.Conv2D(128, 2, activation='relu', padding='same'),
    keras.layers.MaxPooling2D(2),
    
    keras.layers.Conv2D(256, 3, activation='relu', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(256, 3, activation='relu', padding='same'),
    keras.layers.MaxPooling2D(2),
    
    keras.layers.Conv2D(128, 3, activation='relu', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(128, 3, activation='relu', padding='same'),
    keras.layers.MaxPooling2D(2),
    
    keras.layers.Conv2D(64, 3, activation='relu', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(64, 3, activation='relu', padding='same'),
    keras.layers.MaxPooling2D(2),
    
    keras.layers.Flatten(),    
    keras.layers.Dense(32, activation='relu'),
    keras.layers.Dense(10, activation='softmax')
])

model.compile(loss="sparse_categorical_crossentropy",
              optimizer=keras.optimizers.Adam(learning_rate=1e-3),
              metrics=["accuracy"])

model.summary()

history = model.fit(X_train, 
                    Y_train, 
                    epochs=150,
                    batch_size=64,
                    validation_data=(X_valid, Y_valid),
                    callbacks=[checkpoint])

Awesome, it overfit pretty quickly! Within just a few epochs, it started overfitting the data, and by epoch 31, it got up to 98%, with a lower validation accuracy:

Epoch 1/150
704/704 [==============================] - 149s 210ms/step - loss: 1.9561 - accuracy: 0.4683 - val_loss: 2.5060 - val_accuracy: 0.3760
...
Epoch 31/150
704/704 [==============================] - 149s 211ms/step - loss: 0.0610 - accuracy: 0.9841 - val_loss: 1.0433 - val_accuracy: 0.6958

Since there are only 10 output classes, even though we tried overfitting it a lot by creating an unnecessarily big CNN, the validation accuracy is still fairly high.

Simplifying the Convolutional Neural Network on CIFAR10

Now, let's simplify it to see how it fares with a more reasonable architecture. We'll add in BatchNormalization and Dropout, as both help with generalization:

checkpoint = keras.callbacks.ModelCheckpoint("simplified_cnn_cifar10.h5", save_best_only=True)
early_stopping = keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)

model = keras.models.Sequential([
    keras.layers.Conv2D(32, 3, activation='relu', kernel_initializer="he_normal", kernel_regularizer=keras.regularizers.l2(l=0.01), padding='same', input_shape=[32, 32, 3]),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(32, 3, activation='relu', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPooling2D(2),
    keras.layers.Dropout(0.4),
    
    keras.layers.Conv2D(64, 2, activation='relu', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(64, 2, activation='relu', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPooling2D(2),
    keras.layers.Dropout(0.4),
    
    keras.layers.Conv2D(128, 3, activation='relu', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(128, 3, activation='relu', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPooling2D(2),
    keras.layers.Dropout(0.5),
    
    keras.layers.Flatten(),    
    keras.layers.Dense(32, activation='relu'),
    keras.layers.BatchNormalization(),
    keras.layers.Dropout(0.3),
    keras.layers.Dense(10, activation='softmax')
])

model.compile(loss="sparse_categorical_crossentropy",
              optimizer=keras.optimizers.Adam(learning_rate=1e-3),
              metrics=["accuracy"])

model.summary()

history = model.fit(X_train, 
                    Y_train, 
                    epochs=150,
                    batch_size=64,
                    validation_data=(X_valid, Y_valid),
                    callbacks=[checkpoint, early_stopping])
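The simplified model leans heavily on BatchNormalization. Here's a simplified NumPy sketch of what the layer computes in training mode for channels-last feature maps, along with the moving-average update that produces its non-trainable statistics. The values eps=1e-3 and momentum=0.99 mirror Keras' defaults; the random input is a stand-in for real activations:

```python
import numpy as np

def batch_norm_train(x, gamma, beta, eps=1e-3):
    """Training-mode batch normalization for NHWC feature maps: normalize each
    channel with the current batch's statistics, then scale by gamma and shift by beta."""
    mean = x.mean(axis=(0, 1, 2))
    var = x.var(axis=(0, 1, 2))
    x_hat = (x - mean) / np.sqrt(var + eps)
    return gamma * x_hat + beta, mean, var

def update_moving_stats(moving_mean, moving_var, batch_mean, batch_var, momentum=0.99):
    """Exponential moving averages, used instead of batch statistics at inference time."""
    new_mean = momentum * moving_mean + (1 - momentum) * batch_mean
    new_var = momentum * moving_var + (1 - momentum) * batch_var
    return new_mean, new_var

rng = np.random.default_rng(0)
# Stand-in activations: batch of 64, 8x8 spatial, 32 channels, deliberately off-center
feature_maps = rng.normal(loc=5.0, scale=3.0, size=(64, 8, 8, 32))
gamma, beta = np.ones(32), np.zeros(32)  # Keras initializes gamma to 1 and beta to 0

out, batch_mean, batch_var = batch_norm_train(feature_maps, gamma, beta)
moving_mean, moving_var = update_moving_stats(np.zeros(32), np.ones(32), batch_mean, batch_var)

print(out.mean(axis=(0, 1, 2))[:3])  # approximately 0 per channel
print(out.std(axis=(0, 1, 2))[:3])   # approximately 1 per channel
```

Keeping every layer's inputs centered and scaled like this tends to make training less sensitive to initialization and learning rate, and the noise from per-batch statistics acts as a mild regularizer.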

This model has a (modest) count of 323,146 trainable parameters, compared to 1,579,178 from the previous CNN. How does it perform?

Epoch 1/150
704/704 [==============================] - 91s 127ms/step - loss: 2.1327 - accuracy: 0.3910 - val_loss: 1.5495 - val_accuracy: 0.5406
...
Epoch 52/150
704/704 [==============================] - 89s 127ms/step - loss: 0.4091 - accuracy: 0.8648 - val_loss: 0.4694 - val_accuracy: 0.8500

It actually achieves a pretty decent ~85% accuracy! Occam's Razor strikes again. Let's take a look at some of the results:

y_preds = model.predict(X_test)
print(y_preds[1])
print(np.argmax(y_preds[1]))

fig, ax = plt.subplots(6, 6, figsize=(10, 10))
ax = ax.ravel()

for i in range(0, 36):
    ax[i].imshow(X_test[i])
    ax[i].set_title("Actual: %s\nPred: %s" % (class_names[Y_test[i][0]], class_names[np.argmax(y_preds[i])]))
    ax[i].axis('off')
plt.subplots_adjust(wspace=1)

plt.show()

The main misclassifications are two images in this small set - a dog was misclassified as a deer (respectable enough), but a closeup of an emu was classified as a cat (funny enough that we'll let it slide).
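The trainable-parameter counts quoted earlier (323,146 for the simplified CNN versus 1,579,178 for the overcomplicated one) can be reproduced by hand. A Conv2D layer has k * k * c_in * c_out kernel weights plus c_out biases, BatchNormalization contributes two trainable values per channel, and MaxPooling2D, Dropout and Flatten have none. With padding='same', only the 2x2 poolings shrink the feature maps:

```python
def conv2d(k, c_in, c_out):
    """k x k kernels over c_in channels for each of c_out filters, plus one bias per filter."""
    return k * k * c_in * c_out + c_out

def dense(n_in, n_out):
    return n_in * n_out + n_out

def bn_trainable(c):
    return 2 * c  # gamma and beta; the moving statistics are non-trainable

overcomplicated = [
    conv2d(3, 3, 64), conv2d(3, 64, 64),
    conv2d(2, 64, 128), conv2d(2, 128, 128),
    conv2d(3, 128, 256), bn_trainable(256), conv2d(3, 256, 256),
    conv2d(3, 256, 128), bn_trainable(128), conv2d(3, 128, 128),
    conv2d(3, 128, 64), bn_trainable(64), conv2d(3, 64, 64),
    dense(1 * 1 * 64, 32), dense(32, 10),  # five 2x2 poolings: 32x32 -> 1x1
]
simplified = [
    conv2d(3, 3, 32), bn_trainable(32), conv2d(3, 32, 32), bn_trainable(32),
    conv2d(2, 32, 64), bn_trainable(64), conv2d(2, 64, 64), bn_trainable(64),
    conv2d(3, 64, 128), bn_trainable(128), conv2d(3, 128, 128), bn_trainable(128),
    dense(4 * 4 * 128, 32), bn_trainable(32), dense(32, 10),  # three poolings: 32x32 -> 4x4
]
print(sum(overcomplicated))  # 1579178
print(sum(simplified))       # 323146
```

In the simplified model, the single Dense layer on the flattened 4x4x128 output (65,568 parameters) accounts for about a fifth of the total.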

Overfitting Convolutional Neural Network on CIFAR100

What happens when we go for the CIFAR100 dataset?

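# Switch to CIFAR100, with the same split and normalization as before
(X_train_full, Y_train_full), (X_test, Y_test) = keras.datasets.cifar100.load_data()
X_valid, X_train = X_train_full[:5000]/255.0, X_train_full[5000:]/255.0
Y_valid, Y_train = Y_train_full[:5000], Y_train_full[5000:]
X_test = X_test/255.0

checkpoint = keras.callbacks.ModelCheckpoint("overcomplicated_cnn_model_cifar100.h5", save_best_only=True)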
early_stopping = keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)

model = keras.models.Sequential([
    keras.layers.Conv2D(32, 3, activation='relu', kernel_initializer="he_normal", kernel_regularizer=keras.regularizers.l2(l=0.01), padding='same', input_shape=[32, 32, 3]),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(32, 3, activation='relu', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPooling2D(2),
    
    keras.layers.Conv2D(64, 2, activation='relu', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(64, 2, activation='relu', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPooling2D(2),
    
    keras.layers.Conv2D(128, 3, activation='relu', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(128, 3, activation='relu', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPooling2D(2),
    
    keras.layers.Conv2D(128, 3, activation='relu', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(128, 3, activation='relu', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPooling2D(2),
    
    keras.layers.Conv2D(64, 3, activation='relu', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(64, 3, activation='relu', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPooling2D(2),
    
    keras.layers.Flatten(),    
    keras.layers.Dense(256, activation='relu'),
    keras.layers.BatchNormalization(),
    keras.layers.Dense(128, activation='relu'),
    keras.layers.BatchNormalization(),
    
    keras.layers.Dense(100, activation='softmax')
])

model.compile(loss="sparse_categorical_crossentropy",
              optimizer=keras.optimizers.Adam(learning_rate=1e-3),
              metrics=["accuracy"])

model.summary()

history = model.fit(X_train, 
                    Y_train, 
                    epochs=150,
                    batch_size=64,
                    validation_data=(X_valid, Y_valid),
                    callbacks=[checkpoint])
Epoch 1/150
704/704 [==============================] - 97s 137ms/step - loss: 4.1752 - accuracy: 0.1336 - val_loss: 3.9696 - val_accuracy: 0.1392
...
Epoch 42/150
704/704 [==============================] - 95s 135ms/step - loss: 0.1543 - accuracy: 0.9572 - val_loss: 4.1394 - val_accuracy: 0.4458

Wonderful! ~96% accuracy on the training set! Don't mind the ~44% validation accuracy just yet. Let's simplify the model real quick to get it to generalize better.

Failure to Generalize After Simplification

And this is where it becomes clear that the ability to overfit doesn't guarantee that the model will generalize better when simplified. In the case of CIFAR100, there aren't many training instances per class, and this will likely prevent a simplified version of the previous model from learning well. Let's try it out:

checkpoint = keras.callbacks.ModelCheckpoint("simplified_cnn_model_cifar100.h5", save_best_only=True)
early_stopping = keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)

model = keras.models.Sequential([
    keras.layers.Conv2D(32, 3, activation='relu', kernel_initializer="he_normal", kernel_regularizer=keras.regularizers.l2(l=0.01), padding='same', input_shape=[32, 32, 3]),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(32, 3, activation='relu', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPooling2D(2),
    keras.layers.Dropout(0.4),
    
    keras.layers.Conv2D(64, 2, activation='relu', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(64, 2, activation='relu', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPooling2D(2),
    keras.layers.Dropout(0.4),
    
    keras.layers.Conv2D(128, 3, activation='relu', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(128, 3, activation='relu', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPooling2D(2),
    keras.layers.Dropout(0.5),
    
    keras.layers.Flatten(),    
    keras.layers.Dense(256, activation='relu'),
    keras.layers.BatchNormalization(),
    keras.layers.Dropout(0.3),
    keras.layers.Dense(100, activation='softmax')
])

model.compile(loss="sparse_categorical_crossentropy",
              optimizer=keras.optimizers.Adam(learning_rate=1e-3),
              metrics=["accuracy"])

history = model.fit(X_train, 
                    Y_train, 
                    epochs=150,
                    batch_size=64,
                    validation_data=(X_valid, Y_valid),
                    callbacks=[checkpoint, early_stopping])
Epoch 1/150
704/704 [==============================] - 96s 135ms/step - loss: 4.4432 - accuracy: 0.1112 - val_loss: 3.7893 - val_accuracy: 0.1702
...
Epoch 48/150
704/704 [==============================] - 92s 131ms/step - loss: 1.2550 - accuracy: 0.6370 - val_loss: 1.7147 - val_accuracy: 0.5466

It plateaus and can't really generalize to the data. In this case, it might not be the model's fault: it may be just right for the task, especially given the high accuracy on the CIFAR10 dataset, which has the same input shape and similar images. It appears that the model can learn general shapes reasonably well, but not the fine distinctions between them.

The simpler model actually performs better than the more complex one in terms of validation accuracy, so the extra complexity doesn't help the CNN pick up these fine details at all. Here, the problem most likely lies in the fact that there are only 500 training images per class, which really isn't enough. The more complex network overfits because there isn't enough diversity in the data, and when it's simplified to avoid overfitting, it underfits for the same reason.
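To check how thin the data actually is, you can count the training images per class. CIFAR100 ships 50,000 training images across 100 classes (500 each), and carving a validation set out of them leaves fewer still. A minimal numpy sketch, where `class_counts` is a hypothetical helper:

```python
import numpy as np

def class_counts(labels, num_classes=100):
    """Count samples per class for integer labels shaped (n,) or (n, 1).

    Keras' cifar100.load_data() returns labels shaped (n, 1), hence the ravel().
    """
    labels = np.asarray(labels).ravel()
    return np.bincount(labels, minlength=num_classes)


# Usage with the article's training labels (Y_train is defined earlier in the article):
# counts = class_counts(Y_train)
# print(counts.min(), counts.max())
```

Looking at the minimum as well as the mean is useful here, since a random validation split can leave some classes noticeably below the average.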

This is why most of the papers linked earlier, and most networks trained on CIFAR100 in general, augment its data.

It's genuinely not an easy dataset to get high accuracy on, unlike the MNIST handwritten digits dataset, and a simple CNN like the one we're building probably won't cut it. Just remember the number of quite specific classes, how uninformative some of the images are, and how much prior knowledge humans rely on to tell them apart.

Let's do our best by augmenting the images and artificially expanding the training data, to at least try to get a higher accuracy. Keep in mind that CIFAR100 is, again, a genuinely difficult dataset to get high accuracy on with simple models. State-of-the-art models use different and novel techniques to shave off errors, and many of them aren't even CNNs; they're Transformers.

If you'd like to take a look at the landscape of these models, PapersWithCode has done a beautiful compilation of papers, source code and results.

Data Augmentation with Keras' ImageDataGenerator Class

Will data augmentation help? Usually it does, but with a serious lack of training data like the one we're facing, there's only so much you can do with random rotations, flipping, cropping, etc. If an architecture can't generalize well on a dataset, data augmentation will likely give it some boost, but probably not a large one.

That being said, let's use Keras' ImageDataGenerator class to generate new training data with random changes, in hopes of improving the model's accuracy. If it does improve, it shouldn't be by a huge amount, and the model will likely end up partially overfitting the dataset again, neither generalizing well nor fully overfitting the data.
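ImageDataGenerator draws new random parameters for every image it yields, so the network rarely sees exactly the same pixels twice. As a simplified illustration of that idea (not ImageDataGenerator's implementation, which also handles rotation, shear and zoom), here's a numpy-only sketch of two of the transforms configured in the generator that follows: a random horizontal flip and a random shift with edge filling. The function names are hypothetical:

```python
import numpy as np

def random_flip_and_shift(image, rng, max_shift=0.2):
    """Randomly flip an (H, W, C) image left-right and shift it by up to
    max_shift * size pixels along each axis.

    Pixels uncovered by the shift repeat the nearest edge pixel, similar in
    spirit to fill_mode='nearest'.
    """
    h, w = image.shape[:2]
    if rng.random() < 0.5:
        image = image[:, ::-1]  # horizontal flip
    pad_y, pad_x = int(max_shift * h), int(max_shift * w)
    dy = int(rng.integers(-pad_y, pad_y + 1))
    dx = int(rng.integers(-pad_x, pad_x + 1))
    # Pad with edge values, then crop an HxW window offset by (dy, dx)
    padded = np.pad(image, ((pad_y, pad_y), (pad_x, pad_x), (0, 0)), mode="edge")
    top, left = pad_y - dy, pad_x - dx
    return padded[top:top + h, left:left + w]


def augment_batch(images, rng, max_shift=0.2):
    """Apply a fresh random transform to every image in an (N, H, W, C) batch."""
    return np.stack([random_flip_and_shift(img, rng, max_shift) for img in images])


rng = np.random.default_rng(0)
batch = rng.random((4, 32, 32, 3)).astype("float32")  # stand-in for CIFAR images
print(augment_batch(batch, rng).shape)  # (4, 32, 32, 3)
```

Each call draws new random numbers, so two passes over the same batch produce different images; that per-epoch variation is what makes a much longer training run feasible.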

Given the constant random variations in the data, the model is less likely to overfit within the same number of epochs, as the variations keep it adjusting to "new" data. Let's run it for, say, 300 epochs, which is significantly more than any of the networks we've trained so far. Again, this is possible without major overfitting thanks to the random modifications made to the images as they flow in:

from tensorflow.keras.preprocessing.image import ImageDataGenerator

checkpoint = keras.callbacks.ModelCheckpoint("augmented_cnn.h5", save_best_only=True)

model = keras.models.Sequential([
    keras.layers.Conv2D(64, 3, activation='relu', kernel_initializer="he_normal", kernel_regularizer=keras.regularizers.l2(l2=0.01), padding='same', input_shape=[32, 32, 3]),
    keras.layers.Conv2D(64, 3, activation='relu', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPooling2D(2),
    keras.layers.Dropout(0.4),
    
    keras.layers.Conv2D(128, 2, activation='relu', padding='same'),
    keras.layers.Conv2D(128, 2, activation='relu', padding='same'),
    keras.layers.Conv2D(128, 2, activation='relu', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPooling2D(2),
    keras.layers.Dropout(0.4),
    
    keras.layers.Conv2D(256, 3, activation='relu', padding='same'),
    keras.layers.Conv2D(256, 3, activation='relu', padding='same'),
    keras.layers.Conv2D(256, 3, activation='relu', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPooling2D(2),
    keras.layers.Dropout(0.4),
    
    keras.layers.Flatten(),    
    keras.layers.Dense(512, activation='relu'),
    keras.layers.BatchNormalization(),
    keras.layers.Dropout(0.3),
    keras.layers.Dense(100, activation='softmax')
])

    
train_datagen = ImageDataGenerator(rotation_range=30,
        height_shift_range=0.2,
        width_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        vertical_flip=True,
        fill_mode='nearest')

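valid_datagen = ImageDataGenerator()

# fit() only computes statistics for featurewise_center, featurewise_std_normalization
# and zca_whitening; none are enabled here, so these two calls have no effect
train_datagen.fit(X_train)
valid_datagen.fit(X_valid)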

train_generator = train_datagen.flow(X_train, Y_train, batch_size=128)
valid_generator = valid_datagen.flow(X_valid, Y_valid, batch_size=128)

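model.compile(loss="sparse_categorical_crossentropy",
              # Newer Keras optimizers no longer accept the legacy `decay` argument;
              # InverseTimeDecay with decay_steps=1 reproduces it: lr = 1e-3 / (1 + 1e-6 * step)
              optimizer=keras.optimizers.Adam(
                  learning_rate=keras.optimizers.schedules.InverseTimeDecay(
                      1e-3, decay_steps=1, decay_rate=1e-6)),
              metrics=["accuracy"])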

history = model.fit(train_generator,
                    epochs=300,
                    # batch_size is already set in flow(), so it isn't passed to fit() again
                    steps_per_epoch=len(X_train)//128,
                    validation_data=valid_generator,
                    callbacks=[checkpoint])
Epoch 1/300
351/351 [==============================] - 16s 44ms/step - loss: 5.3788 - accuracy: 0.0487 - val_loss: 5.3474 - val_accuracy: 0.0440
...
Epoch 300/300
351/351 [==============================] - 15s 43ms/step - loss: 1.0571 - accuracy: 0.6895 - val_loss: 2.0005 - val_accuracy: 0.5532
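It helps to put the training settings into numbers. The step counts in the logs are consistent with a training split of about 45,000 images (an assumption here, since the split itself isn't shown in this section), and the learning-rate decay of 1e-6 in the compile call follows the inverse-time formula lr = lr0 / (1 + decay * step), applied once per optimizer step. A small sketch with a hypothetical `inverse_time_decay` helper:

```python
import math

def inverse_time_decay(initial_lr, decay, step):
    """Learning rate after `step` optimizer steps: lr0 / (1 + decay * step)."""
    return initial_lr / (1 + decay * step)


train_size = 45_000  # assumption: inferred from the 704 and 351 steps/epoch in the logs

steps_simple = math.ceil(train_size / 64)  # batches per epoch with batch_size=64
steps_aug = train_size // 128              # steps_per_epoch=len(X_train)//128
total_steps = steps_aug * 300              # optimizer steps over 300 epochs
final_lr = inverse_time_decay(1e-3, 1e-6, total_steps)

print(steps_simple, steps_aug, total_steps, f"{final_lr:.3e}")  # 704 351 105300 9.047e-04
```

So over the full 300 epochs the learning rate only falls by roughly 10%, from 1e-3 to about 9.05e-4; in this run, the decay term barely changes anything.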

The model reaches ~55% accuracy on the validation set and is still partially overfitting the data. The val_loss has stopped going down and is quite rocky, even with a larger batch_size.
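The train/validation gap is a simple way to quantify "partially overfitting". Here's a minimal sketch that computes it from a Keras-style history dict. `generalization_gap` is a hypothetical helper, and the numbers are the last logged epochs of the two runs in this section, so they're single snapshots rather than each run's best epoch:

```python
def generalization_gap(history, epoch=-1):
    """Return (train acc - val acc, val loss - train loss) at one epoch.

    `history` has the same layout as Keras' History.history dict:
    metric name -> list of per-epoch values.
    """
    acc_gap = history["accuracy"][epoch] - history["val_accuracy"][epoch]
    loss_gap = history["val_loss"][epoch] - history["loss"][epoch]
    return acc_gap, loss_gap


# Last logged epoch of each run, copied from the training logs above
simple_cnn = {"loss": [1.2550], "accuracy": [0.6370],
              "val_loss": [1.7147], "val_accuracy": [0.5466]}  # epoch 48
augmented_cnn = {"loss": [1.0571], "accuracy": [0.6895],
                 "val_loss": [2.0005], "val_accuracy": [0.5532]}  # epoch 300

for name, hist in [("simplified", simple_cnn), ("augmented", augmented_cnn)]:
    acc_gap, loss_gap = generalization_gap(hist)
    print(f"{name}: accuracy gap {acc_gap:.4f}, loss gap {loss_gap:.4f}")
```

By these snapshots, the augmented run ends with a slightly higher val_accuracy (0.5532 vs 0.5466) but a wider gap (about 0.136 vs 0.090 in accuracy). Since the augmented run has no EarlyStopping, its last epoch isn't necessarily the one ModelCheckpoint kept; passing the real `history.history` with different `epoch` values lets you trace the gap over the whole run.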

This network simply can't learn and fit the data with high accuracy, even though variations of it do have the entropic capacity to overfit the data.
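One rough proxy for that capacity is the parameter count. The sketch below tallies it by hand for the augmented model, assuming Keras' standard conventions (a bias per filter or unit, and four stored values per BatchNormalization channel, two of them trainable); on the real model, `model.count_params()` or `model.summary()` reports this directly. The helper names are hypothetical:

```python
def conv2d_params(kernel, in_ch, out_ch):
    """Weights plus one bias per filter for a square Conv2D kernel."""
    return (kernel * kernel * in_ch + 1) * out_ch

def dense_params(in_units, out_units):
    """Weights plus one bias per unit."""
    return (in_units + 1) * out_units

def batchnorm_params(channels):
    # gamma, beta (trainable) + moving mean, moving variance (non-trainable)
    return 4 * channels

def augmented_cnn_params():
    total = 0
    in_ch, size = 3, 32  # input_shape=[32, 32, 3]
    # (filters, kernel size, number of Conv2D layers) per block, as in the augmented model
    for filters, kernel, n_convs in [(64, 3, 2), (128, 2, 3), (256, 3, 3)]:
        for _ in range(n_convs):
            total += conv2d_params(kernel, in_ch, filters)
            in_ch = filters
        total += batchnorm_params(filters)
        size //= 2  # 'same' padding keeps the size; MaxPooling2D(2) halves it
    flat = size * size * in_ch  # Flatten: 4 * 4 * 256 = 4096
    total += dense_params(flat, 512) + batchnorm_params(512)
    total += dense_params(512, 100)  # softmax output layer
    return total


print(augmented_cnn_params())  # 3831076
```

That comes to about 3.8 million parameters, more than half of them in the first Dense layer, against roughly 45,000 training images (a split size inferred from the logs): around 85 parameters per image.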

Conclusion?

Overfitting isn't inherently a bad thing - it's just a thing. No, you don't want overfit end-models, but it shouldn't be treated as the plague and can even be a good sign that a model could perform better given more data and a simplification step. This isn't guaranteed, by any means, and the CIFAR100 dataset has been used as an example of a dataset that's not easy to generalize well to.

The point of this rambling is, again, not to be contrarian, but to incite discussion on a topic that doesn't appear to be getting much of it.

Who am I to make this claim?

Just someone who sits home, practicing the craft, with a deep fascination towards tomorrow.

Do I have the ability to be wrong?

Very much so.

How should you take this piece?

Take it as you will, and think for yourself whether it makes sense. If you don't think I'm out of place for noting this, let me know. If you think I'm wrong on this, by all means, please let me know and don't mince your words. :)

Author: David Landup

Entrepreneur, Software and Machine Learning Engineer, with a deep fascination towards the application of Computation and Deep Learning in Life Sciences (Bioinformatics, Drug Discovery, Genomics), Neuroscience (Computational Neuroscience), robotics and BCIs.

Great passion for accessible education and promotion of reason, science, humanism, and progress.

© 2013-2025 Stack Abuse. All rights reserved.
