BatchTopK Sparse Autoencoders
Bart Bussmann, Patrick Leask, Neel Nanda
TL;DR
Problem: TopK SAEs enforce a fixed number of active latents per sample, which can limit reconstruction fidelity.
Approach: BatchTopK SAEs relax this constraint to the batch level, which allows a variable number of active latents per sample. At inference time they use a global threshold, so a sample's encoding no longer depends on the other samples in its batch.
Contributions: experiments on GPT-2 Small and Gemma 2 2B show that BatchTopK improves normalized mean squared error (NMSE) and reduces cross-entropy (CE) degradation relative to TopK. It is also competitive with JumpReLU SAEs and has the practical advantage that sparsity is specified directly.
Impact: the method illustrates how batch-aware sparsity and small adjustments to the activation function can improve reconstruction and interpretability without sacrificing average sparsity; code is provided.
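The global inference threshold mentioned above can be illustrated with a minimal pure-Python sketch. It assumes non-negative latent activations (for example, after a ReLU). It also assumes that the threshold is estimated as the average, over several training batches, of the smallest activation that batch-level top-k keeps. The helper names `batch_cutoff`, `estimate_threshold` and `encode_with_threshold` are hypothetical, and this is not the code from the linked repository.

```python
# Minimal sketch (assumption): a global inference threshold estimated from
# training batches. Helper names are hypothetical, not the authors' API.

def batch_cutoff(batch, k):
    """Smallest activation kept when batch-level top-k selects the
    k * len(batch) largest entries of the whole (flattened) batch."""
    flat = sorted((v for row in batch for v in row), reverse=True)
    return flat[k * len(batch) - 1]


def estimate_threshold(batches, k):
    """Assumed estimator: average the per-batch cutoff over several batches."""
    cutoffs = [batch_cutoff(batch, k) for batch in batches]
    return sum(cutoffs) / len(cutoffs)


def encode_with_threshold(row, theta):
    """Inference: keep each latent that exceeds theta, independently of
    any other sample, so the code no longer depends on the batch."""
    return [v if v > theta else 0.0 for v in row]


# Toy activations: two training batches, each with 2 samples x 4 latents.
training_batches = [
    [[0.9, 0.1, 0.0, 0.4], [0.2, 0.0, 0.7, 0.3]],
    [[0.5, 0.6, 0.0, 0.1], [0.0, 0.8, 0.2, 0.0]],
]
# Mean of the two per-batch cutoffs, 0.3 and 0.2.
theta = estimate_threshold(training_batches, k=2)
print(encode_with_threshold([0.9, 0.1, 0.0, 0.4], theta))  # [0.9, 0.0, 0.0, 0.4]
```

Each latent is compared against a single fixed `theta`, so encoding one sample requires no other samples. This is the batch independence the approach aims for at inference time.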
Abstract
Sparse autoencoders (SAEs) have emerged as a powerful tool for interpreting language model activations by decomposing them into sparse, interpretable features. A popular approach is the TopK SAE, which uses a fixed number of the most active latents per sample to reconstruct the model activations. We introduce BatchTopK SAEs, a training method that improves upon TopK SAEs by relaxing the top-k constraint to the batch level, allowing a variable number of latents to be active per sample. As a result, BatchTopK adaptively allocates more or fewer latents depending on the sample, improving reconstruction without sacrificing average sparsity. We show that BatchTopK SAEs consistently outperform TopK SAEs in reconstructing activations from GPT-2 Small and Gemma 2 2B, and achieve performance comparable to state-of-the-art JumpReLU SAEs. Unlike JumpReLU SAEs, however, BatchTopK lets the average number of active latents be specified directly, rather than approximately tuned through a costly hyperparameter sweep. We provide code for training and evaluating BatchTopK SAEs at https://github.com/bartbussmann/BatchTopK.
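To make the difference between the two constraints concrete, here is a small pure-Python sketch that contrasts per-sample TopK with batch-level BatchTopK on a toy batch. Activations are assumed to be non-negative, and ties are broken by sort order. The function names `topk_per_sample`, `batch_topk` and `active_counts` are illustrative assumptions rather than the API of the BatchTopK repository.

```python
# Minimal sketch: per-sample TopK vs batch-level BatchTopK on toy,
# non-negative (assumed post-ReLU) activations. Illustrative names only.

def topk_per_sample(batch, k):
    """TopK: keep the k largest activations of each row independently."""
    codes = []
    for row in batch:
        order = sorted(range(len(row)), key=lambda j: row[j], reverse=True)
        keep = set(order[:k])
        codes.append([v if j in keep else 0.0 for j, v in enumerate(row)])
    return codes


def batch_topk(batch, k):
    """BatchTopK: keep the k * len(batch) largest activations of the whole
    batch, so a single row may keep more or fewer than k latents."""
    entries = [(v, i, j) for i, row in enumerate(batch) for j, v in enumerate(row)]
    entries.sort(key=lambda e: e[0], reverse=True)
    keep = {(i, j) for _, i, j in entries[: k * len(batch)]}
    return [[v if (i, j) in keep else 0.0 for j, v in enumerate(row)]
            for i, row in enumerate(batch)]


def active_counts(codes):
    """Number of nonzero latents per sample."""
    return [sum(1 for v in row if v > 0) for row in codes]


batch = [
    [0.9, 0.8, 0.7, 0.6],   # a "busy" sample with many strong latents
    [0.1, 0.0, 0.0, 0.65],  # a "quiet" sample with one strong latent
]
print(active_counts(topk_per_sample(batch, k=2)))  # [2, 2]
print(active_counts(batch_topk(batch, k=2)))       # [3, 1]
```

Both variants keep k * batch_size = 4 latents in total, so the average is exactly k = 2 per sample. The difference is in how those latents are distributed: BatchTopK gives the busier first sample three latents and the quieter second sample one.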
