How to Flatten Specific Dimensions of NumPy Array


In data manipulation and scientific computing, NumPy is one of the most widely used libraries, providing a broad set of array operations. One such operation, "flattening," transforms a multi-dimensional array into a one-dimensional sequence. While flattening an entire array is straightforward, sometimes you need to flatten only specific dimensions to suit the requirements of your data pipeline or algorithm. In this Byte, we'll look at a few techniques for this more targeted form of flattening.

NumPy Arrays

NumPy, short for Numerical Python, is a Python library that provides support for large, multi-dimensional arrays and matrices, along with a collection of mathematical functions that operate on these arrays. NumPy arrays are a key ingredient in scientific computing with Python. Because they typically store elements of a single data type in one contiguous block of memory, they are more memory-efficient and faster than Python's built-in lists, especially for mathematical operations applied to many elements at once.
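One visible consequence of this design is that arithmetic on an array applies element-wise, with no explicit Python loop. The short sketch below contrasts this with a plain list; the sample values are just illustrative.

```python
import numpy as np

# Illustrative sample data
values = [1, 2, 3, 4]
arr = np.array(values)

# On a list, * repeats the sequence
list_result = values * 2
# On a NumPy array, * multiplies every element
array_result = arr * 2

print(list_result)   # [1, 2, 3, 4, 1, 2, 3, 4]
print(array_result)  # [2 4 6 8]
```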

This code shows what a NumPy array can look like:

import numpy as np

# Creating a 2D NumPy array
array_2D = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(array_2D)

[[1 2 3]
 [4 5 6]
 [7 8 9]]
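Since flattening is all about changing dimensions, it helps to know how to inspect them. This sketch uses the standard ndim, shape, and size attributes on the same array.

```python
import numpy as np

array_2D = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

# ndim: number of dimensions (axes)
print(array_2D.ndim)   # 2
# shape: length of each dimension
print(array_2D.shape)  # (3, 3)
# size: total number of elements (the product of the shape)
print(array_2D.size)   # 9
```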

Why Flatten Some Dimensions of a NumPy Array?

Flattening an array means converting a multidimensional array into a 1D array. But why would you want to flatten just some dimensions of a NumPy array?

Well, there are many scenarios where you might need to do this. For example, in machine learning, we often need to reshape input data before feeding it into a model. Many algorithms expect each sample to be a 1D feature vector, with the samples stacked into a 2D array of shape (n_samples, n_features).

But sometimes, you might not want to flatten the entire array. Instead, you might want to flatten specific dimensions of the array while keeping the other dimensions intact. This can be useful in scenarios where you want to maintain some level of the original structure of the data.
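A common case is a batch of images. The sketch below assumes a hypothetical batch of four 28 x 28 grayscale images (filled with zeros for simplicity) and flattens each image while keeping the batch dimension, compared with flattening everything.

```python
import numpy as np

# Hypothetical batch: 4 grayscale images, each 28 x 28 pixels
images = np.zeros((4, 28, 28))

# Keep the batch axis, turn each image into a 784-element vector
features = images.reshape(images.shape[0], -1)
print(features.shape)  # (4, 784)

# Flattening the whole array loses the per-image grouping
everything = images.flatten()
print(everything.shape)  # (3136,)
```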

How to Flatten a NumPy Array

Flattening a NumPy array is fairly easy to do. You can use the flatten() method provided by NumPy to flatten an array:

import numpy as np

# Creating a 2D NumPy array
array_2D = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

# Flattening the 2D array
flattened_array = array_2D.flatten()
print(flattened_array)

[1 2 3 4 5 6 7 8 9]

As you can see, the flatten() method has transformed our 2D array into a 1D array.
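By default, flatten() reads elements row by row. Its order parameter changes that, and the result is always a copy. This sketch shows both behaviors on the same 2D array.

```python
import numpy as np

array_2D = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

# order='C' (the default) reads row by row
row_major = array_2D.flatten()
# order='F' reads column by column
col_major = array_2D.flatten(order='F')
print(row_major)  # [1 2 3 4 5 6 7 8 9]
print(col_major)  # [1 4 7 2 5 8 3 6 9]

# flatten() returns a copy, so editing it leaves the original untouched
row_major[0] = 100
print(array_2D[0, 0])  # 1
```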

But what if we want to flatten only a specific dimension of the array and not the entire array? We'll explore this in the next sections.

Flattening Specific Dimensions of a NumPy Array

Flattening a whole array takes a single method call. But what if you need to flatten only specific dimensions of an array? This is where the reshape() method comes into play.

Let's say we have a 3D array and we want to flatten the last two dimensions, keeping the first dimension as it is. The reshape function can be used to achieve this. Here's a simple example:

import numpy as np

# Create a 3D array
array_3d = np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])

# Keep axis 0 and merge the last two axes into one
flattened_array = array_3d.reshape(array_3d.shape[0], -1)
print(flattened_array)

[[ 1  2  3  4  5  6]
 [ 7  8  9 10 11 12]]

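In the above code, the -1 passed to reshape() tells NumPy to infer the length of that dimension from the total number of elements and the lengths of the other dimensions. Here the array has 12 elements and the first dimension has length 2, so the second dimension becomes 12 / 2 = 6. Only one dimension can be given as -1.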
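The same idea works for any run of adjacent axes. Below is a minimal sketch of a helper, flatten_dims, which is our own hypothetical function (not part of NumPy), that merges axes start through end. It computes the merged length explicitly instead of passing -1, because NumPy can't infer a -1 dimension when one of the other dimensions has length zero.

```python
import math

import numpy as np


def flatten_dims(a, start, end):
    """Hypothetical helper: merge axes start..end (inclusive) of `a` into one axis.

    Negative axis indices count from the end, as elsewhere in NumPy.
    """
    ndim = a.ndim
    for axis in (start, end):
        if not -ndim <= axis < ndim:
            raise ValueError(f"axis {axis} is out of bounds for a {ndim}-D array")
    start %= ndim
    end %= ndim
    if start > end:
        raise ValueError("start must not come after end")
    # Length of the merged axis: the value -1 would normally infer,
    # computed explicitly so arrays with a zero-length axis also work
    merged = math.prod(a.shape[start:end + 1])
    return a.reshape(a.shape[:start] + (merged,) + a.shape[end + 1:])


array_3d = np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])

print(flatten_dims(array_3d, 1, 2).shape)    # (2, 6), same as reshape(2, -1)
print(flatten_dims(array_3d, 0, 1).shape)    # (4, 3)
print(flatten_dims(array_3d, -2, -1).shape)  # (2, 6)
```

Merging the first two axes instead, as in flatten_dims(array_3d, 0, 1), is the same as array_3d.reshape(-1, array_3d.shape[-1]).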

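Note: reshape() does not change the shape of the original array. It returns a new array object with the requested shape, but whenever possible that object is a view that shares the original's data, so writing to the reshaped array also changes the original. Call .copy() on the result if you need an independent array.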
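This sketch checks the view behavior on the 3D array from the example above, using np.shares_memory.

```python
import numpy as np

array_3d = np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])
flattened_array = array_3d.reshape(array_3d.shape[0], -1)

# The original keeps its shape
print(array_3d.shape)         # (2, 2, 3)
print(flattened_array.shape)  # (2, 6)

# For this contiguous array, reshape() returned a view on the same data
print(np.shares_memory(array_3d, flattened_array))  # True

# So writing through the view changes the original as well
flattened_array[0, 0] = 100
print(array_3d[0, 0, 0])  # 100

# .copy() gives an independent array
independent = array_3d.reshape(array_3d.shape[0], -1).copy()
print(np.shares_memory(array_3d, independent))  # False
```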

Similar Solutions and Use-Cases

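Flattening specific dimensions of a NumPy array isn't the only way to manipulate your data, and there are related tools you might find useful. For example, the ravel function can also flatten an array. However, unlike reshape, ravel always returns a fully flattened 1D array, so on its own it can't keep some dimensions intact. And unlike flatten(), which always returns a copy, ravel returns a view of the original data whenever possible.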
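Here's a quick sketch comparing ravel() and flatten() on the same contiguous 3D array: both give a 1D result, but only ravel() shares memory with the original here.

```python
import numpy as np

array_3d = np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])

raveled = array_3d.ravel()      # always 1D, a view when possible
flattened = array_3d.flatten()  # always 1D, always a copy

print(raveled.shape, flattened.shape)         # (12,) (12,)
print(np.shares_memory(array_3d, raveled))    # True
print(np.shares_memory(array_3d, flattened))  # False
```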

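Additionally, you can use the transpose function to change the order of the array dimensions. This is useful when you need to rearrange your data for specific operations or visualizations. It also lets you flatten dimensions that aren't next to each other: reshape can only merge adjacent axes, so you first move the axes you want to merge side by side with transpose, then merge them with reshape.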
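As a sketch, assume an illustrative array of shape (2, 3, 4) where we want to keep axis 1 and merge axes 0 and 2. Moving axis 1 to the front with np.transpose makes the other two axes adjacent, so reshape can merge them.

```python
import numpy as np

# Illustrative array of shape (2, 3, 4)
a = np.arange(24).reshape(2, 3, 4)

# Move axis 1 to the front so axes 0 and 2 become adjacent
moved = np.transpose(a, (1, 0, 2))          # shape (3, 2, 4)
result = moved.reshape(moved.shape[0], -1)  # shape (3, 8)

print(result.shape)  # (3, 8)
print(result[0])     # [ 0  1  2  3 12 13 14 15]

# The transposed view isn't contiguous, so reshape had to copy the data
print(np.shares_memory(a, result))  # False
```

np.moveaxis(a, 1, 0) produces the same (3, 2, 4) arrangement here and can be easier to read when you only need to move one axis.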

These techniques can be particularly useful in data preprocessing for machine learning. For instance, you might need to flatten the input data for a neural network. Or, you might need to transpose your data to ensure that it's in the correct format for a particular library or mathematical function.


Conclusion

In this Byte, we've explored how to flatten specific dimensions of a NumPy array using the reshape function. We've also looked at similar solutions such as ravel and transpose and discussed some use-cases where these techniques can be particularly useful.

While these techniques are powerful tools for data manipulation, they are just the tip of the iceberg of what you can do with NumPy. I'd suggest taking a deeper look at the NumPy documentation to see what other interesting features you can discover.

Last Updated: September 21st, 2023

© 2013-2024 Stack Abuse. All rights reserved.