Identifying Functionally Important Features with End-to-End Sparse Dictionary Learning

Dan Braun, Jordan Taylor, Nicholas Goldowsky-Dill, Lee Sharkey

TL;DR

The paper introduces end-to-end sparse dictionary learning (e2e SAEs), which identifies functionally important features in neural networks by training sparse autoencoders to minimize the KL divergence between the output distribution of the original model and that of the model with SAE activations inserted. Compared to local SAEs, e2e SAEs achieve a Pareto improvement, explaining more network performance with far fewer active features. The cost is a higher per-layer reconstruction loss, which the e2e+ds variant mitigates by also reconstructing activations at downstream layers. Experiments on GPT2-small and Tinystories-1M show robust efficiency gains while preserving interpretability. The work provides an open-source library for training and analyzing e2e SAEs and advances the goal of concise, accurate explanations of network behavior.
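To make the difference between the three training objectives concrete, the sketch below writes each one as a plain Python function over lists of floats, following the loss terms listed in the Figure 1 caption (MSE at the SAE's own layer for local SAEs, KL divergence on the output distribution for e2e SAEs, and KL plus downstream MSE for e2e+ds, each with an $L_1$ sparsity penalty). It is a minimal sketch under stated assumptions, not the authors' implementation: the function names, the `l1_coeff` and `ds_coeff` weights, the KL direction (original model as the reference distribution), and the unweighted sum over downstream layers are all choices made here for illustration. A real training loop would compute these on batched tensors with automatic differentiation.

```python
import math
from typing import List, Sequence

Vec = Sequence[float]


def softmax(logits: Vec) -> List[float]:
    # Subtract the max logit for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p: Vec, q: Vec) -> float:
    # KL(p || q) = sum_i p_i * log(p_i / q_i); terms with p_i == 0 contribute nothing.
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def mse(a: Vec, b: Vec) -> float:
    # Mean squared error between two activation vectors of equal length.
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def l1(codes: Vec) -> float:
    # L1 norm of the SAE's sparse feature activations (the sparsity penalty).
    return sum(abs(c) for c in codes)


def local_loss(acts: Vec, recon: Vec, codes: Vec, l1_coeff: float) -> float:
    # SAE_local: reconstruct the activations at the SAE's own layer.
    return mse(acts, recon) + l1_coeff * l1(codes)


def e2e_loss(orig_logits: Vec, sae_logits: Vec, codes: Vec, l1_coeff: float) -> float:
    # SAE_e2e: match the original model's output distribution (KL direction is an assumption).
    return kl_divergence(softmax(orig_logits), softmax(sae_logits)) + l1_coeff * l1(codes)


def e2e_ds_loss(orig_logits: Vec, sae_logits: Vec, codes: Vec, l1_coeff: float,
                orig_downstream: Sequence[Vec], sae_downstream: Sequence[Vec],
                ds_coeff: float) -> float:
    # SAE_e2e+ds: e2e loss plus MSE between the original and SAE-inserted
    # activations at every later layer (an unweighted sum, assumed here).
    ds = sum(mse(a, b) for a, b in zip(orig_downstream, sae_downstream))
    return e2e_loss(orig_logits, sae_logits, codes, l1_coeff) + ds_coeff * ds
```

The only difference between `local_loss` and `e2e_loss` is what the SAE is compared against: its own input activations, or the original model's outputs. `e2e_ds_loss` keeps the e2e objective and adds a pull toward the original activations at later layers.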

Abstract

Identifying the features learned by neural networks is a core challenge in mechanistic interpretability. Sparse autoencoders (SAEs), which learn a sparse, overcomplete dictionary that reconstructs a network's internal activations, have been used to identify these features. However, SAEs may learn more about the structure of the dataset than the computational structure of the network. There is therefore only indirect reason to believe that the directions found in these dictionaries are functionally important to the network. We propose end-to-end (e2e) sparse dictionary learning, a method for training SAEs that ensures the features learned are functionally important by minimizing the KL divergence between the output distributions of the original model and the model with SAE activations inserted. Compared to standard SAEs, e2e SAEs offer a Pareto improvement: They explain more network performance, require fewer total features, and require fewer simultaneously active features per datapoint, all with no cost to interpretability. We explore geometric and qualitative differences between e2e SAE features and standard SAE features. E2e dictionary learning brings us closer to methods that can explain network behavior concisely and accurately. We release our library for training e2e SAEs and reproducing our analysis at https://github.com/ApolloResearch/e2e_sae
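The Pareto claim can be made concrete with a dominance check. The hypothetical sketch below scores each trained SAE on the three quantities the abstract names, all lower-is-better: CE loss increase (network performance not explained), number of alive dictionary elements (total features), and $L_0$ (simultaneously active features per datapoint). The `SAERun` record, its field names, and the toy numbers are illustrative placeholders, not results from the paper; Figure 1 instead shows these trade-offs as Pareto curves traced out by varying the sparsity coefficient.

```python
from typing import List, NamedTuple, Sequence


class SAERun(NamedTuple):
    # Hypothetical record of one trained SAE; all metrics are lower-is-better.
    name: str
    l0: float           # mean number of active features per datapoint
    alive: int          # dictionary elements that ever activate on the evaluation set
    ce_increase: float  # CE loss increase when SAE outputs replace the original activations


def dominates(a: SAERun, b: SAERun) -> bool:
    # a Pareto-dominates b if it is no worse on every metric and strictly better on at least one.
    ma = (a.l0, a.alive, a.ce_increase)
    mb = (b.l0, b.alive, b.ce_increase)
    return all(x <= y for x, y in zip(ma, mb)) and any(x < y for x, y in zip(ma, mb))


def pareto_front(runs: Sequence[SAERun]) -> List[SAERun]:
    # Keep every run that no other run dominates.
    return [r for r in runs if not any(dominates(o, r) for o in runs if o is not r)]


# Toy placeholder values, chosen only to exercise the functions.
runs = [
    SAERun("a", 30.0, 20000, 0.10),
    SAERun("b", 60.0, 30000, 0.10),
    SAERun("c", 20.0, 25000, 0.20),
]
print([r.name for r in pareto_front(runs)])  # → ['a', 'c']
```

Run `b` is dropped because `a` is at least as good on every metric and strictly better on two, while `a` and `c` trade off against each other, so both stay on the front.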

Paper Structure (47 sections, 8 equations, 23 figures, 6 tables)

Figures (23)

  • Figure 1: Top: Diagram comparing the loss terms used to train each type of SAE. Each arrow is a loss term that compares the activations represented by circles. $\text{SAE}_{\text{local}}$ uses MSE reconstruction loss between the SAE input and the SAE output. $\text{SAE}_{\text{e2e}}$ uses KL-divergence on the logits. $\text{SAE}_{\text{e2e+ds}}$ (end-to-end $+$ downstream reconstruction) uses KL-divergence in addition to the sum of the MSE reconstruction losses at all future layers. All three are additionally trained with an $L_1$ sparsity penalty (not pictured). Bottom: Pareto curves for three different types of SAE as the sparsity coefficient is varied. E2e-SAEs require fewer features per datapoint (i.e. have a lower $L_0$) and fewer features over the entire dataset (i.e. have fewer alive dictionary elements). GPT2-small has a CE loss of $3.139$ over our evaluation set.
  • Figure 2: Reconstruction mean squared error (MSE) at later layers for our set of GPT2-small layer $6$ SAEs with similar CE loss increases (Table \ref{tab:similar_ce_loss_layer_6}). $\text{SAE}_{\text{local}}$ was trained to minimize MSE at layer 6, $\text{SAE}_{\text{e2e}}$ was trained to match the output probability distribution, and $\text{SAE}_{\text{e2e+ds}}$ was trained to match the output probability distribution and minimize MSE in all downstream layers.
  • Figure 3: Geometric comparisons for our set of GPT2-small layer $6$ SAEs with similar CE loss increases (Table \ref{tab:similar_ce_loss_layer_6}). For each dictionary element, we find its maximum cosine similarity with a set of other dictionary elements. In \ref{fig:within-sae-sim} we compare to the other directions in the same SAE, in \ref{fig:cross-seed-sim} to directions in an SAE of the same type trained with a different random seed, and in \ref{fig:cross_type_similarities_layer_6.png} to directions in the $\text{SAE}_{\text{local}}$ with a similar CE loss increase.
  • Figure 4: Performance of all SAE types on GPT2-small's residual stream at layers $2$, $6$ and $10$. GPT2-small has a CE loss of $3.139$ over our evaluation set.
  • Figure 5: Tinystories-1M runs comparing $\text{SAE}_{\text{local}}$, $\text{SAE}_{\text{e2e}}$ and $\text{SAE}_{\text{e2e+ds}}$ on the residual stream before the $5^\text{th}$ of $8$ layers. Tinystories-1M has a CE loss of $2.306$ over our evaluation set.
  • ...and 18 more figures
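The geometric comparison in Figure 3 reduces to a maximum-cosine-similarity computation. A minimal pure-Python sketch follows, assuming each dictionary is given as a list of nonzero vectors, one per dictionary element; the function names are hypothetical and not taken from the e2e_sae library. Passing no second dictionary gives the within-SAE comparison, which excludes each element's similarity with itself; passing another SAE's dictionary gives the cross-seed or cross-type comparison.

```python
import math
from typing import List, Optional, Sequence

Vector = Sequence[float]


def cosine(u: Vector, v: Vector) -> float:
    # Cosine similarity; assumes both vectors are nonzero.
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def max_cosine_sims(dictionary: Sequence[Vector],
                    others: Optional[Sequence[Vector]] = None) -> List[float]:
    """For each element of `dictionary`, return its max cosine similarity to `others`.

    If `others` is None, compare within `dictionary` itself, skipping each element's
    self-similarity (so the dictionary needs at least two elements).
    """
    within = others is None
    pool = dictionary if within else others
    result = []
    for i, d in enumerate(dictionary):
        sims = [cosine(d, o) for j, o in enumerate(pool) if not (within and i == j)]
        result.append(max(sims))
    return result


# Toy 2-D dictionary used only to illustrate the computation.
D = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
print([round(s, 4) for s in max_cosine_sims(D)])                # → [0.7071, 0.7071, 0.7071]
print([round(s, 4) for s in max_cosine_sims(D, [[1.0, 0.0]])])  # → [1.0, 0.0, 0.7071]
```

A histogram of these per-element values is one way to summarize how redundant a dictionary is internally, or how closely one SAE's directions are matched by another's.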