How to Initialize Weights in PyTorch
A short tutorial on how you can initialize weights in PyTorch with code and interactive visualizations.
What We'll Be Covering
In this article, we'll look at how you can initialize weights for the various layers in your PyTorch models.
Unlike TensorFlow, PyTorch doesn't provide an easy interface for initializing weights in its various layers (although torch.nn.init does exist), so it becomes tricky when you want to initialize weights according to a well-known technique such as Xavier or He initialization.
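For instance, the helpers in torch.nn.init can be applied to a layer's parameters directly. Here's a minimal sketch (the layer sizes are arbitrary, chosen only for illustration):

import torch.nn as nn

layer = nn.Linear(128, 64)

# Xavier (Glorot) initialization for the weight matrix
nn.init.xavier_uniform_(layer.weight)

# Or He (Kaiming) initialization, commonly paired with ReLU activations
nn.init.kaiming_normal_(layer.weight, nonlinearity="relu")

# Biases are often simply zeroed out
nn.init.zeros_(layer.bias)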
We'll look at how we can initialize weights below, though if you'd like to follow along in an executable Colab, you can do so at:
For a closer look at the various techniques and the motivations involved in weight initialization for neural networks, you can reference this article.
Initializing Weights in PyTorch With Class Functions
One of the most popular ways to initialize weights is to use a class function that we can invoke at the end of the __init__ method of a custom PyTorch model.
import torch.nn as nn

class Model(nn.Module):
    # . . .
    def __init__(self):
        super().__init__()
        # . . .
        self.apply(self._init_weights)

    def _init_weights(self, module):
        # Draw nn.Linear weights from N(0, 1) and zero out the biases
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=1.0)
            if module.bias is not None:
                module.bias.data.zero_()
This code snippet initializes all weights from a Normal Distribution with mean 0 and standard deviation 1, and initializes all the biases to zero. It's pretty easy to extend this to other layers such as nn.LayerNorm and nn.Embedding.
def _init_weights(self, module):
    if isinstance(module, nn.Embedding):
        # Normal init for embeddings, keeping the padding vector at zero
        module.weight.data.normal_(mean=0.0, std=1.0)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    elif isinstance(module, nn.LayerNorm):
        # The standard LayerNorm init: zero bias, unit scale
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)
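As a side note, the same hook can also be written with the torch.nn.init helpers instead of mutating .data directly, which is the more idiomatic style in recent PyTorch. A sketch of the nn.Linear branch from above:

def _init_weights(self, module):
    if isinstance(module, nn.Linear):
        # nn.init functions run under torch.no_grad() internally
        nn.init.normal_(module.weight, mean=0.0, std=1.0)
        if module.bias is not None:
            nn.init.zeros_(module.bias)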
The PyTorch Weight Initialization Experiment
The Weights & Biases charts below are drawn from the provided Colab, to better illustrate weight initialization. Once again, you can find it at:
Here we can see how various standard deviations of the normal distribution differ from each other in terms of performance.
Clearly, large values of the standard deviation don't lead to good results and most likely get the model stuck in a local minimum, whereas smaller values lead to much better performance.
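If you'd like to reproduce a comparison like this outside the Colab, a minimal sketch of the experiment might look as follows. The toy data, model, and project name here are illustrative stand-ins, not the exact Colab code:

import torch
import torch.nn as nn
import wandb

def make_model(std):
    # Small MLP whose nn.Linear weights are drawn from N(0, std)
    model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 2))
    for m in model.modules():
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, mean=0.0, std=std)
            nn.init.zeros_(m.bias)
    return model

# Toy classification data, just for illustration
x = torch.randn(256, 10)
y = torch.randint(0, 2, (256,))

# One W&B run per initialization standard deviation
for std in [0.01, 0.1, 1.0, 10.0]:
    with wandb.init(project="weight-init", name=f"std={std}") as run:
        model = make_model(std)
        opt = torch.optim.SGD(model.parameters(), lr=0.1)
        loss_fn = nn.CrossEntropyLoss()
        for step in range(100):
            opt.zero_grad()
            loss = loss_fn(model(x), y)
            loss.backward()
            opt.step()
            run.log({"loss": loss.item()})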
Summary
In this article, you saw how you can initialize weights for your PyTorch deep learning models and how using Weights & Biases to monitor your metrics can lead to valuable insights.
To see the full suite of W&B features, please check out this short 5-minute guide. If you want more reports covering the math and "from-scratch" code implementations, let us know in the comments down below or on our forum ✨!
Check out these other reports on Fully Connected covering other fundamental development topics like GPU Utilization and Saving Models.
Try Weights & Biases
Weights & Biases helps you keep track of your machine learning experiments. Try our tool to log hyperparameters and output metrics from your runs, then visualize and compare results and quickly share findings with your colleagues.
Get started in 5 minutes, or run 2 quick experiments on Replit to see how W&B can help organize your work. Follow the instructions below:
Instructions:
- Click the green "Run" button below (the first time you click Run, Replit will take approx 30-45 seconds to allocate a machine)
- Follow the prompts in the terminal window (the bottom right pane below)
- You can resize the terminal window (bottom right) for a larger view
Recommended Reading
How To Use GPU with PyTorch
A short tutorial on using GPUs for your deep learning models with PyTorch, from checking availability to visualizing usage.
A Gentle Introduction To Weight Initialization for Neural Networks
An explainer and comprehensive overview of various strategies for neural network weight initialization
PyTorch Dropout for regularization - tutorial
Learn how to regularize your PyTorch model with Dropout, complete with a code tutorial and interactive visualizations
How to save and load models in PyTorch
This article is a machine learning tutorial on how to save and load your models in PyTorch using Weights & Biases for version control.
Image Classification Using PyTorch Lightning and Weights & Biases
This article provides a practical introduction on how to use PyTorch Lightning to improve the readability and reproducibility of your PyTorch code.
How to Compare Keras Optimizers in TensorFlow for Deep Learning
A short tutorial outlining how to compare Keras optimizers for your deep learning pipelines in TensorFlow, with a Colab to help you follow along.