BatchTopK Sparse Autoencoders

Bart Bussmann, Patrick Leask, Neel Nanda

TL;DR

Problem: TopK SAEs enforce a fixed number of active latents per sample, which can limit reconstruction fidelity. Approach: BatchTopK SAEs relax this constraint to the batch level, allowing a variable number of active latents per sample, and use a global threshold at inference to remove the dependence on batch composition. Contributions: experiments on GPT-2 Small and Gemma 2 2B show that BatchTopK improves NMSE and reduces CE degradation relative to TopK and is competitive with JumpReLU SAEs, with the practical advantage that sparsity can be specified directly. Impact: the method shows how batch-level sparsity, a small change to the activation function, can improve reconstruction and interpretability without sacrificing average sparsity; code is provided.
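The inference-time step mentioned in the TL;DR can be illustrated with a small sketch. It assumes the global threshold θ is estimated by averaging, over training batches, the smallest activation that BatchTopK kept in each batch; that estimation rule and the helper names `estimate_threshold` and `threshold_activation` are hypothetical choices for illustration, not taken from this summary. Plain Python lists stand in for tensors.

```python
from typing import List

Matrix = List[List[float]]  # rows = samples, columns = SAE latents


def estimate_threshold(sparse_batches: List[Matrix]) -> float:
    """Estimate a global threshold from post-BatchTopK training activations.

    Assumption (hypothetical rule): theta is the mean, over batches, of the
    smallest positive activation BatchTopK kept in that batch.
    """
    minima = []
    for batch in sparse_batches:
        kept = [v for row in batch for v in row if v > 0]
        if kept:  # skip batches in which nothing survived
            minima.append(min(kept))
    if not minima:
        raise ValueError("no positive activations to estimate a threshold from")
    return sum(minima) / len(minima)


def threshold_activation(pre_acts: Matrix, theta: float) -> Matrix:
    """Inference rule: keep a latent only if its activation exceeds theta.

    Each row is handled on its own, so a sample's output does not depend on
    which other samples share its batch.
    """
    return [[v if v > theta else 0.0 for v in row] for row in pre_acts]


# Two toy training batches after BatchTopK (zeros = dropped latents).
batches = [
    [[0.0, 0.9, 0.0], [0.4, 0.0, 0.7]],  # smallest kept value: 0.4
    [[0.6, 0.0, 0.0], [0.0, 0.0, 0.8]],  # smallest kept value: 0.6
]
theta = estimate_threshold(batches)  # mean of 0.4 and 0.6
print(theta, threshold_activation([[0.3, 0.55, 0.9]], theta))
```

Because `threshold_activation` looks at one row at a time, the number of active latents per sample can still vary, but it no longer changes with the rest of the batch.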

Abstract

Sparse autoencoders (SAEs) have emerged as a powerful tool for interpreting language model activations by decomposing them into sparse, interpretable features. A popular approach is the TopK SAE, which uses a fixed number of the most active latents per sample to reconstruct the model activations. We introduce BatchTopK SAEs, a training method that improves upon TopK SAEs by relaxing the top-k constraint to the batch level, allowing a variable number of latents to be active per sample. As a result, BatchTopK adaptively allocates more or fewer latents depending on the sample, improving reconstruction without sacrificing average sparsity. We show that BatchTopK SAEs consistently outperform TopK SAEs in reconstructing activations from GPT-2 Small and Gemma 2 2B, and achieve comparable performance to state-of-the-art JumpReLU SAEs. Unlike JumpReLU, however, BatchTopK lets the average number of active latents be specified directly, rather than approximately tuned through a costly hyperparameter sweep. We provide code for training and evaluating BatchTopK SAEs at https://github.com/bartbussmann/BatchTopK
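To make the difference between the two selection rules concrete, here is a minimal sketch. It assumes the encoder activations are already non-negative (for example, after a ReLU) and uses hypothetical helper names. `topk_per_sample` keeps the k largest values in every row, while `batch_topk` keeps the k × batch_size largest values across the whole batch, so individual rows can end up with more or fewer than k active latents.

```python
from typing import List

Matrix = List[List[float]]  # rows = samples, columns = SAE latents


def topk_per_sample(acts: Matrix, k: int) -> Matrix:
    """Standard TopK: keep exactly the k largest activations in every row."""
    out = []
    for row in acts:
        keep = set(sorted(range(len(row)), key=lambda j: row[j], reverse=True)[:k])
        out.append([v if j in keep else 0.0 for j, v in enumerate(row)])
    return out


def batch_topk(acts: Matrix, k: int) -> Matrix:
    """BatchTopK: keep the k * batch_size largest activations in the whole batch."""
    n_rows = len(acts)
    # Flatten to (value, row, column) so every entry competes batch-wide.
    flat = [(v, i, j) for i, row in enumerate(acts) for j, v in enumerate(row)]
    flat.sort(key=lambda t: t[0], reverse=True)
    keep = {(i, j) for _, i, j in flat[: k * n_rows]}
    return [[v if (i, j) in keep else 0.0 for j, v in enumerate(row)]
            for i, row in enumerate(acts)]


def l0_per_sample(acts: Matrix) -> List[int]:
    """Number of active (non-zero) latents in each row."""
    return [sum(1 for v in row if v != 0.0) for row in acts]


# Assumes strictly positive activations, so every kept entry counts as active.
acts = [[0.9, 0.8, 0.7, 0.1], [0.6, 0.2, 0.05, 0.03]]
print(l0_per_sample(topk_per_sample(acts, 2)))  # per-row budget
print(l0_per_sample(batch_topk(acts, 2)))       # batch-wide budget
```

On this toy batch with k = 2, TopK activates 2 latents in each row, whereas BatchTopK activates 3 in the first row and 1 in the second. Either way the total is k × batch_size, so under this sketch's assumptions the average number of active latents per sample equals k by construction, which is what makes it directly specifiable.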

Paper Structure

This paper contains 7 sections, 6 equations, and 3 figures.

Figures (3)

  • Figure 1: On GPT-2 Small activations, BatchTopK largely achieves better NMSE and CE than standard TopK across dictionary sizes, with the number of active latents fixed at k = 32 (left). JumpReLU SAEs are omitted from this comparison because their L0 cannot be set to a fixed value. For a fixed dictionary size (12288) and varying k, BatchTopK outperforms both TopK and JumpReLU SAEs in terms of NMSE and CE (right).
  • Figure 2: On Gemma 2 2B activations, BatchTopK outperforms TopK SAEs across different values of k. Although BatchTopK achieves better reconstruction than JumpReLU (left), it outperforms JumpReLU on downstream CE degradation only at k = 16 (right).
  • Figure 3: Distribution of the number of active latents per sample for a BatchTopK model. The distribution shows that some samples use very few latents, while others use many, illustrating the flexibility that BatchTopK provides.
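The figures report NMSE and the per-sample number of active latents. The sketch below computes both on plain Python lists. It assumes NMSE is the total squared reconstruction error divided by the total squared norm of the original activations (some work instead averages a per-sample ratio), and the function names are hypothetical. CE degradation is not sketched because it requires running the language model with the SAE reconstruction spliced in.

```python
from collections import Counter
from typing import Dict, List

Matrix = List[List[float]]  # rows = samples


def nmse(x: Matrix, x_hat: Matrix) -> float:
    """Assumed definition: sum of squared errors / sum of squared norms of x."""
    err = sum((a - b) ** 2 for row, row_hat in zip(x, x_hat)
              for a, b in zip(row, row_hat))
    norm = sum(a ** 2 for row in x for a in row)
    return err / norm


def l0_histogram(latents: Matrix) -> Dict[int, int]:
    """Map 'number of active latents' -> 'number of samples' (Figure 3-style)."""
    return dict(Counter(sum(1 for v in row if v != 0.0) for row in latents))


x = [[1.0, 0.0], [0.0, 2.0]]
x_hat = [[1.0, 0.0], [0.0, 1.0]]
print(nmse(x, x_hat))  # squared error 1 over squared norm 5

latents = [[0.5, 0.0, 0.1], [0.0, 0.0, 0.3], [0.2, 0.4, 0.0]]
print(l0_histogram(latents))  # two samples with 2 active latents, one with 1
```

Aggregating the L0 counts over many batches gives the kind of distribution shown in Figure 3, where some samples use very few latents and others use many.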