Training a KANFormer: KANs Are All You Need?
We will dive into a new experimental architecture, replacing the MLP layers in transformers with KAN layers!
In this tutorial, we'll introduce an experimental concept in neural network design known as Kolmogorov-Arnold Networks (KANs), and explore their potential integration with transformer architectures, thus creating a "KANFormer."
KANs are inspired by the Kolmogorov-Arnold representation theorem and represent a significant departure from traditional neural networks. Rather than using fixed activation functions and linear weights, KANs incorporate adjustable univariate functions, or splines, at each connection.
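For reference, and without going deep into the math, the Kolmogorov-Arnold representation theorem says (roughly) that any continuous multivariate function on a bounded domain can be written as a composition of sums of univariate functions:

\[
f(x_1, \dots, x_n) = \sum_{q=0}^{2n} \Phi_q\!\left( \sum_{p=1}^{n} \phi_{q,p}(x_p) \right)
\]

KANs take this two-layer structure as inspiration and generalize it to networks of arbitrary width and depth built from learnable univariate functions.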
This tutorial will focus on the practical aspects of implementing and testing KANs within transformers, rather than delving into the deep mathematical foundations of KANs. Our aim is to provide hands-on experience with KANs and see how well they perform in place of traditional MLPs within the transformer.
Let's get started.

What we'll cover
The idea
The MLP
Splines
A high level comparison between MLPs and KANs
The KANFormer
The architecture
The data
Sweeps
Comparison to a Stock GPT-2 Transformer
The idea
The idea of replacing the multi-layer perceptron (MLP) blocks in transformers with KAN layers is based on the hypothesis that the adaptive flexibility of splines could enhance the model's ability to process and interpret intricate data patterns.
It's crucial to note that the development and application of KANs within transformers are still at a very early, exploratory stage. As such, this approach has not been extensively tested and is not recommended for production use cases at this time. This tutorial aims to spark discussion and interest in KANs by exploring their theoretical benefits and potential applications in advanced neural network structures.
The MLP
Unlike traditional multi-layer perceptrons—which are a mainstay in deep learning due to their capability to approximate nonlinear functions under the universal approximation theorem—KANs introduce a significant architectural shift. MLPs typically utilize fixed activation functions located at nodes (neurons) and linear weights connecting these nodes.
KANs, on the other hand, replace linear weight parameters with learnable univariate functions, called splines, attached to the edges (weights) of the network. Let's talk about those now.
Splines
In the context of KANs, a spline is a piecewise polynomial function which allows for greater flexibility and local control within the model. These splines are parametrized, allowing the network to adapt these functions during training based on the data it encounters.
This fundamental shift from fixed activation functions and linear weights to flexible, parameterized functions at every connection can make KANs particularly adept at modeling complex, nonlinear phenomena on certain problems, compared to MLPs.
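To make the idea concrete, here is a minimal, illustrative sketch of a learnable univariate edge function: a piecewise-linear curve over a fixed grid whose knot values are trained by gradient descent. This is not the KAN implementation used later in this post (real KAN layers use B-spline bases plus a base activation, and learn many of these functions in parallel), but it captures the core idea that the function on the edge, not a scalar weight, is what gets learned.

import torch
import torch.nn as nn

class LearnableSpline1D(nn.Module):
    """Toy learnable univariate function: piecewise-linear interpolation
    over a fixed grid, with a trainable value at each knot."""
    def __init__(self, num_knots: int = 8, grid_range=(-1.0, 1.0)):
        super().__init__()
        self.lo, self.hi = grid_range
        self.register_buffer("knots", torch.linspace(self.lo, self.hi, num_knots))
        self.values = nn.Parameter(0.1 * torch.randn(num_knots))  # trainable knot values

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Clamp to the grid, find the bracketing knots, and linearly interpolate
        x = x.clamp(self.lo, self.hi)
        idx = torch.searchsorted(self.knots, x).clamp(1, len(self.knots) - 1)
        x0, x1 = self.knots[idx - 1], self.knots[idx]
        y0, y1 = self.values[idx - 1], self.values[idx]
        t = (x - x0) / (x1 - x0)
        return y0 + t * (y1 - y0)

spline = LearnableSpline1D()
print(spline(torch.tensor([-0.5, 0.0, 0.7])))  # one learned curve, evaluated pointwise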
A high level comparison between MLPs and KANs
The core differences between KANs and MLPs are rooted in their structural and functional approach to processing inputs. In MLPs, each layer's output is computed as a non-linear transformation (through an activation function) of a linear combination of inputs from the previous layer, which are summed and passed through fixed activation functions.
In contrast, KANs utilize a summing operation at the nodes without additional non-linear transformation, while the complexity and non-linear capabilities are provided by the splines on the edges. This arrangement allows KANs to focus on learning the optimal transformations directly through the adjustable splines, enhancing both the model's accuracy and interpretability.
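In rough notation, an MLP layer applies a fixed nonlinearity to a learned linear combination, while a KAN layer applies a learned univariate function to each input and simply sums the results at the node:

\[
\text{MLP:}\quad \mathbf{y} = \sigma\!\left(W\mathbf{x} + \mathbf{b}\right)
\qquad
\text{KAN:}\quad y_j = \sum_{i} \phi_{j,i}(x_i)
\]

where each \(\phi_{j,i}\) is a trainable spline rather than a fixed scalar weight.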
Overall, KANs represent a promising evolution in neural network design, shifting from the traditional neuron-centric approach of MLPs to a more flexible, edge-based model where each connection contributes adaptively to the overall function approximation task. This makes KANs particularly suitable for tasks requiring high accuracy and interpretability, potentially revolutionizing fields such as computational physics and complex system modeling.
The KANFormer
Within the transformer, the multi-layer perceptron block is a crucial component for introducing non-linearity and depth. Positioned between the self-attention layers, each MLP processes input independently across positions, allowing the model to capture complex data patterns beyond what linear operations alone can achieve.
Typically consisting of two linear layers separated by a non-linear activation function, the MLP significantly enhances the model's representational power. By integrating KAN layers instead, which use spline-based adaptive functions for each connection, we can get a better idea of how KAN layers perform within the Transformer.
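For comparison, the standard MLP block in NanoGPT looks roughly like the following (paraphrased from the NanoGPT source, so treat the exact names and signatures as illustrative):

import torch.nn as nn

class GPT2MLP(nn.Module):
    """Standard transformer feedforward block: expand to 4x the embedding
    width, apply a fixed GELU nonlinearity, project back down, then dropout."""
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        return self.dropout(self.c_proj(self.gelu(self.c_fc(x))))

This is the block the KAN-based version below swaps out.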
The architecture
In my KANFormer model, I used a KAN layer implementation I found on GitHub and used it to replace the standard MLP block typically found in transformers. This adaptation incorporates spline-based functions, aiming to explore different model dynamics compared to traditional MLP setups. I'm building off of Andrej Karpathy's NanoGPT implementation, which uses variants of GPT-2; the core modification is changing the MLP object to use KAN layers instead of traditional feedforward layers.
Additionally, I added a second KANFormer variant that uses a bottleneck layer to reduce the dimensionality of the input before it is processed by the KAN layers, allowing them to operate with a smaller computational load. After the KAN layers, the dimensionality is restored to its original size so the output can be seamlessly integrated back into the rest of the transformer. The goal of this approach is to improve processing efficiency while maintaining, or even improving, the quality of information flow through the model.
For the code, we will be modifying Andrej Karpathy's NanoGPT library to accommodate our specific needs for this experiment, leveraging its capabilities to enhance our model's training and evaluation processes.
Before we can test the KANFormer, we first need to modify the transformer so it uses KAN layers instead of MLPs. We'll also add functionality for an alternative configuration, where linear layers downscale the input features before the KAN block and then scale them back up to a dimension that fits the rest of the architecture. Here is my modified MLP block featuring KAN layers:
class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        if config.downsizeMLP:
            kan_layers = [config.downsize_size, config.kan_hidden_size, config.downsize_size]
        else:
            kan_layers = [config.n_embd, config.kan_hidden_size, config.n_embd]
        self.kan = KAN(
            layers_hidden=kan_layers,
            grid_size=config.grid_size,
            spline_order=config.spline_order,
            scale_noise=config.scale_noise,
            scale_base=config.scale_base,
            scale_spline=config.scale_spline,
            base_activation=config.base_activation,
            grid_eps=config.grid_eps,
            grid_range=config.grid_range,
        )
        self.dropout = nn.Dropout(config.dropout)
        self.downsizeMLP = config.downsizeMLP
        if self.downsizeMLP:
            # Define the downsizing linear layer and a ReLU activation
            self.downsizer = nn.Sequential(
                nn.Linear(config.n_embd, config.downsize_size),
                nn.ReLU()
            )
            # Define the upsizing linear layer and a ReLU activation
            self.upsizer = nn.Sequential(
                nn.Linear(config.downsize_size, config.n_embd),
                nn.ReLU()
            )

    def forward(self, x):
        batch_size, seq_length, n_embd = x.shape
        x = x.view(batch_size * seq_length, n_embd)
        if self.downsizeMLP:
            x = self.downsizer(x)
        x = self.kan(x)
        if self.downsizeMLP:
            x = self.upsizer(x)
        x = x.view(batch_size, seq_length, n_embd)
        x = self.dropout(x)
        return x
The new MLP module integrates Kolmogorov-Arnold Networks into the transformer's architecture, replacing traditional MLP blocks. This module includes an optional dimensionality reduction layer, which compresses input features before processing through the KAN layer, potentially improving computational efficiency and focusing on the most relevant features. Note that my experiments will try both options, with and without the dimensionality reduction.
The KAN block itself is quite configurable, with parameters such as grid size, spline order, and noise scale, allowing for lots of experimentation. After the KAN processing, another optional dimensionality expansion layer restores the feature size to align with the rest of the transformer architecture, ensuring seamless integration.
The configuration of the KAN layer is controlled by several parameters in the config object. If downsizing is enabled, the embeddings are reduced in dimensionality before the KAN layer and restored afterward: a 'downsizer' (a linear layer followed by a ReLU) shrinks the features, the KAN layer processes them, and an 'upsizer' with a similar structure restores the original dimension. The kan_layers list determines the layer sizes of the KAN block itself, while grid_size, spline_order, grid_eps, scale_noise, scale_base, and scale_spline control the behavior of the splines, guiding how the spline functions are initialized and adapted during training.
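To keep the knobs straight, here is a hypothetical sketch of the config fields the modified block reads. The names match the module above, the defaults mirror values used in my sweeps, and the comments reflect my (possibly imperfect) understanding of each parameter:

from dataclasses import dataclass
import torch.nn as nn

@dataclass
class KANFormerConfig:
    # Standard GPT-2-style dimensions
    n_embd: int = 768
    dropout: float = 0.1

    # KAN-specific settings consumed by the modified MLP block
    kan_hidden_size: int = 512        # hidden width of the KAN block
    grid_size: int = 5                # number of grid intervals per spline
    spline_order: int = 3             # polynomial order of each spline segment
    scale_noise: float = 0.1          # noise scale used when initializing splines
    scale_base: float = 1.0           # scale of the base-activation path
    scale_spline: float = 1.0         # scale of the spline path
    base_activation: type = nn.SiLU   # activation applied on the base path
    grid_eps: float = 0.02            # blend between uniform and data-adaptive grids
    grid_range: tuple = (-1, 1)       # input range covered by the spline grid

    # Optional bottleneck around the KAN block
    downsizeMLP: bool = False
    downsize_size: int = 256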
It's important to note that I cannot elaborate extensively on the significance of each parameter within the KAN, as this concept is relatively new and has not undergone substantial experimentation.
To (partially) address this, I will utilize Weights & Biases sweeps to explore and optimize these parameters. W&B sweeps are a powerful tool for hyperparameter tuning that automate the process of testing different combinations of parameters across multiple runs. This method involves defining a range of values for each parameter and then conducting experiments to observe which combinations yield the best performance, effectively streamlining the iterative testing and optimization process.
The data
Next, we will need to prepare a dataset that will be used for training and validating our model. We'll use a dataset available through Hugging Face's datasets library, specifically the "iamtarun/python_code_instructions_18k_alpaca," and process it for use with a transformer model. Here's how we'll do it:
import os
import numpy as np
from datasets import load_dataset
import tiktoken

# Load the dataset from Hugging Face
dataset = load_dataset("iamtarun/python_code_instructions_18k_alpaca", split='train')

# Combine the relevant columns into a single text column
dataset = dataset.map(lambda x: {'text': x['instruction'] + ' ' + x['input'] + ' ' + x['output']})

# Get the text data as a list of strings
data = dataset['text']

# Split the data into train and validation sets
n = len(data)
train_data = data[:int(n*0.9)]
val_data = data[int(n*0.9):]

# Encode with tiktoken gpt2 bpe
enc = tiktoken.get_encoding("gpt2")
train_ids = enc.encode_ordinary(' '.join(train_data))
val_ids = enc.encode_ordinary(' '.join(val_data))
print(f"train has {len(train_ids):,} tokens")
print(f"val has {len(val_ids):,} tokens")

# Export to bin files
train_ids = np.array(train_ids, dtype=np.uint16)
val_ids = np.array(val_ids, dtype=np.uint16)
train_ids.tofile(os.path.join(os.path.dirname(__file__), 'train.bin'))
val_ids.tofile(os.path.join(os.path.dirname(__file__), 'val.bin'))
The provided code loads a Python code dataset from Hugging Face and splits it into training and validation sets. It then encodes the text with GPT-2's byte-pair encoding via tiktoken for compatibility with the model. The token counts are reported, and the tokenized data is saved in a compact binary format for efficient storage and retrieval during training.
To generate your dataset, run the following command:
python data/alpaca/prepare.py
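For context, NanoGPT's training loop reads these .bin files back with a memory-mapped array and samples random blocks from them. A simplified sketch of that loader (paraphrasing NanoGPT's get_batch and assuming the files live under data/alpaca) looks like this:

import numpy as np
import torch

block_size = 1024  # context length
batch_size = 1     # matches the batch size used in the sweep configs

def get_batch(split, data_dir='data/alpaca', device='cpu'):
    # Memory-map the tokenized file so the whole dataset never sits in RAM
    data = np.memmap(f'{data_dir}/{split}.bin', dtype=np.uint16, mode='r')
    # Sample random starting offsets and slice out (input, target) pairs
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([torch.from_numpy(data[i:i + block_size].astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy(data[i + 1:i + 1 + block_size].astype(np.int64)) for i in ix])
    return x.to(device), y.to(device)

# xb, yb = get_batch('train')  # each has shape (batch_size, block_size)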
Sweeps
NanoGPT uses Python configuration files to manage its settings, making it highly adaptable for experimental modifications. To leverage this existing infrastructure in our KAN experiments, we will simply write a new configuration file for each iteration of the sweep.
This approach allows us to dynamically adjust the model parameters for each sweep iteration, seamlessly integrating new settings without altering the underlying codebase. By writing new config files that specify different parameters based on the sweep's current iteration, we can ensure that each experimental run is both consistent with the NanoGPT framework and tailored to explore the optimal configuration for the KAN.
Note that the script below is not "pretty," but I think the definition of good code is shifting in the ChatGPT era: programmers are less fazed by large amounts of repetitive code, and more attention can be allocated to the overall architecture. Here is the script I used to write the config files and run the sweep:
import os
import subprocess
import wandb
import torch.nn as nn

def update_config_file(params, run_id, filename):
    """Update the configuration file with new parameter values."""
    unique_run_name = f"ft-{run_id}"
    unique_out_dir = f"out-alpaca-{run_id}"
    config_content = f"""from re import T
import time
import torch.nn as nn
out_dir = '{unique_out_dir}'
eval_interval = 100
eval_iters = 40
wandb_log = True
wandb_project = 'alpaca-gpt2'
wandb_run_name = '{unique_run_name}'
dataset = 'alpaca'
init_from = 'gpt2'
always_save_checkpoint = False
batch_size = 1
gradient_accumulation_steps = 32
max_iters = 2000
learning_rate = {params['learning_rate']}
decay_lr = True
n_embd = 768
dropout = 0.1
kan_hidden_size = {params['kan_hidden_size']}
grid_size = {params['grid_size']}
spline_order = {params['spline_order']}
scale_noise = {params['scale_noise']}
grid_eps = {params['grid_eps']}
scale_base = 1.0
scale_spline = 1.0
base_activation = nn.SiLU
grid_range = [-1, 1]
drop = True
att_resid = True
mlp_resid = True
ln = True
downsizeMLP = {params['downsizeMLP']}
downsize_size = {params['downsize_size']}
"""
    with open(filename, 'w') as f:
        f.write(config_content)

def train():
    """Runs the training script with the updated configuration."""
    try:
        run = wandb.init(dir='/root/wandbout')
        config_filename = f'config_{run.id}.py'
        update_config_file(run.config, run.id, config_filename)
        # Execute the training script with the specific config file
        subprocess.run(['python', 'train_kan.py', config_filename])
        run.finish()
        # Optionally, clean up the config file after the run
        # os.remove(config_filename)
    except Exception as e:
        print(str(e))

def main():
    wandb.login()
    sweep_config = {
        'method': 'random',
        'metric': {'name': 'val_loss', 'goal': 'minimize'},
        'parameters': {
            'learning_rate': {'values': [1e-5, 3e-5, 6e-5, 1e-4]},
            'grid_eps': {'values': [0.01, 0.02, 0.05]},
            'scale_noise': {'values': [0.05, 0.1, 0.15]},
            'grid_size': {'values': [3, 5, 7]},
            'spline_order': {'values': [2, 3, 4]},
            'downsize_size': {'values': [512, 256, 128, 64]},
            'downsizeMLP': {'values': [True, False]},
            'kan_hidden_size': {'values': [768, 512, 256, 128]}
        }
    }
    sweep_id = wandb.sweep(sweep_config, project='alpaca-exp')
    wandb.agent(sweep_id, train)

if __name__ == "__main__":
    main()
The sweep mechanism, defined in the sweep_config dictionary, tests various combinations of hyperparameters for the KANFormer, including learning rate, grid size, and spline configuration parameters like spline order and noise scale. Each configuration is dynamically written to a Python file using the update_config_file function, ensuring that each training run is tailored to the specific parameters under test. This file is then used to run the training script, allowing the model to learn under different conditions. The train function orchestrates this process, initializing a WandB run for each iteration, executing the training, and capturing the results, which are automatically logged for analysis. This methodical approach allows for extensive testing of the KAN integration, aiming to optimize the model’s performance through iterative experimentation.
This setup helps in efficiently finding the best model settings by tracking and comparing the results of each run within the W&B dashboard. In the provided script, the parameters responsible for W&B logging include wandb_log, wandb_project, and wandb_run_name. These settings ensure that each training run is logged and identifiable within the wandb interface, allowing for detailed performance tracking and analysis across different configurations.
NanoGPT supports logging the config file to W&B, so it's as simple as running the sweep and letting W&B handle the rest. This lets us display the results in a parallel coordinates plot, which shows how different hyperparameters affect the model's performance. Here are the results for my KANFormer sweep:
Overall, the best model on the validation set used the dimensionality reduction MLP layers, with a two-layer KAN block of 512 units each. I wasn't able to find any strong connections between the different hyperparameters, so I think there is plenty of room for further experimentation with different benchmarks and datasets.
[W&B panels: KANFormer sweep results, 17 runs]
One pattern I did notice: runs without the dimensionality reduction layers before the KAN block tended to have more consistent loss, though some of the best runs did use the dimensionality reduction layers.
[W&B panel: KANFormer sweep, 17 runs]
Comparison to a Stock GPT-2 Transformer
I also ran a similar sweep on the regular transformer, mostly testing different MLP layer sizes and learning rates, in order to fairly compare the model to the KANFormer. Here's the script I used for this:
import os
import subprocess
import wandb

def update_config_file(params, run_id, filename):
    """Update the configuration file with new parameter values."""
    unique_run_name = f"ft-{run_id}"
    unique_out_dir = f"out-alpaca-{run_id}"
    with open(filename, 'w') as f:
        f.write("from re import T\n")
        f.write("import time\n")
        f.write("import torch.nn as nn\n")
        f.write(f"out_dir = '{unique_out_dir}'\n")
        f.write(f"eval_interval = 100\n")
        f.write(f"eval_iters = 40\n")
        f.write(f"wandb_log = True\n")
        f.write(f"wandb_project = 'alpaca-exp-gpt2-baseline'\n")
        f.write(f"wandb_run_name = '{unique_run_name}'\n")
        f.write(f"dataset = 'alpaca'\n")
        f.write(f"init_from = 'gpt2'\n")
        f.write(f"always_save_checkpoint = False\n")
        f.write(f"batch_size = 1\n")
        f.write(f"gradient_accumulation_steps = 32\n")
        f.write(f"max_iters = 2000\n")
        f.write(f"learning_rate = {params['learning_rate']}\n")
        f.write(f"decay_lr = True\n")
        f.write(f"n_embd = 768\n")
        f.write(f"dropout = 0.1\n")
        f.write(f"kan_hidden_size = {params['kan_hidden_size']}\n")
        f.write(f"grid_size = {params['grid_size']}\n")
        f.write(f"spline_order = {params['spline_order']}\n")
        f.write(f"scale_noise = {params['scale_noise']}\n")
        f.write(f"grid_eps = {params['grid_eps']}\n")
        f.write(f"scale_base = 1.0\n")
        f.write(f"scale_spline = 1.0\n")
        f.write(f"base_activation = nn.SiLU\n")
        f.write(f"grid_range = [-1, 1]\n")
        f.write(f"drop = True\n")
        f.write(f"att_resid = True\n")
        f.write(f"mlp_resid = True\n")
        f.write(f"ln = True\n")
        f.write(f"downsizeMLP = {params['downsizeMLP']}\n")
        f.write(f"downsize_size = {params['downsize_size']}\n")
        f.flush()
        f.close()

def train():
    """Runs the training script with the updated configuration."""
    try:
        run = wandb.init()
        config_filename = f'config_{run.id}.py'
        update_config_file(run.config, run.id, config_filename)
        # Execute the training script with the specific config file
        subprocess.run(['python', 'train_stock.py', config_filename])
        run.finish()
        # Optionally, clean up the config file after the run
        # os.remove(config_filename)
    except Exception as e:
        print(str(e))

def main():
    wandb.login()
    sweep_config = {
        'method': 'random',
        'metric': {'name': 'val_loss', 'goal': 'minimize'},
        'parameters': {
            'learning_rate': {'values': [1e-5, 3e-5, 6e-5, 1e-4]},
            'grid_eps': {'values': [0.1]},
            'scale_noise': {'values': [0.1]},
            'grid_size': {'values': [1]},
            'spline_order': {'values': [1]},
            'downsize_size': {'values': [1]},
            'downsizeMLP': {'values': [False]},
            'kan_hidden_size': {'values': [768*3, 768*4, 768*5, 768*6, 768*7]}
        }
    }
    sweep_id = wandb.sweep(sweep_config, project='alpaca-exp-gpt2-baseline')
    wandb.agent(sweep_id, train)

if __name__ == "__main__":
    main()
I left most of the KANFormer parameters in place to avoid unnecessary code changes, so the KAN parameters can all be ignored here, except for kan_hidden_size, which simply determines the hidden size of the MLP layers. You'll also notice that this update_config_file is written slightly differently from the KAN sweep's version; this is a purely syntactical change, and the two function the same, so feel free to use either.
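Concretely, that means the baseline's feedforward block is a standard two-layer MLP whose hidden width comes from kan_hidden_size. A hypothetical reconstruction (not verbatim from my train_stock.py) looks like:

import torch.nn as nn

class BaselineMLP(nn.Module):
    """Plain feedforward block for the stock GPT-2 runs; kan_hidden_size is
    reused as the hidden width, so 768*4 recovers the usual 4x expansion."""
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, config.kan_hidden_size)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(config.kan_hidden_size, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        return self.dropout(self.c_proj(self.gelu(self.c_fc(x))))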
Here are the results for my stock GPT-2 run:
[W&B panels: stock GPT-2 baseline sweep results, 35 runs]
As a reminder for this chart: kan_hidden_size here is just the dimensionality of the hidden layer in the regular MLP block, and no KAN layers were used in this test.
[W&B panel: stock GPT-2 baseline sweep, 35 runs]
Overall the results seemed to be quite similar for this particular dataset and model size. It's really difficult to say how well KANs can perform until we scale them on much larger amounts of data, with larger architectures.
In conclusion, exploring Kolmogorov-Arnold Networks within the transformer architecture presents an exciting frontier in neural network design. This tutorial has laid the groundwork for practical testing and experimentation with KANs, but their true potential in real-world applications remains to be uncovered. Further testing, tuning, and validation are needed to optimize their performance and fully understand their capabilities, and this iterative experimentation will help identify the most effective configurations and the practical applications where KANs can meaningfully improve model performance. I hope you enjoyed this tutorial!