The Softmax Function: The Workhorse of Machine Learning Classification
In this article, we explore how to implement the softmax function in Python and how to make good use of it, giving some background and context along the way.
Created on May 5 | Last edited on August 3
Introduction
The softmax function is everywhere in machine learning. After all, it's the de facto activation for the output layer in classification problems. A softmax layer pushes the model to focus on a single choice, acting like a smooth version of argmax that is both continuous and differentiable. Combined with the cross-entropy loss, it also gives an objective that is convex in the scores it receives, and therefore convex in the model parameters when those scores are linear in them (as in multinomial logistic regression).
In this article, we explore what the softmax function is, how it is used in machine learning, and how to implement it in Python. Here's what we'll cover:
Table of Contents
- Introduction
- My Softmax Journey
- How Do We Use Softmax for Machine Learning?
- Implementing in NumPy and PyTorch
- Should You Use the Softmax Function or Not?
- Conclusion
Let's get started!
My Softmax Journey
My journey with the softmax function dates back to my Master's and Ph.D. in Transportation Economics and Urban Planning. In that field we don't call it softmax (it is usually known as the multinomial logit model), but it is simply a consequence of the assumptions made when modeling consumer behavior. In fact, a chapter of my Ph.D. thesis explains the basics (see page 34).
In the microeconomic context, the softmax function is derived from an assumption about the error terms of the utility function perceived by consumers. Without going into too much detail: if we make an error when measuring the utility (or cost) that a user perceives for alternative j, and we assume that error follows a certain distribution, then computing the probability that the user chooses j, we get:
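Writing the utility of alternative $j$ as a measured part $V_j$ plus an error term $\varepsilon_j$ (notation assumed here, following the usual random-utility convention), the probability of choosing $j$ is the probability that $j$ beats every other alternative $k$:

$$
P(j) = \Pr\left(U_j > U_k \;\; \forall\, k \neq j\right), \qquad U_j = V_j + \varepsilon_j
$$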
The user will choose j if they perceive greater utility from it than from any other alternative: "You will buy X if it makes you happier than the other alternatives."
If we assume that the errors are independent and identically Gumbel distributed (the Gumbel distribution is also known as the type I extreme value distribution), then the probabilities follow the formula:
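In the same assumed notation, with errors that are independent across alternatives and share a common scale (normalized to 1 here), the result is exactly the softmax of the measured utilities:

$$
P(j) = \frac{e^{V_j}}{\sum_{k} e^{V_k}}
$$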
The derivation relies on the fact that the difference between two independent, identically Gumbel-distributed errors follows a logistic distribution:
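For two independent standard Gumbel errors, the difference follows a standard logistic distribution. In the two-alternative case this gives the softmax directly (same assumed notation as above):

$$
\Pr(\varepsilon_k - \varepsilon_j < t) = \frac{1}{1 + e^{-t}}
\quad\Longrightarrow\quad
P(j) = \Pr\left(\varepsilon_k - \varepsilon_j < V_j - V_k\right) = \frac{1}{1 + e^{-(V_j - V_k)}} = \frac{e^{V_j}}{e^{V_j} + e^{V_k}}
$$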
And that's where the name logits comes from when we refer to the terms that feed the softmax function!
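The Gumbel-to-softmax result can also be checked numerically. The sketch below is a hypothetical simulation (the utilities `V`, the number of draws, and the seed are arbitrary choices, not taken from the thesis): it adds i.i.d. standard Gumbel noise to fixed utilities, lets each simulated user pick the alternative with the highest total utility, and compares the observed choice frequencies with the softmax of the utilities.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical measured utilities V_j for three alternatives
V = np.array([1.0, 2.0, 0.5])
n_draws = 200_000

# Add i.i.d. standard Gumbel errors; each simulated user picks the max utility
U = V + rng.gumbel(loc=0.0, scale=1.0, size=(n_draws, V.size))
choices = U.argmax(axis=1)
empirical = np.bincount(choices, minlength=V.size) / n_draws

# Closed-form softmax of the measured utilities
softmax_probs = np.exp(V) / np.exp(V).sum()
```

With this many draws, the empirical frequencies should agree with `softmax_probs` to roughly two decimal places; increasing `n_draws` tightens the match.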
The logistic distribution has other nice properties that we will not explore here, but it is also a much-loved workhorse of microeconomics 🏇.
How Do We Use Softmax for Machine Learning?
When solving classification problems with machine learning models, we often want a model that is capable of picking one alternative for a given input:
- Is this a cat or a dog?
- Is this text positive or negative?
- What breed of dog is this?
- Rate this movie from 1 to 5
All of these problems can be solved using ML. What we want is a model that chooses one of the given alternatives, so how can softmax help us here?
Example Case
Generally, ML models output real values (floats), which we often call scores. Let's say we want to decide which dog breed appears in a picture. For simplicity, let's assume there are 10 possible breeds. We would probably build a neural network with a bunch of layers whose last layer outputs 10 scores, one per breed.
The simplest approach is to pick the highest score: that's the dog breed we predict. This is exactly what the argmax function does: it returns the breed with the highest score.
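The problem with this approach is that argmax is not differentiable: it is piecewise constant, so its gradient is zero almost everywhere and undefined where scores tie. Gradient descent gets no signal from it, so we cannot train a model with argmax as the activation function 😭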
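A quick way to see the difference is to nudge one score by a tiny amount and watch how each function reacts. This is a minimal plain-Python sketch; the helper names and the step size `eps` are illustrative choices. The argmax output does not move at all (its finite difference is exactly zero), while the softmax probability changes smoothly, with a slope close to the analytic derivative p(1 - p).

```python
from math import exp

def softmax(x):
    m = max(x)  # subtract the max to avoid overflow; probabilities are unchanged
    exp_x = [exp(xi - m) for xi in x]
    total = sum(exp_x)
    return [ex / total for ex in exp_x]

def argmax(x):
    # Index of the largest score
    return max(range(len(x)), key=lambda i: x[i])

scores = [6, 2, 3, 10, 5, 1, 7, 8, 9, 4]
eps = 1e-4
nudged = list(scores)
nudged[3] += eps  # nudge the Poodle score slightly

# argmax is piecewise constant: the nudge changes nothing, so the slope is 0
argmax_slope = (argmax(nudged) - argmax(scores)) / eps

# softmax responds smoothly; compare with the analytic derivative p * (1 - p)
p = softmax(scores)[3]
softmax_slope = (softmax(nudged)[3] - p) / eps
analytic_slope = p * (1 - p)
```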
Ok, here is a practical example. Let's start with our breeds:
0: Beagles
1: Golden Retrievers
2: German Shepherd Dogs
3: Poodles
4: Bulldogs
5: Rottweilers
6: Beagles
7: Dachshunds
8: German Shorthaired Pointers
9: Rottweilers
And now, an input image:

The scores are:
scores = [6, 2, 3, 10, 5, 1, 7, 8, 9, 4]
The "choice" is given by the maximum value, 10, at index 3, so the model should predict "Poodles". Let's compute the softmax:
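```python
from math import exp

scores = [6, 2, 3, 10, 5, 1, 7, 8, 9, 4]

def softmax(x: list) -> list:
    # Exponentiate every score, then normalize so the values sum to 1
    exp_x = [exp(xi) for xi in x]
    total = sum(exp_x)  # compute the normalizer once, not once per element
    return [ex / total for ex in exp_x]

for i, p in enumerate(softmax(scores)):
    print(f"{i}: {p:.2f}")

# output
# 0: 0.01
# 1: 0.00
# 2: 0.00
# 3: 0.63  <--- clearly this one is higher than all the others =)
# 4: 0.00
# 5: 0.00
# 6: 0.03
# 7: 0.09
# 8: 0.23
# 9: 0.00
```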
The softmax function forces the model to choose, and it has all kinds of nice properties when combined with the negative log-likelihood loss (NLLLoss): taking the log of the softmax output and feeding it to the NLL loss gives the cross-entropy loss.
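To make that connection concrete, here is a minimal plain-Python sketch (the helper names are hypothetical, and the true class index 3 is assumed for illustration). It computes the loss two ways: as the negative log-likelihood of the softmax probability of the true class, and directly from the logits as logsumexp(x) - x[target]. The two agree, which is the sense in which softmax followed by NLL is the cross-entropy loss.

```python
from math import exp, log

def softmax(x):
    m = max(x)  # subtract the max for numerical stability
    e = [exp(xi - m) for xi in x]
    s = sum(e)
    return [ei / s for ei in e]

def nll(probs, target):
    # Negative log-likelihood of the true class
    return -log(probs[target])

def cross_entropy_from_logits(x, target):
    # Same quantity written directly in terms of the logits:
    # -log softmax(x)[target] = logsumexp(x) - x[target]
    m = max(x)
    lse = m + log(sum(exp(xi - m) for xi in x))
    return lse - x[target]

scores = [6, 2, 3, 10, 5, 1, 7, 8, 9, 4]
target = 3  # suppose the true breed is index 3 (Poodles)
loss_two_step = nll(softmax(scores), target)
loss_fused = cross_entropy_from_logits(scores, target)
```

The fused form never takes the log of a probability that may have underflowed to zero, which is one reason losses are commonly computed straight from logits.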
Relative Difference Invariance of the Softmax
What happens if we add a constant, say 100, to every score and recompute the probabilities?
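```python
# A plain Python list does not support `+ 100`, so shift each score explicitly
scores = [6, 2, 3, 10, 5, 1, 7, 8, 9, 4]
shifted_scores = [s + 100 for s in scores]

for i, p in enumerate(softmax(shifted_scores)):  # softmax as defined above
    print(f"{i}: {p:.2f}")

# output
# 0: 0.01
# 1: 0.00
# 2: 0.00
# 3: 0.63  <--- Same as before =)
# 4: 0.00
# 5: 0.00
# 6: 0.03
# 7: 0.09
# 8: 0.23
# 9: 0.00
```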
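We get exactly the same values as before! What happens is that softmax only depends on the differences between scores, not on their absolute level: adding a constant c multiplies every exp(x_i), in the numerator and the denominator alike, by exp(c), which cancels out. This mirrors the choice model, where the logistic distribution only perceives differences in utility between alternatives. Note that this holds for shifts, not for scale: multiplying all scores by a constant does change the probabilities.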
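The cancellation can be checked directly. A minimal sketch, reusing the plain-Python softmax from the example above (redefined here so the block runs on its own): shifting every score by 100 leaves the probabilities unchanged, while multiplying every score by 2 (an arbitrary illustrative factor) changes them.

```python
from math import exp, isclose

def softmax(x):
    exp_x = [exp(xi) for xi in x]
    total = sum(exp_x)
    return [ex / total for ex in exp_x]

scores = [6, 2, 3, 10, 5, 1, 7, 8, 9, 4]
base = softmax(scores)

# Adding a constant multiplies every exp(x_i) by the same factor, which cancels
shifted = softmax([s + 100 for s in scores])
same_after_shift = all(isclose(a, b) for a, b in zip(base, shifted))  # True

# Multiplying by a constant is different: it sharpens the distribution
scaled = softmax([2 * s for s in scores])
same_after_scale = all(isclose(a, b) for a, b in zip(base, scaled))  # False
```

Scaling the scores up pushes softmax closer to a hard argmax; scaling them down flattens it toward a uniform distribution.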
Implementing in NumPy and PyTorch
The softmax function is straightforward to implement with an array-based library like NumPy:
```python
import numpy as np

scores = np.array([1, 2, 3, 4])

def softmax(scores: np.ndarray) -> np.ndarray:
    # Exponentiate element-wise, then normalize by the sum
    exp_scores = np.exp(scores)
    return exp_scores / exp_scores.sum()

softmax(scores)
# > array([0.0320586 , 0.08714432, 0.23688282, 0.64391426])
```
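In PyTorch, softmax is built in; you only need a floating-point tensor and the dimension to normalize over:
```python
import torch

# softmax expects a floating-point tensor, and `dim` must be given explicitly
scores = torch.tensor([1.0, 2.0, 3.0, 4.0])
scores.softmax(dim=0)  # or torch.softmax(scores, dim=0)
# > tensor([0.0321, 0.0871, 0.2369, 0.6439])
```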
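One caveat with the direct NumPy version: `np.exp` overflows to `inf` for large scores (`np.exp(1000.0)` already does), and `inf / inf` gives `nan`. The shift invariance from the previous section offers a standard fix: subtract the maximum score before exponentiating. A minimal sketch (the name `stable_softmax` and its `axis` parameter are our own choices) that also handles a batch of score vectors:

```python
import numpy as np

def stable_softmax(scores: np.ndarray, axis: int = -1) -> np.ndarray:
    # Subtracting the max along `axis` leaves the result unchanged
    # (shift invariance) but keeps np.exp from overflowing
    shifted = scores - scores.max(axis=axis, keepdims=True)
    exp_scores = np.exp(shifted)
    return exp_scores / exp_scores.sum(axis=axis, keepdims=True)

big = np.array([1000.0, 1001.0, 1002.0, 1003.0])
probs = stable_softmax(big)  # same as the softmax of [0, 1, 2, 3], no overflow

batch = np.array([[1.0, 2.0, 3.0, 4.0],
                  [6.0, 2.0, 3.0, 10.0]])
batch_probs = stable_softmax(batch, axis=1)  # one distribution per row
```

If SciPy is available, `scipy.special.softmax` also provides a numerically stable softmax with an `axis` argument.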
Should You Use the Softmax Function or Not?
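Actually, when training neural networks, you should NOT apply the softmax function yourself! Build a network that outputs raw scores (logits) and feed those scores directly to the cross-entropy loss. Why? Because the cross-entropy loss already has the softmax baked in (in PyTorch, nn.CrossEntropyLoss combines LogSoftmax and NLLLoss; in Keras, the same happens when you pass from_logits=True), so adding your own softmax layer would apply it twice. I wrote a full article explaining this for PyTorch and Keras: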
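To see why applying softmax yourself hurts, here is a minimal plain-Python sketch. The helper `ce_with_softmax_inside` is hypothetical and only stands in for a loss with softmax baked in (it is not a library API). Feeding it raw scores lets the loss approach zero for a confident, correct prediction; feeding it scores that already went through softmax applies softmax twice, and for 10 classes the loss can never drop below log(1 + 9/e), roughly 1.46.

```python
from math import e, exp, log

def softmax(x):
    m = max(x)  # subtract the max to avoid overflow
    exp_x = [exp(xi - m) for xi in x]
    total = sum(exp_x)
    return [ex / total for ex in exp_x]

def ce_with_softmax_inside(scores, target):
    # Hypothetical stand-in for a cross-entropy loss that applies softmax
    # internally and returns -log(probability of the true class)
    return -log(softmax(scores)[target])

# A very confident, correct prediction for class 3 out of 10
logits = [0.0] * 10
logits[3] = 100.0

loss_raw = ce_with_softmax_inside(logits, 3)              # raw scores in
loss_double = ce_with_softmax_inside(softmax(logits), 3)  # softmax applied twice

# Probabilities live in [0, 1], so the second softmax can never be confident:
# with 10 classes the loss is bounded below by log(1 + 9/e)
loss_floor = log(1 + 9 / e)
```

In PyTorch terms, this means passing the network's raw outputs to nn.CrossEntropyLoss, which expects unnormalized logits, and applying softmax only when you actually need probabilities, for example at inference time.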
Conclusion
This little function is very useful in our ML journey, but, ironically, the right way to use it when training is not to use it at all 🤣. There are other use cases of the softmax function, for instance inside transformer layers, where it turns the attention scores between tokens into attention weights, but that is for another article.