Practical Deep Learning for Computer Vision with Python - Image Classification with Transfer Learning - Creating Cutting Edge CNN Models

David Landup

New models are being released and benchmarked against community-accepted datasets frequently, and keeping up with all of them is getting harder.

Most of these models are open source, and you can implement them yourself as well.

This means that the average enthusiast can load and experiment with cutting edge models at home, on very average machines, not only to gain a deeper understanding and appreciation of the craft, but also to contribute to the scientific discourse and publish their own improvements whenever they make them.

In this lesson, you'll learn how to use pre-trained, cutting edge Deep Learning models for Image Classification and repurpose them for your own specific application. This way, you're leveraging their high performance, ingenious architectures and someone else's training time - while applying these models to your own domain!

Transfer Learning for Computer Vision and Convolutional Neural Networks

Knowledge and knowledge representations are largely universal. A computer vision model trained on one dataset learns to recognize patterns that are likely to be prevalent in many other datasets.

Notably, in "Deep Learning for the Life Sciences", by Bharath Ramsundar, Peter Eastman, Patrick Walters and Vijay Pande, it's noted that:

"There have been multiple studies looking into the use of recommendation system algorithms for use in molecular binding prediction. Machine learning architectures used in one field tend to carry over to other fields, so it’s important to retain the flexibility needed for innovative work."

For instance, straight and curved lines, which are typically learned at a lower level of a CNN hierarchy, are bound to be present in practically all datasets. Some high-level features, such as the ones that distinguish a bee from an ant, are going to be represented and learned much higher in the hierarchy:

The "fine line" between these is what you can reuse! Depending on the level of similarity between your dataset and the one a model's been pre-trained on, you may be able to reuse a small or large portion of it.

A model that classifies human-made structures (trained on a dataset such as Places365) and a model that classifies general images (trained on a dataset such as ImageNet) are bound to have some shared patterns, although not many.

You might want to train a model to distinguish, say, buses and cars for a self-driving car's vision system. You may also reasonably choose to use a very performant architecture that has proven to work well on datasets similar to yours. Then, the long process of training begins, and you end up having a performant model of your own!

However, if another model is likely to have similar representations on lower and higher levels of abstraction, there's no need to re-train a model from scratch. You may decide to use some of the already pre-trained weights, which are just as applicable to your own application of the model as they were to the creator of the original architecture. You'd be transferring some of the knowledge from an already existing model to a new one, and this is known as Transfer Learning. The importance and versatility of transfer learning are, in my opinion, understated. It's oftentimes put to the side or briefly mentioned at the end of lessons and lectures, and it's frequently the last concept covered when learning about CNNs.

Whenever you're reading about the application of computer vision to a specific problem - chances are, it's transfer learning in the background. If you, like me, spend your afternoons reading research papers in various fields (including ones I have barely any knowledge of), you'll notice how commonly transfer learning is used, even when not mentioned by that name. It's so prevalent that its use is practically assumed. With pre-loaded models and transferred knowledge - almost anyone can utilize the power of deep learning to further a field.

  • Doctors can use computer vision models to diagnose images (X-ray, histology, retinoscopy, etc.)
  • Cities can use computer vision to detect pedestrians and cars on streets and adapt traffic lights to optimize the flow of transportation
  • Malls can use computer vision to detect parking occupancy
  • Marine biologists can use computer vision to detect endangered coral reefs (TensorFlow's Great Barrier Reef competition)
  • Manufacturers can use computer vision to detect defects in production lines (such as missing pills in medicine)
  • News outlets can use computer vision to digitize old newspaper issues
  • Farms can use computer vision to estimate crop yields, monitor crop health and detect insects and other pests

From optimizing workflows and investments to saving human lives - computer vision is very applicable. Though - read through the list again. Who are the people using these technologies? Doctors, biologists, farmers, city planners. They might not have an extensive computer/data science background or the powerful hardware required to train large networks, but they can benefit from those networks even if the networks aren't optimized for their exact task. Through democratized models, they don't need a data science background. Through free and cheap cloud training providers and pre-trained networks, they don't need powerful hardware.

Training with pre-built architectures and downloadable weights has become so streamlined that a kid with a slow internet connection and a barely working computer can create more accurate models than professionals with top-of-the-line equipment could a decade or two ago.

The benefit of transfer learning isn't limited to shortening training. If you don't have a lot of data, a network won't be able to learn some of the distinctions early on. If you train it extensively on one dataset and then transfer it to another, a lot of the representations are already learned and the network can be fine-tuned on the new dataset. In the case of the CIFAR100 dataset we worked with in the last lesson, many of the images can be found (in larger sizes) in datasets like ImageNet, and a lot could've been transferred with a pre-trained model. This would, in effect, achieve what data augmentation sought to do - expand the dataset (albeit indirectly) with instances of data from another dataset. While you don't really expand the new dataset, the knowledge encoded in the model being fine-tuned is transferred between them.

The closer the dataset of a pre-trained model is to your own, the more you can transfer. The more you can transfer, the more of your own time and computation you can save. It's worth remembering that training neural networks does have a carbon footprint, so you're not only saving time!

Typically, transfer learning is done by loading a pre-trained model and freezing its layers. In many cases, you can just cut off the classification layers (the final layers, also called the head or the densely-connected layers) and re-train a new top, while keeping all of the other abstraction layers intact. This is tantamount to using the convolutional base as a feature extractor and re-training only the classifier, which holds most of the task-specific knowledge (the convolutional blocks are much more generic). In other cases, you may decide to re-train several layers in the convolutional hierarchy alongside the top; this is typically done when the datasets contain sufficiently different data points that re-training multiple layers is warranted. You may also decide to re-train the entirety of the model to fine-tune all of the layers.

These two approaches can be summarized as:

  • Using the Convolutional Network as a Feature Extractor
  • Fine-Tuning the Convolutional Network

In the former, you use the underlying model as a fixed feature extractor and just train a dense network on top to discern between these features. In the latter, you fine-tune the entire convolutional network (or a portion of it), relying on the already trained feature maps and updating them to also fit your own, more specific dataset, for which the original feature maps may not be fully representative.
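The two approaches can be sketched in Keras roughly as follows. This is a minimal sketch, not the lesson's own code: it assumes a 100-class problem, builds EfficientNetB0 with weights=None so nothing is downloaded (you'd pass weights='imagenet' for actual transfer learning), and the choice of unfreezing the last 20 layers and the two learning rates are arbitrary illustrative values.

```python
from tensorflow import keras

NUM_CLASSES = 100  # assumption: a 100-class dataset such as CIFAR100

# weights=None keeps the sketch offline; use weights='imagenet' in practice
base = keras.applications.EfficientNetB0(weights=None, include_top=False,
                                         input_shape=(224, 224, 3))

# Approach 1: feature extraction - freeze the entire convolutional base
base.trainable = False

inputs = keras.Input(shape=(224, 224, 3))
x = base(inputs, training=False)  # run the base in inference mode
x = keras.layers.GlobalAveragePooling2D()(x)
outputs = keras.layers.Dense(NUM_CLASSES, activation="softmax")(x)
model = keras.Model(inputs, outputs)
model.compile(optimizer=keras.optimizers.Adam(1e-3),
              loss="sparse_categorical_crossentropy", metrics=["accuracy"])
n_head_only = len(model.trainable_weights)  # only the Dense kernel and bias
# model.fit(...) here would train just the new head

# Approach 2: fine-tuning - unfreeze the last 20 layers of the base,
# but keep BatchNormalization layers frozen to preserve their statistics
base.trainable = True
for layer in base.layers[:-20]:
    layer.trainable = False
for layer in base.layers[-20:]:
    if isinstance(layer, keras.layers.BatchNormalization):
        layer.trainable = False

# Recompile after changing trainability, with a much lower learning rate
model.compile(optimizer=keras.optimizers.Adam(1e-5),
              loss="sparse_categorical_crossentropy", metrics=["accuracy"])
n_fine_tune = len(model.trainable_weights)
print(n_head_only, n_fine_tune)
```

Recompiling is what makes the new trainability take effect, and the low learning rate keeps the pre-trained feature maps from being overwritten too aggressively.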

Here's a visual representation of how Transfer Learning works:

Established and Cutting Edge Image Classification Models

Many models exist out there, and for well-known datasets, you're likely to find hundreds of well-performing models published in online repositories and papers. A good holistic view of models trained on the ImageNet dataset can be seen at PapersWithCode.

Some of the well-known published architectures that have subsequently been ported into many Deep Learning frameworks include:

  • EfficientNet
  • SENet
  • Inception and Xception
  • ResNet
  • VGGNet

Note: Being well-known doesn't mean that an architecture is going to perform at the state of the art. For example, you probably don't want to use VGGNet for transfer learning, because newer, more robust, more efficient architectures have been ported and pre-trained.

The list of models on PapersWithCode is constantly being updated, and you shouldn't get hung up on the position of these models there. Many of the new models that take the top places are actually based on the ones outlined in the list above.

Unfortunately, some of the newest models aren't ported as pre-trained models within frameworks such as TensorFlow and PyTorch, though the teams are pretty diligent in porting them with pre-trained weights. You won't be losing out on a lot of performance, so going with any of the well-established ones isn't bad at all.

Transfer Learning with Keras - Adapting Existing Models

With Keras, the pre-trained models are available under the tensorflow.keras.applications module. Each model has its own sub-module and class. When loading a model in, you can set a couple of optional arguments to control how the models are being loaded.

Note: You can find the ported models at, but the list doesn't include the newest and experimental models. For a more up-to-date list, visit TensorFlow's Docs.

For instance, the weights argument defines which pre-trained weights are to be used. It defaults to 'imagenet', which returns a network pre-trained on ImageNet. If you set it to None, only the architecture (an untrained, randomly initialized network) will be loaded in. Alternatively, you can provide a path to a file with the weights you want to load in (as long as it's the exact same architecture).

Additionally, since you'll most likely be removing the top layer(s) for Transfer Learning, the include_top argument is used to define whether the top layer(s) should be present or not!

import tensorflow.keras.applications as models

# 98 MB
resnet = models.resnet50.ResNet50(weights='imagenet', include_top=False)
# 528MB
vgg16 = models.vgg16.VGG16(weights='imagenet', include_top=False)
# 23MB
nnm = models.NASNetMobile(weights='imagenet', include_top=False)
# etc...

Note: If you've never loaded pre-trained models before, they'll be downloaded over an internet connection. This may take anywhere between a few seconds and a couple of minutes, depending on your internet speed and the size of the models. The size of pre-trained models spans from as little as 14MB (typically the Mobile models) to as much as 549MB.

EfficientNet is a family of networks that are quite performant, scalable and, well, efficient. They were designed with reducing the number of learnable parameters in mind: the smallest member, EfficientNetB0, has only around 4M parameters without its classification top (about 5.3M with it). Consider that VGG19, for instance, has around 144M. On a home setup, this also helps with training times significantly!
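To check these numbers on your own machine, you can count the parameters directly. A sketch, assuming TensorFlow 2.x: weights=None builds the bare architecture, so nothing is downloaded, and the parameter count doesn't depend on the weight values.

```python
from tensorflow import keras

# weights=None builds the architecture only, so no download is needed
without_top = keras.applications.EfficientNetB0(weights=None, include_top=False)
with_top = keras.applications.EfficientNetB0(weights=None, include_top=True)

n_without_top = without_top.count_params()
n_with_top = with_top.count_params()
print(f"EfficientNetB0 without top: {n_without_top:,}")
print(f"EfficientNetB0 with top:    {n_with_top:,}")
# The difference is the 1280 -> 1000 Dense classifier: 1280 * 1000 weights + 1000 biases
print(f"Difference: {n_with_top - n_without_top:,}")
```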

Let's load in one of the members of the EfficientNet family - EfficientNetB0:

from tensorflow import keras

effnet = keras.applications.EfficientNetB0(weights='imagenet', include_top=False)

This results in:

Model: "efficientnetb0"
 Layer (type)                              Output Shape               Param #    Connected to
 input_2 (InputLayer)                      [(None, None, None, 3)]    0          []
 rescaling_1 (Rescaling)                   (None, None, None, 3)      0          ['input_2[0][0]']
 ...
 block7a_project_bn (BatchNormalization)   (None, None, None, 320)    1280       ['block7a_project_conv[0][0]']
 top_conv (Conv2D)                         (None, None, None, 1280)   409600     ['block7a_project_bn[0][0]']
 top_bn (BatchNormalization)               (None, None, None, 1280)   5120       ['top_conv[0][0]']
 top_activation (Activation)               (None, None, None, 1280)   0          ['top_bn[0][0]']
Total params: 4,049,571
Trainable params: 4,007,548
Non-trainable params: 42,023

On the other hand, if we were to load in EfficientNetB0 with the top included, we'd also have a few new layers at the end that were trained to classify ImageNet's 1000 classes. This top is what we'll be replacing with our own and training for our application:

effnet = keras.applications.EfficientNetB0(weights='imagenet', include_top=True)

This would include the final top layers, with a Dense classifier in the end:

Model: "efficientnetb0"
 Layer (type)                              Output Shape               Param #    Connected to
 input_1 (InputLayer)                      [(None, 224, 224, 3)]      0          []
 rescaling (Rescaling)                     (None, 224, 224, 3)        0          ['input_1[0][0]']
 ...
 block7a_project_bn (BatchNormalization)   (None, 7, 7, 320)          1280       ['block7a_project_conv[0][0]']
 top_conv (Conv2D)                         (None, 7, 7, 1280)         409600     ['block7a_project_bn[0][0]']
 top_bn (BatchNormalization)               (None, 7, 7, 1280)         5120       ['top_conv[0][0]']
 top_activation (Activation)               (None, 7, 7, 1280)         0          ['top_bn[0][0]']
 avg_pool (GlobalAveragePooling2D)         (None, 1280)               0          ['top_activation[0][0]']
 top_dropout (Dropout)                     (None, 1280)               0          ['avg_pool[0][0]']
 predictions (Dense)                       (None, 1000)               1281000    ['top_dropout[0][0]']
Total params: 5,330,571
Trainable params: 5,288,548
Non-trainable params: 42,023

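Some of these layers have names starting with top_ to annotate the fact that they belong to the top classifier, although the naming isn't fully consistent: avg_pool and predictions are part of the top as well.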

Note: This structure may change over time. In an earlier version of Keras, top_conv, top_bn and top_activation weren't loaded in if the include_top argument was set to False, while in the newest version, they are (and their names still have the top_ prefix, making it a bit more confusing). Always check what the "top" is in a model before defining your own, whether it's inspired by the original implementation or not.

We won't be using the original top layers, as we'll be adding our own top to the EfficientNet model and re-training only the layers we add (before fine-tuning the convolutional base). It is worth noting what the architecture originally uses for the top, though! It uses a GlobalAveragePooling2D and a Dropout layer before the final Dense classification layer. These tops are typically tuned for their respective networks, so it's wise to reuse the structure, at least for the baseline.

Preprocessing Input for Pre-trained Models

Note: Data preprocessing plays a crucial role in model training, and most models will have different preprocessing pipelines. You don't have to perform guesswork here! Where applicable, a model comes with its own preprocess_input() function.

The preprocess_input() function applies the same preprocessing steps to the input as were applied during training. If a model resides in its own module, you can import the function from that module. For instance, ResNets have their own preprocess_input function:

from keras.applications.resnet50 import preprocess_input
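To get a feel for what such a function does: as far as I know, ResNet50's preprocess_input() uses the 'caffe' mode, which flips the channel order from RGB to BGR and subtracts the per-channel ImageNet means, without any scaling. The following NumPy sketch reimplements that idea for illustration only; caffe_style_preprocess() is a hypothetical helper, not a Keras function.

```python
import numpy as np

# Per-channel ImageNet means in BGR order, as used by 'caffe'-style preprocessing
IMAGENET_BGR_MEAN = np.array([103.939, 116.779, 123.68], dtype=np.float32)

def caffe_style_preprocess(images):
    """Sketch of 'caffe'-mode preprocessing: RGB -> BGR, then zero-center each channel.

    Expects pixel values in [0, 255] with channels last.
    """
    images = np.asarray(images, dtype=np.float32)
    bgr = images[..., ::-1]         # reverse the channel order: RGB -> BGR
    return bgr - IMAGENET_BGR_MEAN  # subtract the per-channel mean (no scaling)

# A batch with a single dark-brown pixel (R=90, G=60, B=30)
pixel = np.array([[[[90.0, 60.0, 30.0]]]])
print(caffe_style_preprocess(pixel))
```

Because the channels are reordered and shifted, an image preprocessed this way looks very different when displayed as if it were still RGB.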

That being said, loading in a model, preprocessing input for it and predicting a result in Keras is as easy as:

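import numpy as np
import tensorflow.keras.applications as models
from tensorflow.keras.applications.resnet50 import preprocess_input

resnet50 = models.ResNet50(weights='imagenet', include_top=True)

# Placeholder batch: replace with real images of shape (n, 224, 224, 3), pixel values in [0, 255]
img = np.random.uniform(0, 255, size=(1, 224, 224, 3)).astype('float32')
img = preprocess_input(img)
pred = resnet50.predict(img)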

Note: Not all models need a dedicated preprocess_input() function, because the preprocessing can be done within the model itself. For instance, EfficientNet, which we'll be using, includes its preprocessing layers (such as Rescaling) in the model, and the preprocess_input() function in its module is only a pass-through placeholder. This is becoming more and more common.

That's it! Now, since the pred array doesn't really contain human-readable data, you can also import the decode_predictions() function alongside the preprocess_input() function from a module. Alternatively, you can import the generic decode_predictions() function that also applies to models that don't have their dedicated modules:

from keras.applications.model_name import preprocess_input, decode_predictions  # model_name is a placeholder, e.g. resnet50
# OR
from keras.applications.imagenet_utils import decode_predictions
# ...
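Under the hood, decoding boils down to picking the k highest probabilities and looking up their labels. Here's a minimal NumPy sketch of that idea; decode_top_k() and the four-class label list are hypothetical, and unlike Keras' decode_predictions(), it doesn't return the WordNet IDs.

```python
import numpy as np

def decode_top_k(probs, class_names, k=3):
    """Return the k most likely (class_name, probability) pairs for each row of probs."""
    probs = np.asarray(probs)
    top = np.argsort(probs, axis=-1)[:, ::-1][:, :k]  # indices sorted by descending probability
    return [[(class_names[i], float(row[i])) for i in idx]
            for row, idx in zip(probs, top)]

class_names = ["cat", "dog", "bear", "panda"]  # hypothetical 4-class label list
probs = np.array([[0.05, 0.10, 0.60, 0.25]])   # one prediction, e.g. a softmax output
print(decode_top_k(probs, class_names, k=2))
```

The same idea works for any classifier whose output indices line up with a list of label names.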

Tying this together, let's get an image of a black bear via urllib, load it at a target size suitable for EfficientNet (the input layer expects a shape of (batch_size, 224, 224, 3)) and classify it with the pre-trained model:

from tensorflow import keras
from tensorflow.keras.applications.imagenet_utils import decode_predictions
# Models with a dedicated preprocessing function expose it from their own module, e.g.:
# from tensorflow.keras.applications.resnet50 import preprocess_input
from tensorflow.keras.preprocessing import image

import urllib.request
import matplotlib.pyplot as plt
import numpy as np

# Public domain image
url = ''
urllib.request.urlretrieve(url, 'bear.jpg')

# Load image and resize (doesn't keep aspect ratio)
img = image.load_img('bear.jpg', target_size=(224, 224))
# Turn to array of shape (224, 224, 3)
img = image.img_to_array(img)
# Expand array into (1, 224, 224, 3)
img = np.expand_dims(img, 0)
# Preprocess for models that have specific preprocess_input() function
# img_preprocessed = preprocess_input(img)

# Load model and run prediction
effnet = keras.applications.EfficientNetB0(weights='imagenet', include_top=True)
pred = effnet.predict(img)
# Decode the 1000 ImageNet probabilities into the 5 most likely labels
print(decode_predictions(pred, top=5)[0])

We got the image from a URL, however - you can fetch the image from a mobile device, a REST API call, or any other source and classify it. Really - using a pre-trained classifier is as easy as importing it, feeding an image into it and decoding the results. You can serve a computer vision model to an end-user with only a few lines of code! This results in:

('n02133161', 'American_black_bear', 0.6024658),
('n02132136', 'brown_bear', 0.1457715),
('n02134418', 'sloth_bear', 0.09819221),
('n02510455', 'giant_panda', 0.0069221947),
('n02509815', 'lesser_panda', 0.005077324)

It's fairly certain that the image is an image of an American Black Bear, which is right! When preprocessed with a preprocessing function, the image may change significantly. For instance, ResNet's preprocessing function would change the color of the bear's fur:

It looks a lot more brown now! If we were to feed this image into EfficientNet, it'd think it's a brown bear:

('n02132136', 'brown_bear', 0.7152758),
('n02133161', 'American_black_bear', 0.15667434),
('n02134418', 'sloth_bear', 0.012813852),
('n02134084', 'ice_bear', 0.0067828503),
('n02117135', 'hyena', 0.0050422684)

It's important not to mix and match preprocessing functions between models. For instance, ResNet learns that what we see as brown is called black, since the color got changed through preprocessing, and it only ever saw what we call "brown" with the label "black". Now, it wasn't trained to classify colors - but it was trained to classify between a black bear and a brown bear, and the colors are definitely mixed.

Is this a good thing or a bad thing?

Depends on who you ask. John Locke, one of the most influential philosophers of all time, classified properties of objects into primary and secondary qualities and made a clear distinction between them. Primary qualities are those that exist independently of an observer. A book is a book and it has a size, irrespective of how I see it. That's a primary quality. Secondary qualities are those that depend on an observer (color, taste, smell, etc.), and these are quite subjective. From an early age, many people have asked themselves whether "my yellow" is the same as "your yellow". We might see different colors but were taught to call them "yellow". This doesn't change the fact that a yellow book is a book!

Regardless of whether it's true or not, it's conceivable that we all see the world in a slightly different way. There is no clear reason why that would stop us from communicating, building and understanding the world, especially since we can assign numerical, ubiquitous values to explain the sources of subjective experience. This isn't "yellow" - it's an electromagnetic wave with a wavelength of around 580nm. Your red and green receptors in the eye react to it and you "see yellow"! Nowadays, we can describe secondary qualities, such as color, as non-disputable properties as well. It is true that it's easier to feed a raw image into a model and have the model do the preprocessing itself (like EfficientNet does), rather than having a separate function, since you then don't have to think about the preprocessing as much. However, it's not objectively better or worse that ResNet "mixes up" colors. As a matter of fact, this diversity in knowledge can actually lead to some beautiful visualizations down the line. We'll see what that entails in another lesson when we cover the DeepDream algorithm.

Awesome! The model works. Now, let's add a new top to it and re-train the top to perform classification for something outside of the ImageNet set.

Adding a New Top to a Pre-trained Model

When performing transfer learning, you'll typically be loading models without tops, or removing the tops manually:

# Load without top
# When adding new layers, we also need to define the input_shape
effnet_base = keras.applications.EfficientNetB0(weights='imagenet',
                                                include_top=False,
                                                input_shape=(224, 224, 3))

# Or load the full model
full_effnet = keras.applications.EfficientNetB0(weights='imagenet',
                                                input_shape=(224, 224, 3))
# And then remove layers from the top: layers[-3] is 'avg_pool',
# so the Dropout and the Dense classifier are cut off
trimmed_effnet = keras.Model(inputs=full_effnet.input, outputs=full_effnet.layers[-3].output)
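Keras' application models also accept a pooling argument, which appends a global pooling layer when include_top=False. The sketch below (assuming TensorFlow 2.x, with weights=None to skip the download) compares that shortcut against trimming the full model, and confirms that layers[-3] is the avg_pool layer in this architecture.

```python
from tensorflow import keras

# Option 1: no top, but let Keras append global average pooling
pooled_base = keras.applications.EfficientNetB0(weights=None, include_top=False,
                                                input_shape=(224, 224, 3),
                                                pooling="avg")

# Option 2: full model, trimmed just before the Dropout and Dense classifier
full = keras.applications.EfficientNetB0(weights=None, input_shape=(224, 224, 3))
trimmed = keras.Model(inputs=full.input, outputs=full.layers[-3].output)

print(full.layers[-3].name)  # the layer we cut at
print(tuple(pooled_base.output.shape), tuple(trimmed.output.shape))
```

Both variants end in a 1280-dimensional feature vector per image, so either can feed a new Dense classifier.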

We'll be going with the first option since it's more convenient. Depending on whether you'd like to fine-tune the convolutional blocks or not, you'll either leave them trainable or freeze them. Say we want to use the underlying pre-trained feature maps and freeze the layers so that we only re-train the new classification layers at the top:

effnet_base.trainable = False

You don't need to iterate through the model and set each layer to be trainable or not, though you can. If you'd like to turn off the first n layers, and allow some higher-level feature maps to be fine-tuned, but leave the lower-level ones untouched, you can:

for layer in effnet_base.layers[:-2]:
    layer.trainable = False

Here, we've set all layers in the base model to be untrainable, except for the last two. If we check the model, there are only ~2.5K trainable parameters now:

# ...                
Total params: 4,049,571
Trainable params: 2,560
Non-trainable params: 4,047,011
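You can also verify these counts programmatically rather than reading the summary. A sketch, assuming TensorFlow 2.x and using weights=None, since freezing doesn't depend on the weight values; count() is a hypothetical helper.

```python
import numpy as np
from tensorflow import keras

base = keras.applications.EfficientNetB0(weights=None, include_top=False,
                                         input_shape=(224, 224, 3))

# Freeze everything except the last two layers (top_bn and top_activation)
for layer in base.layers[:-2]:
    layer.trainable = False

def count(weights):
    """Total number of scalar values across a list of weight variables."""
    return int(sum(np.prod(tuple(w.shape)) for w in weights))

trainable = count(base.trainable_weights)
non_trainable = count(base.non_trainable_weights)
print(trainable, non_trainable)
```

Only the gamma and beta vectors of top_bn (2 x 1280 values) remain trainable; top_activation has no weights.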

Now, let's define the new layers that'll be put on top of this effnet_base. Fortunately, chaining models in Keras is as easy as making a new model and putting it on top of another one! With the Functional API, you can simply chain a few new layers onto the output of a model.

Let's add a GlobalAveragePooling2D layer, some Dropout and a dense classification layer:

gap = keras.layers.GlobalAveragePooling2D()(effnet_base.output)
do = keras.layers.Dropout(0.2)(gap)
output = keras.layers.Dense(100, activation='softmax')(do)

new_model = keras.Model(inputs=effnet_base.input, outputs=output)

Note: Weight trainability (trainable) is different from the mode a layer runs in (training). For most layers, trainable only controls whether the weights get updated, while training controls behavior such as whether Dropout is active. BatchNormalization is the exception: since TF 2.0, setting trainable to False on a BatchNormalization layer (including through the model's trainable attribute) also makes it run in inference mode, using its moving mean and variance instead of batch statistics. This is crucial if you wish to unfreeze layers later on: BatchNormalization computes moving statistics, and once unfrozen, it'll start updating them on the new data, which can "undo" the training done before fine-tuning. Since the layers above are chained directly onto effnet_base.output, there's no call to the base model through which training=False could be passed, so keep the BatchNormalization layers frozen when fine-tuning this model. The alternative is to call the base model as a single layer, as in effnet_base(inputs, training=False), which keeps it in inference mode even after unfreezing.
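The difference between trainable and training is easiest to see on a single BatchNormalization layer. The following sketch (assuming TensorFlow 2.x Keras; moving_mean_after_call() is a hypothetical helper) calls a fresh layer once and checks whether its moving mean was updated.

```python
import numpy as np
from tensorflow import keras

# A batch of 32 samples with 3 features, centered around 5.0
x = np.random.default_rng(0).normal(loc=5.0, scale=1.0, size=(32, 3)).astype("float32")

def moving_mean_after_call(trainable, training):
    """Build a fresh BatchNormalization layer, call it once and return its moving mean."""
    bn = keras.layers.BatchNormalization()  # moving mean starts at zeros, momentum 0.99
    bn.build((None, 3))
    bn.trainable = trainable
    bn(x, training=training)
    return bn.moving_mean.numpy()

print(moving_mean_after_call(trainable=True, training=True))    # updated towards the batch mean
print(moving_mean_after_call(trainable=True, training=False))   # unchanged
print(moving_mean_after_call(trainable=False, training=True))   # unchanged
```

Only the trainable layer called with training=True moves its statistics; with trainable=False, they stay fixed even when training=True is passed. The Keras transfer learning guide additionally recommends calling the base model with training=False, so that BatchNormalization stays in inference mode even after the base is unfrozen.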

Alternatively, you can use the Sequential API and call the add() method multiple times, or pass the layers in as a list:

new_model = keras.Sequential([
    effnet_base,
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dense(100, activation='softmax')
])

new_model.layers[0].trainable = False

This adds the entire model as a layer itself, so it's treated as one entity:

Layer: 0, Trainable: False # Entire EfficientNet model
Layer: 1, Trainable: True
Layer: 2, Trainable: True

If a model is sequential, you can instead unwrap it and add its layers one by one:

new_model = keras.Sequential()
for layer in base_network.layers:  # base_network stands for any Sequential model
    new_model.add(layer)  # Add unwrapped layers

Though, this fails for non-sequential models. It's advised to use the Functional API for applications like these, since the Sequential API doesn't offer the required flexibility and not all models are sequential (as a matter of fact, since TF 2.4.0, all pre-trained models are functional). Additionally, you can't easily put the base network into inference mode - there's no training argument. The fact that the entire EfficientNet model is a black-box layer doesn't help us work easily with it, so the minor convenience of the Sequential API doesn't really benefit us much, and has several cons.

Back to our model - there are 100 output neurons for the CIFAR100 classes, with a softmax activation. Let's take a look at the trainable layers in the network:

for index, layer in enumerate(new_model.layers):
    print("Layer: {}, Trainable: {}".format(index, layer.trainable))

This results in:

Layer: 0, Trainable: False
Layer: 1, Trainable: False
Layer: 2, Trainable: False
...
Layer: 235, Trainable: False
Layer: 236, Trainable: False
Layer: 237, Trainable: True
Layer: 238, Trainable: True
Layer: 239, Trainable: True
Layer: 240, Trainable: True
Layer: 241, Trainable: True

Awesome! Let's load in the dataset, preprocess it and re-train the classification layers on it. We'll be using the same CIFAR100 dataset from the last lesson, since it proved to be a difficult one to train a CNN on. The lack of data and limitations of data augmentation made it difficult to create a powerful classifier. Let's see if we can employ transfer learning to help us!

TensorFlow Datasets

We'll be working with the CIFAR100 dataset, again. Though, this time around, we won't be loading it as a bare NumPy array from Keras. We'll be working with TensorFlow Datasets!

Keras' datasets module contains a few datasets, but these are mainly meant for benchmarking and learning and aren't too useful beyond that point. We can use tensorflow_datasets to get access to a much larger corpus of datasets! Additionally, all of the datasets from the module are standardized, so you don't have to bother with different preprocessing steps for every single dataset you're testing your models on. While it may sound like a simple convenience rather than a game-changer, if you train a lot of models, the time it takes to do overhead work gets beyond annoying. The library provides access to datasets ranging from MNIST (11MB) to Google Open Images (565GB), spanning several categories:

  • Audio
  • D4rl
  • Graphs
  • Image
  • Image Classification
  • Object Detection
  • Question Answering
  • Ranking
  • Rlds
  • Robomimic
  • Robotics
  • Text
  • Time Series
  • Text Simplification
  • Vision Language
  • Video
  • Translate
  • etc...

And the list is growing! As of 2022, there are 278 datasets available, the names of which you can obtain via tfds.list_builders(). Additionally, TensorFlow Datasets supports community datasets, with over 700 HuggingFace datasets and the Kubric dataset generator. If you're building a general intelligent system, there's a very good chance there's a public dataset there. For all other purposes - you can download public datasets and work with them, with custom pre-processing steps. Kaggle, HuggingFace and academic repositories are popular choices.

Additionally, in a similar effort, TensorFlow released an amazing GUI tool - Know Your Data, which is still in beta (as of writing) and aims to answer important questions on data corruption (broken images, bad labels, etc.), data sensitivity (does your data contain sensitive content), data gaps (obvious lack of samples), data balance, etc.

A lot of these can help with avoiding bias and data skew - arguably one of the most important things to do when working on projects that can have an impact on other humans.

Another amazing feature is that datasets coming from TensorFlow Datasets are tf.data.Dataset objects, with which you can maximize the performance of your network through pre-fetching, automated optimization, easy transformations, etc.

Note: If you're not a fan of framework-specific classes, such as Dataset, you can convert it back into a simple NumPy array for framework-agnosticism. It's advised to work with Datasets, though.
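Here's a small sketch of such a pipeline. It uses tf.data directly on synthetic data so it runs without downloading anything: the random 32x32 images merely stand in for CIFAR100, and the batch size is arbitrary. It upscales the images to 224x224, batches and prefetches them, and then converts the dataset back into a NumPy array.

```python
import numpy as np
import tensorflow as tf

# Synthetic stand-in for CIFAR100: 8 RGB images of 32x32 with integer labels
images = np.random.default_rng(0).integers(0, 256, size=(8, 32, 32, 3), dtype=np.uint8)
labels = np.arange(8, dtype=np.int64)
ds = tf.data.Dataset.from_tensor_slices((images, labels))

def resize(image, label):
    # Upscale to the 224x224 input size EfficientNetB0 was pre-trained on
    return tf.image.resize(image, (224, 224)), label

pipeline = (ds.map(resize, num_parallel_calls=tf.data.AUTOTUNE)  # parallel transformation
              .batch(4)
              .prefetch(tf.data.AUTOTUNE))  # prepare the next batch while the current one is used

for batch_images, batch_labels in pipeline.take(1):
    print(batch_images.shape, batch_labels.numpy())

# Converting back into a plain NumPy array
all_images = np.stack([img for img, _ in ds.as_numpy_iterator()])
print(all_images.shape)
```

The same map(), batch() and prefetch() calls apply unchanged to the datasets returned by tfds.load().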

The module can be installed through:

$ pip install tensorflow_datasets

Once installed, you can access the list of available datasets via:

import tensorflow_datasets as tfds

print(tfds.list_builders())
print(f'Number of Datasets: {len(tfds.list_builders())}')

['abstract_reasoning', 'accentdb', 'aeslc', 'aflw2k3d', ...]
Number of Datasets: 278

Though, you're more likely to use the relevant web pages on the TensorFlow Datasets website, which offers more information, sample images, etc. rather than this list. To load a dataset, you can use the load() function:

dataset, info = tfds.load("cifar100", as_supervised=True, with_info=True)
class_names = info.features["label"].names
n_classes = info.features["label"].num_classes
print('Class names:', class_names)
print('Num of classes:', n_classes)

Datasets can be imported as unsupervised or supervised, and with or without additional information, such as the label names and the number of classes. In the code snippet above, we've loaded in "cifar100" as a supervised dataset (with labels) and information:

Class names: ['apple', 'aquarium_fish', 'baby', ...]
Num of classes: 100

The info.features["label"].names list can come in handy! It's a list of human-readable labels that correspond to the numerical labels in the dataset. We don't have to find a list manually online or type it out this way!

Train, Test and Validation Splits with TensorFlow Datasets

One of the optional arguments you can pass into the load() function is the split argument. The Split API allows you to define which splits of the dataset you want to load. By default, this dataset only supports a 'train' and a 'test' split - these are its "official" splits. There's no 'validation' split.

Note: Each dataset has "official" splits. Some only have a 'train' split, some have a 'train' and 'test' split and some even include a 'validation' split. These are the intended splits, and you can only use a split's string alias if the dataset supports that split. If a dataset contains only a 'train' split, you can split that training data into train/test/valid sets without issues.

These correspond to the tfds.Split.TRAIN and tfds.Split.TEST and tfds.Split.VALIDATION enums, which used to be exposed through the API in an earlier version.

You can really slice a Dataset into any arbitrary number of sets, though, we typically do three - train_set, test_set, valid_set:

(test_set, valid_set, train_set), info = tfds.load("cifar100", 
                                           split=["test", "train[0%:20%]", "train[20%:]"],
                                           as_supervised=True, with_info=True)

class_names = info.features["label"].names
n_classes = info.features["label"].num_classes
print(f'Class names: {class_names[:10]}...', ) # ['apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', 'bicycle', 'bottle']...
print('Num of classes:', n_classes) # Num of classes: 100

print("Train set size:", len(train_set)) # Train set size: 40000
print("Test set size:", len(test_set)) # Test set size: 10000
print("Valid set size:", len(valid_set)) # Valid set size: 10000

We've taken the 'test' split and extracted it into the test_set. The slice between 0% and 20% of the 'train' split is assigned to the valid_set, and everything beyond 20% is the train_set. The sizes of the sets confirm this as well.

Instead of percentages, you can use absolute values or a mix of percentage and absolute values:

# Absolute value split
test_set, valid_set, train_set = tfds.load("cifar100", 
                                           split=["test", "train[0:10000]", "train[10000:]"],
                                           as_supervised=True)

print("Train set size:", len(train_set)) # Train set size: 40000
print("Test set size:", len(test_set)) # Test set size: 10000
print("Valid set size:", len(valid_set)) # Valid set size: 10000

# Mixed notation split
# Samples 10000-25000 of 'train' (15000 samples) are left unassigned
test_set, valid_set, train_set = tfds.load("cifar100", 
                                           split=["test[:2500]",      # First 2500 of 'test' are assigned to `test_set`
                                           "train[0:10000]",          # 0-10000 of 'train' are assigned to `valid_set`
                                           "train[50%:]"],            # 50% - 100% of 'train' (25000) assigned to `train_set`
                                           as_supervised=True)
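To make the slicing arithmetic above concrete, here's a small, dependency-free sketch that resolves split strings like the ones above into absolute index ranges and sizes. The helper names (resolve_split(), split_size()) and the hard-coded CIFAR100 split sizes are assumptions for illustration, not part of tfds, and percentages are simply rounded to the nearest index, which is a simplification of TFDS's own rounding options.

```python
import re

# Official CIFAR100 split sizes (assumed here rather than read from tfds)
CIFAR100_SIZES = {"train": 50_000, "test": 10_000}

# Matches 'name' or 'name[start:stop]', where start/stop may be empty,
# absolute integers, or percentages such as '20%'
_SLICE = re.compile(r"(\w+)(?:\[([^:\]]*):([^\]]*)\])?")


def _bound(text, n, default):
    """Turn one slice bound into an absolute index within [0, n]."""
    if text == "":
        return default
    if text.endswith("%"):
        value = round(n * float(text[:-1]) / 100)
    else:
        value = int(text)
    if value < 0:  # negative values count from the end, like Python slices
        value += n
    return min(max(value, 0), n)


def resolve_split(spec, sizes=CIFAR100_SIZES):
    """Return (split_name, start, stop) ranges for a spec such as
    'train[0%:20%]', 'train[10000:]' or 'train[:50%]+test'."""
    ranges = []
    for part in spec.split("+"):
        match = _SLICE.fullmatch(part.strip())
        if match is None:
            raise ValueError(f"Unsupported split spec: {part!r}")
        name, start, stop = match.group(1), match.group(2) or "", match.group(3) or ""
        n = sizes[name]
        ranges.append((name, _bound(start, n, 0), _bound(stop, n, n)))
    return ranges


def split_size(spec, sizes=CIFAR100_SIZES):
    """Total number of samples selected by a split spec."""
    return sum(stop - start for _, start, stop in resolve_split(spec, sizes))


for spec in ["test[:2500]", "train[0:10000]", "train[50%:]", "train[:50%]+test"]:
    print(spec, resolve_split(spec), split_size(spec))
```

Resolving the mixed-notation specs this way shows that 'train' samples 10000 to 25000 aren't used by any of the three sets.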

You can additionally take a union of splits, though this is less commonly used, since the samples from both splits then end up interleaved:

train_and_test, half_of_train_and_test = tfds.load("cifar100", 
                                split=['train+test', 'train[:50%]+test'],
                                as_supervised=True)

print("Train+test: ", len(train_and_test))               # Train+test:  60000
print("Train[:50%]+test: ", len(half_of_train_and_test)) # Train[:50%]+test:  35000

These two sets are now heavily interleaved.

Even Splits for N Sets

Again, you can create any arbitrary number of splits, just by adding more splits to the split list:

split=["train[:10%]", "train[10%:20%]", "train[20%:30%]", "train[30%:40%]", ...]

However, if you're creating many splits, especially if they're even - the strings you'll be passing in are very predictable. This can be automated by creating a list of strings, with a given equal interval (such as 10%) instead. For exactly this purpose, the tfds.even_splits() function generates a list of strings, given a prefix string and the desired number of splits:

import tensorflow_datasets as tfds

s1, s2, s3, s4, s5 = tfds.even_splits('train', n=5)
# Each of these elements is just a string
split_list = [s1, s2, s3, s4, s5]
print(f"Type: {type(s1)}, contents: '{s1}'")
# Type: <class 'str'>, contents: 'train[0%:20%]'

for split in split_list:
    test_set = tfds.load("cifar100", split=split, as_supervised=True)
    print(f"Test set length for Split {split}: ", len(test_set))

This results in:

Test set length for Split train[0%:20%]: 10000
Test set length for Split train[20%:40%]: 10000
Test set length for Split train[40%:60%]: 10000
Test set length for Split train[60%:80%]: 10000
Test set length for Split train[80%:100%]: 10000

Alternatively, you can pass in the entire split_list as the split argument itself, to construct several split datasets outside of a loop:

ts1, ts2, ts3, ts4, ts5 = tfds.load("cifar100", split=split_list, as_supervised=True)
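Since tfds.even_splits() only hands back strings, the idea is easy to reproduce by hand. Here's a minimal, dependency-free sketch that builds evenly spaced percentage slices. even_split_specs() is a hypothetical helper, not the tfds implementation, and it rounds boundaries to whole percents, which may differ from how tfds handles splits that don't divide evenly.

```python
def even_split_specs(name, n):
    """Build n contiguous percentage slices of the split `name`,
    covering 0% to 100% without gaps or overlaps."""
    if n < 1:
        raise ValueError("n must be a positive integer")
    # Boundaries at 0%, 100/n %, 2*100/n %, ..., 100%, rounded to whole percents
    bounds = [round(100 * i / n) for i in range(n + 1)]
    return [f"{name}[{lo}%:{hi}%]" for lo, hi in zip(bounds, bounds[1:])]


print(even_split_specs("train", 5))
```

Because consecutive slices share their boundaries, the resulting subsets don't overlap and, together, cover the whole split.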

Loading CIFAR100 and Data Augmentation

With a working understanding of tfds under your belt - let's load the CIFAR100 dataset in:

import tensorflow_datasets as tfds
import tensorflow as tf

(test_set, valid_set, train_set), info = tfds.load("cifar100", 
                                           split=["test", "train[0%:20%]", "train[20%:]"],
                                           as_supervised=True, with_info=True)

class_names = info.features["label"].names
n_classes = info.features["label"].num_classes
print(f'Class names: {class_names[:10]}...', ) # ['apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', 'bicycle', 'bottle']...
print('Num of classes:', n_classes) # Num of classes: 100

print("Train set size:", len(train_set)) # Train set size: 40000
print("Test set size:", len(test_set)) # Test set size: 10000
print("Valid set size:", len(valid_set)) # Valid set size: 10000

Let's take note of a couple of relevant variables in a config dictionary:

config = {
    'TRAIN_SIZE' : len(train_set),
    'BATCH_SIZE' : 32
}

Now, the CIFAR100 images are significantly different from the ImageNet images! Namely, CIFAR100 images are just 32x32, while our EfficientNet model expects 224x224 images, so we'll want to resize the images in any case. We might also want to apply random transformations to copies of the images, to artificially expand the number of samples per class, since the dataset only has a limited number of them. With ImageDataGenerator, we've seen that you have a great deal of freedom when it comes to augmentation, and the process is highly automated. When dealing with TensorFlow Datasets, in order to take advantage of every bit of optimization they provide, you'll typically use tf.image operations to translate, rotate, etc. images in the preprocess_image() function.

Instead of a dedicated preprocess_image() function, you can simply chain several map() calls with lambda functions, but this approach is significantly less readable and isn't recommended for any larger number of operations. It's better to define functions and call them instead of using lambdas.

The downside is that tf.image is fairly rudimentary. Compared to Keras' richer operations, surprisingly few of its operations can be used for random augmentation, and they offer less freedom. This is in part because tf.image is meant for general image operations more than for augmentation. We'll talk more about Keras augmentations and external libraries later.

Note: A great alternative, which makes your model more preprocessing-agnostic, is to embed preprocessing layers into the model, such as keras.layers.RandomFlip() and keras.layers.RandomRotation(0.2).

Let's define a preprocessing function for each image and its associated label:

def preprocess_image(image, label):
    resized_image = tf.image.resize(image, [224, 224])
    img = tf.image.random_flip_left_right(resized_image)
    img = tf.image.random_brightness(img, 0.4)
    # Preprocess image with model-specific function if it has one
    # processed_image = preprocess_input(resized_image)
    return img, label
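To see what the two random operations in preprocess_image() actually do to the pixel values, here's a rough NumPy re-creation. It's a sketch of the behavior, not TensorFlow's implementation: the function names are hypothetical, and the flip probability of 0.5 and the additive brightness delta drawn from [-max_delta, max_delta) follow how tf.image.random_flip_left_right() and tf.image.random_brightness() are documented.

```python
import numpy as np


def random_flip_left_right(image, rng):
    """Mirror an (height, width, channels) image along its width axis
    with a probability of 0.5."""
    return image[:, ::-1, :] if rng.random() < 0.5 else image


def random_brightness(image, max_delta, rng):
    """Add a single random delta from [-max_delta, max_delta) to every pixel."""
    delta = rng.uniform(-max_delta, max_delta)
    return image + delta


rng = np.random.default_rng(seed=42)
# A stand-in for a resized image: float values in the [0, 255] range
image = rng.uniform(0, 255, size=(224, 224, 3))

flipped = random_flip_left_right(image, rng)
brightened = random_brightness(flipped, 0.4, rng)
shift = (brightened - flipped).mean()
print(f"Brightness shift: {shift:.3f} on a 0-255 scale")
```

Since tf.image.resize() returns float images whose values still span [0, 255], a max_delta of 0.4 shifts every pixel by less than half an intensity level, which is barely noticeable. This is worth keeping in mind when choosing max_delta.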

Additionally, since we don't want to perform any random transformations on the validation and testing sets, let's define a separate function for those:

def preprocess_test_valid(image, label):
    resized_image = tf.image.resize(image, [224, 224])
    # Preprocess image with model-specific function if it has one
    # processed_image = preprocess_input(resized_image)
    return resized_image, label

And finally, we'll want to apply these functions to each image in the sets! This is easily done via the map() function. Since the network expects batched input ((batch_size, 224, 224, 3) instead of (224, 224, 3)), we'll also batch() the datasets after mapping:

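train_set = train_set.map(preprocess_image).batch(config['BATCH_SIZE']).repeat().prefetch(tf.data.AUTOTUNE)
test_set = test_set.map(preprocess_test_valid).batch(config['BATCH_SIZE']).prefetch(tf.data.AUTOTUNE)
valid_set = valid_set.map(preprocess_test_valid).batch(config['BATCH_SIZE']).prefetch(tf.data.AUTOTUNE)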

In this example, we're using tf.data - the built-in module for creating data pipelines and optimizing their usage. It's not to be confused with tfds, which is just a library for fetching datasets, while tf.data does the heavy lifting on the hardware. The prefetch() function is optional but helps with efficiency, and passing tf.data.AUTOTUNE lets TensorFlow decide how to perform prefetching. While the model is training on one batch, prefetch() fetches the next batch, so it doesn't have to be waited upon when the training step finishes. Similarly, you could use functions like cache() and interleave() to further optimize IO and data extraction, though these aren't to be used blindly. If used at the wrong place or time, they're likely to make your pipelines slower! We'll dedicate a lesson to optimizing data pipelines later. For now - let's just prefetch().

We have a repeat() call on the train_set, which isn't present for the other sets. This is analogous to the ImageDataGenerator class, which produces an infinite number of training samples with random transformations. On each request, the preprocess_image() function we wrote will randomly transform the incoming images, so we get a steady stream of fresh, slightly altered data. We don't want to do this for the testing and validation sets, other than resizing the images and applying the common preprocessing step if there is one (EfficientNetB0 performs its input preprocessing inside the model, so there's nothing to apply externally).
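To build some intuition for how map(), batch() and repeat() fit together, here's a plain-Python sketch of the same flow using generators. It's only an analogy, not how tf.data works internally, and map_batch_repeat() and upscale_nearest() are hypothetical helpers. The latter uses nearest-neighbor upscaling (each pixel repeated 7 times per axis, since 32 x 7 = 224) as a stand-in for tf.image.resize(), which defaults to bilinear interpolation. The sketch also shows where steps_per_epoch comes from for an infinite training stream.

```python
import itertools

import numpy as np


def upscale_nearest(image, label, factor=7):
    """Stand-in for resizing: repeat every pixel `factor` times along
    height and width, turning a 32x32 image into a 224x224 one."""
    return np.repeat(np.repeat(image, factor, axis=0), factor, axis=1), label


def map_batch_repeat(samples, fn, batch_size, repeat=False):
    """Apply `fn` to every (image, label) pair, group the results into
    batches, and optionally start over forever once an epoch is exhausted."""
    def one_epoch():
        images, labels = [], []
        for image, label in samples:
            image, label = fn(image, label)
            images.append(image)
            labels.append(label)
            if len(images) == batch_size:
                yield np.stack(images), np.array(labels)
                images, labels = [], []
        if images:  # the last, smaller batch of the epoch
            yield np.stack(images), np.array(labels)

    while True:
        yield from one_epoch()
        if not repeat:
            return


# Ten fake 32x32 RGB images with integer labels
samples = [(np.zeros((32, 32, 3), dtype=np.uint8), i) for i in range(10)]

finite = list(map_batch_repeat(samples, upscale_nearest, batch_size=4))
print([images.shape for images, _ in finite])

# An infinite stream never ends on its own, so we have to say how many
# batches make up one epoch (here, with the numbers used in this lesson)
TRAIN_SIZE, BATCH_SIZE = 40_000, 32
STEPS_PER_EPOCH = TRAIN_SIZE // BATCH_SIZE
infinite = map_batch_repeat(samples, upscale_nearest, batch_size=4, repeat=True)
first_batches = list(itertools.islice(infinite, 7))
```

Like batch() followed by repeat() in tf.data, each pass yields a smaller final batch when the dataset size isn't a multiple of the batch size. With 40000 training images and a batch size of 32, it is a multiple, giving 1250 steps per epoch, which matches the 1250/1250 in the training logs.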

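Note: Test-time augmentation, where predictions are averaged over several augmented copies of each test image, is a technique in its own right, but we won't be using it here.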

Let's quickly take a look at some of the images from any of the sets:

import matplotlib.pyplot as plt
import numpy as np

fig = plt.figure(figsize=(10, 10))

i = 1
for entry in test_set.take(25):
    # Each entry is a batch - plot its first image and label
    sample_image = np.squeeze(entry[0].numpy()[0])
    sample_label = class_names[entry[1].numpy()[0]]
    ax = fig.add_subplot(5, 5, i)
    ax.imshow(np.array(sample_image, np.int32))
    ax.set_title(f"Class: {sample_label}")
    i = i+1

plt.show()


Training a Model with Transfer Learning

With the data loaded, preprocessed and split into adequate sets - we can finally train the model on it.

Since our labels are integer-encoded (sparse), a sparse_categorical_crossentropy loss should work well, and the Adam optimizer is a reasonable default. Let's compile the model and train it for a few epochs. It's worth remembering that most of the layers in the network are frozen! We're only training the new classifier on top of the extracted feature maps.

Only once we've trained the top layers may we decide to unfreeze the feature extraction layers and let them fine-tune a bit more. This step is optional and allows you to really squeeze the best out of a model, but naturally takes more resources. A good rule of thumb is to compare the datasets and estimate which levels of the feature hierarchy you can re-use as-is, so that, if your machine can't handle re-training the whole network, you can skip re-training the levels that likely don't need it.

It's actually surprising how well ImageNet weights transfer to most datasets, even if they don't appear to have any remote connection to the domain. We'll especially see this in the next Guided Project on breast cancer classification from histology images. Another good rule of thumb is to always use transfer learning when you can.

Let's compile the network and check its structure:

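checkpoint = keras.callbacks.ModelCheckpoint(filepath='effnet_transfer_learning.h5', save_best_only=True)

new_model.compile(loss="sparse_categorical_crossentropy",
                  optimizer=keras.optimizers.Adam(),  # Assumption: Adam with its default learning rate
                  metrics=["accuracy", keras.metrics.SparseTopKCategoricalAccuracy(k=3)])

new_model.summary()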

This is a great time to validate whether you've correctly frozen the layers:

Total params: 4,177,671
Trainable params: 128,100
Non-trainable params: 4,049,571
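These numbers can be sanity-checked by hand. Assuming the new top boils down to a single Dense layer with 100 outputs sitting on EfficientNetB0's 1280-dimensional pooled features (parameter-free layers such as pooling or dropout don't add to the count), its parameter count is one weight per input-output pair plus one bias per output. The helper below is just that arithmetic, and dense_params() is a hypothetical name.

```python
def dense_params(n_inputs, n_outputs):
    """Weights (n_inputs * n_outputs) plus one bias per output."""
    return n_inputs * n_outputs + n_outputs


# EfficientNetB0's final feature maps have 1280 channels; CIFAR100 has 100 classes
trainable = dense_params(1280, 100)
frozen_base = 4_049_571  # Non-trainable params reported by summary() above
total = trainable + frozen_base

print(f"Trainable: {trainable:,}, Total: {total:,}")
```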

Only 128k trainable parameters! Naturally, it'll take longer to train this network than a 128k network, since there's a lot more going on - the entirety of the network is there, just not all of it is trainable. It'll take less time than training the entire network, though. Let's train the new network (really, only the top of it) for 10 epochs:

history = new_model.fit(train_set,
                        epochs=10,
                        steps_per_epoch=config['TRAIN_SIZE'] // config['BATCH_SIZE'],
                        validation_data=valid_set,
                        callbacks=[checkpoint])

Since the train_set is infinite, we'll want to define steps_per_epoch. Training may take some time, depending on how large the model and the dataset being fed into it are, and is ideally done on a GPU. If you don't have access to a GPU, it's advised to run this code on any of the cloud providers that give you access to a free GPU, such as Google Colab, Kaggle Notebooks, etc. Each epoch can take anywhere from 60 seconds on stronger GPUs to 10 minutes on weaker ones.

This is the point in which you sit back and go grab a coffee (or tea)! After 10 epochs, the train and validation accuracy are looking good:

Epoch 1/10
1250/1250[==============================] - 97s 76ms/step - loss: 1.9179 - accuracy: 0.5196 - sparse_top_k_categorical_accuracy: 0.7216 - val_loss: 1.3436 - val_accuracy: 0.6324 - val_sparse_top_k_categorical_accuracy: 0.8225
Epoch 10/10
1250/1250[==============================] - 86s 74ms/step - loss: 0.8610 - accuracy: 0.7481 - sparse_top_k_categorical_accuracy: 0.9015 - val_loss: 1.0820 - val_accuracy: 0.6935 - val_sparse_top_k_categorical_accuracy: 0.8651

It has a 69% validation accuracy, and an 86% Top-3 validation accuracy. These are far from the potential of the network though - the classification top has probably done all it could with the feature extractors as they are now. Let's take a look at the learning curves!

Evaluating Before Fine-Tuning

Let's first test this model out, before trying to unfreeze all of the layers. We'll perform some basic evaluation - metrics, learning curves and a confusion matrix. Let's start with the metrics:

new_model.evaluate(test_set)
# 157/157 [==============================] - 10s 65ms/step - loss: 1.0806 - accuracy: 0.6884 - sparse_top_k_categorical_accuracy: 0.8718

~69% on the testing set, and close to the accuracy on the validation set. It has a pretty decent 87% Top-3 accuracy. Looks like our model is generalizing well, but there's still room for improvement. Let's take a look at the learning curves:
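If you'd like to draw these curves yourself, here's a minimal matplotlib sketch. It assumes the `history` object returned by fit() earlier, whose history attribute is a dict mapping metric names such as 'accuracy' and 'val_accuracy' to one value per epoch. plot_learning_curves() is a hypothetical helper.

```python
import matplotlib.pyplot as plt


def plot_learning_curves(history_dict, metrics=("accuracy", "loss")):
    """Plot training vs. validation curves for each metric in a
    Keras-style history dict ({'loss': [...], 'val_loss': [...], ...})."""
    fig, axes = plt.subplots(1, len(metrics), figsize=(6 * len(metrics), 4), squeeze=False)
    for ax, metric in zip(axes[0], metrics):
        epochs = range(1, len(history_dict[metric]) + 1)
        ax.plot(epochs, history_dict[metric], label=f"Training {metric}")
        ax.plot(epochs, history_dict[f"val_{metric}"], label=f"Validation {metric}")
        ax.set_xlabel("Epoch")
        ax.set_title(metric.capitalize())
        ax.legend()
    fig.tight_layout()
    return fig


# With the model trained above: plot_learning_curves(history.history); plt.show()
```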

The training curves are to be expected - they're pretty short since we only trained for 10 epochs, but they've quickly plateaued, so we probably wouldn't have gotten much better performance with more epochs.

Let's predict the test set and extract the labels from it to produce a classification report and confusion matrix:

y_pred = new_model.predict(test_set)
labels = tf.concat([y for x, y in test_set], axis=0)

Since we have 100 classes - both the classification report and confusion matrix are going to be very large and hardly readable:

from sklearn import metrics
print(metrics.classification_report(labels, np.argmax(y_pred, axis=1)))
      precision    recall  f1-score   support

   0       0.89      0.89      0.89        55
   1       0.76      0.78      0.77        49
   2       0.45      0.64      0.53        45
   3       0.45      0.58      0.50        52

We only have around 50 images per class in the testing set, but we can't really get more than that. It's clear that some classes are better-learned than other classes, such as 0 having significantly higher recall and precision than, say, class 3.

This is actually surprising, since class 0 is apple and class 3 is bear! ImageNet has images of bears, and even classifies different types of bears, so you'd expect the network to generalize to bears well by transferring the knowledge from ImageNet. If anything, this speaks to how strong a "prescription" the network effectively has, given how small the images are.

Let's plot the confusion matrix:

from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

matrix = confusion_matrix(labels, y_pred.argmax(axis=1))

# Plot on heatmap
fig, ax = plt.subplots(figsize=(15, 15))
sns.heatmap(matrix, ax=ax, fmt='g')

# Stylize heatmap
ax.set_xlabel('Predicted labels')
ax.set_ylabel('True labels')
ax.set_title('Confusion Matrix')

# Set ticks at the centers of the heatmap cells
ax.xaxis.set_ticks(np.arange(0, 100, 1) + 0.5)
ax.yaxis.set_ticks(np.arange(0, 100, 1) + 0.5)
ax.xaxis.set_ticklabels(class_names, rotation=90, fontsize=8)
ax.yaxis.set_ticklabels(class_names, rotation=0, fontsize=8)

This results in:

Again - the confusion matrix is pretty large, since we have 100 classes. Still, for the most part, it looks like the model is generalizing to the classes well, albeit not perfectly.
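The classification report and the confusion matrix carry the same information: per-class precision and recall can be read straight off the matrix. Here's a minimal NumPy sketch, assuming scikit-learn's convention of true labels on rows and predicted labels on columns (the same layout as the heatmap above). per_class_precision_recall() is a hypothetical helper.

```python
import numpy as np


def per_class_precision_recall(matrix):
    """Compute precision, recall and support per class from a confusion
    matrix with true labels on rows and predicted labels on columns."""
    matrix = np.asarray(matrix, dtype=float)
    true_positives = np.diag(matrix)
    predicted_counts = matrix.sum(axis=0)  # column sums: how often each class was predicted
    support = matrix.sum(axis=1)           # row sums: how often each class actually occurs
    # Guard against division by zero for classes that never occur or are never predicted
    precision = np.divide(true_positives, predicted_counts,
                          out=np.zeros_like(true_positives), where=predicted_counts > 0)
    recall = np.divide(true_positives, support,
                       out=np.zeros_like(true_positives), where=support > 0)
    return precision, recall, support.astype(int)


# A toy 3-class example: class 2 is often mistaken for class 1
toy = [[9, 1, 0],
       [0, 8, 2],
       [0, 4, 6]]
precision, recall, support = per_class_precision_recall(toy)
print(precision.round(2), recall.round(2), support)
```

Applied to the `matrix` computed above, this should reproduce the precision and recall columns of the classification report.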

Can we fine-tune this network further? We've replaced and re-trained the top layers concerned with classification of feature maps, but the feature maps themselves might not be ideal! While they are pretty good, these images are simply different from ImageNet, so it's worth taking the time to update the feature extraction layers as well. Let's try unfreezing the convolutional layers and fine-tuning them as well.

Unfreezing Layers - Fine-Tuning a Network Trained with Transfer Learning

Once you've finished re-training the top layers, you can close the deal and be happy with your model. For instance, suppose you got a 95% accuracy - you seriously don't need to go further. However, why not?

If you can squeeze out an additional 1% in accuracy, it might not sound like a lot, but consider the other end of the trade. If your model has a 95% accuracy on 100 samples, it misclassified 5 samples. If you up that to 96% accuracy, it misclassified 4 samples.

That 1% of accuracy translates to a 20% decrease in misclassifications (from 5 down to 4).
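The arithmetic generalizes: the relative drop in misclassifications is the difference in error rates divided by the old error rate. A quick sketch, with relative_error_reduction() as a hypothetical helper name:

```python
def relative_error_reduction(old_accuracy, new_accuracy):
    """Fraction by which the error rate (1 - accuracy) shrinks."""
    old_error = 1 - old_accuracy
    new_error = 1 - new_accuracy
    return (old_error - new_error) / old_error


print(f"{relative_error_reduction(0.95, 0.96):.0%}")  # 5 -> 4 errors per 100 samples
```

The same calculation applies to the accuracy gains later in this lesson, e.g. relative_error_reduction(0.66, 0.792) is roughly 0.39.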

Whatever you can further squeeze out of your model can actually make a significant difference in the number of incorrect classifications. Again, the images in CIFAR100 are much smaller than ImageNet images, and it's almost as if someone with great eyesight suddenly gained a huge prescription and only saw the world through blurry eyes. The feature maps have to be at least somewhat different!

Let's save the model into a file so we don't lose the progress, and unfreeze/fine-tune a loaded copy, so we don't accidentally mess up the weights of the original one:

new_model.save('effnet_transfer_learning.h5')
loaded_model = keras.models.load_model('effnet_transfer_learning.h5')

Now, we can fiddle around and change the loaded_model without impacting new_model. To start out, we'll want to change the loaded_model from inference mode back to training mode - i.e. unfreeze the layers so that they're trainable again.

Note: Again, if a network uses BatchNormalization layers (and most do), you'll want to keep them frozen while fine-tuning the network. Since we're no longer freezing the entire base network, we'll freeze just the BatchNormalization layers and allow the other layers to be updated.

Let's freeze the BatchNormalization layers so our training doesn't go down the drain:

for layer in loaded_model.layers:
    # Keep BatchNormalization layers frozen, unfreeze everything else
    if isinstance(layer, keras.layers.BatchNormalization):
        layer.trainable = False
    else:
        layer.trainable = True

for index, layer in enumerate(loaded_model.layers):
    print("Layer: {}, Trainable: {}".format(index, layer.trainable))

Let's check if that worked:

Layer: 0, Trainable: True
Layer: 1, Trainable: True
Layer: 2, Trainable: True
Layer: 3, Trainable: True
Layer: 4, Trainable: True
Layer: 5, Trainable: False
Layer: 6, Trainable: True
Layer: 7, Trainable: True
Layer: 8, Trainable: False

Awesome! Before we can do anything with the model, to "solidify" the trainability, we have to recompile it. This time around, we'll be using a smaller learning_rate, since we don't want to train the network, but rather just fine-tune what's already there:

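checkpoint = keras.callbacks.ModelCheckpoint(filepath='effnet_transfer_learning_finetuned.h5', save_best_only=True)

# Recompile after turning to trainable
# Note: newer Keras optimizers (TF 2.11+) reject `decay`; there, use
# keras.optimizers.legacy.Adam or a learning rate schedule instead
loaded_model.compile(loss="sparse_categorical_crossentropy",
                     optimizer=keras.optimizers.Adam(learning_rate=3e-6, decay=(1e-6)), 
                     metrics=["accuracy", keras.metrics.SparseTopKCategoricalAccuracy(k=3)])

history = loaded_model.fit(train_set,
                           epochs=15,
                           steps_per_epoch=config['TRAIN_SIZE'] // config['BATCH_SIZE'],
                           validation_data=valid_set,
                           callbacks=[checkpoint])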

Again, this may take some time - so sip on another beverage of your choice (stay hydrated) while this runs in the background. The fine-tuning time heavily depends on the architecture you chose to go with but most of the cutting-edge architectures will take some time on a home-grade setup.
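The decay argument above applies time-based decay. For the legacy Keras optimizers, the learning rate at a given iteration is the initial rate divided by (1 + decay * iterations). The sketch below assumes that formula to show how mild a decay of 1e-6 is over this fine-tuning run (15 epochs of 1250 steps). legacy_time_decay() is a hypothetical helper.

```python
def legacy_time_decay(initial_lr, decay, iteration):
    """Learning rate after `iteration` optimizer steps under time-based
    decay: initial_lr / (1 + decay * iteration)."""
    return initial_lr / (1.0 + decay * iteration)


steps = 15 * 1250  # epochs * steps_per_epoch in the fine-tuning run
final_lr = legacy_time_decay(3e-6, 1e-6, steps)
print(f"Learning rate after {steps} steps: {final_lr:.4e}")
```

So the learning rate drops by less than 2% (from 3e-6 to roughly 2.94e-6) by the end of training. In TensorFlow versions where the optimizers no longer accept decay, passing keras.optimizers.schedules.InverseTimeDecay(3e-6, decay_steps=1, decay_rate=1e-6) as the learning_rate should give the same schedule.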

Once it finishes, it should reach up to around 80% in accuracy and around 93% on Top-3 accuracy on the validation set:

Epoch 1/15
1250/1250[==============================] - 384s 322ms/step - loss: 0.6567 - accuracy: 0.8024 - sparse_top_k_categorical_accuracy: 0.9356 - val_loss: 0.8687 - val_accuracy: 0.7520 - val_sparse_top_k_categorical_accuracy: 0.9069
Epoch 15/15
1250/1250[==============================] - 377s 322ms/step - loss: 0.3858 - accuracy: 0.8790 - sparse_top_k_categorical_accuracy: 0.9715 - val_loss: 0.7071 - val_accuracy: 0.7971 - val_sparse_top_k_categorical_accuracy: 0.9331

Additionally, if you take a look at the learning curves, they don't appear to have plateaued, so we could probably have increased the performance of the model further by training it for longer:

Note: Training for longer, naturally, takes time. While comparatively short next to many other architectures and datasets, 100 epochs on this dataset took over 10h to train on a home GPU. It's understandable if you're antsy about waiting this long, but unfortunately, 10h isn't even that long to wait for a network to train.

Let's evaluate it and visualize some of the predictions:

loaded_model.evaluate(test_set)
# 157/157 [==============================] - 10s 61ms/step - loss: 0.7041 - accuracy: 0.7920 - sparse_top_k_categorical_accuracy: 0.9336

fig = plt.figure(figsize=(10, 10))

i = 1
for entry in test_set.take(25):
    # Predict, get the raw Numpy prediction probabilities
    # Reshape entry to the model's expected input shape
    pred = np.argmax(loaded_model.predict(entry[0].numpy()[0].reshape(1, 224, 224, 3)))

    # Get sample image as numpy array
    sample_image = entry[0].numpy()[0]
    # Get associated label
    sample_label = class_names[entry[1].numpy()[0]]
    # Get human label based on the prediction
    prediction_label = class_names[pred]
    ax = fig.add_subplot(5, 5, i)
    # Plot image and sample_label alongside prediction_label
    ax.imshow(np.array(sample_image, np.int32))
    ax.set_title(f"Actual: {sample_label}\nPred: {prediction_label}")
    i = i+1


A couple of misclassifications, as you'd expect of an 80% accurate model. A raccoon was classified as a shrew, which is a mole-like animal (not too far from the truth). A chimpanzee was classified as a lamp (I'd have classified it as a beer bottle). A bus was classified as a pickup truck. This one is curious, as the blue stripe on the bus makes it appear a bit like a pickup truck. It looks like the model understood the blue stripe as the end of the bed of a pickup truck, instead of recognizing the grey top as part of the bus. Finally, a spider was classified as a trout, which is a very different class, but the image is so blurred and small that it's totally understandable.

Our previous model, a custom one built and trained on this dataset, had a 66% Top-1 accuracy, which means we've decreased the error rate by roughly 39% (from 34 to about 21 misclassifications per 100 images).

If you want to obtain the Top-K predictions (not just the most probable one), instead of using argmax() you can utilize TensorFlow's tf.nn.top_k() function:

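k = 3  # Number of top classes to return
# `img` is a single (224, 224, 3) image, such as one taken from test_set
pred = loaded_model.predict(np.expand_dims(img, 0))
top_probas, top_indices = tf.nn.top_k(pred, k=k)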

print(top_probas)  # tf.Tensor([[0.900319   0.07157221 0.00889194]], shape=(1, 3), dtype=float32)
print(top_indices) # tf.Tensor([[66 88 21]], shape=(1, 3), dtype=int32)

If you'd like to display this information alongside the input and the predictions - you could plot the input image, next to a bar chart of the confidence of the network:

for entry in test_set.take(1):
    img = entry[0][0].numpy().astype('int')
    label = entry[1][0]
    # Predict and get top-k classes
    pred = loaded_model.predict(np.expand_dims(img, 0))
    top_probas, top_indices = tf.nn.top_k(pred, k=3)
    # Convert to NumPy, squeeze and convert to list for ease of plotting
    top_probas = top_probas.numpy().squeeze().tolist()
    # Turn indices into classes
    pred_classes = []
    for index in top_indices.numpy().squeeze():
        pred_classes.append(class_names[index])
    # Plot the input image next to the top-k confidences
    fig, ax = plt.subplots(1, 2, figsize=(16, 4))
    ax[0].imshow(img)
    ax[0].set_title(f"Actual: {class_names[int(label)]}")
    ax[1].bar(pred_classes, top_probas)

Here, the network is pretty confident that the image is an image of a raccoon. There's a bit of a tiger and a chimpanzee there, but the probabilities are really low:

What about the spider-trout from before?

The network is fairly lost here - all of the probabilities are low, and none of them are right. If you take the top probability at face value and return that class, it sounds like the model was seriously wrong, but when you inspect its confidence, its "line of reasoning" can become a lot clearer. Generally, and especially when returning results to an end-user, you'll want to display the confidence of the model, and potentially other Top-K classes and their probabilities, if the highest probability isn't too high.

For instance, if the top probability is below, say, 50% - you could return multiple classes and their probabilities, such as in the second input image. If the model is fairly certain, you could return just the top class and its probability.
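Here's one way that rule of thumb could look in code. It's a minimal NumPy sketch that assumes `probs` is a single row of softmax probabilities (e.g. one row of loaded_model.predict()). describe_prediction() and the 0.5 threshold are illustrative choices, not a fixed recipe.

```python
import numpy as np


def describe_prediction(probs, class_names, threshold=0.5, k=3):
    """Return [(class_name, probability)]: just the top class if the model is
    confident enough, otherwise the top-k classes so the user sees the alternatives."""
    probs = np.asarray(probs, dtype=float).ravel()
    top_indices = np.argsort(probs)[::-1][:k]  # indices of the k highest probabilities
    top = [(class_names[i], float(probs[i])) for i in top_indices]
    return top[:1] if top[0][1] >= threshold else top


names = ["raccoon", "tiger", "chimpanzee", "trout"]
print(describe_prediction([0.90, 0.05, 0.04, 0.01], names))  # confident: one class
print(describe_prediction([0.10, 0.30, 0.25, 0.35], names))  # unsure: top three
```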

Finally, let's take a look at the confusion matrix compared to the previous one:

While it's still not perfect - it's looking much cleaner!


Transfer learning is the process of transferring already learned knowledge representations from one model to another, when applicable. This concludes the lesson on transfer learning for Image Classification with Keras and Tensorflow. We've started out with taking a look at what transfer learning is and how knowledge representations can be shared between models and architectures.

Then, we've taken a look at some of the most popular and cutting edge models for Image Classification released publicly, and piggy-backed on one of them - EfficientNet - to help us in classifying some of our own data. We've taken a look at how to load and examine pre-trained models, how to work with their layers, predict with them and decode the results, as well as how to define your own layers and intertwine them with the existing architecture.

This lesson introduced TensorFlow Datasets, the benefits of using the module and the basics of working with it. Finally, we've loaded and preprocessed a dataset, and trained our new classification top layers on it, before unfreezing the layers and fine-tuning it further through several additional epochs.


© 2013-2024 Stack Abuse. All rights reserved.