
Getting started with Apple MLX

A guide to Apple's new deep learning framework. Is it faster than Torch on the M1?
Created on October 13 | Last edited on November 1
Machine learning frameworks like PyTorch and TensorFlow are widely used across hardware platforms, including powerful NVIDIA GPUs and Google TPUs, which are ideal for training large models. However, MLX is a newer framework designed specifically for Apple silicon, making it optimized for the architecture found in M1 and M2 chips. While high-performance hardware remains the best choice for large-scale deep learning tasks, MLX provides an alternative for running smaller experiments or prototyping directly on a Mac, which can be useful when a deep learning machine is not readily available.
[Image: A young Steve Jobs]





What is MLX?

MLX is a machine learning framework designed for Apple hardware. It combines a familiar NumPy-like array API with higher-level packages for neural networks (mlx.nn) and optimizers (mlx.optimizers). It supports automatic differentiation and lazy evaluation, and it can compile computation graphs. Together, these features make it well suited to quick model prototyping and experimentation on Macs.
MLX also takes advantage of Apple's unified memory architecture. Arrays live in memory shared by the CPU and GPU, so they do not need to be copied between devices, and operations on the same arrays can run on either one. This makes model development smoother without requiring high-end hardware. GPUs and TPUs remain the best choice for large-scale projects, but MLX offers a convenient, accessible framework for researchers and developers in the Apple ecosystem who work on small-to-medium tasks.

MLX training example and a comparison to PyTorch

To illustrate how MLX works, let’s look at a basic training loop and compare it to a similar loop in PyTorch. Both frameworks perform the same fundamental tasks but differ in how these tasks are expressed and handled.
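In MLX, a training loop might look like this:
```python
import mlx.core as mx
import mlx.nn as nn

# Initialize model, optimizer, and loss function
# (model, optimizer, loss_fn, train_loader and num_epochs are assumed to exist)

# Define a function that computes both the loss and the gradients
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

# Training loop
for epoch in range(num_epochs):
    for X, y in train_loader:
        loss, grads = loss_and_grad_fn(model, X, y)   # Forward + backward pass
        optimizer.update(model, grads)                # Update model parameters
        mx.eval(model.parameters(), optimizer.state)  # Force lazy evaluation
```
In this MLX example, nn.value_and_grad wraps loss_fn so that a single call, loss_and_grad_fn(model, X, y), returns both the loss and the gradients with respect to the model's trainable parameters. The optimizer then applies those gradients to the model, and mx.eval forces the computation to run, because MLX builds computations lazily and only evaluates arrays when they are needed. Since the gradients are returned as ordinary values rather than stored on the parameters, there is nothing to reset between steps and no separate backpropagation call.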
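The same functional pattern can be shown without MLX. The NumPy sketch below uses a linear model with a mean-squared-error loss. The gradient is written out by hand, where MLX would derive it automatically, and `loss_and_grad` and `sgd_update` are hypothetical helpers, not MLX APIs. The point is that every call returns fresh gradients, so no state carries over from one step to the next.

```python
import numpy as np


def loss_and_grad(params, X, y):
    """Return (loss, grads) for the mean-squared error of a linear model.

    Hypothetical helper: the gradient is derived by hand here, whereas
    MLX's nn.value_and_grad computes it automatically.
    """
    err = X @ params["w"] + params["b"] - y   # residuals, shape (n,)
    loss = np.mean(err ** 2)
    n = X.shape[0]
    grads = {
        "w": 2.0 * X.T @ err / n,             # d loss / d w
        "b": 2.0 * np.mean(err),              # d loss / d b
    }
    return loss, grads


def sgd_update(params, grads, lr):
    """Return new parameters; nothing is accumulated between calls."""
    return {k: params[k] - lr * grads[k] for k in params}


rng = np.random.default_rng(0)
X = rng.normal(size=(64, 3))
true_w = np.array([1.0, -2.0, 0.5])
y = X @ true_w + 0.3

params = {"w": np.zeros(3), "b": 0.0}
initial_loss, _ = loss_and_grad(params, X, y)
for step in range(200):
    loss, grads = loss_and_grad(params, X, y)   # forward + backward in one call
    params = sgd_update(params, grads, lr=0.1)  # parameter update
final_loss, _ = loss_and_grad(params, X, y)
```

In MLX, `nn.value_and_grad(model, loss_fn)` plays the role of `loss_and_grad`, and `optimizer.update(model, grads)` plays the role of `sgd_update`.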
In contrast, here’s how a similar process might look in PyTorch:
```python
# Initialize model, optimizer, and loss function

for epoch in range(num_epochs):
    for X, y in train_loader:
        optimizer.zero_grad()      # Clear gradients
        output = model(X)          # Forward pass
        loss = loss_fn(output, y)  # Compute loss
        loss.backward()            # Backward pass to compute gradients
        optimizer.step()           # Update model parameters
```
In PyTorch, each step is handled more explicitly. The gradients must be cleared with optimizer.zero_grad() on every iteration, because loss.backward() accumulates gradients into each parameter's .grad attribute instead of overwriting them. The forward pass, loss computation, backward pass, and parameter update are all separate calls.
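To see why the explicit reset matters, the toy `Param` class below imitates PyTorch's accumulating `.grad` field in pure Python. It is a hypothetical stand-in, not PyTorch code. It uses the loss (w * x)**2, whose gradient with respect to w is 2 * w * x**2.

```python
class Param:
    """Toy stand-in for a tensor with an accumulating .grad field."""

    def __init__(self, value: float):
        self.value = value
        self.grad = 0.0


def backward(w: Param, x: float) -> None:
    """Add d/dw of (w * x)**2 into w.grad, like PyTorch's backward()."""
    w.grad += 2.0 * w.value * x ** 2


def zero_grad(params) -> None:
    """Reset accumulated gradients, like optimizer.zero_grad()."""
    for p in params:
        p.grad = 0.0


w = Param(1.5)
backward(w, x=2.0)
first = w.grad       # 12.0
backward(w, x=2.0)   # no reset: the second gradient is added to the first
stale = w.grad       # 24.0
zero_grad([w])
backward(w, x=2.0)
fresh = w.grad       # 12.0 again
```

Accumulation is deliberate in PyTorch because it enables techniques such as gradient accumulation over several mini-batches. The price is that it has to be reset manually when it is not wanted.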
The rest of MLX is actually quite similar to Torch in terms of syntax, and the fundamental structure remains familiar. You still define models, optimizers, and loss functions in a way that closely resembles PyTorch, which makes transitioning between the two frameworks relatively straightforward.

Training on MNIST and comparing the performance with NanoGPT

Now that we've outlined the differences between MLX and PyTorch in a simple training loop, the next step is to see how MLX handles more complex tasks. In the following sections, we'll train a model on the MNIST dataset and then compare the performance of MLX and PyTorch on NanoGPT, a lightweight implementation of smaller GPT-2-style models. This gives a clearer picture of how well MLX handles different workloads and how it performs relative to PyTorch on Apple silicon.
```python
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
import wandb
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Initialize wandb
wandb.init(project="MLX_MNIST")


class MLP(nn.Module):
    """A simple MLP."""

    def __init__(self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
        self.layers = [
            nn.Linear(idim, odim)
            for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
        ]

    def __call__(self, x):
        for l in self.layers[:-1]:
            x = nn.relu(l(x))
        return self.layers[-1](x)


def loss_fn(model, X, y):
    return nn.losses.cross_entropy(model(X), y, reduction="mean")


def main():
    seed = 0
    num_layers = 2
    hidden_dim = 32
    num_classes = 10
    batch_size = 256
    num_epochs = 10
    learning_rate = 1e-1

    np.random.seed(seed)
    mx.random.seed(seed)  # Seed MLX's parameter initialization

    # Data preprocessing and loading with input flattening
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
        transforms.Lambda(lambda x: x.view(-1))  # Flatten input tensor
    ])

    train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Initialize model, optimizer, and loss function
    model = MLP(num_layers, 28 * 28, hidden_dim, num_classes)

    optimizer = optim.SGD(learning_rate=learning_rate)
    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

    # Log the hyperparameters to the run's config
    wandb.config.update({
        "num_layers": num_layers,
        "hidden_dim": hidden_dim,
        "learning_rate": learning_rate,
        "batch_size": batch_size,
        "num_epochs": num_epochs
    })

    # Training loop
    for epoch in range(num_epochs):
        epoch_loss = 0
        for X, y in train_loader:
            X, y = mx.array(X.numpy()), mx.array(y.numpy())
            loss, grads = loss_and_grad_fn(model, X, y)
            optimizer.update(model, grads)
            # Force MLX's lazy computation of the updated parameters
            mx.eval(model.parameters(), optimizer.state)
            epoch_loss += loss.item()  # Accumulate loss

        # Evaluation on the first test batch only (batch_size images)
        test_images, test_labels = next(iter(test_loader))
        test_images, test_labels = mx.array(test_images.numpy()), mx.array(test_labels.numpy())
        accuracy = mx.mean(mx.argmax(model(test_images), axis=1) == test_labels)

        # Log metrics with wandb
        wandb.log({
            "epoch": epoch,
            "loss": epoch_loss / len(train_loader),  # Average training loss
            "accuracy": accuracy.item()
        })

        print(f"Epoch {epoch}: Test accuracy {accuracy.item():.3f}, Avg Loss: {epoch_loss / len(train_loader):.3f}")


if __name__ == "__main__":
    mx.set_default_device(mx.gpu)
    main()
```

This code demonstrates how to use MLX, a framework optimized for Apple silicon, to train a simple multi-layer perceptron (MLP) on the MNIST dataset.
The MLP model is defined with MLX. The network has a configurable number of hidden layers, each implemented as a fully connected nn.Linear layer. ReLU activations are applied after every layer except the output layer, which returns raw logits. The network takes the flattened 28x28-pixel MNIST images (784 values) as input.
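As a quick check on model size, the plain-Python sketch below repeats the layer-size construction from `MLP.__init__` and counts the weights and biases of each `nn.Linear` layer (MLX's `nn.Linear` includes a bias by default). The helper names are illustrative only and are not part of MLX.

```python
def mlp_layer_sizes(num_layers: int, input_dim: int, hidden_dim: int, output_dim: int) -> list:
    """Same list construction as MLP.__init__ in the script above."""
    return [input_dim] + [hidden_dim] * num_layers + [output_dim]


def count_parameters(layer_sizes: list) -> int:
    """Sum of weights (idim * odim) and biases (odim) over all Linear layers."""
    return sum(
        idim * odim + odim
        for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
    )


sizes = mlp_layer_sizes(num_layers=2, input_dim=28 * 28, hidden_dim=32, output_dim=10)
print(sizes)                    # [784, 32, 32, 10]
print(count_parameters(sizes))  # 26506
```

With these hyperparameters the model has only about 26.5k parameters, which is why it trains quickly even on a laptop.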
Data loading is handled with PyTorch's DataLoader, which takes care of batching and shuffling the MNIST dataset. Each batch is then converted to an MLX array via NumPy. Preprocessing normalizes and flattens the images into the format the MLP expects. The loss function is cross-entropy, the standard choice for classification, and the optimizer is stochastic gradient descent, both provided by MLX. The value_and_grad function computes the loss and gradients in one step, and the model parameters are updated with these gradients at each iteration.
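The torchvision preprocessing chain can be mirrored in NumPy to show exactly what the MLP receives. This is an illustrative re-implementation that assumes an 8-bit grayscale image, and `preprocess` is a hypothetical helper. `ToTensor` scales pixel values to [0, 1]. `Normalize((0.5,), (0.5,))` then computes (x - 0.5) / 0.5, which maps them to [-1, 1]. Finally, the 28x28 image is flattened into 784 values.

```python
import numpy as np


def preprocess(image_uint8: np.ndarray) -> np.ndarray:
    """NumPy equivalent of ToTensor -> Normalize((0.5,), (0.5,)) -> flatten."""
    x = image_uint8.astype(np.float32) / 255.0  # ToTensor: scale to [0, 1]
    x = (x - 0.5) / 0.5                         # Normalize: map to [-1, 1]
    return x.reshape(-1)                        # flatten 28x28 -> 784


image = np.zeros((28, 28), dtype=np.uint8)  # black background
image[10:18, 10:18] = 255                   # an 8x8 white square
flat = preprocess(image)
print(flat.shape)  # (784,)
```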
Weights & Biases is integrated into the script to track and log key metrics like the training loss and test accuracy in real-time. This provides an easy way to visualize the model's performance and tune hyperparameters based on logged metrics. Here are the results for the MNIST run!

[W&B panel: run divine-spaceship-6]


Porting NanoGPT to MLX

After working with MLX on smaller models, let's port NanoGPT to MLX and run some experiments to see which framework performs better. NanoGPT is a lightweight GPT implementation, which makes it a good candidate for testing MLX on Apple silicon. To keep the experiments simple and the comparison fair, we'll disable more advanced techniques such as gradient scaling and gradient accumulation and use a basic configuration in both frameworks.

Experiment configuration and hyperparameters

One of the key changes was in the optimizer configuration. The original NanoGPT setup groups parameters by dimensionality and applies weight decay selectively. Weight tensors used in matrix multiplications (and embeddings) are decayed, while biases and layer normalization parameters are not. To keep the comparison between MLX and PyTorch "apples-to-apples," we'll instead apply the same weight decay and learning rate to all parameters. Treating every parameter uniformly reduces the chance that small implementation differences skew the performance results.
This simpler configuration changes how some parameters are regularized, but it leaves the model architecture and training procedure intact. It lets us compare the frameworks' raw performance without treating parts of the model differently in the two implementations. The goal is to isolate the performance differences between the frameworks, not to add complexity that would make the results harder to interpret.
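The difference between the two optimizer setups can be sketched in plain Python. The snippet works on a hypothetical mapping of parameter names to shapes. NanoGPT's own code applies the same dimensionality rule (`p.dim() >= 2`) to real tensors and passes the resulting groups to `torch.optim.AdamW`. The function names and the example weight decay of 0.1 are illustrative assumptions.

```python
def nanogpt_style_groups(param_shapes: dict, weight_decay: float) -> list:
    """Decay parameters with >= 2 dims; skip 1-D biases and LayerNorm params."""
    decay = [name for name, shape in param_shapes.items() if len(shape) >= 2]
    no_decay = [name for name, shape in param_shapes.items() if len(shape) < 2]
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]


def uniform_groups(param_shapes: dict, weight_decay: float) -> list:
    """Simplified setup used for the comparison: one group, same decay for all."""
    return [{"params": list(param_shapes), "weight_decay": weight_decay}]


# Hypothetical names and shapes for part of one transformer block (out, in)
shapes = {
    "attn.c_attn.weight": (1152, 384),
    "attn.c_attn.bias": (1152,),
    "ln_1.weight": (384,),
    "ln_1.bias": (384,),
    "mlp.c_fc.weight": (1536, 384),
}

grouped = nanogpt_style_groups(shapes, weight_decay=0.1)
flat = uniform_groups(shapes, weight_decay=0.1)
```

With real tensors in place of the names, either list of dicts can be passed directly as the first argument of `torch.optim.AdamW`, since PyTorch optimizers accept parameter groups in this form.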
In addition to simplifying the optimizer, we'll use PyTorch's torch.compile feature on the M1 Pro. Introduced in PyTorch 2.0, torch.compile just-in-time compiles a model into optimized code, which reduces Python overhead and accelerates execution. Setting it up on the M1 Pro took some effort, but it was worth it.
This lets the PyTorch version of NanoGPT make better use of the hardware capabilities of Apple silicon and provides a stronger baseline to compare against MLX.
We're also using PyTorch's Metal Performance Shaders (MPS) backend to accelerate training on the MacBook's GPU. The MPS backend builds on Apple's Metal programming framework. It maps machine learning computational graphs onto the Metal Performance Shaders Graph framework and uses kernels tuned specifically for Apple hardware, which significantly speeds up training compared with running on the CPU.
The MPS backend adds a new "mps" device to PyTorch, so operations can run directly on the GPU under macOS. Moving tensors and models to this device lets you use the GPU on Apple silicon as an alternative to more traditional setups such as NVIDIA GPUs with CUDA. This is particularly useful within the Apple ecosystem, since it provides GPU acceleration without external hardware.
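Selecting the MPS device in PyTorch typically looks like the sketch below, which prefers MPS, then CUDA, then the CPU. It assumes PyTorch 1.12 or later, the release that introduced `torch.backends.mps`. The `pick_device` name is hypothetical. The `ImportError` guard exists only so the snippet still runs, falling back to the CPU, on machines without PyTorch.

```python
try:
    import torch
except ImportError:  # lets the sketch run (and fall back to CPU) without PyTorch
    torch = None


def pick_device() -> str:
    """Prefer Apple's MPS backend, then CUDA, then the CPU."""
    if torch is None:
        return "cpu"
    if torch.backends.mps.is_available():
        return "mps"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


device = pick_device()
print(f"Using device: {device}")

# With PyTorch installed, the model and each batch are then moved explicitly:
#   model = model.to(device)
#   X, y = X.to(device), y.to(device)
```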
For the experiments, we used a batch size of 8 and measured the time needed to complete 100 iterations on the Shakespeare_char dataset. This character-level text dataset is a good fit for testing NanoGPT on a real-world task, since it resembles the language-modeling work GPT models are typically used for. Keeping the batch size and the number of iterations identical lets us compare both frameworks under the same conditions.
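A fixed number of iterations can be timed with a small harness like the one below. In this minimal sketch, `step_fn` is a hypothetical callable that runs one training iteration. MLX evaluates lazily and the MPS backend executes asynchronously, so each step must wait for its results (for example with `mx.eval` in MLX or `torch.mps.synchronize()` in PyTorch) before the measured time means anything. The optional `warmup` argument, which can exclude one-off costs such as torch.compile's first compilation, is an assumption of this sketch and defaults to 0.

```python
import time


def time_iterations(step_fn, num_iters: int = 100, warmup: int = 0) -> float:
    """Return the wall-clock seconds taken by num_iters calls of step_fn.

    step_fn must block until its work is finished; otherwise lazily
    evaluated or asynchronous GPU work would not be counted.
    """
    for _ in range(warmup):   # untimed iterations (e.g. compilation)
        step_fn()
    start = time.perf_counter()
    for _ in range(num_iters):
        step_fn()
    return time.perf_counter() - start


# Usage with a trivial stand-in for a training step
calls = []
elapsed = time_iterations(lambda: calls.append(1), num_iters=100, warmup=5)
print(f"{len(calls)} calls, {elapsed:.6f} s for the timed 100")
```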
Most of the training script is left unchanged. The main differences involve setting up the correct compute backend, for example the Metal Performance Shaders backend for PyTorch. For the MLX side, Apple's MLX examples repository contains a clean implementation of the GPT-2 model:
```python
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union

import mlx.core as mx
import mlx.nn as nn
import numpy as np

from base import BaseModelArgs, create_attention_mask


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    n_ctx: int
    n_embd: int
    n_head: int
    n_layer: int
    n_positions: int
    layer_norm_epsilon: float
    vocab_size: int
    num_key_value_heads: int = None

    def __post_init__(self):
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.n_head


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        assert args.n_embd % args.n_head == 0, "n_embd must be divisible by n_head"

        self.n_embd = args.n_embd
        self.n_head = args.n_head
        self.head_dim = self.n_embd // self.n_head

        self.scale = self.head_dim**-0.5

        self.c_attn = nn.Linear(self.n_embd, 3 * self.n_embd, bias=True)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=True)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        B, L, D = x.shape

        qkv = self.c_attn(x)
        queries, keys, values = mx.split(qkv, 3, axis=-1)

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, self.n_head, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_head, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_head, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            keys, values = cache.update_and_fetch(keys, values)

        output = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )

        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.c_proj(output)


class MLP(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.n_embd = args.n_embd
        self.c_fc = nn.Linear(self.n_embd, 4 * self.n_embd)
        self.c_proj = nn.Linear(4 * self.n_embd, self.n_embd)

    def __call__(self, x) -> mx.array:
        return self.c_proj(nn.gelu_approx(self.c_fc(x)))


class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.n_head = args.n_head
        self.n_embd = args.n_embd
        self.layer_norm_epsilon = args.layer_norm_epsilon
        self.attn = Attention(args)
        self.mlp = MLP(args)
        self.ln_1 = nn.LayerNorm(
            self.n_embd,
            eps=self.layer_norm_epsilon,
        )
        self.ln_2 = nn.LayerNorm(self.n_embd, eps=self.layer_norm_epsilon)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        r = self.attn(self.ln_1(x), mask, cache)
        h = x + r
        r = self.mlp(self.ln_2(h))
        out = h + r
        return out


class GPT2Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_embd = args.n_embd
        self.n_positions = args.n_positions
        self.vocab_size = args.vocab_size
        self.n_layer = args.n_layer
        self.layer_norm_epsilon = args.layer_norm_epsilon
        assert self.vocab_size > 0
        self.wte = nn.Embedding(self.vocab_size, self.n_embd)
        self.wpe = nn.Embedding(self.n_positions, self.n_embd)
        self.h = [TransformerBlock(args=args) for _ in range(self.n_layer)]
        self.ln_f = nn.LayerNorm(self.n_embd, eps=self.layer_norm_epsilon)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        _, L = inputs.shape

        hidden_states = self.wte(inputs)

        mask = None
        if hidden_states.shape[1] > 1:

            position_ids = mx.array(np.arange(L))
            hidden_states += self.wpe(position_ids)

            mask = create_attention_mask(hidden_states, cache)

        if cache is None:
            cache = [None] * len(self.h)

        for layer, c in zip(self.h, cache):
            hidden_states = layer(hidden_states, mask, cache=c)

        return self.ln_f(hidden_states)


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.model = GPT2Model(args)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        out = self.model(inputs, cache)
        out = self.model.wte.as_linear(out)
        return out

    def sanitize(self, weights):
        new_weights = {}
        for i in range(self.args.n_layer):
            if f"h.{i}.attn.bias" in weights:
                del weights[f"h.{i}.attn.bias"]
            if f"h.{i}.attn.c_attn.weight" in weights:
                weights[f"h.{i}.attn.c_attn.weight"] = weights[
                    f"h.{i}.attn.c_attn.weight"
                ].transpose(1, 0)
            if f"h.{i}.attn.c_proj.weight" in weights:
                weights[f"h.{i}.attn.c_proj.weight"] = weights[
                    f"h.{i}.attn.c_proj.weight"
                ].transpose(1, 0)
            if f"h.{i}.mlp.c_fc.weight" in weights:
                weights[f"h.{i}.mlp.c_fc.weight"] = weights[
                    f"h.{i}.mlp.c_fc.weight"
                ].transpose(1, 0)
            if f"h.{i}.mlp.c_proj.weight" in weights:
                weights[f"h.{i}.mlp.c_proj.weight"] = weights[
                    f"h.{i}.mlp.c_proj.weight"
                ].transpose(1, 0)
        for weight in weights:
            if not weight.startswith("model."):
                new_weights[f"model.{weight}"] = weights[weight]
            else:
                new_weights[weight] = weights[weight]
        return new_weights

    @property
    def layers(self):
        return self.model.h
```
If you're interested in the exact implementation details, including how torch.compile was configured on the M1 Pro or any other specifics of the NanoGPT port to MLX, the full code is available in the project repository.

The results

The results were quite similar, with MLX running slightly faster. It finished in 2 minutes and 33 seconds, compared with 2 minutes and 45 seconds for PyTorch. MLX also supports bfloat16, but switching to it gave no notable speedup at the same batch size. Because bfloat16 uses half the memory of float32 per value, however, it allows a larger batch size, which could increase overall throughput. It is worth trying if you work with MLX.
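Converting the reported wall-clock times into a speedup and an iteration rate makes the gap easier to quantify. The arithmetic below uses only the numbers reported above: 100 iterations, 2:33 for MLX, and 2:45 for PyTorch.

```python
def to_seconds(minutes: int, seconds: int) -> int:
    """Convert a minutes:seconds duration into seconds."""
    return 60 * minutes + seconds


iters = 100                    # iterations timed in each run
mlx_s = to_seconds(2, 33)      # MLX: 153 s
torch_s = to_seconds(2, 45)    # PyTorch (MPS + torch.compile): 165 s

speedup = torch_s / mlx_s                           # throughput ratio
time_saved_pct = 100 * (torch_s - mlx_s) / torch_s  # relative time saved

print(f"speedup: {speedup:.3f}x")            # speedup: 1.078x
print(f"time saved: {time_saved_pct:.1f}%")  # time saved: 7.3%
print(f"MLX: {iters / mlx_s:.3f} it/s, PyTorch: {iters / torch_s:.3f} it/s")
# MLX: 0.654 it/s, PyTorch: 0.606 it/s
```

In other words, MLX completed the 100 iterations in about 7% less time, which corresponds to roughly 8% higher throughput.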
Here are the charts for both runs:

[W&B panel: run set with 5 runs]

I also ran PyTorch on the CPU alone, instead of the Metal Performance Shaders backend, and the results were quite different. The CPU run took over 5 minutes, significantly slower than either GPU run.

Overall

The implementation here demonstrates how MLX can be a solid choice for local experiments, especially on Apple silicon. MLX simplifies the workflow and takes advantage of the hardware it’s optimized for, allowing for effective experimentation without needing a dedicated deep learning machine.
However, developers should approach MLX with caution if there's a high likelihood that their experiments will need to scale up or migrate to environments using GPUs. MLX is optimized for Apple silicon, and transitioning to more conventional GPU-heavy setups (like those utilizing NVIDIA GPUs or cloud-based GPU instances) might require significant code refactoring. The syntax and optimizations used in MLX differ from traditional frameworks like PyTorch or TensorFlow, which are built to seamlessly leverage GPU acceleration. As a result, what works well locally on MLX may not directly translate to GPU-based environments, and modifications to the codebase might be necessary to accommodate these changes.
For more details, you can find the project repo here.


