Learning Rate Warmup with Cosine Decay in Keras/TensorFlow

The learning rate is an important hyperparameter in deep learning networks: it directly dictates the size of the updates applied to the weights, which are estimated to minimize some given loss function. In SGD:

$$
\text{weight}_{t+1} = \text{weight}_t - \text{lr} \cdot \frac{\partial\, \text{error}}{\partial\, \text{weight}_t}
$$

With a learning rate of 0, the updated weight is just the previous weight, weight_t. The learning rate is effectively a knob we can turn to enable or disable learning. It has a major influence over how much learning happens because it directly controls the size of the weight updates.
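The following sketch makes the update rule concrete with a single SGD step on a toy one-weight loss. It assumes loss = (weight - 3)^2, so the gradient is 2 * (weight - 3). The helper name `sgd_step` is hypothetical and not part of any library.

```python
def sgd_step(weight, grad, lr):
    """One plain SGD update: weight_{t+1} = weight_t - lr * grad."""
    return weight - lr * grad


# Toy loss (weight - 3)^2 has gradient 2 * (weight - 3)
weight = 0.0
grad = 2 * (weight - 3.0)  # -6.0 at weight = 0

for lr in (0.0, 0.1, 1.0):
    print(f"lr={lr}: new weight = {sgd_step(weight, grad, lr):.2f}")
```

With lr=0 the weight does not move. With lr=1.0 it jumps from 0 past the minimum at 3 to 6, where the loss is exactly as large as before. This is a small-scale picture of an "over-correction".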

Different optimizers utilize learning rates differently, but the underlying concept stays the same. Needless to say, learning rates have been the object of many studies, papers and practitioners' benchmarks.

Generally speaking, pretty much everyone agrees that a static learning rate won't cut it. Most techniques that tune the learning rate during training apply some type of reduction, whether monotonic, cosine, triangular or otherwise.

A technique that has been gaining a foothold in recent years is learning rate warmup, which can be paired with practically any other reduction technique.

Learning Rate Warmup

The idea behind learning rate warmup is simple. In the earliest stages of training, weights are far from their ideal states. This means large updates all across the board, which can be seen as "over-corrections" for each weight. The drastic update of one weight may negate the update of another, making the initial stages of training more unstable.

These changes eventually iron out, but they can be avoided altogether by starting with a small learning rate. The network first reaches a more stable suboptimal state, and only then is a larger learning rate applied. You can sort of ease the network into updates, rather than hit it with them.

That's learning rate warmup: starting with a low (or 0) learning rate and increasing it to a starting learning rate (the one you'd start with anyway). This increase can follow any function, but it is commonly linear.
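Here is a minimal sketch of such a linear warmup. It uses a hypothetical helper, `linear_warmup`, which takes the same `start_lr`/`target_lr` pair used later in this guide. Once `warmup_steps` have passed, the value is capped at `target_lr`, which is where a second schedule would normally take over.

```python
def linear_warmup(step, warmup_steps, start_lr=0.0, target_lr=1e-3):
    """Linearly interpolate from start_lr to target_lr over warmup_steps."""
    # Fraction of the warmup completed, capped at 1.0 once warmup is over
    progress = min(step / warmup_steps, 1.0)
    return start_lr + (target_lr - start_lr) * progress


for step in (0, 25, 50, 100, 150):
    print(step, linear_warmup(step, warmup_steps=100))
```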

After reaching the initial rate, other schedules such as cosine decay, linear reduction, etc. can be applied to progressively lower the rate until the end of training. Learning rate warmup is usually one half of a two-phase schedule. LR warmup comes first, and another schedule takes over once the rate has reached its starting point.
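Written out, the warmup-plus-cosine-decay schedule used in this guide can be sketched as the piecewise function below. The symbols are introduced here for illustration:

- t is the global step.
- W is the number of warmup steps.
- H is the number of hold steps.
- T is the total number of steps.
- lr_start and lr_target are the start and target learning rates.

$$
\text{lr}(t) =
\begin{cases}
\text{lr}_{start} + (\text{lr}_{target} - \text{lr}_{start}) \cdot \dfrac{t}{W} & t < W \\
\text{lr}_{target} & W \le t \le W + H \\
\dfrac{1}{2} \, \text{lr}_{target} \left(1 + \cos\left(\pi \cdot \dfrac{t - W - H}{T - W - H}\right)\right) & t > W + H
\end{cases}
$$

At t = T the cosine term reaches cos(π) = -1, so the learning rate decays to 0 by the end of training.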

In this guide, we'll implement a learning rate warmup in Keras/TensorFlow in two ways: as a keras.optimizers.schedules.LearningRateSchedule subclass and as a keras.callbacks.Callback callback. The learning rate will be increased from 0 to target_lr and will then follow a cosine decay, since this is a very common secondary schedule. As usual, Keras makes it simple to implement flexible solutions in various ways and ship them with your network.

Note: The implementation is generic and inspired by Tony's Keras implementation of the tricks outlined in "Bag of Tricks for Image Classification with Convolutional Neural Networks".

Learning Rate with Keras Callbacks

The simplest way to implement any learning rate schedule is to create a function that takes the epoch index and the current lr (a float), passes the lr through some transformation, and returns the new value. This function is then passed on to the LearningRateScheduler callback, which applies it to the learning rate at the start of each epoch.

Now, tf.keras.callbacks.LearningRateScheduler() passes the epoch number to the function it uses to calculate the learning rate, which is pretty coarse. LR warmup should be done on each step (batch), not on each epoch. We'll therefore derive a global_step (counted across all epochs) to calculate the learning rate instead.

We'll also subclass the Callback class to create a custom callback, rather than just pass a function. We need to keep track of state and pass in extra arguments on each call, which is impossible when just passing a function like this:

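from tensorflow import keras

def func(epoch, lr):
    # LearningRateScheduler calls this with the epoch index and the current LR;
    # transform lr here and return the new value (returned unchanged in this stub)
    return lr

keras.callbacks.LearningRateScheduler(func)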
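The following sketch shows why per-epoch granularity is coarse. It is a warmup written in the (epoch, lr) shape that LearningRateScheduler expects. The function name and the `WARMUP_EPOCHS`/`TARGET_LR` constants are hypothetical. The Keras wiring is shown only as a comment, so the schedule itself runs without TensorFlow.

```python
WARMUP_EPOCHS = 5     # hypothetical warmup length, in epochs
TARGET_LR = 1e-3      # hypothetical target learning rate


def epoch_warmup(epoch, lr):
    """Schedule function with the (epoch, lr) -> new_lr shape of LearningRateScheduler."""
    if epoch < WARMUP_EPOCHS:
        # Linear ramp that can only change once per epoch; the +1 keeps the
        # first epoch from training with a learning rate of exactly 0
        return TARGET_LR * (epoch + 1) / WARMUP_EPOCHS
    # After warmup, leave the learning rate untouched
    return lr


# With TensorFlow installed, this would be wired up as:
# callback = keras.callbacks.LearningRateScheduler(epoch_warmup)

print([epoch_warmup(e, TARGET_LR) for e in range(7)])
```

With 5 warmup epochs, the learning rate takes only five distinct values during warmup, no matter how many batches each epoch contains. This is exactly the coarseness that motivates tracking a global step instead.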

This approach is favorable when you don't need a high level of customization and don't want to interfere with the way Keras treats the lr. It is especially useful if you want to use callbacks like ReduceLROnPlateau(), since they can only work with a float-based lr. Let's implement a learning rate warmup using a Keras callback, starting with a convenience function:

import numpy as np

def lr_warmup_cosine_decay(global_step,
                           warmup_steps,
                           hold=0,
                           total_steps=0,
                           start_lr=0.0,
                           target_lr=1e-3):
    # Cosine decay from target_lr towards 0, starting after warmup + hold
    learning_rate = 0.5 * target_lr * (1 + np.cos(np.pi * (global_step - warmup_steps - hold) / float(total_steps - warmup_steps - hold)))

    # Linear warmup from start_lr to target_lr (=target_lr at the final warmup step)
    warmup_lr = start_lr + (target_lr - start_lr) * (global_step / warmup_steps)

    # If holding, keep target_lr for `hold` steps after warmup, then switch to cosine decay
    if hold > 0:
        learning_rate = np.where(global_step > warmup_steps + hold,
                                 learning_rate, target_lr)

    # Use the warmup LR while still warming up, and the held or decayed LR otherwise
    learning_rate = np.where(global_step < warmup_steps, warmup_lr, learning_rate)
    return learning_rate

On each step, we calculate both the cosine-decayed learning rate and the warmup learning rate (the two elements of the schedule) with respect to start_lr and target_lr. start_lr will usually be 0.0, while target_lr depends on your network and optimizer. 1e-3 might not be a good default, so be sure to set your target starting LR when calling the function.

If the global_step is higher than the warmup_steps we've set, we use the cosine decay schedule LR. If not, we're still warming up, so the warmup LR is used. If the hold argument is set, we hold target_lr for that number of steps after warmup and before the cosine decay. np.where() provides a great syntax for this:

np.where(condition, value_if_true, value_if_false)
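A small sketch of how np.where picks values element-wise follows. It uses hypothetical values: 3 warmup steps and a target LR of 1e-3. Note that np.where is not an if/else. Both value arrays are computed in full before the selection happens. That is why the schedule function above computes both the cosine and the warmup values on every call, and only then chooses between them.

```python
import numpy as np

steps = np.arange(6)           # steps 0..5
warmup_steps = 3
target_lr = 1e-3

warmup_lr = target_lr * steps / warmup_steps   # computed for every step
constant_lr = np.full(steps.shape, target_lr)  # computed for every step, too

# Element-wise: take warmup_lr where step < warmup_steps, else the constant target
lr = np.where(steps < warmup_steps, warmup_lr, constant_lr)
print(lr)
```

Because np.where works on arrays, lr_warmup_cosine_decay() should also accept a whole NumPy array of steps in a single call, rather than one step at a time.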

You can visualize the function with:

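import numpy as np
import matplotlib.pyplot as plt

# Assumes lr_warmup_cosine_decay() from above is already defined
steps = np.arange(0, 1000, 1)
lrs = []

for step in steps:
    lrs.append(lr_warmup_cosine_decay(step, total_steps=len(steps), warmup_steps=100, hold=10))

plt.plot(lrs)
plt.show()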

Now, we'll want to use this function as part of a callback and pass the training step as the global_step, rather than an element of an arbitrary array. Alternatively, you can perform the computation within the class itself. Let's subclass the Callback class:

from tensorflow import keras
from tensorflow.keras import backend as K

# Assumes lr_warmup_cosine_decay() from above is already defined
class WarmupCosineDecay(keras.callbacks.Callback):
    def __init__(self, total_steps=0, warmup_steps=0, start_lr=0.0, target_lr=1e-3, hold=0):
        super().__init__()
        self.start_lr = start_lr
        self.hold = hold
        self.total_steps = total_steps
        self.global_step = 0
        self.target_lr = target_lr
        self.warmup_steps = warmup_steps
        self.lrs = []

    def on_batch_end(self, batch, logs=None):
        # Advance the global step and record the LR that was used for this batch
        self.global_step = self.global_step + 1
        lr = float(self.model.optimizer.learning_rate.numpy())
        self.lrs.append(lr)

    def on_batch_begin(self, batch, logs=None):
        # Compute the LR for the upcoming batch and write it into the optimizer
        lr = lr_warmup_cosine_decay(global_step=self.global_step,
                                    total_steps=self.total_steps,
                                    warmup_steps=self.warmup_steps,
                                    start_lr=self.start_lr,
                                    target_lr=self.target_lr,
                                    hold=self.hold)
        K.set_value(self.model.optimizer.learning_rate, lr)

First, we define the constructor for the class and keep track of its fields. At the end of each batch, we increase the global step, take note of the current LR and add it to the list of LRs so far. At the beginning of each batch, we calculate the LR using the lr_warmup_cosine_decay() function and set it as the optimizer's current LR. This is done with the backend's set_value().

With that done, calculate the total number of steps (dataset length / batch size * epochs) and take a portion of that number for your warmup_steps:

# If already batched
total_steps = len(train_set)*config['EPOCHS']
# If not batched
#total_steps = len(train_set)/config['BATCH_SIZE']*config['EPOCHS']
# 5% of the steps
warmup_steps = int(0.05*total_steps)

callback = WarmupCosineDecay(total_steps=total_steps, 
                             warmup_steps=warmup_steps,
                             hold=int(warmup_steps/2), 
                             start_lr=0.0, 
                             target_lr=1e-3)
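Here is a worked example with hypothetical numbers: 50,000 training samples, a batch size of 128 and 30 epochs. None of these numbers come from this guide. The steps per epoch are rounded up with math.ceil to account for the final, partial batch. Keras still runs that batch as a step unless the remainder is dropped.

```python
import math

n_samples = 50_000   # hypothetical dataset size
batch_size = 128     # hypothetical batch size
epochs = 30          # hypothetical number of epochs

# A final partial batch still counts as one optimizer step
steps_per_epoch = math.ceil(n_samples / batch_size)
total_steps = steps_per_epoch * epochs
warmup_steps = int(0.05 * total_steps)   # 5% of the steps
hold = int(warmup_steps / 2)             # same hold as in the callback above

print(steps_per_epoch, total_steps, warmup_steps, hold)
```

The un-batched formula in the comment above, len(train_set)/config['BATCH_SIZE']*config['EPOCHS'], would instead give 390.625 * 30 = 11718.75. That is a non-integer, and it slightly undercounts the steps.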

Finally, construct your model and provide the callback in the fit() call:

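model = keras.applications.EfficientNetV2B0(weights=None,
                                            classes=n_classes,
                                            input_shape=[224, 224, 3])

model.compile(loss="sparse_categorical_crossentropy",
              optimizer='adam',
              jit_compile=True,
              metrics=['accuracy'])

# Pass the warmup callback so it can update the LR on every batch
history = model.fit(train_set,
                    epochs=config['EPOCHS'],
                    validation_data=valid_set,
                    callbacks=[callback])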

At the end of training, you can obtain and visualize the changed LRs via:

lrs = callback.lrs # [...]
plt.plot(lrs)

If you plot the history of a model trained with and without LR warmup, you should see a distinct difference in the stability of training.

Learning Rate with LearningRateSchedule Subclass

An alternative to creating a callback is to create a LearningRateSchedule subclass, which doesn't manipulate the LR - it replaces it. This approach allows you to prod a bit more into the backend of Keras/TensorFlow, but when used, can't be combined with other LR-related callbacks, such as ReduceLROnPlateau(), which deals with LRs as floating point numbers.

Additionally, if you want to save the whole model, you'll need to make the subclass serializable (override get_config()), since the schedule becomes part of the model's optimizer. Another thing to note is that the class will expect to work exclusively with tf.Tensors. Thankfully, the only difference in the way we work will be calling tf.func() instead of np.func(), since the TensorFlow and NumPy APIs are amazingly similar and compatible.

Let's rewrite our convenience lr_warmup_cosine_decay() function to use TensorFlow operations instead:

import numpy as np
import tensorflow as tf

def lr_warmup_cosine_decay(global_step,
                           warmup_steps,
                           hold=0,
                           total_steps=0,
                           start_lr=0.0,
                           target_lr=1e-3):
    # The optimizer may pass the step as an integer tensor; cast it so it mixes with float ops
    global_step = tf.cast(global_step, tf.float32)

    # Cosine decay
    # There is no tf.pi so we wrap np.pi as a TF constant
    learning_rate = 0.5 * target_lr * (1 + tf.cos(tf.constant(np.pi) * (global_step - warmup_steps - hold) / float(total_steps - warmup_steps - hold)))

    # Linear warmup from start_lr to target_lr (=target_lr at the final warmup step)
    warmup_lr = start_lr + (target_lr - start_lr) * (global_step / warmup_steps)

    # If holding, keep target_lr for `hold` steps after warmup, then switch to cosine decay
    if hold > 0:
        learning_rate = tf.where(global_step > warmup_steps + hold,
                                 learning_rate, target_lr)

    # Use the warmup LR while still warming up, and the held or decayed LR otherwise
    learning_rate = tf.where(global_step < warmup_steps, warmup_lr, learning_rate)
    return learning_rate

With the convenience function, we can subclass the LearningRateSchedule class. On each __call__() (batch), we'll calculate the LR using the function and return it. You can naturally package the calculation within the subclassed class as well.

The syntax is cleaner than the Callback subclass, primarily because we get access to the step field rather than having to keep track of it on our own. However, it also makes it somewhat harder to work with class properties. In particular, it's hard to extract the lr from a tf.Tensor into any other type to keep track of in a list. This can technically be circumvented by running in eager mode. That's an annoyance when you just want to track the LR for debugging purposes, though, and is best avoided:

import tensorflow as tf
from tensorflow import keras

# Assumes the TensorFlow version of lr_warmup_cosine_decay() from above is defined
class WarmUpCosineDecay(keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, start_lr, target_lr, warmup_steps, total_steps, hold):
        super().__init__()
        self.start_lr = start_lr
        self.target_lr = target_lr
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.hold = hold

    def __call__(self, step):
        lr = lr_warmup_cosine_decay(global_step=step,
                                    total_steps=self.total_steps,
                                    warmup_steps=self.warmup_steps,
                                    start_lr=self.start_lr,
                                    target_lr=self.target_lr,
                                    hold=self.hold)

        # Past total_steps, pin the LR to 0 rather than letting the cosine rise again
        return tf.where(
            step > self.total_steps, 0.0, lr, name="learning_rate"
        )
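The final tf.where() clamps the learning rate to 0 once step passes total_steps. In the schedule, the decay phase runs from warmup_steps + hold to total_steps. The minimal pure-Python sketch below shows why the clamp matters. It uses hypothetical helpers `cosine_part` and `clamped`, and needs no TensorFlow. Because the cosine is periodic, the learning rate would climb back up after the end of the schedule without the clamp.

```python
import math


def cosine_part(step, decay_start, decay_end, target_lr=1e-3):
    """Cosine decay from target_lr at decay_start to 0 at decay_end (unclamped)."""
    progress = (step - decay_start) / (decay_end - decay_start)
    return 0.5 * target_lr * (1 + math.cos(math.pi * progress))


def clamped(step, decay_start, decay_end, target_lr=1e-3):
    """Same curve, but pinned to 0 after decay_end, like the tf.where() above."""
    if step > decay_end:
        return 0.0
    return cosine_part(step, decay_start, decay_end, target_lr)


for step in (0, 50, 100, 150, 200):
    print(step, cosine_part(step, 0, 100), clamped(step, 0, 100))
```

At step 200, the unclamped curve is back at target_lr, while the clamped one stays at 0.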

The parameters are the same, and can be calculated in much the same way as before:

# If batched
total_steps = len(train_set)*config['EPOCHS']
# If not batched
#total_steps = len(train_set)/config['BATCH_SIZE']*config['EPOCHS']
# 5% of the steps
warmup_steps = int(0.05*total_steps)

schedule = WarmUpCosineDecay(start_lr=0.0, target_lr=1e-3, warmup_steps=warmup_steps, total_steps=total_steps, hold=warmup_steps)

And the training pipeline only differs in that we set the optimizer's LR to the schedule:

model = keras.applications.EfficientNetV2B0(weights=None,
                                            classes=n_classes,
                                            input_shape=[224, 224, 3])

# The schedule is passed to the optimizer in place of a float learning rate
model.compile(loss="sparse_categorical_crossentropy",
              optimizer=tf.keras.optimizers.Adam(learning_rate=schedule),
              jit_compile=True,
              metrics=['accuracy'])

history = model.fit(train_set,
                    epochs=config['EPOCHS'],
                    validation_data=valid_set)

If you wish to save the model, the WarmUpCosineDecay schedule will have to override the get_config() method:

    def get_config(self):
        config = {
          'start_lr': self.start_lr,
          'target_lr': self.target_lr,
          'warmup_steps': self.warmup_steps,
          'total_steps': self.total_steps,
          'hold': self.hold
        }
        return config

Finally, when loading the model, you'll have to pass WarmUpCosineDecay as a custom object:

# custom_objects must be a dict mapping the class name to the class
model = keras.models.load_model('weights.h5',
                                custom_objects={'WarmUpCosineDecay': WarmUpCosineDecay})

Conclusion

In this guide, we've taken a look at the intuition behind Learning Rate Warmup - a common technique for manipulating the learning rate while training neural networks.

We've implemented a learning rate warmup with cosine decay, one of the most common types of LR reduction paired with warmup. You can implement any other function for reduction, or not reduce the learning rate at all and leave that to other callbacks such as ReduceLROnPlateau(). We've implemented learning rate warmup both as a Keras Callback and as a Keras optimizer schedule, and plotted the learning rate through the epochs.


Author: David Landup


© 2013-2024 Stack Abuse. All rights reserved.
