Don't Use Flatten() - Global Pooling for CNNs with TensorFlow and Keras


Most practitioners first learning about Convolutional Neural Network (CNN) architectures learn that they're composed of three basic segments:

  • Convolutional Layers
  • Pooling Layers
  • Fully-Connected Layers

Most resources have some variation on this segmentation, including my own book. Especially online, "fully-connected layers" usually means a flattening layer followed by (usually) multiple dense layers.

This used to be the norm, and well-known architectures such as VGGNets used this approach, ending in:

model = keras.Sequential([
    #...
    keras.layers.MaxPooling2D((2, 2), strides=(2, 2), padding='same'),
    keras.layers.Flatten(),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(4096, activation='relu'), 
    keras.layers.Dropout(0.5),
    keras.layers.Dense(4096, activation='relu'),
    keras.layers.Dense(n_classes, activation='softmax')
])

Though, for some reason, it's often forgotten that VGGNet was practically the last major architecture to use this approach, due to the obvious computational bottleneck it creates. Starting with ResNets, published just a year after VGGNets (7 years ago at the time of writing), mainstream architectures ended their model definitions with:

model = keras.Sequential([
    #...
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dense(n_classes, activation='softmax')
])
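To put a rough number on that bottleneck, here's a back-of-the-envelope sketch of the parameter counts of the two heads. It assumes the standard 224x224 ImageNet setup: a (7, 7, 512) final feature map for VGG16, a (7, 7, 2048) one for ResNet-50, and 1,000 classes. Dropout and pooling layers have no trainable weights, so only the dense layers are counted; `dense_params` is a hypothetical helper, not a Keras function.

```python
# Back-of-the-envelope parameter counts for two classification heads,
# assuming 224x224 inputs and 1,000 classes (illustrative ImageNet setup).

def dense_params(n_in: int, n_out: int) -> int:
    """Weights plus biases of a fully-connected (Dense) layer."""
    return n_in * n_out + n_out

N_CLASSES = 1000

# VGG16-style head: Flatten a (7, 7, 512) map, then Dense(4096) -> Dense(4096) -> Dense(1000)
vgg_flat = 7 * 7 * 512
vgg_head = (dense_params(vgg_flat, 4096)
            + dense_params(4096, 4096)
            + dense_params(4096, N_CLASSES))

# ResNet-50-style head: global average pooling over a (7, 7, 2048) map, then Dense(1000)
resnet_pooled = 2048  # one value per channel after pooling
resnet_head = dense_params(resnet_pooled, N_CLASSES)

print(f"VGG16-style head:     {vgg_head:,} parameters")    # 123,642,856
print(f"ResNet-50-style head: {resnet_head:,} parameters")  # 2,049,000
```

For reference, VGG16 as a whole has roughly 138M parameters, so the flatten-based head accounts for the large majority of them.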

Yet flattening in CNNs has stuck around for the 7 years since. 7 years! And not enough people seem to be talking about the damaging effect it has on both your learning experience and the computational resources you're using.

Global Average Pooling is preferable to flattening on many counts. If you're prototyping a small CNN - use Global Pooling. If you're teaching someone about CNNs - use Global Pooling. If you're making an MVP - use Global Pooling. Use flattening layers for other use cases where they're actually needed.

Case Study - Flattening vs Global Pooling

Global Pooling condenses each feature map into a single value, pooling the relevant information from every channel into a compact vector that a single dense classification layer can work with, instead of multiple layers. It's typically applied as average pooling (GlobalAveragePooling2D) or max pooling (GlobalMaxPooling2D), and it works for 1D and 3D input as well (e.g. GlobalAveragePooling1D and GlobalAveragePooling3D).

Instead of flattening a feature map such as (7, 7, 32) into a vector of length 1568 and training one or multiple layers to discern patterns from this long vector, we can condense it into a vector of length 32 (one value per channel) and classify directly from there. It's that simple!
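A minimal NumPy sketch of the shape difference, assuming a channels-last layout (the Keras default). Global average and max pooling reduce over the height and width axes, which mirrors what GlobalAveragePooling2D and GlobalMaxPooling2D compute for channels-last input, while flattening keeps every value. The random batch is made-up data used only for its shape.

```python
import numpy as np

# A hypothetical batch of 8 feature maps, each (7, 7, 32), channels-last
rng = np.random.default_rng(0)
feature_maps = rng.standard_normal((8, 7, 7, 32))

# Flattening keeps every value: one long vector per sample
flattened = feature_maps.reshape(feature_maps.shape[0], -1)

# Global average / max pooling reduce over the spatial axes (height, width),
# leaving exactly one value per channel
gap = feature_maps.mean(axis=(1, 2))
gmp = feature_maps.max(axis=(1, 2))

print(flattened.shape)  # (8, 1568)
print(gap.shape)        # (8, 32)
print(gmp.shape)        # (8, 32)
```

The dense layer that follows sees 32 inputs instead of 1568, and that number doesn't change if the spatial size of the feature maps does.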

Note that bottleneck layers for networks like ResNets count in the tens of thousands of features or more once flattened (ResNet-50's final (7, 7, 2048) feature map would flatten into 100,352 values), not a mere 1568. When flattening, you're torturing your network to learn from oddly-shaped vectors in a very inefficient manner. Imagine a 2D image being sliced along every pixel row, with the rows then concatenated into a flat vector. Two pixels that used to be vertically adjacent are now feature_map_width positions apart! While this may not matter too much for a classification algorithm, which favors spatial invariance, it isn't even conceptually a good fit for other applications of computer vision.
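The slicing analogy can be checked with a small NumPy sketch. It fills a hypothetical (4, 5, 3) channels-last feature map with the index each value would land at under row-major (C-order) flattening, which is NumPy's default `reshape` order, then measures how far apart neighboring positions end up. With multiple channels, vertical neighbors land `width * channels` positions apart rather than just `width`.

```python
import numpy as np

H, W, C = 4, 5, 3  # hypothetical feature map: height, width, channels

# np.arange + reshape means each entry holds its own position
# in the row-major flattened vector
index_map = np.arange(H * W * C).reshape(H, W, C)

# Vertically adjacent positions: (row 1, col 2) and (row 2, col 2), channel 0
gap_vertical = index_map[2, 2, 0] - index_map[1, 2, 0]
# Horizontally adjacent positions: (row 1, col 2) and (row 1, col 3), channel 0
gap_horizontal = index_map[1, 3, 0] - index_map[1, 2, 0]

print(gap_vertical)    # 15, i.e. W * C
print(gap_horizontal)  # 3, i.e. C
```

Horizontal neighbors stay close in the flat vector, while vertical neighbors are pushed a whole row apart, so any dense layer after Flatten() has to rediscover that relationship from scratch.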

Let's define a small demonstrative network that uses a flattening layer with a couple of dense layers:

from tensorflow import keras

model = keras.Sequential([
    keras.layers.Input(shape=(224, 224, 3)),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.Flatten(),
    keras.layers.Dropout(0.3),
    keras.layers.Dense(64, activation='relu'),
    keras.layers.Dense(32, activation='relu'),
    keras.layers.Dense(10, activation='softmax')
])
model.summary()

What does the summary look like?

...                                                              
 dense_6 (Dense)             (None, 10)                330       
                                                                 
=================================================================
Total params: 11,574,090
Trainable params: 11,573,898
Non-trainable params: 192
_________________________________________________________________

11.5M parameters for a toy network - and watch the parameters explode with larger input. 11.5M parameters! EfficientNetB0, the smallest member of one of the best-performing model families ever designed, works with roughly 5.3M parameters, and this simple model can't compare with it in terms of actual performance and capacity to learn from data.
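To see where the 11.5M comes from, and how it grows with the input, here's a minimal sketch that recomputes the toy model's parameter count by hand. The helper names are hypothetical; the layer arithmetic assumes 'valid' 3x3 convolutions, 2x2 max pooling with stride 2, four stored values per channel in BatchNormalization, and no parameters in Dropout, matching the layers defined above.

```python
def conv_params(kernel: int, c_in: int, c_out: int) -> int:
    """Weights plus biases of a square Conv2D layer."""
    return kernel * kernel * c_in * c_out + c_out


def dense_params(n_in: int, n_out: int) -> int:
    """Weights plus biases of a Dense layer."""
    return n_in * n_out + n_out


def toy_model_params(input_size: int, head: str = "flatten") -> int:
    """Hand-computed parameter count of the toy CNN for a square RGB input.

    Two blocks of [Conv2D 3x3 'valid', Conv2D 3x3 'valid', MaxPooling2D 2x2,
    BatchNormalization], followed by either
    Flatten -> Dense(64) -> Dense(32) -> Dense(10)   (head="flatten") or
    GlobalAveragePooling2D -> Dense(10)              (head="gap").
    """
    size, channels, total = input_size, 3, 0
    for filters in (32, 64):
        total += conv_params(3, channels, filters)
        total += conv_params(3, filters, filters)
        # Two 'valid' 3x3 convolutions shrink each side by 2, then 2x2 pooling halves it
        size = (size - 4) // 2
        # BatchNormalization stores gamma, beta, moving mean and moving variance per channel
        total += 4 * filters
        channels = filters
    if size < 1:
        raise ValueError("input too small for this stack of layers")
    if head == "flatten":
        total += dense_params(size * size * channels, 64)
        total += dense_params(64, 32) + dense_params(32, 10)
    elif head == "gap":
        total += dense_params(channels, 10)  # pooling yields one value per channel
    else:
        raise ValueError(f"unknown head: {head}")
    return total


for side in (224, 512):
    print(f"{side}x{side}: flatten {toy_model_params(side, 'flatten'):,}, "
          f"GAP {toy_model_params(side, 'gap'):,}")
# 224x224: flatten 11,574,090, GAP 66,602
# 512x512: flatten 64,068,426, GAP 66,602
```

Almost all of the flatten variant's parameters sit in the first Dense(64) layer, and that is the only term that depends on the input size.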

We could reduce this number significantly by making the network deeper, which would introduce more max pooling (and potentially strided convolution) to reduce the feature maps before they're flattened. However, consider that we'd be making the network more complex in order to make it less computationally expensive, all for the sake of a single layer that's throwing a wrench in the plans.
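As a rough illustration of that trade-off, this sketch computes how long the flattened vector would be if the toy model were extended with more blocks of the same kind. The extension itself is hypothetical: it assumes each block is two 'valid' 3x3 convolutions followed by 2x2 max pooling, with the filter count doubling per block (32, 64, 128, ...). BatchNormalization and Dropout are left out since they don't change shapes.

```python
def flattened_length(input_size: int, n_blocks: int, base_filters: int = 32) -> int:
    """Length of the vector Flatten() would produce after n_blocks of
    [Conv2D 3x3 'valid', Conv2D 3x3 'valid', MaxPooling2D 2x2],
    assuming the filter count doubles in each block (32, 64, 128, ...)."""
    if n_blocks < 1:
        raise ValueError("need at least one block")
    size = input_size
    for _ in range(n_blocks):
        size = (size - 4) // 2  # two 'valid' 3x3 convs, then 2x2 pooling
        if size < 1:
            raise ValueError(f"{n_blocks} blocks shrink a {input_size}px input to nothing")
    filters = base_filters * 2 ** (n_blocks - 1)
    return size * size * filters


for n_blocks in range(2, 6):
    n_flat = flattened_length(224, n_blocks)
    print(f"{n_blocks} blocks: {n_flat:,} values -> Dense(64) needs {n_flat * 64 + 64:,} params")
# 2 blocks: 179,776 values -> Dense(64) needs 11,505,728 params
# 3 blocks: 73,728 values -> Dense(64) needs 4,718,656 params
# 4 blocks: 25,600 values -> Dense(64) needs 1,638,464 params
# 5 blocks: 4,608 values -> Dense(64) needs 294,976 params
```

Each extra block does shrink the head, but only by adding layers whose main job here is shrinking the feature map for the flattening layer.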

The point of going deeper should be to extract more meaningful, non-linear relationships between data points, not to reduce the input size to cater to a flattening layer.

Here's a network with global pooling:

from tensorflow import keras

model = keras.Sequential([
    keras.layers.Input(shape=(224, 224, 3)),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2), (2, 2)),
    keras.layers.BatchNormalization(),
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dropout(0.3),
    keras.layers.Dense(10, activation='softmax')
])

model.summary()

Summary?

 dense_8 (Dense)             (None, 10)                650       
                                                                 
=================================================================
Total params: 66,602
Trainable params: 66,410
Non-trainable params: 192
_________________________________________________________________

Much better! If we go deeper with this model, the parameter count will increase, and we might be able to capture more intricate patterns in the data with the new layers. If done naively, though, the same issues that held back VGGNets will arise.

Going Further - Hand-Held End-to-End Project

Your inquisitive nature makes you want to go further? We recommend checking out our Guided Project: "Convolutional Neural Networks - Beyond Basic Architectures".


I'll take you on a bit of time travel - going from 1998 to 2022, highlighting the defining architectures developed throughout the years, what made them unique and what their drawbacks are, and implementing the notable ones from scratch. There's nothing better than getting your hands dirty when it comes to these.

You can drive a car without knowing whether the engine has 4 or 8 cylinders and what the placement of the valves within the engine is. However - if you want to design and appreciate an engine (computer vision model), you'll want to go a bit deeper. Even if you don't want to spend time designing architectures and want to build products instead, which is what most want to do - you'll find important information in this lesson. You'll get to learn why using outdated architectures like VGGNet will hurt your product and performance, and why you should skip them if you're building anything modern, and you'll learn which architectures you can go to for solving practical problems and what the pros and cons are for each.

If you're looking to apply computer vision to your field, using the resources from this lesson - you'll be able to find the newest models, understand how they work and by which criteria you can compare them and make a decision on which to use.

You don't have to Google for architectures and their implementations - they're typically very clearly explained in the papers, and frameworks like Keras make these implementations easier than ever. The key takeaway of this Guided Project is to teach you how to find, read, implement and understand architectures and papers. No resource in the world will be able to keep up with all of the newest developments. I've included the newest papers here - but in a few months, new ones will pop up, and that's inevitable. Knowing where to find credible implementations, compare them to papers and tweak them can give you the competitive edge required for many computer vision products you may want to build.

Conclusion

In this short guide, we've taken a look at an alternative to flattening in CNN architecture design. Albeit short - the guide addresses a common issue when designing prototypes or MVPs, and advises you to use a better alternative to flattening.

Any seasoned Computer Vision Engineer will know and apply this principle, and the practice is taken for granted. Unfortunately, it doesn't seem to be properly relayed to new practitioners who are just entering the field, which can create sticky habits that take a while to get rid of.

If you're getting into Computer Vision - do yourself a favor and don't use flattening layers before classification heads in your learning journey.


Author: David Landup

Entrepreneur, Software and Machine Learning Engineer, with a deep fascination towards the application of Computation and Deep Learning in Life Sciences (Bioinformatics, Drug Discovery, Genomics), Neuroscience (Computational Neuroscience), robotics and BCIs.

Great passion for accessible education and promotion of reason, science, humanism, and progress.


© 2013-2022 Stack Abuse. All rights reserved.
