Introduction
Suppose you want your Keras model to have some specific behavior during training, evaluation or prediction. For instance, you might want to save your model at every training epoch. One way of doing this is using Callbacks.
In general, callbacks are functions that are called when some event happens, and are passed as arguments to other functions. In the case of Keras, they are a tool to customize the behavior of your model - be it during training, evaluation or inference. Some applications are logging, model persistence, early stopping or changing the learning rate. This is done by passing a list of callbacks as an argument to keras.Model.fit(), keras.Model.evaluate() or keras.Model.predict().
Some common use cases for callbacks are modifying the learning rate, logging, monitoring and early stopping of training. Keras has a number of built-in callbacks, detailed in the documentation.
However, some more specific applications might require a custom callback. For instance, implementing Learning Rate warm up with a Cosine Decay after a holding period isn't currently built-in, but is widely used and adopted as a scheduler.
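The schedule itself is just a piecewise function. As a minimal sketch (the function name and parameters below are our own, not a Keras API), a warm-up/hold/cosine-decay learning rate could be computed like this:

```python
import math

def warmup_hold_cosine(step, total_steps, warmup_steps, hold_steps, base_lr):
    """Illustrative schedule: linear warm-up, constant hold, then cosine decay to 0."""
    if step < warmup_steps:
        # Linear warm-up from 0 to base_lr
        return base_lr * step / warmup_steps
    if step < warmup_steps + hold_steps:
        # Hold the peak learning rate
        return base_lr
    # Cosine decay from base_lr down to 0 over the remaining steps
    decay_steps = total_steps - warmup_steps - hold_steps
    progress = (step - warmup_steps - hold_steps) / decay_steps
    return 0.5 * base_lr * (1 + math.cos(math.pi * progress))
```

A function like this could then be wired into training via keras.callbacks.LearningRateScheduler, or by setting the optimizer's learning rate from a custom callback.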
Callback Class and Its Methods
Keras has a specific callback class, keras.callbacks.Callback, with methods that can be called during training, testing and inference at the global, batch or epoch level. In order to create custom callbacks, we need to subclass it and override these methods.

The keras.callbacks.Callback class has three kinds of methods:
- Global methods: called at the beginning or at the end of fit(), evaluate() and predict().
- Batch-level methods: called at the beginning or at the end of processing a batch.
- Epoch-level methods: called at the beginning or at the end of a training epoch.
Note: Each method has access to a dict called logs. The keys and values of logs are contextual - they depend on the event which calls the method. Moreover, we have access to the model inside each method through the self.model attribute.
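For instance (a minimal sketch of our own, not a built-in callback), self.model can be combined with logs to halt training once the loss drops below a threshold:

```python
from tensorflow import keras

class LossThresholdCallback(keras.callbacks.Callback):
    """Stops training once the epoch-level loss drops below a threshold."""
    def __init__(self, threshold=1.0):
        super().__init__()
        self.threshold = threshold

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        # `logs` holds the epoch-level loss/metrics; `self.model` is the model being trained
        if logs.get("loss", float("inf")) < self.threshold:
            self.model.stop_training = True
```

Passing callbacks=[LossThresholdCallback(0.5)] to fit() would then end training early once the loss falls under 0.5.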
Let's take a look at three custom callbacks examples - one for training, one for evaluation and one for prediction. Each one will print at each stage what our model is doing and which logs we have access to. This is helpful for understanding what is possible to do with custom callbacks at each stage.
Let's begin by defining a toy model:
import tensorflow as tf
from tensorflow import keras
import numpy as np
model = keras.Sequential()
model.add(keras.layers.Dense(10, input_dim=1, activation='relu'))
model.add(keras.layers.Dense(10, activation='relu'))
model.add(keras.layers.Dense(1))

model.compile(
    optimizer=keras.optimizers.RMSprop(learning_rate=0.1),
    loss="mean_squared_error",
    metrics=["mean_absolute_error"]
)

x = np.random.uniform(low=0, high=10, size=1000)
y = x**2

x_train, x_test = (x[:900], x[900:])
y_train, y_test = (y[:900], y[900:])
Custom Training Callback
Our first callback is to be called during training. Let's subclass the Callback class:
class TrainingCallback(keras.callbacks.Callback):
    def __init__(self):
        super().__init__()
        self.tabulation = {"train": "", "batch": " " * 8, "epoch": " " * 4}

    def on_train_begin(self, logs=None):
        tab = self.tabulation["train"]
        print(f"{tab}Training!")
        print(f"{tab}available logs: {logs}")

    def on_train_batch_begin(self, batch, logs=None):
        tab = self.tabulation["batch"]
        print(f"{tab}Batch {batch}")
        print(f"{tab}available logs: {logs}")

    def on_train_batch_end(self, batch, logs=None):
        tab = self.tabulation["batch"]
        print(f"{tab}End of Batch {batch}")
        print(f"{tab}available logs: {logs}")

    def on_epoch_begin(self, epoch, logs=None):
        tab = self.tabulation["epoch"]
        print(f"{tab}Epoch {epoch} of training")
        print(f"{tab}available logs: {logs}")

    def on_epoch_end(self, epoch, logs=None):
        tab = self.tabulation["epoch"]
        print(f"{tab}End of Epoch {epoch} of training")
        print(f"{tab}available logs: {logs}")

    def on_train_end(self, logs=None):
        tab = self.tabulation["train"]
        print(f"{tab}Finishing training!")
        print(f"{tab}available logs: {logs}")
Any of these methods that isn't overridden keeps its default behavior. In our example, we simply print the available logs and the level at which the callback is applied, with appropriate indentation.
Let's take a look at the outputs:
model.fit(
    x_train,
    y_train,
    batch_size=500,
    epochs=2,
    verbose=0,
    callbacks=[TrainingCallback()],
)
Training!
available logs: {}
    Epoch 0 of training
    available logs: {}
        Batch 0
        available logs: {}
        End of Batch 0
        available logs: {'loss': 2172.373291015625, 'mean_absolute_error': 34.79669952392578}
        Batch 1
        available logs: {}
        End of Batch 1
        available logs: {'loss': 2030.1309814453125, 'mean_absolute_error': 33.30256271362305}
    End of Epoch 0 of training
    available logs: {'loss': 2030.1309814453125, 'mean_absolute_error': 33.30256271362305}
    Epoch 1 of training
    available logs: {}
        Batch 0
        available logs: {}
        End of Batch 0
        available logs: {'loss': 1746.2772216796875, 'mean_absolute_error': 30.268001556396484}
        Batch 1
        available logs: {}
        End of Batch 1
        available logs: {'loss': 1467.36376953125, 'mean_absolute_error': 27.10252571105957}
    End of Epoch 1 of training
    available logs: {'loss': 1467.36376953125, 'mean_absolute_error': 27.10252571105957}
Finishing training!
available logs: {'loss': 1467.36376953125, 'mean_absolute_error': 27.10252571105957}
<keras.callbacks.History at 0x7f8bce314c10>
Note that we can follow at each step what the model is doing, and to which metrics we have access. At the end of each batch and epoch, we have access to the in-sample loss function and the metrics of our model.
Custom Evaluation Callback
Now, let's call the Model.evaluate() method. We can see that at the end of a batch we have access to the loss function and the metrics at the time, and at the end of the evaluation we have access to the overall loss and metrics:
class TestingCallback(keras.callbacks.Callback):
    def __init__(self):
        super().__init__()
        self.tabulation = {"test": "", "batch": " " * 8}

    def on_test_begin(self, logs=None):
        tab = self.tabulation["test"]
        print(f"{tab}Evaluating!")
        print(f"{tab}available logs: {logs}")

    def on_test_end(self, logs=None):
        tab = self.tabulation["test"]
        print(f"{tab}Finishing evaluation!")
        print(f"{tab}available logs: {logs}")

    def on_test_batch_begin(self, batch, logs=None):
        tab = self.tabulation["batch"]
        print(f"{tab}Batch {batch}")
        print(f"{tab}available logs: {logs}")

    def on_test_batch_end(self, batch, logs=None):
        tab = self.tabulation["batch"]
        print(f"{tab}End of batch {batch}")
        print(f"{tab}available logs: {logs}")
res = model.evaluate(
    x_test, y_test, batch_size=100, verbose=0, callbacks=[TestingCallback()]
)
Evaluating!
available logs: {}
        Batch 0
        available logs: {}
        End of batch 0
        available logs: {'loss': 382.2723083496094, 'mean_absolute_error': 14.069927215576172}
Finishing evaluation!
available logs: {'loss': 382.2723083496094, 'mean_absolute_error': 14.069927215576172}
Custom Prediction Callback
Finally, let's call the Model.predict() method. Notice that at the end of each batch we have access to the predicted outputs of our model:
class PredictionCallback(keras.callbacks.Callback):
    def __init__(self):
        super().__init__()
        self.tabulation = {"prediction": "", "batch": " " * 8}

    def on_predict_begin(self, logs=None):
        tab = self.tabulation["prediction"]
        print(f"{tab}Predicting!")
        print(f"{tab}available logs: {logs}")

    def on_predict_end(self, logs=None):
        tab = self.tabulation["prediction"]
        print(f"{tab}End of Prediction!")
        print(f"{tab}available logs: {logs}")

    def on_predict_batch_begin(self, batch, logs=None):
        tab = self.tabulation["batch"]
        print(f"{tab}batch {batch}")
        print(f"{tab}available logs: {logs}")

    def on_predict_batch_end(self, batch, logs=None):
        tab = self.tabulation["batch"]
        print(f"{tab}End of batch {batch}")
        print(f"{tab}available logs:\n {logs}")
res = model.predict(
    x_test[:10],
    verbose=0,
    callbacks=[PredictionCallback()]
)
Predicting!
available logs: {}
        batch 0
        available logs: {}
        End of batch 0
        available logs:
 {'outputs': array([[ 7.743822],
       [27.748264],
       [33.082104],
       [26.530678],
       [27.939169],
       [18.414223],
       [42.610645],
       [36.69335 ],
       [13.096557],
       [37.120853]], dtype=float32)}
End of Prediction!
available logs: {}
With these, you can customize the behavior, set up monitoring or otherwise alter the processes of training, evaluation or inference. An alternative to subclassing is to use the LambdaCallback.
Using LambdaCallback
One of the built-in callbacks in Keras is the LambdaCallback class. This callback accepts a function which defines how it behaves and what it does! In a sense, it allows you to use any arbitrary function as a callback, thus allowing you to create custom callbacks.
The class has the optional parameters:
- on_epoch_begin
- on_epoch_end
- on_batch_begin
- on_batch_end
- on_train_begin
- on_train_end
Each parameter accepts a function which is called in the respective model event. As an example, let's make a callback to send an email when the model finishes training:
import smtplib
from email.message import EmailMessage

def send_email(logs):
    msg = EmailMessage()
    content = "The model has finished training."
    for key, value in logs.items():
        content = content + f"\n{key}: {value:.2f}"
    msg.set_content(content)
    msg['Subject'] = 'Training report'
    msg['From'] = '[email protected]'
    msg['To'] = 'receiver-email'
    s = smtplib.SMTP('smtp.gmail.com', 587)
    s.starttls()
    s.login("[email protected]", "your-gmail-app-password")
    s.send_message(msg)
    s.quit()
lambda_send_email = lambda logs : send_email(logs)
email_callback = keras.callbacks.LambdaCallback(on_train_end = lambda_send_email)
model.fit(
    x_train,
    y_train,
    batch_size=100,
    epochs=1,
    verbose=0,
    callbacks=[email_callback],
)
To make our custom callback using LambdaCallback, we just need to implement the function that we want to be called, wrap it as a lambda function and pass it to the LambdaCallback class as a parameter.
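As a simpler illustration (the function name below is our own), any named function with the right signature works just as well as a lambda:

```python
from tensorflow import keras

# A plain function with the (epoch, logs) signature expected by on_epoch_end
def report_epoch(epoch, logs):
    print(f"epoch {epoch}: loss={logs['loss']:.4f}")

epoch_logger = keras.callbacks.LambdaCallback(on_epoch_end=report_epoch)
# epoch_logger can now be passed in the callbacks list of fit()
```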
A Callback for Visualizing Model Training
In this section, we'll give an example of a custom callback that makes an animation of our model's performance improving during training. In order to do this, we store the values of the logs at the end of each batch. Then, at the end of the training loop, we create an animation using matplotlib.
In order to enhance the visualization, the loss and the metrics will be plotted in log scale:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation
from IPython import display
class TrainingAnimationCallback(keras.callbacks.Callback):
    def __init__(self, duration=40, fps=1000/25):
        super().__init__()
        self.duration = duration
        self.fps = fps
        self.logs_history = []

    def set_plot(self):
        self.figure = plt.figure()
        plt.xticks(
            range(0, self.params['steps'] * self.params['epochs'], self.params['steps']),
            range(0, self.params['epochs']))
        plt.xlabel('Epoch')
        plt.ylabel('Loss & Metrics ($Log_{10}$ scale)')
        self.plot = {}
        for metric in self.model.metrics_names:
            self.plot[metric], = plt.plot([], [], label=metric)
        max_y = [max(log.values()) for log in self.logs_history]
        self.title = plt.title('batch:0')
        plt.xlim(0, len(self.logs_history))
        plt.ylim(0, max(max_y))
        plt.legend(loc='upper right')

    def animation_function(self, frame):
        batch = frame % self.params['steps']
        self.title.set_text(f'batch:{batch}')
        x = list(range(frame))
        for metric in self.model.metrics_names:
            y = [log[metric] for log in self.logs_history[:frame]]
            self.plot[metric].set_data(x, y)

    def on_train_batch_end(self, batch, logs=None):
        # Store the base-10 logarithm of each metric, matching the axis label
        logarithm_transform = lambda item: (item[0], np.log10(item[1]))
        logs = dict(map(logarithm_transform, logs.items()))
        self.logs_history.append(logs)

    def on_train_end(self, logs=None):
        self.set_plot()
        num_frames = int(self.duration * self.fps)
        num_batches = self.params['steps'] * self.params['epochs']
        selected_batches = range(0, num_batches, num_batches // num_frames)
        interval = 1000 * (1 / self.fps)
        anim_created = FuncAnimation(self.figure,
                                     self.animation_function,
                                     frames=selected_batches,
                                     interval=interval)
        video = anim_created.to_html5_video()
        html = display.HTML(video)
        display.display(html)
        plt.close()
We'll use the same model as before, but with more training samples:
import tensorflow as tf
from tensorflow import keras
import numpy as np
model = keras.Sequential()
model.add(keras.layers.Dense(10, input_dim=1, activation='relu'))
model.add(keras.layers.Dense(10, activation='relu'))
model.add(keras.layers.Dense(1))

model.compile(
    optimizer=keras.optimizers.RMSprop(learning_rate=0.1),
    loss="mean_squared_error",
    metrics=["mean_absolute_error"]
)
def create_sample(sample_size, train_test_proportion=0.9):
    x = np.random.uniform(low=0, high=10, size=sample_size)
    y = x**2
    train_test_split = int(sample_size * train_test_proportion)
    x_train, x_test = (x[:train_test_split], x[train_test_split:])
    y_train, y_test = (y[:train_test_split], y[train_test_split:])
    return (x_train, x_test, y_train, y_test)
x_train,x_test,y_train,y_test = create_sample(35200)
model.fit(
    x_train,
    y_train,
    batch_size=32,
    epochs=2,
    verbose=0,
    callbacks=[TrainingAnimationCallback()],
)
Our output is an animation of the metrics and the loss function as they change through the training process:
Conclusion
In this guide, we've taken a look at the implementation of custom callbacks in Keras.
There are two options for implementing custom callbacks - through subclassing the keras.callbacks.Callback class, or by using the keras.callbacks.LambdaCallback class.
We've seen one practical example using LambdaCallback for sending an email at the end of the training loop, and one example subclassing the Callback class that creates an animation of the training loop.
Although Keras has many built-in callbacks, knowing how to implement a custom callback can be useful for more specific applications.