How To Change Activation Functions in Transformers
In this short tutorial — complete with code and interactive visualizations — we'll learn how to change the hidden activation functions to make models more robust.

While transformers have become the de facto standard for a wide range of language, vision, and audio tasks, many people aren't aware of how much you can customize the transformers you use as base models when building custom models for fine-tuning, for instance, during a Kaggle competition.
Be it Gradient Checkpointing or Freezing Embeddings, transformers are highly customizable, which helps further boost their performance and robustness on downstream tasks.
In this article, we'll look at one configuration option we can use to make the model more robust: changing the hidden activation function used in the encoder and pooler. While GELU (Gaussian Error Linear Units) has become the de facto standard for pre-training, several sources have found that, when fine-tuning on real-world data, activations like swish do a better job of preventing over-fitting and making the model more robust.
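To get a feel for how the two differ, here's a minimal sketch (using PyTorch's built-in nn.GELU and nn.SiLU modules; SiLU is PyTorch's name for swish) that evaluates both activations on the same inputs:

# Compare GELU and swish/SiLU on a few sample points using PyTorch's built-in modules.
import torch
from torch import nn

x = torch.linspace(-3, 3, steps=7)
gelu = nn.GELU()
swish = nn.SiLU()  # SiLU(x) = x * sigmoid(x), also known as swish

print("x:    ", x)
print("gelu: ", gelu(x))
print("swish:", swish(x))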
Let's have a look at how you can switch out activation functions in transformers.
If you're curious about transformers or attention, have a look at these amazing reports:
An Introduction to Attention
Part I in a series on attention. In this installment, we look at its origins and predecessors, and provide a brief example of what's to come.
On the Relationship Between Self-Attention and Convolutional Layers
Our submission to ML Reproducibility Challenge 2020. Original paper "On the Relationship between Self-Attention and Convolutional Layers" by Jean-Baptiste Cordonnier, Andreas Loukas & Martin Jaggi, accepted into ICML 2020.
Show Me the Code
A typical custom model built on top of a transformer looks something like this:
import torch
from torch import nn
from transformers import AutoModel

class CustomModel(nn.Module):
    def __init__(self, model_path: str):
        super().__init__()
        self.base_model = AutoModel.from_pretrained(model_path)
        # ... Further Layers
        self.output = nn.Linear(768, 1)  # 768 = hidden size of the base model (e.g., BERT-base)

    def forward(self, ids: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Compute Forward Pass"""
        features = self.base_model(ids, mask)[0]
        # ... Further Processing
        features = self.output(features)
        return features
We instantiate a base model to process the tokens, apply whatever processing the downstream task requires, and then pass the result through a single linear layer to produce the output logits.
The HuggingFace API makes it easy to change the hidden activation function by editing the model's configuration:
import torch
from torch import nn
from transformers import AutoConfig, AutoModel

class CustomModel(nn.Module):
    def __init__(self, model_path: str):
        super().__init__()
        self.base_config = AutoConfig.from_pretrained(model_path)
        self.base_config.update({"hidden_act": "swish"})
        # from_config builds the architecture with freshly initialized weights;
        # to keep the pretrained weights, use AutoModel.from_pretrained(model_path, config=self.base_config)
        self.base_model = AutoModel.from_config(self.base_config)
        # ... Further Layers
        self.output = nn.Linear(768, 1)  # 768 = hidden size of the base model (e.g., BERT-base)

    def forward(self, ids: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Compute Forward Pass"""
        features = self.base_model(ids, mask)[0]
        # ... Further Processing
        features = self.output(features)
        return features
Notice the difference from the previous snippet: instead of instantiating the model directly from the model name or checkpoint path, we first create the config, update its "hidden_act" key, and then instantiate the model from this edited configuration.
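As a quick sanity check, here's a usage sketch (the bert-base-uncased checkpoint and the example sentence are just illustrative choices): tokenize some text, run it through CustomModel, and confirm that the config now reports swish. Note that because the snippet above builds the model with from_config, the weights are freshly initialized, so the logits themselves are untrained.

# Illustrative usage: verify the edited config and run a forward pass.
from transformers import AutoTokenizer

model_path = "bert-base-uncased"  # assumption: any BERT-style checkpoint works here
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = CustomModel(model_path)

print(model.base_model.config.hidden_act)  # expected: "swish"

enc = tokenizer("Changing activations in transformers", return_tensors="pt")
logits = model(enc["input_ids"], enc["attention_mask"])
print(logits.shape)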
And that is it! That's all it takes. Each architecture supports a different set of activation functions, so check the documentation for the model you're using.
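If you're unsure which names are available, one place to look is the ACT2FN mapping in transformers.activations, which lists the activation strings your installed version of transformers can resolve (the exact set varies by version, and individual architectures may only accept a subset):

# List the activation names the installed transformers version maps to functions.
from transformers.activations import ACT2FN

print(sorted(ACT2FN.keys()))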
And this isn't the end of it: there are plenty of other options you can tweak to further configure your model. Have a look at each model's configuration class to learn more.
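For instance, here's a sketch that tweaks dropout alongside the activation, assuming a BERT-style config where hidden_dropout_prob and attention_probs_dropout_prob are valid keys (other architectures may name these differently):

# Assumes a BERT-style config; adjust the key names for other architectures.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("bert-base-uncased")
config.update(
    {
        "hidden_act": "swish",
        "hidden_dropout_prob": 0.2,           # dropout on hidden states
        "attention_probs_dropout_prob": 0.2,  # dropout on attention weights
    }
)
print(config.hidden_act, config.hidden_dropout_prob)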
Summary
In this article, you saw how you could change the hidden activation functions in Transformers to help make your models more robust and performant.
To see the full suite of W&B features, please check out this short 5-minute guide. If you want more reports covering the math and from-scratch code implementations, let us know in the comments below or on our forum ✨!
Check out these other reports on Fully Connected covering other fundamental development topics like GPU Utilization and Saving Models.
Recommended Reading
Preventing The CUDA Out Of Memory Error In PyTorch
A short tutorial on how you can avoid the "RuntimeError: CUDA out of memory" error while using the PyTorch framework.
How To Use Autocast in PyTorch
In this article, we learn how to implement Tensor Autocasting in a short tutorial, complete with code and interactive visualizations, so you can try it yourself.
How To Use GradScaler in PyTorch
In this article, we explore how to implement automatic gradient scaling (GradScaler) in a short tutorial complete with code and interactive visualizations.
PyTorch Dropout for regularization - tutorial
Learn how to regularize your PyTorch model with Dropout, complete with a code tutorial and interactive visualizations
How to Initialize Weights in PyTorch
A short tutorial on how you can initialize weights in PyTorch with code and interactive visualizations.
How To Use GPU with PyTorch
A short tutorial on using GPUs for your deep learning models with PyTorch, from checking availability to visualizing usage.