What Are Intrinsic Dimensions? The Secret Behind LoRA
This article provides a brief overview of intrinsic dimensions and how they enable Low-Rank Adaptation (LoRA). We also provide code samples which use Weights & Biases for interactive visualizations.
Created on November 24|Last edited on December 12
Recently, LoRA-based fine-tuning methods have become popular as a parameter-efficient way to fine-tune large language models. These methods build on the discovery outlined in the paper "Measuring the Intrinsic Dimension of Objective Landscapes" by Chunyuan Li, Heerad Farkhoor, Rosanne Liu and Jason Yosinski. In this article, we'll look into the concept of intrinsic dimensions and show why it lets us exploit the low-rank nature of large language models!
To follow along with this article, please refer to the following Colab Notebook:
Table of Contents
- What Are Intrinsic Dimensions?
- 😎 Intrinsic Dimensions Paper Methodology
- 👨💻 Coding Intrinsic Dimensions
- 📊 The Results
- 🎬 Conclusion
What Are Intrinsic Dimensions?
The intrinsic dimension of a model on a given task is the minimum number of free parameters needed to reach a good solution, i.e. performance close to what we get when training every parameter. This concept challenges the traditional assumption that all parameters contribute equally to performance and posits that optimizing within a much smaller subspace of the parameter space is enough to unlock most of a model's potential.
This concept offers several key benefits:
- Efficiency: Focusing on intrinsic dimensions allows for a significantly reduced parameter count, leading to faster training times and lower computational requirements. This is particularly advantageous for large language models, where parameter spaces are vast.
- Fine-tuning: Pre-training large language models implicitly reduces their intrinsic dimension, explaining the success of fine-tuning. These models require less data to adapt to new tasks due to their minimized intrinsic dimension.
- LoRA Foundation: Intrinsic dimensions form the basis of Low-Rank Adaptation (LoRA), a technique for efficiently adapting large language models to specific domains. LoRA leverages the concept to achieve rapid development of specialized language models without extensive data.
With that behind us, let's dive into the paper "Measuring the Intrinsic Dimension of Objective Landscapes".
😎 Intrinsic Dimensions Paper Methodology
In this paper, the authors try to train models not in their native parameter space (i.e. the parameter space consisting of all the parameters) but in smaller randomly oriented subspaces. The process of training a neural network can be thought of as traveling along some loss landscape, and trying to find a low point in some valley. All operations after initializing the model, such as forward and backward propagation, can be thought of as taking a step in this objective landscape.

The authors take a random subspace and then compute the gradients only along this smaller subspace. Then, they slowly increase the dimension of this space, note the dimension at which solutions first start appearing, and define this to be the intrinsic dimension.
This, of course, has great implications for what it means to train a model! The model performance peaks at some dimension that is smaller than the full parameter space, i.e. we only need a small number of dimensions to optimize the model. We can loosely define this Intrinsic Dimensionality paradigm as follows:
- Typical Paradigm: If a neural network has $D$ parameters, then we compute the gradient in the entire parameter space $\mathbb{R}^D$.
- Intrinsic Paradigm: If a neural network has $D$ parameters, then we choose a random $d$-dimensional subspace ($d < D$) of the full parameter space $\mathbb{R}^D$ and optimise in this subspace. Typically, $d \ll D$.
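To make the two paradigms concrete, here is a minimal NumPy sketch. It is not code from this article's notebook or from the paper's repository. It assumes the reparameterisation used by Li et al., theta = theta0 + P z, where D is the number of parameters, d is the subspace dimension, P is a fixed random D x d matrix, and only the d coordinates z are optimised. The toy objective (groups of parameters that must each sum to a target), the name subspace_loss, and the closed-form least-squares solve standing in for gradient descent are all illustrative assumptions.

```python
import numpy as np


def subspace_loss(d, D=100, groups=10, seed=0):
    """Best loss reachable when optimising only inside a random d-dim subspace."""
    rng = np.random.default_rng(seed)
    # Toy objective (assumption): group k of the D parameters must sum to k + 1.
    A = np.kron(np.eye(groups), np.ones((1, D // groups)))  # shape (groups, D)
    b = np.arange(1, groups + 1, dtype=float)
    theta0 = np.zeros(D)                 # fixed starting point in the full space
    P = rng.standard_normal((D, d))      # random basis spanning the subspace
    P /= np.linalg.norm(P, axis=0)       # unit-length columns
    # Optimise the d subspace coordinates z, with theta = theta0 + P @ z.
    z, *_ = np.linalg.lstsq(A @ P, b - A @ theta0, rcond=None)
    theta = theta0 + P @ z
    return float(np.sum((A @ theta - b) ** 2))


# Sweep the subspace dimension and watch when a solution first appears.
for d in (2, 5, 10, 20):
    print(d, subspace_loss(d))
```

With 10 independent constraints, the loss should only become numerically zero once d is at least 10, so 10 plays the role of the intrinsic dimension for this toy problem even though the full space has 100 parameters.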
But why are we talking about this? For one, it helps us to understand why fine-tuning works. Fine-tuning as a concept is wonderful in its own right. We essentially tune hundreds of millions of parameters with only thousands of examples.
And why does this work at all? If we analyze fine-tuning through the lens of intrinsic dimension, we can show empirically why fine-tuning works. The authors of "Intrinsic Dimensionality Explains the Effectiveness of Language Model Fine-Tuning" showed empirically that common pre-trained models have very low intrinsic dimensions and that pre-training implicitly minimizes intrinsic dimension.
This concept is the backbone of why LoRA (Low-Rank Adaptation) works. If you want to read more about that, you can refer to my other article here:
👨💻 Coding Intrinsic Dimensions
Computing the intrinsic dimension is usually done via random subspace sampling. This means that, at each gradient update step, we randomly choose a smaller subspace with fewer dimensions and only consider the gradient along this subspace to update the parameters. As we increase the subspace dimension, we eventually reach a point where the model performance begins to plateau. This dimension is then said to be the intrinsic dimension.
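The plateau criterion above is informal. One way to make it concrete, following the 90% convention that Li et al. use in their paper, is to report the smallest subspace dimension whose score reaches 90% of the full-parameter baseline. The sketch below assumes you already have a sweep of results; the function name intrinsic_dimension, its arguments, and the numbers in the example are hypothetical and are not results from this article.

```python
def intrinsic_dimension(results, baseline, threshold=0.9):
    """Smallest subspace dimension whose score reaches threshold * baseline.

    results: mapping {subspace_dim: validation accuracy} from a sweep.
    baseline: accuracy of the same model trained in the full parameter space.
    Returns None if no dimension in the sweep reaches the target.
    """
    target = threshold * baseline
    for dim in sorted(results):
        if results[dim] >= target:
            return dim
    return None


# Made-up numbers purely for illustration.
sweep = {8: 0.31, 32: 0.55, 128: 0.81, 512: 0.90, 2048: 0.93}
estimate = intrinsic_dimension(sweep, baseline=0.98)
```

Because the estimate can only be as fine-grained as the sweep, a coarse grid of dimensions gives an upper bound on the threshold dimension rather than its exact value.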
In this article, we'll implement random subspace sampling using the optax library. Every optax gradient transformation has an init function and an update function. When training with Flax's train_state.TrainState, init is called inside the create method and update is called inside the apply_gradients method. Thus, if we create a custom GradientTransformation object, we can chain it after some optimizer, viz.
optimizer = optax.chain(optax.adam(learning_rate), random_subspace_gradients(seed=seed, subspace_dim=16))
This is the main dev work behind the article as most of the other training boilerplate code is copied from earlier work done with Soumik Rakshit which you can read here:
from typing import NamedTuple

import chex
import jax
import jax.numpy as jnp
import optax


class RandomSubSpaceState(NamedTuple):
    """State containing PRNGKey for `random_subspace_gradients`."""

    rng_key: chex.PRNGKey


def random_subspace_gradients(seed: int, subspace_dim: int) -> optax.GradientTransformation:
    """Computes Gradients along some random subspace

    References:
        [Li et al, 2018](https://arxiv.org/abs/1804.08838)

    Args:
        seed (int): initial seed used for the jax.random.PRNGKey
        subspace_dim (int): number of subspace dimensions

    Returns:
        A `GradientTransformation`.
    """
    # NOTE: subspace_dim is accepted but is not used in the body below.

    def init_fn(params: optax.Params) -> RandomSubSpaceState:
        del params
        return RandomSubSpaceState(rng_key=jax.random.PRNGKey(seed))

    def update_fn(grads, state, params=None):
        del params
        # Split the rng: keys[0] is carried to the next step, keys[1] is used now
        keys = jax.random.split(state.rng_key, 2)
        # Flatten the gradient PyTree into a list of leaves plus its structure
        grads_flat, grads_treedef = jax.tree_util.tree_flatten(grads)
        # Draw a unit-norm random vector with one entry per gradient leaf
        intrinsic_space = jax.random.normal(keys[1], (len(grads_flat),))
        intrinsic_space /= jnp.linalg.norm(intrinsic_space)
        # Rebuild a PyTree with the same structure as grads (one scalar per leaf)
        intrinsic_space = jax.tree_util.tree_unflatten(grads_treedef, intrinsic_space)
        # Scale every gradient leaf by its random scalar
        projected_grads = jax.tree_util.tree_map(
            lambda x, y: jax.jit(lambda a, b: a * b)(x, y), grads, intrinsic_space
        )
        return projected_grads, RandomSubSpaceState(rng_key=keys[0])

    return optax.GradientTransformation(init_fn, update_fn)
While the code is fairly simple, I'd like to focus on some things here:
- We need to flatten the original gradients in order to obtain the structure of the gradient PyTree (in this case the grads_treedef) and the number of leaves it contains.
- We sample a random normal vector with one entry per gradient leaf, normalize it, and give it the same structure as the gradients using the tree_unflatten utility function.
- In this case, grads is a PyTree of arrays, whereas the entries of intrinsic_space are traced JAX values. To multiply the two leaf by leaf, we jit the element-wise multiplication within tree_map.
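Note that the function above multiplies each gradient leaf by a single random scalar and never reads subspace_dim. For comparison, here is a hedged sketch of what an explicit projection onto a random d-dimensional subspace could look like in JAX. It is an illustration under stated assumptions, not the article's or the paper's implementation: the helper name project_to_random_subspace is hypothetical, the basis is a dense Gaussian matrix orthonormalised with QR, and the whole gradient PyTree is flattened into one vector with jax.flatten_util.ravel_pytree.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def project_to_random_subspace(grads, key, subspace_dim):
    """Project a gradient PyTree onto a random subspace_dim-dimensional subspace."""
    # Concatenate every leaf into one vector of length D, remembering how to undo it.
    flat, unravel = ravel_pytree(grads)
    # Random Gaussian basis, orthonormalised so that q @ q.T is a projection.
    # Assumes subspace_dim <= D.
    basis = jax.random.normal(key, (flat.shape[0], subspace_dim))
    q, _ = jnp.linalg.qr(basis)          # q has shape (D, subspace_dim)
    projected = q @ (q.T @ flat)         # component of flat that lies in span(q)
    return unravel(projected)            # back to the original PyTree structure


# Tiny example PyTree with D = 8 entries in total.
example_grads = {"w": jnp.ones((3, 2)), "b": jnp.arange(2.0)}
example_out = project_to_random_subspace(
    example_grads, jax.random.PRNGKey(0), subspace_dim=4
)
```

Passing the same key at every step keeps the subspace fixed for the whole run, which is the setup described by Li et al., while splitting the key each step draws a fresh subspace as the function above does. The dense basis needs memory proportional to D times subspace_dim, so this sketch is only practical for small models.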
📊 The Results
I ran many experiments over various depths and random seeds, training a simple MLP-based image classification model on the MNIST dataset.
Below, we can see how the model's performance varies with increasing subspace dimension with varying random seeds.
[Interactive W&B chart: model performance vs. subspace dimension across random seeds (run set of 42 runs)]
I encourage you to read through the entire codebase in the provided Colab notebook!
🎬 Conclusion
In this article, we saw that objective landscapes have an intrinsic dimension, looked at how random subspace sampling can be used to estimate it, and used Weights & Biases to explore the training process and surface valuable insights.
To see the full suite of W&B features, please check out this short 5-minute guide. If you want more reports covering the math and "from-scratch" code implementations, let us know in the comments down below or on our forum ✨!
Check out these other reports on Fully Connected covering other LLM-related topics like Audio Transformers and hyperparameter optimization.
A Brief Introduction to LoRA
This article gives an overview of LoRA (Low-Rank Adaptation) of Large Language Models, using W&B for interactive visualizations. It includes code samples for you to follow.
A guide to large language models (LLMs)
Learn about the history of LLMs, including the groundbreaking GPT series and how they work, and explore developments like human-guided reinforcement learning.
An Introduction to Transformer Networks
This article provides an A-to-Z guide to how Transformer Networks function, and discusses why they outperform neural network models such as LSTM and RNN.
A Gentle Introduction to Retrieval Augmented Generation (RAG)
In this article, we will learn about Retrieval Augmented Generation (RAG) and how it helps pre-trained LLM models to generate more specific, diverse and factual responses.
Tree of Thoughts, Sophia, Goat, QLoRA, and Other ML News
Here's a round-up of the Tree of Thoughts, Second-order Clipped Stochastic Optimization (Sophia), GOod at Arithmetic Tasks (Goat), QLoRA, and other ML news.
Scaling Llama 2 to 32k Tokens With LongLora
The need for LLMs that can digest long content is becoming increasingly important. Go beyond 4096 tokens with LongLora!