Don't Use Flatten() - Global Pooling for CNNs with TensorFlow and Keras

Most practitioners, while first learning about Convolutional Neural Network (CNN) architectures, learn that they're made up of three basic segments:

  • Convolutional Layers
  • Pooling Layers
  • Fully-Connected Layers

Most resources have some variation on this segmentation, including my own book. Especially online, "fully-connected layers" refers to a flattening layer followed by (usually) multiple dense layers.

This used to be the norm, and well-known architectures such as VGGNets used this approach, and would end in:

model = keras.Sequential([
    #...
    keras.layers.MaxPooling2D((2, 2), strides=(2, 2), padding='same'),
    keras.layers.Flatten(),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(4096, activation='relu'), 
    keras.layers.Dropout(0.5),
    keras.layers.Dense(4096, activation='relu'),
    keras.layers.Dense(n_classes, activation='softmax')
])

Though, for some reason, it's oftentimes forgotten that VGGNet was practically the last mainstream architecture to use this approach, due to the obvious computational bottleneck it creates. Ever since ResNets, published just the year after VGGNets (and 7 years ago, at the time of writing), practically all mainstream architectures have ended their model definitions with:

model = keras.Sequential([
    #...
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dense(n_classes, activation='softmax')
])

Flattening in CNNs has been sticking around for 7 years. 7 years! And not enough people seem to be talking about the damaging effect it has on both your learning experience and the computational resources you're using.

Global Average Pooling is preferable to flattening on many counts. If you're prototyping a small CNN - use Global Pooling. If you're teaching someone about CNNs - use Global Pooling. If you're making an MVP - use Global Pooling. Use flattening layers for other use cases where they're actually needed.

Case Study - Flattening vs Global Pooling

Global Pooling condenses each feature map into a single value, pooling the entire spatial extent of every map into one number, so the stack of maps becomes a compact vector that can be easily understood by a single dense classification layer instead of multiple layers. It's typically applied as average pooling (GlobalAveragePooling2D) or max pooling (GlobalMaxPooling2D) and can work for 1D and 3D input as well.
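To make that concrete, here's a minimal sketch (using a tiny, hand-written tensor rather than real feature maps) of what the two 2D variants compute per feature map:

import tensorflow as tf
from tensorflow import keras

# A single sample with two 2x2 feature maps - shape (batch, height, width, channels)
feature_maps = tf.constant([[[[1., 5.], [2., 6.]],
                             [[3., 7.], [4., 8.]]]])  # shape: (1, 2, 2, 2)

# Each feature map is reduced to one value: its average or its maximum
print(keras.layers.GlobalAveragePooling2D()(feature_maps).numpy())  # [[2.5 6.5]]
print(keras.layers.GlobalMaxPooling2D()(feature_maps).numpy())      # [[4. 8.]]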

Instead of flattening a feature map such as (7, 7, 32) into a vector of length 1568 and training one or multiple layers to discern patterns from this long vector: we can condense it into a vector of length 32 (one value per feature map) and classify directly from there. It's that simple!
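For instance, here's a quick sketch (with a random tensor standing in for a real (7, 7, 32) feature map stack) comparing what Flatten and GlobalAveragePooling2D hand off to the classifier:

import tensorflow as tf
from tensorflow import keras

# Random stand-in for a conv block's output: batch of 1, 7x7 spatial size, 32 feature maps
feature_maps = tf.random.normal((1, 7, 7, 32))

print(keras.layers.Flatten()(feature_maps).shape)                 # (1, 1568)
print(keras.layers.GlobalAveragePooling2D()(feature_maps).shape)  # (1, 32)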

Note that bottleneck layers for networks like ResNets count in tens of thousands of features, not a mere 1568. When flattening, you're torturing your network to learn from oddly-shaped vectors in a very inefficient manner. Imagine a 2D image being sliced on every pixel row and then concatenated into a flat vector. The two pixels that used to be 0 pixels apart vertically are now feature_map_width pixels away horizontally! While this may not matter too much for a classification algorithm, which favors spatial invariance - this wouldn't even be conceptually good for other applications of computer vision.
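As a rough illustration of that last point, here's a small sketch with a hypothetical 7-pixel-wide, single-channel map, showing how vertically adjacent positions end up feature_map_width entries apart once the map is flattened row by row:

# Hypothetical single-channel feature map, 7 pixels wide, flattened in row-major order
feature_map_width = 7

# Two vertically adjacent positions: (row, col) and (row + 1, col)
top, bottom = (2, 3), (3, 3)

flat_top = top[0] * feature_map_width + top[1]           # 17
flat_bottom = bottom[0] * feature_map_width + bottom[1]  # 24

# Neighbors in 2D are now feature_map_width positions apart in the flat vector
print(flat_bottom - flat_top)  # 7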

Let's define a small demonstrative network that uses a flattening layer with a couple of dense layers:

from tensorflow import keras

model = keras.Sequential([
    keras.layers.Input(shape=(224, 224, 3)),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    # Flatten the (53, 53, 64) feature maps into a single long vector
    keras.layers.Flatten(),
    keras.layers.Dropout(0.3),
    keras.layers.Dense(64, activation='relu'),
    keras.layers.Dense(32, activation='relu'),
    keras.layers.Dense(10, activation='softmax')
])
model.summary()
model.summary()

What does the summary look like?

...                                                              
 dense_6 (Dense)             (None, 10)                330       
                                                                 
=================================================================
Total params: 11,574,090
Trainable params: 11,573,898
Non-trainable params: 192
_________________________________________________________________

11.5M parameters for a toy network - and watch the parameters explode with larger input. 11.5M parameters! EfficientNets, one of the best-performing families of networks ever designed, work at around 6M parameters, and this simple model can't be compared with them in terms of actual performance and capacity to learn from data.

We could reduce this number significantly by making the network deeper, which would introduce more max pooling (and potentially strided convolution) to reduce the feature maps before they're flattened. However, consider that we'd be making the network more complex in order to make it less computationally expensive, all for the sake of a single layer that's throwing a wrench in the plans.

Going deeper with layers should be about extracting more meaningful, non-linear relationships between data points, not about reducing the input size to cater to a flattening layer.
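As a back-of-the-envelope sketch (assuming, for simplicity, that each extra downsampling stage only halves the 53x53 spatial size while keeping 64 feature maps and the Dense(64) head from the model above), here's how much downsampling it takes just to tame that first dense layer:

# Parameter count of the first Dense layer after Flatten, as extra downsampling
# stages halve the spatial size (channels and head width held fixed)
channels, dense_units = 64, 64

for spatial in (53, 26, 13, 6):
    flat_len = spatial * spatial * channels
    params = flat_len * dense_units + dense_units
    print(f"{spatial}x{spatial}x{channels} -> flattened length {flat_len:>7,} -> Dense(64) params: {params:,}")

Even after three extra rounds of halving, that single dense layer alone still outweighs the entire ~66K-parameter global pooling model shown next.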


Here's a network with global pooling:

model = keras.Sequential([
    keras.layers.Input(shape=(224, 224, 3)),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    # Average each of the 64 (53, 53) feature maps down to a single value
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dropout(0.3),
    keras.layers.Dense(10, activation='softmax')
])

model.summary()

Summary?

 dense_8 (Dense)             (None, 10)                650       
                                                                 
=================================================================
Total params: 66,602
Trainable params: 66,410
Non-trainable params: 192
_________________________________________________________________

Much better! If we go deeper with this model, the parameter count will increase, and we might be able to capture more intricate patterns of data with the new layers. If done naively, though, the same issues that bogged down VGGNets will arise.

Conclusion

In this short guide, we've taken a look at an alternative to flattening in CNN architecture design. Albeit short, the guide addresses a common issue when designing prototypes or MVPs, and advises you to use global pooling instead of flattening layers.

Any seasoned Computer Vision Engineer will know and apply this principle, and the practice is taken for granted. Unfortunately, it doesn't seem to be properly relayed to new practitioners who are just entering the field, which can create sticky habits that take a while to get rid of.

If you're getting into Computer Vision - do yourself a favor and don't use flattening layers before classification heads in your learning journey.
