Getting started with Apple MLX
A guide to Apple's new deep learning framework. Is it faster than Torch on the M1?
Created on October 13 | Last edited on November 1
Machine learning frameworks like PyTorch and TensorFlow are widely used across hardware platforms, including powerful NVIDIA GPUs and Google TPUs, which are ideal for training large models. MLX, by contrast, is a newer framework designed specifically for Apple silicon, such as the M1 and M2 chips. High-performance hardware remains the best choice for large-scale deep learning. MLX, however, offers a way to run smaller experiments or prototype directly on a Mac, which is useful when a dedicated deep learning machine is not readily available.

[Image: A young Steve Jobs]
Table of contents
- What is MLX?
- MLX training example and a comparison to PyTorch
- Training on MNIST and comparing the performance with NanoGPT
- Porting MLX to NanoGPT
  - Experiment configuration and hyperparameters
  - The results
- Overall
What is MLX?
MLX is a machine learning framework designed for Apple hardware. It combines a familiar NumPy-like array API with higher-level packages for neural networks (mlx.nn) and optimizers (mlx.optimizers). It supports automatic differentiation and computational graph optimization, which makes it well suited to quick model prototyping and experimentation on Macs.
Beyond its core machine learning functions, MLX takes advantage of Apple's unified memory architecture. Arrays live in memory shared by the CPU and GPU, so operations can run on either device without copying data between them. This avoids duplicate copies of data and makes model development practical without high-end hardware. GPUs and TPUs remain the best option for large-scale projects. For researchers and developers in the Apple ecosystem working on small-to-medium tasks, though, MLX is a convenient and accessible framework.
MLX training example and a comparison to PyTorch
To illustrate how MLX works, let’s look at a basic training loop and compare it to a similar loop in PyTorch. Both frameworks perform the same fundamental tasks but differ in how these tasks are expressed and handled.
In MLX, a training loop might look like this:
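```python
import mlx.core as mx
import mlx.nn as nn

# Initialize model, optimizer, and loss function (loss_fn takes model, X, y)
# Define a function to compute both loss and gradients
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

# Training loop
for epoch in range(num_epochs):
    for X, y in train_loader:
        loss, grads = loss_and_grad_fn(model, X, y)   # Forward + backward pass
        optimizer.update(model, grads)                # Update model parameters
        mx.eval(model.parameters(), optimizer.state)  # Run the lazily built computation
```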
In this MLX example, loss_and_grad_fn(model, X, y) computes both the loss and the gradients with respect to the model's trainable parameters in a single call. optimizer.update(model, grads) then applies those gradients to the parameters. Because MLX evaluates lazily, MLX's own examples typically follow the update with mx.eval(model.parameters(), optimizer.state) to force the pending computation to run. Combining the forward and backward passes into one function call removes the need for separate calls for backpropagation and gradient bookkeeping.
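The "loss and gradients in one call" idea can be made concrete without an Apple silicon machine. The minimal NumPy sketch below is an illustration, not MLX. The hypothetical value_and_grad_mse function returns the loss together with a hand-derived gradient for a linear least-squares model; MLX would obtain that gradient through automatic differentiation. The hypothetical sgd_update function plays the role of optimizer.update.

```python
import numpy as np

def mse_loss(w, X, y):
    """Mean squared error of the linear model X @ w against targets y."""
    residual = X @ w - y
    return float(np.mean(residual ** 2))

def value_and_grad_mse(w, X, y):
    """Return (loss, gradient) from a single call.

    The gradient of mean((X @ w - y) ** 2) with respect to w is
    (2 / n) * X.T @ (X @ w - y). It is derived by hand here, whereas
    MLX's value_and_grad obtains it through automatic differentiation.
    """
    residual = X @ w - y
    loss = float(np.mean(residual ** 2))
    grad = (2.0 / X.shape[0]) * (X.T @ residual)
    return loss, grad

def sgd_update(w, grad, learning_rate):
    """Plain SGD step: the role optimizer.update plays in the MLX loop."""
    return w - learning_rate * grad

rng = np.random.default_rng(0)
X = rng.normal(size=(64, 3))
true_w = np.array([1.0, -2.0, 0.5])
y = X @ true_w  # noise-free targets, so the optimum is exactly true_w

w = np.zeros(3)
for step in range(200):
    loss, grad = value_and_grad_mse(w, X, y)  # forward + backward in one call
    w = sgd_update(w, grad, learning_rate=0.1)
```

Returning the loss and the gradient together means the forward computation is shared rather than repeated. This is the same convenience the MLX loop gets from nn.value_and_grad.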
In contrast, here’s how a similar process might look in PyTorch:
```python
# Initialize model, optimizer, and loss function
for epoch in range(num_epochs):
    for X, y in train_loader:
        optimizer.zero_grad()      # Clear gradients
        output = model(X)          # Forward pass
        loss = loss_fn(output, y)  # Compute loss
        loss.backward()            # Backward pass to compute gradients
        optimizer.step()           # Update model parameters
```
In PyTorch, each step is handled more explicitly. By default, PyTorch accumulates gradients into each parameter's .grad attribute. They must therefore be cleared with optimizer.zero_grad() at the start of each iteration, before the next backward pass. The forward pass, loss computation, backward pass, and parameter update are all separate calls.
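A toy model of this accumulation behaviour shows why optimizer.zero_grad() is needed. The ToyParameter class below is a hypothetical NumPy stand-in, not PyTorch. Each backward pass adds into .grad, and the zero_grad helper resets it to None. Recent PyTorch versions also reset gradients to None by default instead of filling them with zeros.

```python
import numpy as np

class ToyParameter:
    """Hypothetical stand-in for a tensor with a .grad buffer (not PyTorch itself)."""

    def __init__(self, value):
        self.value = np.asarray(value, dtype=float)
        self.grad = None

    def accumulate_grad(self, g):
        # New gradients are added to whatever .grad already holds.
        g = np.asarray(g, dtype=float)
        self.grad = g.copy() if self.grad is None else self.grad + g

def zero_grad(params):
    """Reset stored gradients, mirroring what optimizer.zero_grad() does."""
    for p in params:
        p.grad = None

def backward_sum_of_squares(p):
    """Backward pass for loss = sum(value ** 2): accumulates 2 * value into p.grad."""
    p.accumulate_grad(2.0 * p.value)

p = ToyParameter([1.0, -3.0])

# Two backward passes without clearing: the stored gradient doubles.
backward_sum_of_squares(p)
backward_sum_of_squares(p)
accumulated = p.grad.copy()

# Clearing first leaves only the gradient of the latest pass.
zero_grad([p])
backward_sum_of_squares(p)
fresh = p.grad.copy()
```

MLX's functional value_and_grad returns fresh gradients on every call, so it has no equivalent of this reset step.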
The rest of MLX is actually quite similar to Torch in terms of syntax, and the fundamental structure remains familiar. You still define models, optimizers, and loss functions in a way that closely resembles PyTorch, which makes transitioning between the two frameworks relatively straightforward.
Training on MNIST and comparing the performance with NanoGPT
Now that we've outlined the differences between MLX and PyTorch in a simple training loop, the next step is to look at how MLX handles more complex tasks. In the following sections, we'll train a model on the MNIST dataset and then compare MLX and PyTorch on NanoGPT, a compact implementation of GPT-2-style models. This will give us a clearer picture of how MLX handles different workloads and how it performs relative to PyTorch on Apple silicon.
```python
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
import wandb
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Initialize wandb
wandb.init(project="MLX_MNIST")


class MLP(nn.Module):
    """A simple MLP."""

    def __init__(self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
        self.layers = [
            nn.Linear(idim, odim)
            for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
        ]

    def __call__(self, x):
        for l in self.layers[:-1]:
            x = nn.relu(l(x))
        return self.layers[-1](x)


def loss_fn(model, X, y):
    return nn.losses.cross_entropy(model(X), y, reduction="mean")


def main():
    seed = 0
    num_layers = 2
    hidden_dim = 32
    num_classes = 10
    batch_size = 256
    num_epochs = 10
    learning_rate = 1e-1
    np.random.seed(seed)
    mx.random.seed(seed)  # Also seed MLX's parameter initialization

    # Data preprocessing and loading with input flattening
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
        transforms.Lambda(lambda x: x.view(-1)),  # Flatten input tensor
    ])
    train_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST(root="./data", train=False, download=True, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Initialize model, optimizer, and loss function
    model = MLP(num_layers, 28 * 28, hidden_dim, num_classes)
    optimizer = optim.SGD(learning_rate=learning_rate)
    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

    # Log the hyperparameters with wandb
    wandb.config.update({
        "num_layers": num_layers,
        "hidden_dim": hidden_dim,
        "learning_rate": learning_rate,
        "batch_size": batch_size,
        "num_epochs": num_epochs,
    })

    # Training loop
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for X, y in train_loader:
            X, y = mx.array(X.numpy()), mx.array(y.numpy())
            loss, grads = loss_and_grad_fn(model, X, y)
            optimizer.update(model, grads)
            # MLX is lazy: force the parameter and optimizer-state updates to run
            mx.eval(model.parameters(), optimizer.state)
            epoch_loss += loss.item()  # Accumulate loss

        # Evaluation over the full test set (not just its first batch)
        correct, total = 0, 0
        for test_images, test_labels in test_loader:
            test_images = mx.array(test_images.numpy())
            test_labels = mx.array(test_labels.numpy())
            predictions = mx.argmax(model(test_images), axis=1)
            correct += mx.sum(predictions == test_labels).item()
            total += test_labels.size
        accuracy = correct / total

        # Log metrics with wandb
        wandb.log({
            "epoch": epoch,
            "loss": epoch_loss / len(train_loader),  # Average training loss
            "accuracy": accuracy,
        })
        print(f"Epoch {epoch}: Test accuracy {accuracy:.3f}, Avg Loss: {epoch_loss / len(train_loader):.3f}")


if __name__ == "__main__":
    mx.set_default_device(mx.gpu)
    main()
```
This code demonstrates how to use MLX, a framework optimized for Apple silicon, to train a simple multi-layer perceptron (MLP) on the MNIST dataset.
The MLP model is defined with MLX. The network has a configurable number of hidden layers (num_layers), each a fully connected nn.Linear layer. ReLU activations are applied after every layer except the output layer, which returns raw logits. The network takes the MNIST images flattened from 28x28 pixels into 784-dimensional vectors.
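The following NumPy sketch mirrors the same architecture so its shapes and size can be checked by hand. It uses the same layer_sizes construction as the MLX class: ReLU after every layer but the last, and weights stored as (out, in) like nn.Linear. The uniform initialization and zero biases are assumptions made for illustration only. With num_layers=2, hidden_dim=32, and 784 inputs, the model has 784·32+32 + 32·32+32 + 32·10+10 = 26,506 parameters.

```python
import numpy as np

def init_mlp(num_layers, input_dim, hidden_dim, output_dim, seed=0):
    """Create (weight, bias) pairs with the same layer sizes as the MLX MLP class."""
    rng = np.random.default_rng(seed)
    layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
    params = []
    for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:]):
        # Weight stored as (odim, idim); this init scheme is an assumption, not MLX's.
        scale = 1.0 / np.sqrt(idim)
        weight = rng.uniform(-scale, scale, size=(odim, idim))
        bias = np.zeros(odim)
        params.append((weight, bias))
    return params

def mlp_forward(params, x):
    """ReLU after every layer except the last, which returns raw logits."""
    for weight, bias in params[:-1]:
        x = np.maximum(x @ weight.T + bias, 0.0)
    weight, bias = params[-1]
    return x @ weight.T + bias

params = init_mlp(num_layers=2, input_dim=28 * 28, hidden_dim=32, output_dim=10)
num_params = sum(w.size + b.size for w, b in params)
logits = mlp_forward(params, np.zeros((4, 28 * 28)))
print(num_params, logits.shape)  # 26506 (4, 10)
```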
Data loading is handled by PyTorch's DataLoader, which batches and shuffles the MNIST dataset. Preprocessing converts each image to a tensor, normalizes pixel values to the range [-1, 1], and flattens the image so it matches the MLP's input. The loss function is cross-entropy, the standard choice for classification, and the optimizer is stochastic gradient descent; MLX provides both. The value_and_grad function calculates the loss and gradients in one step, and the model parameters are updated with these gradients at each iteration.
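The preprocessing and loss can be spelled out in NumPy to show what the transform pipeline and the mean-reduced cross-entropy compute. This is a sketch under stated assumptions: preprocess mimics ToTensor, then Normalize((0.5,), (0.5,)), then the flatten step, on a uint8 batch of shape (N, 28, 28). cross_entropy_mean is a hypothetical re-implementation of softmax cross-entropy with integer labels, not MLX's nn.losses.cross_entropy itself.

```python
import numpy as np

def preprocess(images_uint8):
    """Mimic ToTensor -> Normalize((0.5,), (0.5,)) -> flatten for a batch of 28x28 images."""
    x = images_uint8.astype(np.float32) / 255.0  # ToTensor: [0, 255] -> [0, 1]
    x = (x - 0.5) / 0.5                          # Normalize: [0, 1] -> [-1, 1]
    return x.reshape(x.shape[0], -1)             # Flatten: (N, 28, 28) -> (N, 784)

def cross_entropy_mean(logits, labels):
    """Mean softmax cross-entropy over a batch of integer class labels."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return float(-log_probs[np.arange(len(labels)), labels].mean())

# An all-black and an all-white image map to -1 and +1 after normalization.
images = np.array([np.zeros((28, 28)), np.full((28, 28), 255)], dtype=np.uint8)
x = preprocess(images)

# Uniform logits over 10 classes give a loss of log(10) regardless of the labels.
loss = cross_entropy_mean(np.zeros((2, 10)), np.array([3, 7]))
```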
Weights & Biases is integrated into the script to track and log key metrics such as the training loss and test accuracy in real time. This makes it easy to visualize the model's performance and tune hyperparameters based on the logged metrics. Here are the results for the MNIST run!
[W&B panel: run divine-spaceship-6]
Porting MLX to NanoGPT
After working with MLX on a small model, let's port NanoGPT to MLX and run some experiments to see which framework performs better. NanoGPT is a lightweight GPT implementation, which makes it a good candidate for testing MLX on Apple silicon. To keep the comparison simple and fair, we'll disable more advanced techniques such as gradient scaling and gradient accumulation and use the same basic configuration in both frameworks.
Experiment configuration and hyperparameters
One of the key changes was to the optimizer configuration. The original NanoGPT setup groups parameters by dimensionality and applies weight decay selectively. Tensors with two or more dimensions (the weights of matrix-multiplication layers and the embeddings) are decayed, while biases and layer normalization parameters are not. To keep the comparison between MLX and PyTorch as "apples-to-apples" as possible, we'll simplify this by applying the same weight decay and learning rate to all parameters. Treating every parameter uniformly reduces the chance that small implementation differences between the two ports skew the performance results.
This simpler configuration slightly changes how the model is regularized, but it leaves the model architecture untouched. It lets us compare the frameworks' raw performance without worrying about tuning specific parts of the model differently in the two implementations. The goal is to isolate the differences between the frameworks rather than add complexity that could make the results harder to interpret.
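The difference between the two optimizer setups comes down to how parameters are split into groups. This minimal sketch shows both. The parameter names, shapes, and the weight_decay value of 0.1 are hypothetical. The groups hold parameter names instead of tensors purely for readability; a real optimizer would receive the tensors themselves.

```python
import numpy as np

def grouped_param_groups(named_params, weight_decay):
    """NanoGPT-style grouping: decay tensors with ndim >= 2, exempt the rest."""
    decay = [name for name, p in named_params.items() if p.ndim >= 2]
    no_decay = [name for name, p in named_params.items() if p.ndim < 2]
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]

def uniform_param_groups(named_params, weight_decay):
    """Simplified setup used in this comparison: one group, same decay everywhere."""
    return [{"params": list(named_params), "weight_decay": weight_decay}]

# Hypothetical parameter names and shapes for a tiny one-layer model.
named_params = {
    "wte.weight": np.zeros((65, 32)),              # token embedding (2-D)
    "h.0.attn.c_attn.weight": np.zeros((96, 32)),  # matmul weight (2-D)
    "h.0.attn.c_attn.bias": np.zeros(96),          # bias (1-D)
    "h.0.ln_1.weight": np.zeros(32),               # LayerNorm scale (1-D)
}

grouped = grouped_param_groups(named_params, weight_decay=0.1)
uniform = uniform_param_groups(named_params, weight_decay=0.1)
```

In the uniform setup, biases and LayerNorm scales are decayed too. That is the small regularization change accepted here in exchange for identical behaviour across both ports.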
In addition to simplifying the optimizer, we'll use PyTorch's torch.compile on the M1 Pro. Introduced in PyTorch 2.0, torch.compile applies just-in-time compilation. It captures the model's operations as graphs and hands them to a compiler backend that generates optimized code, reducing Python overhead and speeding up execution. Setting up torch.compile on the M1 Pro took some effort, but it was worth it.
This lets the PyTorch version of NanoGPT make better use of Apple silicon and gives a stronger comparison point against MLX.
We're also using PyTorch's Metal Performance Shaders (MPS) backend to accelerate training on the MacBook's GPU. The MPS backend builds on Apple's Metal programming framework. It maps machine learning computational graphs onto the Metal Performance Shaders Graph framework and uses kernels tuned specifically for Apple hardware, which significantly speeds up training.
The MPS backend adds a new device type ("mps") to PyTorch, so operations can run directly on the GPU on macOS. Moving tensors and models to the MPS device lets you use the GPU on Apple silicon as an alternative to more traditional setups such as NVIDIA GPUs. This is particularly useful within the Apple ecosystem, since it provides GPU acceleration without any external hardware.
For the experiments, we'll use a batch size of 8 and measure the time it takes to complete 100 iterations on the Shakespeare_char dataset. This dataset consists of character-level Shakespeare text. It is a small but realistic language-modeling task for NanoGPT, with the same next-token prediction objective GPT models are trained on. Keeping the batch size and the number of iterations fixed lets us compare both frameworks under identical conditions.
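A batch from a character-level dataset is built roughly as follows. This is a simplified sketch of the usual approach, not NanoGPT's exact code. Every distinct character gets an integer id, and each batch is made of random windows whose targets are the inputs shifted by one character. The two-line text snippet stands in for the full dataset, and block_size=16 is a hypothetical context length; only the batch size of 8 comes from the experiment.

```python
import numpy as np

def build_vocab(text):
    """Character-level vocabulary: every distinct character gets an integer id."""
    chars = sorted(set(text))
    stoi = {ch: i for i, ch in enumerate(chars)}
    itos = {i: ch for ch, i in stoi.items()}
    return stoi, itos

def get_batch(data, batch_size, block_size, rng):
    """Sample random windows; targets are the inputs shifted by one character."""
    starts = rng.integers(0, len(data) - block_size, size=batch_size)
    x = np.stack([data[s : s + block_size] for s in starts])
    y = np.stack([data[s + 1 : s + 1 + block_size] for s in starts])
    return x, y

# Stand-in snippet; the real dataset is a much larger Shakespeare text file.
text = "First Citizen:\nBefore we proceed any further, hear me speak.\n"
stoi, itos = build_vocab(text)
data = np.array([stoi[ch] for ch in text], dtype=np.int64)

rng = np.random.default_rng(0)
x, y = get_batch(data, batch_size=8, block_size=16, rng=rng)
```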
Most of the training script is left unchanged. The main differences relate to setting up the correct compute backend, for example the Metal Performance Shaders backend on the PyTorch side. You can find a nice implementation of the GPT-2 model in Apple's MLX examples repo.
```python
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union

import mlx.core as mx
import mlx.nn as nn
import numpy as np

from base import BaseModelArgs, create_attention_mask


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    n_ctx: int
    n_embd: int
    n_head: int
    n_layer: int
    n_positions: int
    layer_norm_epsilon: float
    vocab_size: int
    num_key_value_heads: int = None

    def __post_init__(self):
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.n_head


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        assert args.n_embd % args.n_head == 0, "n_embd must be divisible by n_head"

        self.n_embd = args.n_embd
        self.n_head = args.n_head
        self.head_dim = self.n_embd // self.n_head

        self.scale = self.head_dim**-0.5

        self.c_attn = nn.Linear(self.n_embd, 3 * self.n_embd, bias=True)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=True)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        B, L, D = x.shape

        qkv = self.c_attn(x)
        queries, keys, values = mx.split(qkv, 3, axis=-1)

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, self.n_head, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_head, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_head, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            keys, values = cache.update_and_fetch(keys, values)

        output = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )

        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.c_proj(output)


class MLP(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.n_embd = args.n_embd
        self.c_fc = nn.Linear(self.n_embd, 4 * self.n_embd)
        self.c_proj = nn.Linear(4 * self.n_embd, self.n_embd)

    def __call__(self, x) -> mx.array:
        return self.c_proj(nn.gelu_approx(self.c_fc(x)))


class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_head = args.n_head
        self.n_embd = args.n_embd
        self.layer_norm_epsilon = args.layer_norm_epsilon
        self.attn = Attention(args)
        self.mlp = MLP(args)
        self.ln_1 = nn.LayerNorm(
            self.n_embd,
            eps=self.layer_norm_epsilon,
        )
        self.ln_2 = nn.LayerNorm(self.n_embd, eps=self.layer_norm_epsilon)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        r = self.attn(self.ln_1(x), mask, cache)
        h = x + r
        r = self.mlp(self.ln_2(h))
        out = h + r
        return out


class GPT2Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_embd = args.n_embd
        self.n_positions = args.n_positions
        self.vocab_size = args.vocab_size
        self.n_layer = args.n_layer
        self.layer_norm_epsilon = args.layer_norm_epsilon
        assert self.vocab_size > 0
        self.wte = nn.Embedding(self.vocab_size, self.n_embd)
        self.wpe = nn.Embedding(self.n_positions, self.n_embd)
        self.h = [TransformerBlock(args=args) for _ in range(self.n_layer)]
        self.ln_f = nn.LayerNorm(self.n_embd, eps=self.layer_norm_epsilon)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        _, L = inputs.shape

        hidden_states = self.wte(inputs)

        mask = None
        if hidden_states.shape[1] > 1:
            position_ids = mx.array(np.arange(L))
            hidden_states += self.wpe(position_ids)

            mask = create_attention_mask(hidden_states, cache)

        if cache is None:
            cache = [None] * len(self.h)

        for layer, c in zip(self.h, cache):
            hidden_states = layer(hidden_states, mask, cache=c)

        return self.ln_f(hidden_states)


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.model = GPT2Model(args)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        out = self.model(inputs, cache)
        out = self.model.wte.as_linear(out)
        return out

    def sanitize(self, weights):
        new_weights = {}
        for i in range(self.args.n_layer):
            if f"h.{i}.attn.bias" in weights:
                del weights[f"h.{i}.attn.bias"]
            if f"h.{i}.attn.c_attn.weight" in weights:
                weights[f"h.{i}.attn.c_attn.weight"] = weights[
                    f"h.{i}.attn.c_attn.weight"
                ].transpose(1, 0)
            if f"h.{i}.attn.c_proj.weight" in weights:
                weights[f"h.{i}.attn.c_proj.weight"] = weights[
                    f"h.{i}.attn.c_proj.weight"
                ].transpose(1, 0)
            if f"h.{i}.mlp.c_fc.weight" in weights:
                weights[f"h.{i}.mlp.c_fc.weight"] = weights[
                    f"h.{i}.mlp.c_fc.weight"
                ].transpose(1, 0)
            if f"h.{i}.mlp.c_proj.weight" in weights:
                weights[f"h.{i}.mlp.c_proj.weight"] = weights[
                    f"h.{i}.mlp.c_proj.weight"
                ].transpose(1, 0)
        for weight in weights:
            if not weight.startswith("model."):
                new_weights[f"model.{weight}"] = weights[weight]
            else:
                new_weights[weight] = weights[weight]
        return new_weights

    @property
    def layers(self):
        return self.model.h
```
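The Attention module above reshapes queries, keys, and values to (B, heads, L, head_dim) and passes them to mx.fast.scaled_dot_product_attention with scale = head_dim ** -0.5. The NumPy sketch below shows what that call computes conceptually, softmax(q·kᵀ·scale + mask)·v. It assumes the mask from create_attention_mask is the usual additive causal mask. It is an unfused reference for illustration, not MLX's kernel, and the sizes are hypothetical.

```python
import numpy as np

def causal_attention(q, k, v):
    """softmax(q @ k^T * scale + causal_mask) @ v for arrays shaped (B, heads, L, head_dim)."""
    head_dim = q.shape[-1]
    scale = head_dim ** -0.5                        # same scale as self.scale in Attention
    scores = (q @ np.swapaxes(k, -1, -2)) * scale   # (B, heads, L, L)
    L = q.shape[-2]
    # Additive causal mask: -inf above the diagonal, so position i only sees j <= i.
    mask = np.triu(np.full((L, L), -np.inf), k=1)
    scores = scores + mask
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

rng = np.random.default_rng(0)
B, H, L, D = 2, 4, 5, 8
q, k, v = (rng.normal(size=(B, H, L, D)) for _ in range(3))
out = causal_attention(q, k, v)
```

Because of the mask, changing the values at the last position leaves every earlier output unchanged. The fused MLX call computes the same quantity without materializing the intermediate arrays in Python.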
If you're interested in the exact implementation details, including how torch.compile was configured on the M1 Pro or any other specifics of the NanoGPT port to MLX, the full code is available in the project repository.
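The sanitize method in the listing transposes the attention and MLP projection weights when loading a checkpoint. This sketch assumes the weights come from Hugging Face-format GPT-2 checkpoints. Those checkpoints store these projections as Conv1D layers with weights shaped (in_features, out_features) and compute x @ W + b. MLX's nn.Linear instead stores (out_features, in_features) and computes x @ W.T + b. The NumPy sketch below uses random stand-in weights and hypothetical sizes to check that a transpose is all it takes to make the two layouts agree.

```python
import numpy as np

rng = np.random.default_rng(0)
in_dim, out_dim = 6, 18
x = rng.normal(size=(3, in_dim))
bias = rng.normal(size=out_dim)

# Checkpoint layout (assumed Conv1D-style): (in_dim, out_dim), applied as x @ W + b.
w_checkpoint = rng.normal(size=(in_dim, out_dim))
y_checkpoint = x @ w_checkpoint + bias

# Linear layout: (out_dim, in_dim), applied as x @ W.T + b.
w_linear = w_checkpoint.transpose(1, 0)  # the same transpose sanitize() performs
y_linear = x @ w_linear.T + bias
```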
The results
The results were quite similar, with MLX running slightly faster. It finished the 100 iterations in 2 minutes 33 seconds, compared with 2 minutes 45 seconds for PyTorch. MLX also supports bfloat16, but switching to it gave no notable speedup at the same batch size. Because bfloat16 halves the memory per value compared with float32, however, it leaves room for a larger batch size, which could increase overall throughput. It is worth trying if you are working with MLX.
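The reported times can be turned into per-iteration numbers with a little arithmetic, and the bfloat16 memory argument can be checked the same way. The two timings come from the text. The memory part is a simplification that counts a single activation tensor, and its sequence length (256) and width (384) are hypothetical. In practice, the savings depend on which tensors are actually stored in bfloat16.

```python
ITERS = 100
mlx_seconds = 2 * 60 + 33    # reported MLX time for 100 iterations
torch_seconds = 2 * 60 + 45  # reported PyTorch (MPS) time for 100 iterations

per_iter_mlx = mlx_seconds / ITERS            # 1.53 s per iteration
per_iter_torch = torch_seconds / ITERS        # 1.65 s per iteration
time_saved = 1 - mlx_seconds / torch_seconds  # about 0.073, i.e. ~7% less wall-clock time
speedup = torch_seconds / mlx_seconds         # about 1.08x

BYTES_PER_ELEMENT = {"float32": 4, "bfloat16": 2}

def activation_bytes(batch_size, seq_len, width, dtype):
    """Bytes for one activation tensor of shape (batch_size, seq_len, width)."""
    return batch_size * seq_len * width * BYTES_PER_ELEMENT[dtype]

# Hypothetical sizes: batch 16 in bfloat16 uses the same bytes as batch 8 in float32.
fp32_batch8 = activation_bytes(8, 256, 384, "float32")
bf16_batch16 = activation_bytes(16, 256, 384, "bfloat16")
```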
Here are the charts for both runs:
[W&B panel: run set]
I also did a run with PyTorch on the CPU alone instead of the Metal Performance Shaders backend, and the results were quite different. The CPU run took over 5 minutes, more than 1.8 times as long as the MPS run!
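Both MLX and the MPS backend can return control to Python before the device has finished its work. Timings like these are therefore only meaningful if the timer waits for queued work to complete. The hypothetical time_iterations helper below is a sketch of that pattern; it takes a step function and an optional synchronization callback. For PyTorch on MPS, the callback could be torch.mps.synchronize(). In MLX, calling mx.eval inside the step already forces the computation to run. The optional warmup excludes one-time costs such as torch.compile's compilation. The article does not say whether its timings included such costs, so warmup is an assumption here.

```python
import time

def time_iterations(step_fn, num_iters, sync_fn=None, warmup=0):
    """Return wall-clock seconds for num_iters calls to step_fn.

    sync_fn, if given, should block until queued device work has finished;
    it is called before starting and before stopping the timer.
    """
    for _ in range(warmup):
        step_fn()
    if sync_fn is not None:
        sync_fn()
    start = time.perf_counter()
    for _ in range(num_iters):
        step_fn()
    if sync_fn is not None:
        sync_fn()
    return time.perf_counter() - start

# Usage with a dummy step that just records how often it was called.
calls = []
elapsed = time_iterations(lambda: calls.append(1), num_iters=100, sync_fn=lambda: None, warmup=5)
```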
Overall
The implementation here demonstrates how MLX can be a solid choice for local experiments, especially on Apple silicon. MLX simplifies the workflow and takes advantage of the hardware it’s optimized for, allowing for effective experimentation without needing a dedicated deep learning machine.
However, developers should approach MLX with caution if their experiments are likely to scale up or move to GPU environments. MLX is optimized for Apple silicon, so moving to more conventional GPU-heavy setups (such as NVIDIA GPUs or cloud GPU instances) may require significant refactoring. MLX's model-definition syntax resembles PyTorch. Its training APIs and optimizations, however, differ from those of frameworks like PyTorch or TensorFlow, which are built to run on such GPUs. As a result, code that works well locally with MLX may not translate directly to GPU-based environments and may need modification.