Keras Callbacks: Save and Visualize Prediction on Each Training Epoch

Introduction

Keras is a high-level API, typically used with the TensorFlow library, that has lowered the barrier to entry for many and democratized the creation of Deep Learning models and systems.

When just starting out, a high-level API that abstracts most of the inner workings helps people get the hang of the basics and build a starting intuition. Down the line though, practitioners naturally want a stronger intuition of what happens under the hood, both to gain actionable insight and to understand more deeply how their model learns.

In a lot of cases, it's useful to take a look at the learning process of a Deep Neural Network, testing how it predicts values on each learning epoch, and save the values.

These saved values can be used to visualize the predictions, using libraries like Matplotlib or Seaborn, or can be saved in a log for further analysis in smart systems, or simply analyzed by a human. We typically extract the learning curves of a model to gain a better understanding of how it performs through time - but learning curves reflect the mean loss through time, and you don't get to see how the model performs until it's done training.

Keras has a wonderful feature: callbacks, which are snippets of code that are called during training and can be used to customize the training process. Typically, you use callbacks to save the model if it performs well, stop the training if it's overfitting, or otherwise react to or affect the steps in the learning process.

This makes callbacks the natural choice for running predictions on each batch or epoch and saving the results. In this guide, we'll take a look at how to run a prediction on the test set, visualize the results, and save them as images on each training epoch in Keras.

Note: We'll be building a simple Deep Learning model using Keras in the following sections, but won't put much focus on the implementation or the dataset. This isn't meant to be a guide to building regression models, but a model is needed to properly showcase how the callback works.

If you're interested in reading more about how to build these models and how to get them highly accurate instead of just accurate - read our extensive and detailed Hands-On House Price Prediction - Machine Learning with Python!

Building and Evaluating a Deep Learning Model with Keras

Let's build a simple Keras model for illustrative purposes. We'll speed through this section, since this isn't a guide on building regression models. We'll be working with the California Housing Dataset, a regression dataset obtained through Scikit-Learn's datasets module.

Let's go ahead and import the libraries and functions we'll be using:

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split

Now, let's load in the dataset, split it into a training and testing set (we'll split out a validation set later), and visualize the locations of the houses to check if the data's been loaded correctly:

x, y = fetch_california_housing(as_frame=True, return_X_y=True)
x_train, x_test, y_train, y_test = train_test_split(x, y)

plt.figure(figsize=(12, 8))
sns.scatterplot(data=x, x='Longitude', y='Latitude', size=y, alpha=0.5, hue=y, palette='magma')
plt.show()

Looks like California! Since the data is loaded correctly, we can define a simple sequential Keras model:

checkpoint = keras.callbacks.ModelCheckpoint("california.h5", save_best_only=True)

model = keras.Sequential([
    keras.layers.Dense(64, activation='relu', kernel_initializer='normal', kernel_regularizer="l2", input_shape=[x_train.shape[1]]),
    keras.layers.Dropout(0.2),
    keras.layers.BatchNormalization(),
    
    keras.layers.Dense(64, activation='relu', kernel_initializer='normal', kernel_regularizer="l2"),
    keras.layers.Dropout(0.2),
    keras.layers.BatchNormalization(),
  
    keras.layers.Dense(1)
])

model.compile(loss='mae',
              optimizer=keras.optimizers.RMSprop(learning_rate=1e-2, decay=0.1),
              metrics=['mae'])
              
history = model.fit(
    x_train, y_train,
    epochs=150,
    batch_size=64,
    validation_split=0.2,
    callbacks=[checkpoint]
)

Here, we've got a simple MLP, with a bit of Dropout and Batch Normalization to battle overfitting, optimized with the RMSprop optimizer and a Mean Absolute Error loss. We've fitted the model for 150 epochs, with a validation split of 0.2, and a ModelCheckpoint callback that saves the best-performing model to a file. Running this results in:

...
Epoch 150/150
387/387 [==============================] - 3s 7ms/step - loss: 0.6279 - mae: 0.5976 - val_loss: 0.6346 - val_mae: 0.6042

We could visualize the learning curves to gain some basic insight into how the training went, but it doesn't tell us the whole story - these are just aggregate means over the training and validation sets during training:

model_history = pd.DataFrame(history.history)
model_history['epoch'] = history.epoch

fig, ax = plt.subplots(1, figsize=(8,6))
num_epochs = model_history.shape[0]

ax.plot(np.arange(0, num_epochs), model_history["mae"], 
        label="Training MAE")
ax.plot(np.arange(0, num_epochs), model_history["val_mae"], 
        label="Validation MAE")
ax.legend()

plt.tight_layout()
plt.show()

This results in:

And we can evaluate our model with:

model.evaluate(x_test, y_test)
162/162 [==============================] - 0s 2ms/step - loss: 0.5695 - mae: 0.5451 - mape: 32.2959

The target variable is measured in multiples of $100,000, which means our network misses the price by about $54,500 on average, a Mean Absolute Percentage Error of ~32%. Most traditional Machine Learning methods such as Random Forest Regression, even after more extensive data preprocessing and with tuned hyperparameters, achieve around $52,000 on this dataset, so this is actually a pretty decent result, although it could be improved with more preprocessing, better tuning and different architectures.
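The conversion from the reported metric to dollars is plain arithmetic. As a quick sketch, using the MAE from the evaluation output above and the dataset's unit of $100,000 (the variable names are ours):

```python
# MAE reported by model.evaluate() above, in units of $100,000
test_mae = 0.5451
unit_in_usd = 100_000

# Convert the metric into dollars to make it easier to interpret
mae_in_usd = test_mae * unit_in_usd
print(f"Average miss: ${mae_in_usd:,.0f}")
```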

The point here wasn't to build a particularly accurate model, but we did choose a dataset on which the model wouldn't converge very quickly, so we can observe its dance around the target variables.

A more illustrative way to evaluate how the model's working ditches the aggregate Mean Absolute Error and Mean Absolute Percentage Error entirely: we can draw a scatter plot of the predicted prices against the actual prices. If they're equal, the plotted markers will follow a straight diagonal line. For reference, we can also plot that diagonal and see how close each marker is to it:

test_predictions = model.predict(x_test)
test_labels = y_test

fig, ax = plt.subplots(figsize=(8,4))
plt.scatter(test_labels, test_predictions, alpha=0.6, 
            color='#FF0000', lw=1, ec='black')
lims = [0, 5]

plt.plot(lims, lims, lw=1, color='#0000FF')
plt.ticklabel_format(useOffset=False, style='plain')
plt.xticks(fontsize=18)
plt.yticks(fontsize=18)
plt.xlim(lims)
plt.ylim(lims)

plt.tight_layout()
plt.show()

Running this code results in:

The network over-prices cheaper houses and under-prices more expensive ones, and the estimates have a pretty wide spread. Some predictions on the right are far off the diagonal, though this happens because we haven't cleaned the dataset and many house prices are capped at that value when imported.

This isn't the insight you get from the learning curves, and a network that had the opposite effect - underpricing cheaper houses and overpricing expensive ones might have the same MAE and MAPE but behave totally differently.
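To make that concrete, here's a small sketch with made-up numbers (the prices and both sets of predictions are hypothetical, not taken from the dataset). Two models with opposite biases end up with identical MAE and MAPE, and only the per-house errors tell them apart:

```python
import numpy as np

# Hypothetical prices in units of $100,000, not taken from the dataset
actual = np.array([1.0, 2.0, 3.0, 4.0])

# Model A over-prices the cheap houses and under-prices the expensive ones
pred_a = actual + np.array([0.5, 0.5, -0.5, -0.5])
# Model B does the opposite
pred_b = actual + np.array([-0.5, -0.5, 0.5, 0.5])

def mae(y_true, y_pred):
    """Mean Absolute Error."""
    return np.mean(np.abs(y_true - y_pred))

def mape(y_true, y_pred):
    """Mean Absolute Percentage Error, in percent."""
    return np.mean(np.abs((y_true - y_pred) / y_true)) * 100

for name, pred in [("A", pred_a), ("B", pred_b)]:
    print(f"Model {name}: MAE={mae(actual, pred):.2f}, "
          f"MAPE={mape(actual, pred):.2f}%, signed errors={pred - actual}")
```

The aggregate metrics match, while the signed errors have opposite signs, which is exactly what the scatter plot makes visible.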

What we're also interested in is how the model got here and how these predictions changed through time and the learning process. This is just the end point of the training process, and there was a fair bit of training involved to get here.

Let's go ahead and write a custom callback to add to the list of callbacks in the training process, that will run a prediction on the test set on each epoch, visualize the predictions and save them as an image.

Custom Prediction Keras Callback with Plots

Just like we've used the ModelCheckpoint callback to check whether a model is in its best-performing state on each epoch and, if so, save it to a .h5 file, we can write a custom callback that'll run predictions, visualize them, and save the images to disk.

Creating a custom callback boils down to extending the Callback class and overriding any of the methods it provides - the ones you don't override, retain their default behavior:

class PerformancePlotCallback(keras.callbacks.Callback):

    def on_train_end(self, logs=None):
        ...
    def on_epoch_begin(self, epoch, logs=None):
        ...
    def on_epoch_end(self, epoch, logs=None):
        ...
    def on_test_begin(self, logs=None):
        ...
    def on_test_end(self, logs=None):
        ...
    # Etc.
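To get a feel for when these hooks fire, here's a minimal sketch that records the order in which Keras calls them. The HookRecorder class, the throwaway one-layer model and the random data are all hypothetical and exist only to trigger the hooks:

```python
import numpy as np
from tensorflow import keras

class HookRecorder(keras.callbacks.Callback):
    """Hypothetical helper that records which hooks Keras calls, in order."""

    def __init__(self):
        super().__init__()
        self.events = []

    def on_train_begin(self, logs=None):
        self.events.append("train_begin")

    def on_epoch_begin(self, epoch, logs=None):
        self.events.append(f"epoch_begin {epoch}")

    def on_epoch_end(self, epoch, logs=None):
        self.events.append(f"epoch_end {epoch}")

    def on_train_end(self, logs=None):
        self.events.append("train_end")

# Tiny throwaway model and random data, only there to trigger the hooks
x = np.random.rand(16, 3).astype("float32")
y = np.random.rand(16, 1).astype("float32")

model = keras.Sequential([keras.layers.Dense(1)])
model.compile(loss="mae", optimizer="sgd")

recorder = HookRecorder()
model.fit(x, y, epochs=2, verbose=0, callbacks=[recorder])

print(recorder.events)
```

Note that the epoch argument is zero-based, so the first epoch is reported as 0 even though the Keras progress output labels it "Epoch 1".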

Depending on when you'd like to predict with the model while it's still training, you'll choose the appropriate method. An epoch is a good unit for measuring progress, so at the end of each training epoch, we'll test the model on our test set.

We need a way to provide the test set to the callback, since this is external data. The easiest way to do that is to define a constructor that accepts the test set and stores it, so the callback can evaluate the current model on the same data at the end of every epoch, giving you a consistent result:

class PerformancePlotCallback(keras.callbacks.Callback):
    def __init__(self, x_test, y_test):
        super().__init__()
        self.x_test = x_test
        self.y_test = y_test
        
    def on_epoch_end(self, epoch, logs=None):
        print('Evaluating Model...')
        # Pass the targets as well, so the loss and MAE are computed against them
        print('Model Evaluation: ', self.model.evaluate(self.x_test, self.y_test))

This simple callback accepts the test set of houses and relevant target variables and evaluates itself on each epoch, printing the result to the console, right alongside the usual Keras output.

If we were to instantiate and add this callback to the model, and fit() it again, we'd see a different result from before:

performance_simple = PerformancePlotCallback(x_test, y_test)

# Model definition and compilation...

history = model.fit(
    x_train, y_train,
    epochs=150,
    validation_split=0.2,
    callbacks=[performance_simple]
)

This results in:

Epoch 1/150
387/387 [==============================] - 3s 7ms/step - loss: 1.0785 - mae: 1.0140 - val_loss: 0.9455 - val_mae: 0.8927
Evaluating Model...
162/162 [==============================] - 0s 1ms/step - loss: 0.0528 - mae: 0.0000e+00
Model Evaluation:  [0.05277165770530701, 0.0]
Epoch 2/150
387/387 [==============================] - 3s 7ms/step - loss: 0.9048 - mae: 0.8553 - val_loss: 0.8547 - val_mae: 0.8077
Evaluating Model...
162/162 [==============================] - 0s 1ms/step - loss: 0.0471 - mae: 0.0000e+00
Model Evaluation:  [0.04705655574798584, 0.0]
...

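Awesome! The model is evaluating itself on each epoch, on the data we've passed into the callback. Keep in mind that evaluate() needs both the inputs and the targets: the near-zero loss and the mae of 0.0 in the log above are what you get when only x_test is passed, since just the regularization penalty remains in the loss. Now, let's modify the callback so it visualizes the predictions instead of printing them to the already cluttered output.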

To simplify things, we'll get the callback to save the images to a folder, so that we can stitch them together into a video or a Gif later on. We'll also include a model_name in the constructor to help us differentiate models when generating the images and their filenames:

import os

class PerformancePlotCallback(keras.callbacks.Callback):
    def __init__(self, x_test, y_test, model_name):
        super().__init__()
        self.x_test = x_test
        self.y_test = y_test
        self.model_name = model_name
        # Make sure the output folder exists before the first plot is saved
        os.makedirs('model_train_images', exist_ok=True)
        
    def on_epoch_end(self, epoch, logs=None):
        y_pred = self.model.predict(self.x_test)
        fig, ax = plt.subplots(figsize=(8,4))
        # Use the targets stored on the callback, not a global variable
        plt.scatter(self.y_test, y_pred, alpha=0.6, 
            color='#FF0000', lw=1, ec='black')
        
        lims = [0, 5]

        # Diagonal reference line: perfect predictions would fall on it
        plt.plot(lims, lims, lw=1, color='#0000FF')
        plt.ticklabel_format(useOffset=False, style='plain')
        plt.xticks(fontsize=18)
        plt.yticks(fontsize=18)
        plt.xlim(lims)
        plt.ylim(lims)

        plt.title(f'Prediction Visualization Keras Callback - Epoch: {epoch}')
        plt.tight_layout()
        plt.savefig('model_train_images/'+self.model_name+"_"+str(epoch))
        # Close the figure so 150 open figures don't pile up in memory
        plt.close()

Here, we create a Matplotlib figure on each epoch, and plot a scatter plot of the predicted prices against the actual prices. Additionally, we've added a diagonal reference line - the closer our scatter plot markers are to the diagonal line, the more accurate our model's predictions were.

The plot is then saved via plt.savefig() with the model's name and the epoch number, alongside an informative title that lets you know which epoch the model is in during training.
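One thing to keep in mind about filenames built this way: plain epoch numbers don't sort alphabetically in training order, which matters if a tool picks up the images by name. A small sketch of the difference, where zero-padding the epoch is our own suggestion and not part of the callback above:

```python
# Hypothetical filenames, built the same way the callback builds them
model_name = "california_model"
epochs = [0, 1, 2, 10, 11, 100]

plain = [f"{model_name}_{epoch}.png" for epoch in epochs]
padded = [f"{model_name}_{epoch:03d}.png" for epoch in epochs]

# Alphabetical sorting puts epoch 10 before epoch 2 for the plain names
print(sorted(plain))
# Zero-padded names keep their training order when sorted
print(sorted(padded))
```

If you'd rather keep the plain names, sorting by the numeric suffix when reading the files back works just as well.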

Now, let's use this custom callback again, providing a model name in addition to the x_test and y_test sets:

checkpoint = keras.callbacks.ModelCheckpoint("california.h5", save_best_only=True)
performance = PerformancePlotCallback(x_test, y_test, "california_model")

# Model definition and compilation...
              
history = model.fit(
    x_train, y_train,
    epochs=150,
    validation_split=0.2,
    callbacks=[checkpoint, performance]
)

The PerformancePlotCallback goes into full swing, and in the designated folder generates an image of the performance on each epoch. The model_train_images folder is now filled with 150 plots:

You can now use your favorite tool to stitch the images together into a video or a Gif file, or simply peruse them manually. Here's a Gif of the model we've built training on this data:
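If you'd like to do the stitching in Python, one option is Pillow. The following is a sketch under a few assumptions: Pillow is installed, the images are PNG files whose names end in "_<epoch>" as produced by the callback, and the helper names epoch_number() and images_to_gif() are ours:

```python
import re
from pathlib import Path

from PIL import Image

def epoch_number(path):
    """Extract the trailing epoch number from a name like 'california_model_12.png'."""
    match = re.search(r"_(\d+)$", path.stem)
    return int(match.group(1)) if match else -1

def images_to_gif(image_dir, gif_path, ms_per_frame=100):
    """Combine all PNG files in image_dir into one animated GIF, in epoch order."""
    # Sort numerically so that epoch 10 comes after epoch 9, not after epoch 1
    paths = sorted(Path(image_dir).glob("*.png"), key=epoch_number)
    if not paths:
        raise ValueError(f"No PNG files found in {image_dir}")

    # convert() returns an in-memory copy, so each file can be closed right away
    frames = []
    for path in paths:
        with Image.open(path) as img:
            frames.append(img.convert("RGB"))

    first, rest = frames[0], frames[1:]
    # loop=0 makes the GIF repeat forever; duration is per frame, in milliseconds
    first.save(gif_path, save_all=True, append_images=rest,
               duration=ms_per_frame, loop=0)
    return len(frames)
```

Calling images_to_gif("model_train_images", "california_model.gif") would then write the animation next to your script and return the number of frames it used.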

Conclusion

In this guide, we've built a simple model to predict the price of a house in the California Housing Dataset with okay-ish accuracy. We've then taken a look at how to write a custom Keras callback to test a Deep Learning model's performance and visualize it during training, on each epoch.

We've proceeded to save these images to the disk and created a Gif from them, giving us a different perspective on the training process than the one we get from analyzing the learning curves of a model.

Author: David Landup

Entrepreneur, Software and Machine Learning Engineer, with a deep fascination towards the application of Computation and Deep Learning in Life Sciences (Bioinformatics, Drug Discovery, Genomics), Neuroscience (Computational Neuroscience), robotics and BCIs.

Great passion for accessible education and promotion of reason, science, humanism, and progress.

© 2013-2025 Stack Abuse. All rights reserved.
