Grokking: Improved generalization through over-overfitting
Grokking, one of the most mysterious phenomena in deep learning, is the tendency of neural networks to suddenly generalize long after they appear to have overfit their training data.
A few years ago, researchers at OpenAI and Google discovered a fascinating phenomenon known as grokking, one with significant implications for our understanding of neural network generalization.
Grokking was observed during experiments with small, algorithmically generated datasets, where neural networks displayed an unexpected ability to improve generalization performance to near-perfect levels long after initial "overfitting." This discovery challenges traditional notions of overfitting and learning, highlighting a deeper, more intricate learning process within neural networks.

What we'll cover
What is overfitting?
What is grokking?
How does Grokfast work?
Gradient filtering
Speeding up grokking by 50x
Grokfast pseudocode
Grokking a neural network on MNIST
Testing with a 'Mini' GPT-2
Overall
What is overfitting?
In machine learning, overfitting occurs when an algorithm fits too closely or even exactly to its training data, resulting in a model that can’t make accurate predictions or conclusions from any data other than the training data. It's important to note that the definition of overfitting can be a little fuzzy and may differ depending on who you ask.
This will be important as we touch on this later in the article.
What is grokking?
Grokking occurs when a neural network, after initially overfitting, eventually grasps the data's patterns and rules, boosting validation accuracy. The original Grokking experiments used datasets based on binary operations with abstract symbols, making the network learn from symbol interactions, like solving a complex puzzle.
One striking example of grokking involved training a network on the binary operation of division modulo 97. Initially, the training accuracy approached perfection while validation accuracy remained at chance levels. However, after an extended period of optimization, the validation accuracy suddenly improved, indicating that the network had 'grokked' the underlying pattern in the data.

From "GROKKING: GENERALIZATION BEYOND OVERFITTING ON SMALL ALGORITHMIC DATASETS" by Power et. al
A significant limitation of grokking is the extensive training time required to achieve a high level of generalization. To address this issue, a team of researchers from Seoul National University developed a novel algorithm called Grokfast, designed to accelerate the grokking phenomenon, making it more accessible to machine learning practitioners.
How does Grokfast work?
The key idea behind Grokfast is to view the model's learning process as a combination of fast-varying and slow-varying changes in gradients. Fast-varying changes happen quickly during training and primarily enable the model to fit the training data, while slow-varying changes occur more gradually and are crucial for the model’s ability to generalize to new data.
By emphasizing these slow-varying changes, Grokfast aims to speed up the grokking process and help the model generalize faster.
Gradient filtering
To achieve this, the researchers applied a "low-pass filter" to the gradients used to update the model's parameters during training. A low-pass filter, a signal processing tool, removes high-frequency components from a signal, allowing only the low-frequency components to pass through.
In the context of Grokfast, this means filtering out the high-frequency, fast-varying components of the gradients and amplifying the low-frequency, slow-varying components, helping the model focus on changes that are more likely to lead to generalization.
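To build some intuition for what "low-pass filtering" means here, the toy snippet below (my own illustration, unrelated to the Grokfast code itself) smooths a noisy signal with an exponential moving average: the slow sine trend survives, while the fast noise is suppressed. Grokfast applies the same idea to each parameter's sequence of gradients.

import numpy as np

# Toy illustration: an exponential moving average acts as a low-pass filter.
rng = np.random.default_rng(0)
t = np.linspace(0, 10, 500)
signal = np.sin(t) + 0.5 * rng.standard_normal(t.size)  # slow trend + fast noise

alpha = 0.95
filtered = np.zeros_like(signal)
for i in range(1, signal.size):
    filtered[i] = alpha * filtered[i - 1] + (1 - alpha) * signal[i]

# `filtered` tracks sin(t) closely; most of the high-frequency noise is gone.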
Implementing Grokfast is straightforward and can be done with just a few lines of additional code in standard machine learning frameworks. The researchers proposed two variants of the algorithm: Grokfast-MA, which uses a moving average filter, and Grokfast-EMA, which uses an exponential moving average filter. Both variants work by maintaining a running average of the gradients and adding this average to the current gradients during each training step.
The moving average filter in Grokfast-MA works by taking the average of the gradients over a fixed window of past iterations. By doing this, the algorithm smooths out short-term fluctuations in the gradient values, effectively reducing the high-frequency noise and retaining the slower, more significant changes that contribute to generalization. This process can be thought of as averaging out the "noise" in the gradients, allowing the "signal" that represents meaningful parameter updates to be more prominent. The result is a gradient signal that emphasizes stable, long-term trends over transient, short-term variations.
On the other hand, the exponential moving average (EMA) filter in Grokfast-EMA works by giving exponentially decreasing weights to older gradients. This means that recent gradients have more influence than older ones, but all past gradients still contribute to the average. This approach maintains a smoothed version of the gradient history that adapts more quickly to changes in the gradient signal while still filtering out high-frequency noise. The exponential weighting ensures that the filter responds to recent changes more rapidly than the simple moving average, making it more dynamic while still achieving the low-pass filtering effect.
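A quick way to see the difference is to look at how much weight each filter assigns to a gradient from k steps ago (an illustrative calculation of my own, using a window of 100 and α = 0.98 as example values):

# Weight given to a gradient from k steps ago by each filter (illustrative values).
window_size = 100
alpha = 0.98

for k in [0, 10, 50, 100, 200]:
    ma_weight = 1 / window_size if k < window_size else 0.0   # uniform inside the window
    ema_weight = (1 - alpha) * alpha ** k                     # geometric decay, never exactly zero
    print(f"k={k:3d}  MA weight={ma_weight:.4f}  EMA weight={ema_weight:.4f}")

The moving average weights everything inside its window equally and then forgets it abruptly, while the EMA fades old gradients out gradually.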
Speeding up grokking by 50x
Experiments on a wide range of tasks and model architectures showed that Grokfast can speed up the grokking process by up to fifty times, leading to faster generalization. This means that models trained with Grokfast can start performing well on new data much earlier than models trained without it.
By focusing on the low-frequency components of the gradients, Grokfast ensures that the model's parameters are updated in a way that emphasizes meaningful, generalizable patterns in the data, reducing the time and computational resources needed to achieve effective generalization.
Grokfast pseudocode
To give a better intuition for how these two methods work, I'll share some pseudocode showing how the gradients are filtered.
# Grokfast-MA: Moving Average

# Initialize parameters
initialize model parameters θ
initialize queue Q with capacity w
initialize scalar factor λ

# Training loop
for each iteration t:
    # Compute current gradient
    g_t = compute_gradient(θ)

    # Insert current gradient into the queue
    insert(Q, g_t)

    # Calculate the moving average of gradients
    MA_t = average(Q)

    # Amplify the low-frequency components
    g_t_hat = g_t + λ * MA_t

    # Update parameters using the modified gradient
    θ = θ - learning_rate * g_t_hat
And the other variant of Grokfast, which incorporates an exponential moving average:
# Grokfast-EMA: Exponential Moving Average

# Initialize parameters
initialize model parameters θ
initialize EMA of gradients μ to zero
initialize scalar momentum α
initialize scalar factor λ

# Training loop
for each iteration t:
    # Compute current gradient
    g_t = compute_gradient(θ)

    # Update the exponential moving average of gradients
    μ = α * μ + (1 - α) * g_t

    # Amplify the low-frequency components
    g_t_hat = g_t + λ * μ

    # Update parameters using the modified gradient
    θ = θ - learning_rate * g_t_hat
Now that we have a bit of a theoretical understanding of how the algorithm is implemented, we can move on to implementing it in Python. Compared to most deep learning research code, I found the implementation quite intuitive. The overall idea is that we keep a running buffer of the gradients, modify the current gradients in place, and return the buffer so its state carries over to the next iteration. Here are the functions used to modify the gradients:
from collections import deque
from typing import Dict, Optional, Literal

import torch
import torch.nn as nn


def gradfilter_ma(
    m: nn.Module,
    grads: Optional[Dict[str, deque]] = None,
    window_size: int = 100,
    lamb: float = 5.0,
    filter_type: Literal['mean', 'sum'] = 'mean',
    warmup: bool = True,
    trigger: bool = False,  # For ablation study.
) -> Dict[str, deque]:
    if grads is None:
        grads = {n: deque(maxlen=window_size) for n, p in m.named_parameters() if p.requires_grad}

    for n, p in m.named_parameters():
        if p.requires_grad:
            grads[n].append(p.grad.data.detach())  # .cpu()

            # Modify the gradients.
            if not warmup or (len(grads[n]) == window_size and not trigger):
                if filter_type == "mean":
                    avg = sum(grads[n]) / len(grads[n])
                elif filter_type == "sum":
                    avg = sum(grads[n])
                else:
                    raise ValueError(f"Unrecognized filter_type {filter_type}")
                p.grad.data = p.grad.data + avg * lamb

    return grads


def gradfilter_ema(
    m: nn.Module,
    grads: Optional[Dict[str, torch.Tensor]] = None,
    alpha: float = 0.98,
    lamb: float = 2.0,
) -> Dict[str, torch.Tensor]:
    if grads is None:
        grads = {n: p.grad.data.detach() for n, p in m.named_parameters() if p.requires_grad}

    for n, p in m.named_parameters():
        if p.requires_grad:
            grads[n] = grads[n] * alpha + p.grad.data.detach() * (1 - alpha)
            p.grad.data = p.grad.data + grads[n] * lamb

    return grads
Utilizing these functions in code is actually quite simple: you just pass the model and the gradient buffer on each training iteration. To adapt this to any existing PyTorch training loop, you can use the following snippet (credit to the Grokfast repo for this excellent micro-tutorial on incorporating Grokfast):
# ... in the optimization loop.
loss.backward()  # Calculate the gradients.

### Option 1: Grokfast-EMA (has arguments alpha, lamb)
grads = gradfilter_ema(model, grads=grads, alpha=alpha, lamb=lamb)

### Option 2: Grokfast-MA (has arguments window_size, lamb)
# grads = gradfilter_ma(model, grads=grads, window_size=window_size, lamb=lamb)

optimizer.step()  # Call the optimizer.
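To make the bookkeeping explicit, here is a minimal, self-contained sketch of the whole loop on a toy regression problem (the model and data are my own placeholders, not from the Grokfast repo); the only Grokfast-specific pieces are the grads = None initialization and the filter call between backward() and step(), using the gradfilter_ema function defined above:

import torch
import torch.nn as nn

# Toy model and data; gradfilter_ema is the function defined above.
model = nn.Sequential(nn.Linear(10, 64), nn.ReLU(), nn.Linear(64, 1))
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
loss_fn = nn.MSELoss()
x, y = torch.randn(256, 10), torch.randn(256, 1)

grads = None  # Grokfast state: the running (exponential) average of past gradients.
for step in range(1000):
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()                                                   # 1. raw gradients
    grads = gradfilter_ema(model, grads=grads, alpha=0.98, lamb=2.0)  # 2. amplify the slow component
    optimizer.step()                                                  # 3. update with the modified gradients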
Grokking a neural network on MNIST
Ok, now we are ready to move on to training our models, and hopefully be able to see grokking firsthand! The authors of Grokfast provide a really nice repo to reproduce their experiments, which includes a few different models and datasets. Feel free to check out their repo (or mine) if you would like to utilize Weights & Biases for logging.
To get started, go ahead and clone my repo and install the packages in requirements.txt into your environment. Next, we will take a look at the training code, which trains a simple MLP on the MNIST dataset. I chose to work with MNIST for this tutorial, but there are also other datasets available, like the 'Mod 97' dataset, a simple synthetic dataset generated by computing modulo 97 arithmetic operations.
Here's the training script for training MNIST:
import random
import time
import math
from argparse import ArgumentParser
from collections import defaultdict, deque
from itertools import islice
from pathlib import Path
from typing import Dict, Optional, Literal

import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torchvision

from grokfast import *
import wandb


def cycle(iterable):
    while True:
        for x in iterable:
            yield x


def compute_accuracy(network, dataset, device, N=2000, batch_size=50):
    """Computes accuracy of `network` on `dataset`."""
    with torch.no_grad():
        N = min(len(dataset), N)
        batch_size = min(batch_size, N)
        dataset_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
        correct = 0
        total = 0
        for x, labels in islice(dataset_loader, N // batch_size):
            logits = network(x.to(device))
            predicted_labels = torch.argmax(logits, dim=1)
            correct += torch.sum(predicted_labels == labels.to(device))
            total += x.size(0)
        return (correct / total).item()


def compute_loss(network, dataset, loss_function, device, N=2000, batch_size=50):
    """Computes mean loss of `network` on `dataset`."""
    with torch.no_grad():
        N = min(len(dataset), N)
        batch_size = min(batch_size, N)
        dataset_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
        loss_fn = loss_function_dict[loss_function](reduction='sum')
        one_hots = torch.eye(10, 10).to(device)
        total = 0
        points = 0
        for x, labels in islice(dataset_loader, N // batch_size):
            y = network(x.to(device))
            if loss_function == 'CrossEntropy':
                total += loss_fn(y, labels.to(device)).item()
            elif loss_function == 'MSE':
                total += loss_fn(y, one_hots[labels]).item()
            points += len(labels)
        return total / points


optimizer_dict = {
    'AdamW': torch.optim.AdamW,
    'Adam': torch.optim.Adam,
    'SGD': torch.optim.SGD
}

activation_dict = {
    'ReLU': nn.ReLU,
    'Tanh': nn.Tanh,
    'Sigmoid': nn.Sigmoid,
    'GELU': nn.GELU
}

loss_function_dict = {
    'MSE': nn.MSELoss,
    'CrossEntropy': nn.CrossEntropyLoss
}


def main(args):
    # Initialize wandb
    wandb.init(project="grokfast_mnist", config=args)

    log_freq = math.ceil(args.optimization_steps / 150)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float32
    torch.set_default_dtype(dtype)

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    # load dataset
    train = torchvision.datasets.MNIST(root=args.download_directory, train=True,
        transform=torchvision.transforms.ToTensor(), download=True)
    test = torchvision.datasets.MNIST(root=args.download_directory, train=False,
        transform=torchvision.transforms.ToTensor(), download=True)
    train = torch.utils.data.Subset(train, range(args.train_points))
    train_loader = torch.utils.data.DataLoader(train, batch_size=args.batch_size, shuffle=True)

    assert args.activation in activation_dict, f"Unsupported activation function: {args.activation}"
    activation_fn = activation_dict[args.activation]

    # create model
    layers = [nn.Flatten()]
    for i in range(args.depth):
        if i == 0:
            layers.append(nn.Linear(784, args.width))
            layers.append(activation_fn())
        elif i == args.depth - 1:
            layers.append(nn.Linear(args.width, 10))
        else:
            layers.append(nn.Linear(args.width, args.width))
            layers.append(activation_fn())
    mlp = nn.Sequential(*layers).to(device)
    with torch.no_grad():
        for p in mlp.parameters():
            p.data = args.initialization_scale * p.data
    nparams = sum([p.numel() for p in mlp.parameters() if p.requires_grad])
    print(f'Number of parameters: {nparams}')
    wandb.config.update({"nparams": nparams})

    # create optimizer
    assert args.optimizer in optimizer_dict, f"Unsupported optimizer choice: {args.optimizer}"
    optimizer = optimizer_dict[args.optimizer](mlp.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    # define loss function
    assert args.loss_function in loss_function_dict
    loss_fn = loss_function_dict[args.loss_function]()

    train_losses, test_losses, train_accuracies, test_accuracies = [], [], [], []
    norms, last_layer_norms, log_steps = [], [], []
    grads = None

    steps = 0
    one_hots = torch.eye(10, 10).to(device)
    with tqdm(total=args.optimization_steps, dynamic_ncols=True) as pbar:
        for x, labels in islice(cycle(train_loader), args.optimization_steps):
            do_log = (steps < 30) or (steps < 150 and steps % 10 == 0) or steps % log_freq == 0
            if do_log:
                train_loss = compute_loss(mlp, train, args.loss_function, device, N=len(train))
                train_acc = compute_accuracy(mlp, train, device, N=len(train))
                test_loss = compute_loss(mlp, test, args.loss_function, device, N=len(test))
                test_acc = compute_accuracy(mlp, test, device, N=len(test))

                train_losses.append(train_loss)
                train_accuracies.append(train_acc)
                test_losses.append(test_loss)
                test_accuracies.append(test_acc)
                log_steps.append(steps)

                pbar.set_description(
                    "L: {0:1.1e}|{1:1.1e}. A: {2:2.1f}%|{3:2.1f}%".format(
                        train_loss,
                        test_loss,
                        train_acc * 100,
                        test_acc * 100,
                    )
                )

                # Log to wandb
                wandb.log({
                    "train_loss": train_loss,
                    "train_acc": train_acc,
                    "test_loss": test_loss,
                    "test_acc": test_acc,
                    "step": steps
                })

            y = mlp(x.to(device))
            if args.loss_function == 'CrossEntropy':
                loss = loss_fn(y, labels.to(device))
            elif args.loss_function == 'MSE':
                loss = loss_fn(y, one_hots[labels])

            optimizer.zero_grad()
            loss.backward()

            trigger = False

            if args.filter == "none":
                pass
            elif args.filter == "ma":
                grads = gradfilter_ma(mlp, grads=grads, window_size=args.window_size, lamb=args.lamb, trigger=trigger)
            elif args.filter == "ema":
                grads = gradfilter_ema(mlp, grads=grads, alpha=args.alpha, lamb=args.lamb)
            else:
                raise ValueError(f"Invalid gradient filter type `{args.filter}`")

            #######

            optimizer.step()

            steps += 1
            pbar.update(1)


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument("--label", default="")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--train_points", type=int, default=1000)
    parser.add_argument("--optimization_steps", type=int, default=100000)
    parser.add_argument("--batch_size", type=int, default=200)
    parser.add_argument("--loss_function", type=str, default="MSE")
    parser.add_argument("--optimizer", type=str, default="AdamW")
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--initialization_scale", type=float, default=8.0)
    parser.add_argument("--download_directory", type=str, default=".")
    parser.add_argument("--depth", type=int, default=3)
    parser.add_argument("--width", type=int, default=200)
    parser.add_argument("--activation", type=str, default="ReLU")

    # Grokfast
    parser.add_argument("--filter", type=str, choices=["none", "ma", "ema", "fir"], default="none")
    parser.add_argument("--alpha", type=float, default=0.99)
    parser.add_argument("--window_size", type=int, default=100)
    parser.add_argument("--lamb", type=float, default=5.0)

    args = parser.parse_args()

    filter_str = ('_' if args.label != '' else '') + args.filter
    window_size_str = f'_w{args.window_size}'
    alpha_str = f'_a{args.alpha:.3f}'.replace('.', '')
    lamb_str = f'_l{args.lamb:.2f}'.replace('.', '')

    if args.filter == 'none':
        filter_suffix = ''
    elif args.filter == 'ma':
        filter_suffix = window_size_str + lamb_str
    elif args.filter == 'ema':
        filter_suffix = alpha_str + lamb_str
    else:
        raise ValueError(f"Unrecognized filter type {args.filter}")

    optim_suffix = ''
    if args.weight_decay != 0:
        optim_suffix = optim_suffix + f'_wd{args.weight_decay:.1e}'.replace('.', '')
    if args.lr != 1e-3:
        optim_suffix = optim_suffix + f'_lrx{int(args.lr / 1e-3)}'

    args.label = args.label + filter_str + filter_suffix + optim_suffix
    print(f'Experiment results saved under name: {args.label}')

    main(args)
This script is fairly straightforward if you have experience with Torch, so I’ll mainly touch on some of the most important components like the usage of Grokfast gradient filtering and some of the hyperparameters.
The moving average filter maintains a window of past gradients and computes their average, which is then added to the current gradient, scaled by a factor (lamb). The exponential moving average filter computes an exponentially decaying average of past gradients and adds this decayed average to the current gradient, also scaled by a factor (lamb). These filters are implemented in the gradfilter_ma and gradfilter_ema functions, and their usage is controlled by the --filter argument.
A Weights & Biases integration is also a key feature of this script, allowing for logging training progress and results. W&B is initialized at the beginning, and we log various metrics (such as loss and accuracy) during training. This helps in monitoring the training process and analyzing the results effectively, which is of particular interest for analyzing the grokking behavior.
Now that we have some good training infrastructure set up, we can run a "baseline" test to see how well grokking works without the Grokfast method.
Run python main_mnist.py --label test to get a baseline result for the model without Grokfast functionality.
Here are my results for the baseline:
Run set: baseline (1 run)
This particular training run shows a less exaggerated version of the grokking originally reported by OpenAI and Google: validation accuracy is much slower to improve than training accuracy.
Next, we can move on to implementing Grokfast. The following command will run the script with the moving average filter:
python main_mnist.py --label test --alpha 0.8 --lamb 0.1 --weight_decay 2.0 --filter ma
This run also uses a much higher value for the weight decay, a regularization technique that adds a penalty to the loss function based on the magnitude of the model's weights. The penalty term is typically proportional to the sum of the squared values of the weights, scaled by the weight decay parameter, which encourages the network to keep its weights small. Weight decay has been found to play an important role in grokking, as revealed by the original grokking paper.
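As a quick illustration of what that penalty looks like (a generic sketch, not code from the training script above), an explicit L2 penalty adds the sum of squared weights to the loss, while AdamW, used in these experiments, applies a closely related decoupled decay inside the parameter update:

import torch
import torch.nn as nn

model = nn.Linear(10, 1)
wd = 2.0  # weight decay strength, as in the command above

# Option A: explicit L2 penalty added to the task loss.
task_loss = torch.tensor(0.0)  # stand-in for the usual data loss
l2_penalty = sum(p.pow(2).sum() for p in model.parameters())
loss = task_loss + wd * l2_penalty

# Option B: decoupled weight decay, handled internally by AdamW.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=wd)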
We can also run the script with the exponential moving average:
python main_mnist.py --label test --alpha 0.8 --lamb 0.1 --weight_decay 2.0 --filter ema
Here are the results for both the moving average and exponential moving average runs, along with the baseline for comparison:
Run set (9 runs)
Here we see that Grokfast has a much quicker path to high validation accuracy in comparison to the baseline test.
From these results, it seems as though Grokfast increases the speed of "generalization" for this particular model and dataset, which is really interesting! Now, earlier, I mentioned that the true definition of overfitting is a bit fuzzy. Some say it's when the model learns the noise in the training data, which has a negative impact on validation set performance. In this case, you can see that validation accuracy increases nearly monotonically, meaning there is rarely a step where it goes down while training accuracy improves.
So if the true definition of overfitting is a reduction in validation accuracy due to the model memorizing features in the training set, this run was never truly in the "overfitting" stage. However, if you consider overfitting to simply be a large discrepancy between training and validation accuracy, then this would indeed be considered overfitting.
In my opinion, it makes more sense to think about these phases in terms of generalization ability rather than overfitting. For example, there seems to be a slow generalization phase, where validation accuracy improves at a very gradual pace. Additionally, there seems to be a rapid generalization phase, where validation accuracy improves quite quickly. Finally, there is a third phase where generalization accuracy consistently degrades (we will see this in the following experiment). This third phase seems very different from the first phase, where the model slowly improves, and this is why I have a hard time classifying both stages as overfitting.
Testing with a 'Mini' GPT-2
I also decided to try this new method on a GPT model (10 million parameters) trained on text data. I started by applying the EMA filter from the very beginning of training, but the results were somewhat poor, and it turns out the authors of Grokfast came across a similar phenomenon.
Basically, it seems that the fast-varying components of the gradient are important in the early stages of training, as the model approaches the overfitting phase. Once that phase is reached, however, the fast-varying components can be filtered out, accelerating generalization.
In order to accomplish this, I chose to wait until after the 2000th iteration (which is roughly when validation loss levels off) to begin applying the Grokfast method. If you are interested in the full training script, check out the repo here.
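Here is a rough sketch of what that gating looks like inside a NanoGPT-style training loop (variable names like grokfast_start_iter are my own, and the snippet is simplified relative to the actual script in the repo):

# Simplified NanoGPT-style loop with delayed Grokfast (illustrative, not the exact repo code).
grads = None
grokfast_start_iter = 2000  # roughly where validation loss levels off

for iter_num in range(max_iters):
    X, Y = get_batch('train')           # NanoGPT's batch sampler
    logits, loss = model(X, Y)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()

    # Only amplify the slow gradient component once the model has (nearly) fit the training set.
    if iter_num >= grokfast_start_iter:
        grads = gradfilter_ema(model, grads=grads, alpha=0.98, lamb=2.0)

    optimizer.step()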
I ran three different experiments with NanoGPT on the Shakespeare-char dataset from the NanoGPT repo: a baseline run without any Grokfast gradient filtering, an experiment with Grokfast applied from the very beginning, and another applying Grokfast after the 2000th iteration.
For this particular experiment, Grokfast gradient filtering seemed to act almost like a "cure" for overfitting, as validation loss stayed steady throughout the duration of the training run (after a large loss spike when the gradient filtering kicked in at the 2000th iteration). Additionally, the Grokfast run slightly outperformed the regular backprop run in terms of final validation loss, although the two were quite close and the difference is probably just random chance.
The chart below shows just the training loss for the three runs:
Run set (3 runs)
Here is the validation loss chart, which gives a clearer picture of the model's ability to generalize to unseen data:
Run set (3 runs)
Earlier, I mentioned that the true definition of overfitting is a bit debatable. The chart above demonstrates a particular variant of overfitting, where validation loss degrades (rises) even as training loss keeps improving. For this run, I was really impressed with how Grokfast was able to essentially prevent this form of overfitting.
I decided to do a few more runs, applying the Grokfast gradient filtering at different steps: two additional runs that start the gradient filtering at the 1000th and 5000th iterations, respectively. Here are the results with these two runs added:
Run set (5 runs)
Interestingly, it seems as though Grokfast must be applied at the right time in order to prevent the overfitting. As shown in the chart, applying Grokfast earlier or later than our original 2000th-step procedure results in overfitting. I'm a bit unsure how to counteract the loss 'spikes' when Grokfast is initially applied, but I'm guessing it has something to do with the average of the gradients not being initialized, which could probably be solved fairly easily if that is indeed the root cause of the spike.
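If the spike really does come from the gradient average being uninitialized, one possible (untested) fix is to warm up the EMA buffer with lamb = 0 for a few hundred steps before switching the amplification on, so the running average already reflects recent gradients by the time it starts being added; something along these lines:

# Hypothetical warmup for the EMA buffer (not from the Grokfast repo; values are illustrative).
warmup_start = 1500   # start accumulating the EMA here
filter_start = 2000   # start amplifying the slow gradients here

# ... inside the training loop, after loss.backward():
if iter_num >= warmup_start:
    lamb = 2.0 if iter_num >= filter_start else 0.0  # lamb = 0 accumulates without changing the update
    grads = gradfilter_ema(model, grads=grads, alpha=0.98, lamb=lamb)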
Finally, I want to mention another interesting finding from OpenAI's grokking work. The chart I showed earlier from OpenAI's paper shows a very delayed improvement in validation accuracy relative to training accuracy. However, it's interesting to notice that the validation accuracy curve never shows much evidence of declining performance (it also never shows much improvement until late in the run); it is simply very slow to improve.

Now, this raises the question, does this chart ever show overfitting? Again, it depends on what you define as overfitting. In my opinion, the above chart does not show overfitting at any stage, rather it simply shows a very slowly improving model.
However, if the validation accuracy were to be trending in the opposite direction, I would be more inclined to believe that the model was in fact "overfitting" (but again this is just my opinion of what overfitting is). This doesn't mean that I don't think the Grokking work is interesting (quite the opposite), however, I think it's a bit of a stretch to say that these models are going from overfitting to generalizing.
Additionally, from what I can tell, the Grokfast technique seems to have plenty of applications beyond grokking research. It seems like this particular idea has a lot of potential for reducing overfitting in general, which is really exciting overall.
Overall
The discovery of grokking has fundamentally shifted our understanding of neural network training dynamics and generalization. By uncovering the capacity for networks to learn underlying data patterns after periods of overfitting, researchers have made a significant improvement in our understanding of deep learning models.
Grokfast, with its innovative gradient filtering techniques, provides a practical enhancement to this phenomenon, significantly reducing the training time required for achieving high generalization performance. This development not only refines the theoretical framework of overfitting and generalization but also offers tangible benefits for practitioners seeking efficient and effective training methodologies. I really recommend checking out the Grokfast paper here, and if you are interested in the code, feel free to check out their repo, or my repo which includes W&B logging.
6 "gotchas" in machine learning—and how to avoid them
ML is hard and you can't plan for everything. Here are a few things I've learned and a few tips to avoid common missteps
How to fine-tune Phi-3 Vision on a custom dataset
Here's how to fine tune a state of the art multimodal LLM on a custom dataset
Training a KANFormer: KAN's Are All You Need?
We will dive into a new experimental architecture, replacing the MLP layers in transformers with KAN layers!
Creating videos from static images with Stable Video Diffusion
The model known for generating images has been upgraded to handle video! We will cover the basics of the model, and also generate some sample videos!