How To Use the Flatten Layer in Keras

In this article, we explore how to use the Flatten layer in Keras and provide a very short tutorial, complete with code, so that you can easily follow along yourself.
Created on June 21|Last edited on January 23
When building neural network architectures, you often need to flatten your tensors. For example, if you're processing a batch of images with a convolutional neural network or a vision transformer, you're working with a 4D tensor of shape $\text{batch size} \times \text{height} \times \text{width} \times \text{channels}$ (TensorFlow's default channels-last layout).
To process that output with a loss function, to feed it into another model (for instance, a decoder as in DALL-E), or to pass it to a multilayer perceptron (MLP), you might need each sample represented as a single one-dimensional vector. How do you do that?
You "flatten" the tensor using TensorFlow's Flatten layer, tf.keras.layers.Flatten(). It collapses every dimension except the batch dimension, so a tensor of shape (batch size, height, width, channels) becomes a tensor of shape (batch size, height × width × channels).
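As a quick sanity check, here is a minimal sketch that applies Flatten to a random batch of images and compares the result with an explicit tf.reshape. The batch of 8 RGB images of size 32×32 is an arbitrary, illustrative choice. Note that the batch dimension survives: each sample becomes a vector, so the result is 2D.

```python
import tensorflow as tf

# Illustrative batch: 8 RGB images of size 32x32 in channels-last layout,
# i.e. (batch size, height, width, channels)
images = tf.random.normal((8, 32, 32, 3))

flat = tf.keras.layers.Flatten()(images)
print(flat.shape)  # (8, 3072), since 32 * 32 * 3 = 3072

# Flatten keeps the batch axis and collapses the rest in row-major order,
# which here matches reshaping every sample to a 1D vector
reshaped = tf.reshape(images, (images.shape[0], -1))
print(bool(tf.reduce_all(flat == reshaped)))  # True
```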

Show me the Code

TensorFlow's Keras API provides this layer as a class you can add when defining your models, for example:
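from tensorflow.keras import layers
from tensorflow.keras.models import Sequential

model = Sequential([
    layers.Input(shape=(28, 28, 1)),  # e.g. 28x28 grayscale images (illustrative shape)
    # ... Preprocessing layers
    layers.Conv2D(32, 3, activation="relu"),  # -> (batch, 26, 26, 32)
    layers.MaxPooling2D(),                    # -> (batch, 13, 13, 32)
    # ... More convolutional layers
    layers.Flatten(),  # <---- ⭐️⭐️⭐️⭐️ -> (batch, 13 * 13 * 32) = (batch, 5408)
    # ... MLP layers
    layers.Dense(10),
])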
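To see what the Flatten layer does inside a convolutional stack, the sketch below traces tensor shapes through one convolution, one pooling layer and the flatten step. The 28×28 grayscale input, the batch of 4 and the 32 filters are illustrative assumptions.

```python
import tensorflow as tf
from tensorflow.keras import layers

# Hypothetical input: a batch of 4 grayscale 28x28 images (channels-last)
images = tf.random.normal((4, 28, 28, 1))

conv = layers.Conv2D(32, 3, activation="relu")  # 3x3 kernel, "valid" padding: 28 -> 26
pool = layers.MaxPooling2D()                     # default 2x2 pooling: 26 -> 13
flatten = layers.Flatten()

feature_maps = pool(conv(images))
flat = flatten(feature_maps)

print(feature_maps.shape)  # (4, 13, 13, 32)
print(flat.shape)          # (4, 5408), since 13 * 13 * 32 = 5408
```

The flattened size (5,408 here) is the input dimension of the first Dense layer after Flatten, so it grows quickly with the spatial size and channel count of the last feature map.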
You don't need a container abstraction such as Sequential to use the Flatten layer: with the functional API, you call a Flatten instance directly on a tensor. For instance:
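from tensorflow import keras
from tensorflow.keras import layers

def build_model(input_shape, num_labels, embedding_dim=16):
    # Two inputs: an image and one integer class label per sample
    input_image = keras.Input(shape=input_shape)
    input_labels = keras.Input(shape=(1,), dtype="int32")
    # .. Process the labels using a learned Embedding layer -> (batch, 1, embedding_dim)
    embedding = layers.Embedding(num_labels, embedding_dim)(input_labels)
    # ... Flatten the tensor -> (batch, embedding_dim)
    embedding = layers.Flatten()(embedding)
    # ... Further process the image using convolutional blocks, then flatten
    x = layers.Conv2D(32, 3, activation="relu")(input_image)
    x = layers.Flatten()(x)
    feature = layers.Concatenate()([x, embedding])
    # ... Generate output logits
    output = layers.Dense(1)(feature)
    # Create a Model instance
    return keras.models.Model([input_image, input_labels], output)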
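The sketch below isolates the Embedding-then-Flatten step to show the shapes involved. The vocabulary of 10 labels, the 5 label ids per sample and the 4-dimensional embeddings are hypothetical sizes chosen for illustration. Embedding returns a 3D tensor of shape (batch size, sequence length, embedding dim), and Flatten concatenates the per-position embeddings into one vector per sample.

```python
import tensorflow as tf
from tensorflow.keras import layers

num_labels, seq_len, embed_dim = 10, 5, 4  # hypothetical sizes

# A batch of 2 samples, each holding 5 integer label ids in [0, num_labels)
label_ids = tf.constant([[0, 1, 2, 3, 4],
                         [9, 8, 7, 6, 5]])

embedded = layers.Embedding(num_labels, embed_dim)(label_ids)
flat = layers.Flatten()(embedded)

print(embedded.shape)  # (2, 5, 4)
print(flat.shape)      # (2, 20), since 5 * 4 = 20

# The first embed_dim entries of each flattened row are the embedding
# of that sample's first label id (row-major flattening)
print(bool(tf.reduce_all(flat[:, :embed_dim] == embedded[:, 0, :])))  # True
```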
And that's it! The Flatten layer is a small but essential tool in any machine learning engineer's toolkit.





Summary

In this short tutorial, we saw how to use the Flatten layer in Keras and why it is useful. To see the full suite of W&B features, please check out this short 5-minute guide. If you'd like more reports covering the math and from-scratch code implementations, let us know in the comments below or on our forum ✨!
Check out these other reports on Fully Connected covering fundamental development topics such as GPU utilization and saving models.