
Jumping Ahead: Improving Reconstruction Fidelity with JumpReLU Sparse Autoencoders

Senthooran Rajamanoharan, Tom Lieberum, Nicolas Sonnerat, Arthur Conmy, Vikrant Varma, János Kramár, Neel Nanda

TL;DR

This paper tackles the sparsity-versus-fidelity trade-off in sparse autoencoders (SAEs) applied to language model activations. It introduces JumpReLU SAEs, a simple modification of vanilla SAEs that replaces the ReLU with a JumpReLU activation using a learned per-feature threshold, and trains them with straight-through estimators (STEs) to directly optimize an $L_0$ sparsity penalty. Empirically, JumpReLU SAEs achieve state-of-the-art reconstruction fidelity at fixed sparsity on Gemma 2 9B activations while maintaining interpretability comparable to existing approaches such as Gated and TopK SAEs. The work demonstrates efficient training and broad evaluation across layers and activation sites (residual stream, MLP output and attention output), and provides a principled approach to training with discontinuous activations that could generalize to other discontinuous loss functions and model families.

Abstract

Sparse autoencoders (SAEs) are a promising unsupervised approach for identifying causally relevant and interpretable linear features in a language model's (LM) activations. To be useful for downstream tasks, SAEs need to decompose LM activations faithfully; yet to be interpretable the decomposition must be sparse -- two objectives that are in tension. In this paper, we introduce JumpReLU SAEs, which achieve state-of-the-art reconstruction fidelity at a given sparsity level on Gemma 2 9B activations, compared to other recent advances such as Gated and TopK SAEs. We also show that this improvement does not come at the cost of interpretability through manual and automated interpretability studies. JumpReLU SAEs are a simple modification of vanilla (ReLU) SAEs -- where we replace the ReLU with a discontinuous JumpReLU activation function -- and are similarly efficient to train and run. By utilising straight-through-estimators (STEs) in a principled manner, we show how it is possible to train JumpReLU SAEs effectively despite the discontinuous JumpReLU function introduced in the SAE's forward pass. Similarly, we use STEs to directly train L0 to be sparse, instead of training on proxies such as L1, avoiding problems like shrinkage.
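The forward pass and training objective described above can be illustrated with a small plain-Python sketch. It assumes the usual SAE form $\mathbf{f}(\mathbf{x}) = \mathrm{JumpReLU}_{\boldsymbol\theta}(W_{\mathrm{enc}}\mathbf{x} + \mathbf{b}_{\mathrm{enc}})$ and $\hat{\mathbf{x}} = W_{\mathrm{dec}}\mathbf{f}(\mathbf{x}) + \mathbf{b}_{\mathrm{dec}}$, with a squared-error reconstruction term plus $\lambda$ times the $L_0$ norm of the features. The function names, toy weights and coefficient `lam` are hypothetical. Training, which needs the STEs because both JumpReLU and $L_0$ are piecewise constant in $\boldsymbol\theta$, is not shown.

```python
# Minimal sketch of a JumpReLU SAE forward pass and L0-penalised loss.
# Names, toy weights and lam are hypothetical; gradients/STEs are not shown.

def heaviside(u):
    """Heaviside step H(u): 1.0 for u > 0, else 0.0."""
    return 1.0 if u > 0 else 0.0

def jumprelu(z, theta):
    """JumpReLU_theta(z) = z * H(z - theta): zero below theta, identity above."""
    return z * heaviside(z - theta)

def matvec(w, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(wij * vj for wij, vj in zip(row, v)) for row in w]

def sae_forward(x, w_enc, b_enc, theta, w_dec, b_dec):
    """Return encoder pre-activations, feature activations and reconstruction."""
    pre = [p + b for p, b in zip(matvec(w_enc, x), b_enc)]
    f = [jumprelu(z, t) for z, t in zip(pre, theta)]  # one threshold per feature
    x_hat = [r + b for r, b in zip(matvec(w_dec, f), b_dec)]
    return pre, f, x_hat

def sae_loss(x, params, lam):
    """Squared reconstruction error plus lam * L0, where L0 counts pre > theta."""
    pre, f, x_hat = sae_forward(x, *params)
    recon = sum((xi - xh) ** 2 for xi, xh in zip(x, x_hat))
    l0 = sum(heaviside(z - t) for z, t in zip(pre, params[2]))
    return recon + lam * l0

# Toy SAE: 2-dim input, 3 features (hypothetical numbers).
w_enc = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
b_enc = [0.0, 0.0, -1.0]
theta = [0.5, 0.5, 0.5]
w_dec = [[1.0, 0.0, 0.5], [0.0, 1.0, 0.5]]
b_dec = [0.0, 0.0]
params = (w_enc, b_enc, theta, w_dec, b_dec)

x = [1.0, 0.2]
pre, f, x_hat = sae_forward(x, *params)
print(f)                            # only feature 0 clears its threshold
print(sae_loss(x, params, lam=0.1))
```

With a plain ReLU, features 1 and 2 would also fire (their pre-activations are about 0.2), so the $L_0$ term would count 3 instead of 1. The per-feature threshold screens out such small positive pre-activations without shrinking the feature that does fire.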

Paper Structure (40 sections, 3 theorems, 39 equations, 17 figures)

Key Result

Lemma 1

Let $\mathbf{X}$ be an $n$-dimensional real random variable with probability density $p_\mathbf{X}$ and let $Y = g(\mathbf{X})$ for a differentiable function $g:\mathbb{R}^n \to \mathbb{R}$. Then we can express the probability density function of $Y$ as the surface integral

$$p_Y(y) = \oint_{\partial V(y)} \frac{p_\mathbf{X}(\mathbf{x})}{\|\nabla g(\mathbf{x})\|} \, \mathrm{d}S,$$

where $\partial V(y)$ is the surface $g(\mathbf{x}) = y$ and $\mathrm{d} S$ is its surface element.
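As a sanity check on the surface-integral formula $p_Y(y) = \oint_{\partial V(y)} p_\mathbf{X}(\mathbf{x}) / \|\nabla g(\mathbf{x})\| \, \mathrm{d}S$, take $n = 2$, $\mathbf{X}$ standard bivariate normal and $g(\mathbf{x}) = \|\mathbf{x}\|$. This example is our own choice, not one from the paper. Here $\partial V(y)$ is the circle of radius $y$ and $\|\nabla g\| = 1$ on it, so the integral should recover the known Rayleigh density $y e^{-y^2/2}$. The sketch approximates the integral with a Riemann sum over the circle.

```python
import math

def p_x(x1, x2):
    """Density of a standard bivariate normal."""
    return math.exp(-(x1 * x1 + x2 * x2) / 2) / (2 * math.pi)

def grad_norm_g(x1, x2):
    """||grad g(x)|| for g(x) = ||x||; the gradient is x / ||x||, so this is 1 off the origin."""
    r = math.hypot(x1, x2)
    return math.hypot(x1 / r, x2 / r)

def p_y_surface_integral(y, n_points=1000):
    """Riemann-sum approximation of the integral of p_x / ||grad g|| over the circle g(x) = y."""
    d_angle = 2 * math.pi / n_points
    ds = y * d_angle  # arc-length element on the circle of radius y
    total = 0.0
    for k in range(n_points):
        x1, x2 = y * math.cos(k * d_angle), y * math.sin(k * d_angle)
        total += p_x(x1, x2) / grad_norm_g(x1, x2) * ds
    return total

def rayleigh_pdf(y):
    """Known density of ||X|| for a standard bivariate normal X."""
    return y * math.exp(-y * y / 2)

for y in (0.5, 1.0, 2.0):
    print(y, p_y_surface_integral(y), rayleigh_pdf(y))
```

Because the integrand is constant on each circle in this example, the Riemann sum agrees with the Rayleigh density up to floating-point rounding. For a general $g$, the $1/\|\nabla g\|$ factor accounts for how quickly $g$ changes across the level surface.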

Figures (17)

  • Figure 1: A toy model illustrating why JumpReLU (or similar activation functions, such as TopK) are an improvement over ReLU for training sparse yet faithful SAEs. Consider a direction in which the encoder pre-activation is high when the corresponding feature is active and low, but not always negative, when the feature is inactive (far-left). Applying a ReLU activation function fails to remove all false positives (centre-left), harming sparsity. It is possible to get rid of false positives while maintaining the ReLU, e.g. by decreasing the encoder bias (centre-right), but this leads to feature magnitudes being systematically underestimated, harming fidelity. The JumpReLU activation function (far-right) provides an independent threshold below which pre-activations are screened out, minimising false positives, while leaving pre-activations above the threshold unaffected, improving fidelity.
  • Figure 2: JumpReLU SAEs offer reconstruction fidelity that equals or exceeds Gated and TopK SAEs at a fixed level of sparsity. These results are for SAEs trained on the residual stream after layers 9, 20 and 31 of Gemma 2 9B. Analogous plots for SAEs trained on MLP and attention output activations at these layers are given in separate figures.
  • Figure 3: The JumpReLU activation function zeroes inputs below the threshold, $\theta$, and is an identity function for inputs above the threshold.
  • Figure 4: The JumpReLU activation function (left) and the Heaviside step function (right) used to calculate the sparsity penalty are piecewise constant with respect to the JumpReLU threshold. Therefore, to be able to train a JumpReLU SAE, we define the pseudo-derivatives illustrated in these plots (the straight-through estimators for the JumpReLU and Heaviside functions), which approximate the Dirac delta functions present in the actual (weak) derivatives of these functions. These pseudo-derivatives provide a gradient signal to the threshold whenever pre-activations are within a small window of width $\varepsilon$ around the threshold. Note that these plots show the profile of the pseudo-derivatives in the $z$ direction, not the $\theta$ direction, as $z$ is the stochastic input that is averaged over when computing the mean gradient.
  • Figure 5: The proportion of features that activate very frequently versus delta LM loss by SAE type for Gemma 2 9B residual stream SAEs. TopK and JumpReLU SAEs tend to have relatively more very-high-frequency features (those active on over 10% of tokens, top) than Gated SAEs. If we instead count features that are active on over 1% of tokens (bottom), the picture is more mixed: Gated SAEs can have more of these high-frequency (but not necessarily very-high-frequency) features than JumpReLU SAEs, particularly in the low-loss (and therefore lower-sparsity) regime.
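The trade-off in Figure 1 can be mimicked with a handful of hand-picked pre-activations along one encoder direction. In this sketch, the values, the bias shift of 1.0 and the JumpReLU threshold of 1.0 are all hypothetical. Inactive tokens have small pre-activations that are sometimes positive, and active tokens have pre-activations equal to the true feature magnitude. Counting false positives and the mean magnitude error reproduces the qualitative picture in the caption.

```python
# Hypothetical pre-activations along one encoder direction.
inactive = [-0.4, 0.3, -0.1, 0.6, 0.2]  # feature off: low, but not always negative
active = [2.0, 2.5, 3.0]                # feature on: pre-activation equals true magnitude

def relu(z):
    return max(z, 0.0)

def relu_shifted(z, shift=1.0):
    """ReLU after lowering the encoder bias by `shift`."""
    return max(z - shift, 0.0)

def jumprelu(z, theta=1.0):
    """Zero below theta, identity above it."""
    return z if z > theta else 0.0

def evaluate(act):
    """Return (number of false positives, mean absolute magnitude error on active tokens)."""
    false_positives = sum(1 for z in inactive if act(z) > 0)
    magnitude_error = sum(abs(act(z) - z) for z in active) / len(active)
    return false_positives, magnitude_error

for name, act in [("ReLU", relu), ("ReLU, lower bias", relu_shifted), ("JumpReLU", jumprelu)]:
    fp, err = evaluate(act)
    print(f"{name:18s} false positives={fp}  mean magnitude error={err:.2f}")
```

The plain ReLU passes three false positives, lowering the bias removes them but underestimates every active magnitude by the size of the shift, and the JumpReLU removes the false positives while leaving active magnitudes untouched.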
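The pseudo-derivatives in Figure 4 can be sketched as follows. This is a minimal sketch that assumes a rectangle kernel $K(u) = 1$ for $|u| < 1/2$ (else 0). It also assumes the forms $-(\theta/\varepsilon)\,K((z-\theta)/\varepsilon)$ for the JumpReLU and $-(1/\varepsilon)\,K((z-\theta)/\varepsilon)$ for the Heaviside step, which stand in for the Dirac deltas in the weak derivatives $\partial_\theta\,[z\,H(z-\theta)] = -\theta\,\delta(z-\theta)$ and $\partial_\theta H(z-\theta) = -\delta(z-\theta)$. The final lines use standard-normal pre-activations, chosen purely for illustration. They check that averaging the step pseudo-derivative over samples of $z$ approximates the true derivative $\partial_\theta \Pr[z > \theta] = -p(\theta)$.

```python
import math
import random

def rect(u):
    """Rectangle kernel: 1 on (-1/2, 1/2), 0 elsewhere (assumed kernel choice)."""
    return 1.0 if abs(u) < 0.5 else 0.0

def jumprelu_theta_pseudo_grad(z, theta, eps):
    """Stand-in for d/dtheta JumpReLU_theta(z) = -theta * delta(z - theta)."""
    return -(theta / eps) * rect((z - theta) / eps)

def step_theta_pseudo_grad(z, theta, eps):
    """Stand-in for d/dtheta H(z - theta) = -delta(z - theta)."""
    return -(1.0 / eps) * rect((z - theta) / eps)

# Only pre-activations within eps/2 of the threshold give the threshold a gradient.
theta, eps = 0.5, 0.1
print(step_theta_pseudo_grad(0.52, theta, eps))  # inside the window: nonzero (-1/eps)
print(step_theta_pseudo_grad(0.80, theta, eps))  # outside the window: zero

# Averaged over z ~ N(0, 1), the pseudo-gradient approximates -p(theta).
rng = random.Random(0)
samples = [rng.gauss(0.0, 1.0) for _ in range(200_000)]
mc_grad = sum(step_theta_pseudo_grad(z, theta, eps) for z in samples) / len(samples)
true_grad = -math.exp(-theta ** 2 / 2) / math.sqrt(2 * math.pi)
print(mc_grad, true_grad)
```

Averaging a kernel of width $\varepsilon$ over samples of $z$ amounts to a kernel density estimate of $p$ at $\theta$. This is why the caption stresses that the pseudo-derivatives are profiled in the $z$ direction.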
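The statistics in Figure 5 only require, for each feature, the fraction of tokens on which it is active, followed by a count of how many features exceed a frequency cutoff (10% or 1%). This is a minimal sketch, assuming activations are available as a tokens-by-features list of lists. The toy matrix and helper names are hypothetical.

```python
def firing_frequencies(acts):
    """Fraction of tokens on which each feature has a nonzero activation."""
    n_tokens = len(acts)
    n_features = len(acts[0])
    return [sum(1 for row in acts if row[j] != 0) / n_tokens for j in range(n_features)]

def proportion_above(freqs, cutoff):
    """Proportion of features active on more than `cutoff` of tokens."""
    return sum(1 for f in freqs if f > cutoff) / len(freqs)

# Toy activations: 20 tokens x 4 features (hypothetical values).
acts = [[0.0] * 4 for _ in range(20)]
for t in range(20):
    acts[t][0] = 1.0   # feature 0 fires on every token (100%)
for t in range(0, 20, 5):
    acts[t][1] = 0.5   # feature 1 fires on 4/20 = 20% of tokens
acts[3][2] = 2.0       # feature 2 fires on 1/20 = 5% of tokens
                       # feature 3 never fires

freqs = firing_frequencies(acts)
print(freqs)                          # [1.0, 0.2, 0.05, 0.0]
print(proportion_above(freqs, 0.10))  # 0.5  (features 0 and 1)
print(proportion_above(freqs, 0.01))  # 0.75 (features 0, 1 and 2)
```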

Theorems & Definitions (6)

  • Lemma 1
  • proof
  • Theorem 1
  • proof
  • Lemma 2
  • proof