Comparison of Variational Methods in SAEs
Created on June 14 | Last edited on June 17
Introduction
Sparse Autoencoders (SAEs) learn overcomplete dictionaries of interpretable features from neural network activations by reconstructing the original activations while enforcing sparsity through various means of regularization. This approach has proven effective for mechanistic interpretability, revealing monosemantic features that correspond to human-understandable concepts within transformer representations.
Experimental Setup
Our experimental framework provides controlled comparisons between standard SAE variants and their variational counterparts across a unified architecture. We evaluate all models on the gelu-1l transformer at layer 0 (blocks.0.mlp.hook_post), using a dictionary size of 4× the MLP dimension (8,192 features). Training occurs over 60k steps with bfloat16 precision, using consistent buffer configurations (2.5k contexts of length 128) and identical evaluation metrics including fraction variance explained, L0 sparsity, and loss recovery. Key hyperparameters—learning rate and method-specific parameters (e.g., KL coefficient for variational methods, top-k values)—are systematically varied while maintaining architectural consistency. Variational SAE experiments additionally incorporate KL annealing schedules to prevent posterior collapse, with learned variance options available for enhanced expressivity.
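As a rough illustration of the KL annealing idea mentioned above, here is a minimal sketch of a linear warmup. The linear shape, the function name `kl_weight`, and the `warmup_steps` parameter are assumptions for illustration only; the schedule actually used in these runs is not specified here.

```python
def kl_weight(step: int, kl_coeff: float, warmup_steps: int) -> float:
    """Linearly ramp the KL weight from 0 up to kl_coeff over warmup_steps.

    Hypothetical schedule: the report only states that KL annealing is used,
    not which shape it takes.
    """
    if warmup_steps <= 0:
        # No warmup requested: use the full coefficient from the first step.
        return kl_coeff
    return kl_coeff * min(1.0, step / warmup_steps)
```

Starting the KL weight near zero lets the reconstruction term shape the latents first, which is the usual motivation for annealing when posterior collapse is a concern.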
Other runs are available on this profile under Projects.
Standard Sparse Autoencoder (SAE)
The standard SAE learns a sparse representation through deterministic encoding with explicit sparsity regularization:
Encoder:
Decoder:
Loss Function:
where λ controls the sparsity-reconstruction tradeoff and the L1 penalty enforces sparse activations.
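For reference, one common way to write the encoder, decoder, and loss named above is sketched below. The symbols W_enc, b_enc, W_dec, b_dec, and f(x) are assumptions chosen for illustration; details such as bias handling or decoder-norm weighting of the L1 term may differ in the code used for these runs.

```latex
\begin{aligned}
f(x) &= \mathrm{ReLU}(W_{\text{enc}} x + b_{\text{enc}}) \\
\hat{x} &= W_{\text{dec}} f(x) + b_{\text{dec}} \\
\mathcal{L} &= \lVert x - \hat{x} \rVert_2^2 + \lambda \lVert f(x) \rVert_1
\end{aligned}
```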
Variational Sparse Autoencoder (VSAE)
The VSAE replaces explicit sparsity regularization with a probabilistic framework using a Gaussian prior:
Encoder (Mean):
Encoder (Variance, when learned):
Reparameterization:
Decoder:
Loss Function:
where the KL term is the divergence between the learned posterior and the standard normal prior.
The fundamental distinction lies in sparsity mechanisms: standard SAEs achieve sparsity through deterministic ReLU gating and explicit L1 regularization, while VSAEs achieve sparsity through stochastic sampling from a learned posterior regularized toward a sparse prior. This allows VSAEs to model uncertainty in feature activations and potentially capture richer representational structure through the probabilistic latent space.
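A textbook form of the VSAE pieces listed above, assuming a diagonal Gaussian posterior and a standard normal prior, can be written as follows. The symbols W_mu, W_sigma, and the KL coefficient beta are illustrative assumptions, as is the choice to leave the mean without a ReLU; when the variance is not learned, sigma is a fixed constant instead of an encoder output.

```latex
\begin{aligned}
\mu(x) &= W_{\mu} x + b_{\mu} \\
\log \sigma^2(x) &= W_{\sigma} x + b_{\sigma} \\
z &= \mu(x) + \sigma(x) \odot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I) \\
\hat{x} &= W_{\text{dec}} z + b_{\text{dec}} \\
\mathcal{L} &= \lVert x - \hat{x} \rVert_2^2 + \beta \, D_{\mathrm{KL}}\big(q(z \mid x) \,\|\, \mathcal{N}(0, I)\big) \\
D_{\mathrm{KL}}\big(q(z \mid x) \,\|\, \mathcal{N}(0, I)\big) &= \tfrac{1}{2} \sum_i \big(\mu_i^2 + \sigma_i^2 - 1 - \log \sigma_i^2\big)
\end{aligned}
```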
When we say an SAE is variational, we mean that the reparameterization trick was added with an isotropic Gaussian prior, along with a KL loss term that pulls the learned distribution toward the standard normal prior.
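The two ingredients named here, the reparameterization trick and the KL term against a standard normal prior, can be sketched in plain Python. The function names and the list-based representation are assumptions for illustration; the training code itself presumably works on tensors.

```python
import math
import random


def reparameterize(mu, log_var, rng):
    """Sample z = mu + sigma * eps with eps ~ N(0, 1) and sigma = exp(0.5 * log_var)."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, log_var)]


def kl_to_standard_normal(mu, log_var):
    """Closed-form KL(N(mu, diag(sigma^2)) || N(0, I)), summed over latent dimensions."""
    return 0.5 * sum(m * m + math.exp(lv) - 1.0 - lv for m, lv in zip(mu, log_var))
```

With a fixed unit variance (`log_var` all zeros), the KL term reduces to half the squared norm of the mean, so in that setting the KL coefficient acts like an L2 penalty on the latent means.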
Algorithms Used
Standard, isotropic Gaussian, Laplace, and spike/slab priors failed to produce any sparsity in experiments. Active features were pinned at 8,192 (the dictionary size) despite massive changes to all hyperparameters. This could likely be fixed with some thought and code changes, but the sheer ease of use of the topk methods and their strong performance makes me want to prioritize those over trying to get meaningful performance out of the more standard implementations. Even with the standard SAE, it was very difficult to hit a target sparsity and be confident in the results. The training runs for these methods can be found in their respective projects, but I left out the graphs for the sake of brevity.
Both p-annealing and gdm were unstable during training and were also omitted. Their results can likewise be found in their respective projects.
Fraction of Variance Explained is a useful way to gauge model performance, defined as:
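A standard way to write this metric, assuming a batch of activations x_n with reconstructions and a batch mean (the symbols are chosen here for illustration), is:

```latex
\mathrm{FVE} = 1 - \frac{\sum_n \lVert x_n - \hat{x}_n \rVert_2^2}{\sum_n \lVert x_n - \bar{x} \rVert_2^2}
```

A value of 1 means perfect reconstruction, and a value of 0 means the SAE does no better than always predicting the mean activation.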
L2 loss is equivalent to MSE loss here, defined as:
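Both metrics can be sketched in plain Python over a batch given as a list of equal-length vectors. The function names and the exact normalization (mean over every entry for MSE, variance taken about the per-dimension batch mean) are assumptions; the logged metrics may normalize differently.

```python
def mse(x, x_hat):
    """Mean squared error over every entry of the batch."""
    n = sum(len(row) for row in x)
    return sum((a - b) ** 2 for r, rh in zip(x, x_hat) for a, b in zip(r, rh)) / n


def fraction_variance_explained(x, x_hat):
    """1 - (residual sum of squares) / (total sum of squares about the batch mean)."""
    dim = len(x[0])
    mean = [sum(row[j] for row in x) / len(x) for j in range(dim)]
    residual = sum((a - b) ** 2 for r, rh in zip(x, x_hat) for a, b in zip(r, rh))
    total = sum((a - m) ** 2 for r in x for a, m in zip(r, mean))
    return 1.0 - residual / total
```

Reconstructing every sample as the batch mean gives a fraction of variance explained of 0, which is a handy sanity check when reading the plots.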
topk
Topk keeps the k most strongly activating features for each sample and sets the remaining SAE features to 0. From there, it is similar to a standard SAE.
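A minimal sketch of the per-sample selection described above, in plain Python. The function name is hypothetical, and real implementations typically operate on tensors and may apply a ReLU before or after the selection; that detail is not shown here.

```python
def topk_activation(pre_acts, k):
    """Keep the k largest entries of one sample's activations and zero the rest."""
    order = sorted(range(len(pre_acts)), key=lambda i: pre_acts[i], reverse=True)
    keep = set(order[:k])
    return [v if i in keep else 0.0 for i, v in enumerate(pre_acts)]
```

For example, `topk_activation([0.1, 2.0, -1.0, 0.7], 2)` keeps only the entries 2.0 and 0.7, so every sample ends up with an L0 of exactly k.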
nonvariational: run set of 14 runs
variational: run set of 34 runs
We will focus on the purple nonvariational run, as it has an L0 of 64, which is very close to the L0 of 65 for the variational case. Dividing or multiplying the target L0 by 2 results in a ~8% change in performance. The remaining nonvariational methods use the same target L0 values of 32, 64, and 128, and we will see a similar effect there.
The variational methods all have a target L0 of 65 (instead of 64 because of artifacts in the code; performance differences will be negligible). In the variational case, we vary the KL divergence weighting coefficient across (30, 1, 0.01), with a few outliers. The KL value can be seen in the name of each run. A lower KL coefficient improves performance: it shrinks the KL contribution to the loss, and the variational effect relaxes as the coefficient tends to 0.
Not only do we see great performance across the board, but a KL coefficient of 0.01 actually outperforms the nonvariational method!
batchtopk
Batchtopk functions very similarly to topk, but instead of applying topk to each sample in the batch, the topk is applied across the whole batch: one sample might have more than k active features, but across the batch you are guaranteed an average of k active features per sample.
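The batch-level selection can be sketched in plain Python as follows. The function name is hypothetical, and the sketch covers only the training-time selection over a batch; how a threshold is chosen at inference time is not covered.

```python
def batch_topk(batch, k):
    """Keep the k * batch_size largest activations across the whole batch."""
    flat = [(v, r, c) for r, row in enumerate(batch) for c, v in enumerate(row)]
    flat.sort(key=lambda t: t[0], reverse=True)
    keep = {(r, c) for _, r, c in flat[: k * len(batch)]}
    return [
        [v if (r, c) in keep else 0.0 for c, v in enumerate(row)]
        for r, row in enumerate(batch)
    ]
```

With `batch = [[5.0, 4.0, 3.0], [0.5, 0.2, 0.1]]` and `k = 1`, the two largest values both sit in the first sample, so it keeps two active features while the second keeps none. The average is still one active feature per sample.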
nonvariational: run set of 5 runs
variational: run set of 4 runs
Batchtopk seems to perform about as well as topk, although all of the values appear slightly lower. As a quick check, hovering with the mouse gives a ~0.02 variability between logging steps for variance explained, but the final values appear to be (0.71, 0.72) for topk and (0.70, 0.67) for batchtopk, listed as (standard, variational).
jumprelu
nonvariational: run set of 4 runs
variational: run set of 4 runs
Jumprelu was quite difficult to train, as we could not simply give a target L0. We had to find a good balance among the hyperparameters and work out how long to train. Where the L0 is ~60, the variance explained sits just under 0.7, and the L2 loss was relatively high at 0.5 or so. The variational techniques underperformed. According to the Jumprelu paper, it should be quite competitive with the topk method, so there may still be some performance left on the table, but the training difficulty makes it hard to recover.
gated anneal
nonvariational: run set of 3 runs
variational: run set of 3 runs
Performance here was exceptional around the 11th and 12th step (where L0 was ~60), giving an MSE loss of ~0.18 for the nonvariational and ~0.23 for the variational methods at similar L0. The variational results are quite good. Jumprelu is supposed to be superior to this technique, so I wonder whether there is a bug in the code, or whether hyperparameters and training time need additional tuning.
matryoshka (with batchtopk)
nonvariational: run set of 3 runs
variational: run set of 5 runs
Here we see the best results so far.
Conclusion
Topk, gated anneal, and Matryoshka were the best performers for both the variational and nonvariational methods. However, the gated annealing SAE was quite difficult to train and does not offer the same benefits for feature splitting/absorption as the Matryoshka SAE, and the Matryoshka method already has batchtopk baked in. So the Matryoshka method (somewhat subjectively, as we have not yet measured any feature splitting/absorption) seems to be the best performer of the bunch!
Questions:
What is the effect of feature splitting?
What is the effect of larger/smaller KL parameters on the latent space?
1l transformers have been shown to be "an ensemble of bigram and 'skip-trigram' (sequences of the form 'A… B C') models". How would the results be affected by using gpt2?
How do these perform out of distribution?