Revisiting End-To-End Sparse Autoencoder Training: A Short Finetune Is All You Need

Adam Karvonen

TL;DR

The paper shows that a brief KL+MSE fine-tuning stage applied to the final 25M tokens of SAE training reduces the cross-entropy loss gap by 20-50%, comparable to full end-to-end training, at roughly half the wall-clock time. Dynamically balancing the KL and MSE losses lets hyperparameters and sparsity penalties carry over between training phases despite the different scales of the two losses, without modifying the model architecture. Across TopK and ReLU SAEs, the method improves cross-entropy loss but yields mixed results on supervised SAEBench metrics, depending on architecture and task. The work offers a low-overhead recipe that may benefit interpretability applications such as circuit analysis, and highlights the trade-off between reconstruction quality and downstream interpretability benchmarks.

Abstract

Sparse autoencoders (SAEs) are widely used for interpreting language model activations. A key evaluation metric is the increase in cross-entropy loss between the original model logits and the reconstructed model logits when replacing model activations with SAE reconstructions. Typically, SAEs are trained solely on mean squared error (MSE) when reconstructing precomputed, shuffled activations. Recent work introduced training SAEs directly with a combination of KL divergence and MSE ("end-to-end" SAEs), significantly improving reconstruction accuracy at the cost of substantially increased computation, which has limited their widespread adoption. We propose a brief KL+MSE fine-tuning step applied only to the final 25M training tokens (just a few percent of typical training budgets) that achieves comparable improvements, reducing the cross-entropy loss gap by 20-50%, while incurring minimal additional computational cost. We further find that multiple fine-tuning methods (KL fine-tuning, LoRA adapters, linear adapters) yield similar, non-additive cross-entropy improvements, suggesting a common, easily correctable error source in MSE-trained SAEs. We demonstrate a straightforward method for effectively transferring hyperparameters and sparsity penalties between training phases despite scale differences between KL and MSE losses. While both ReLU and TopK SAEs see significant cross-entropy loss improvements, evaluations on supervised SAEBench metrics yield mixed results, with improvements on some metrics and decreases on others, depending on both the SAE architecture and downstream task. Nonetheless, our method may offer meaningful improvements in interpretability applications such as circuit analysis with minor additional cost.
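
To make the quantities in the abstract concrete, here is a minimal pure-Python sketch of a combined KL+MSE objective and of the "cross-entropy loss gap" metric. The balancing rule (rescaling the KL term so that it matches the MSE term) is an assumption made for illustration; the abstract only says that the two losses differ in scale and are balanced. All function names and numbers are hypothetical.

```python
import math

def softmax(logits):
    # Numerically stable softmax over a list of logits.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def kl_divergence(orig_logits, recon_logits):
    # KL(p_orig || p_recon) for a single token position.
    p = softmax(orig_logits)
    q = softmax(recon_logits)
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def mse(x, x_hat):
    # Mean squared error between an activation and its SAE reconstruction.
    return sum((a - b) ** 2 for a, b in zip(x, x_hat)) / len(x)

def balanced_kl_mse(kl, mse_val, eps=1e-8):
    # ASSUMED balancing rule: rescale KL so its magnitude matches MSE.
    # In an autograd framework the scale would be treated as a constant.
    scale = mse_val / (kl + eps)
    return 0.5 * (mse_val + scale * kl), scale

def ce_gap_reduction(ce_clean, ce_mse_sae, ce_finetuned_sae):
    # Fraction of the cross-entropy gap (SAE-spliced model vs. clean model)
    # that is closed by fine-tuning.
    gap_before = ce_mse_sae - ce_clean
    gap_after = ce_finetuned_sae - ce_clean
    return 1.0 - gap_after / gap_before

if __name__ == "__main__":
    kl = kl_divergence([2.0, 0.5, -1.0], [1.8, 0.7, -0.9])
    err = mse([1.0, 2.0, 3.0], [1.5, 1.5, 3.5])
    loss, scale = balanced_kl_mse(kl, err)
    print(f"KL={kl:.4f}  MSE={err:.4f}  scale={scale:.2f}  loss={loss:.4f}")
    print(f"gap reduction: {ce_gap_reduction(3.0, 3.2, 3.1):.0%}")
```

With the hypothetical cross-entropy values above (3.0 for the clean model, 3.2 with an MSE-trained SAE, 3.1 after fine-tuning), half of the gap is closed, which would sit at the upper end of the 20-50% range the abstract reports.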

Paper Structure

This paper contains 26 sections, 7 equations, 14 figures.

Figures (14)

  • Figure 1: Comparison of training approaches for a sparse autoencoder (K=80, width=16K) on Pythia-160M. The proposed KL+MSE fine-tuning approach (25M tokens) achieves slightly better performance than full end-to-end (E2E) training (Braun et al., 2024) on the same amount of data while reducing wall-clock time by approximately 50%.
  • Figure 2: Comparison of KL+MSE finetuning (25M tokens) vs full end-to-end training (E2E) on Gemma-2-2B with 65K width SAEs.
  • Figure 3: Comparison of KL+MSE finetuning (25M tokens) vs LoRA adapters on Gemma-2-2B with 65K width SAEs.
  • Figure 4: SAEBench results for KL+MSE finetuning (15M tokens) on 65K width SAEs on Gemma-2-2B.
  • Figure 5: Comparison of training with KL+MSE loss versus KL-only loss. There is virtually no difference in validation loss between the two methods, while KL-only training shows significantly worse MSE on the training set.
  • ...and 9 more figures