Revisiting End-To-End Sparse Autoencoder Training: A Short Finetune Is All You Need
Adam Karvonen
TL;DR
The paper shows that a brief KL+MSE fine-tuning stage applied to the final portion of SAE training (the last 25M tokens) yields cross-entropy improvements comparable to full end-to-end training, at minimal additional computational cost. By dynamically balancing the KL and MSE losses, the approach transfers hyperparameters and sparsity penalties between training phases despite the difference in scale between the two losses. Across TopK and ReLU SAEs, the method reduces the cross-entropy loss gap by 20-50%, but results on supervised SAEBench metrics are mixed, varying with SAE architecture and downstream task. The work offers a low-overhead recipe that may benefit interpretability applications such as circuit analysis, and highlights a trade-off between reconstruction quality and downstream interpretability benchmarks.
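To make the idea of balancing KL and MSE concrete, here is a minimal sketch in plain Python. It is not the paper's implementation: the balancing rule (rescaling KL so its magnitude matches MSE at each step) is an assumption about one plausible way to do it, and the function names `kl_divergence`, `mse`, and `balanced_loss` are hypothetical.

```python
import math

def log_softmax(logits):
    # Numerically stable log-softmax over a list of logits.
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]

def kl_divergence(orig_logits, recon_logits):
    # KL(p_orig || p_recon) for a single next-token distribution.
    log_p = log_softmax(orig_logits)
    log_q = log_softmax(recon_logits)
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))

def mse(acts, recon):
    # Mean squared error between activations and their SAE reconstruction.
    return sum((a - r) ** 2 for a, r in zip(acts, recon)) / len(acts)

def balanced_loss(kl_value, mse_value, eps=1e-8):
    # ASSUMPTION: rescale KL so that it has the same magnitude as MSE.
    # In an autograd framework the scale would be treated as a constant,
    # so it changes the gradient weighting but is not itself optimized.
    scale = mse_value / (kl_value + eps)
    return mse_value + scale * kl_value, scale

# Toy values, chosen only for illustration.
kl_value = kl_divergence([2.0, 0.5, -1.0], [1.5, 0.7, -0.8])
mse_value = mse([1.0, -2.0, 0.5], [0.8, -1.5, 0.9])
total, scale = balanced_loss(kl_value, mse_value)
print(f"KL={kl_value:.4f} MSE={mse_value:.4f} scale={scale:.2f} total={total:.4f}")
```

Under this rule the total loss stays on the MSE scale, which is one way learning rates and sparsity penalties tuned for MSE-only training could carry over to the fine-tuning phase without retuning.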
Abstract
Sparse autoencoders (SAEs) are widely used for interpreting language model activations. A key evaluation metric is the increase in the model's cross-entropy loss when its activations are replaced with SAE reconstructions. Typically, SAEs are trained solely on mean squared error (MSE) when reconstructing precomputed, shuffled activations. Recent work introduced training SAEs directly with a combination of KL divergence and MSE ("end-to-end" SAEs), significantly improving reconstruction accuracy at the cost of substantially increased computation, which has limited their widespread adoption. We propose a brief KL+MSE fine-tuning step applied only to the final 25M training tokens (just a few percent of typical training budgets) that achieves comparable improvements, reducing the cross-entropy loss gap by 20-50%, while incurring minimal additional computational cost. We further find that multiple fine-tuning methods (KL fine-tuning, LoRA adapters, linear adapters) yield similar, non-additive cross-entropy improvements, suggesting a common, easily correctable error source in MSE-trained SAEs. We demonstrate a straightforward method for effectively transferring hyperparameters and sparsity penalties between training phases despite scale differences between KL and MSE losses. While both ReLU and TopK SAEs see significant cross-entropy loss improvements, evaluations on supervised SAEBench metrics yield mixed results, with improvements on some metrics and decreases on others, depending on both the SAE architecture and downstream task. Nonetheless, our method may offer meaningful improvements in interpretability applications such as circuit analysis with minor additional cost.
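The two quantities the abstract reports, the reduction in the cross-entropy loss gap and the share of the training budget spent on fine-tuning, can be worked through with a short sketch. All numeric inputs below (the cross-entropy values and the 500M-token budget) are hypothetical and chosen only to illustrate the arithmetic; the function names are likewise assumptions, not taken from the paper.

```python
def ce_gap(ce_original, ce_with_sae):
    # Increase in cross-entropy loss caused by splicing in SAE reconstructions.
    return ce_with_sae - ce_original

def gap_reduction(ce_original, ce_mse_sae, ce_finetuned_sae):
    # Fraction of the MSE-trained SAE's loss gap removed by fine-tuning.
    gap_before = ce_gap(ce_original, ce_mse_sae)
    gap_after = ce_gap(ce_original, ce_finetuned_sae)
    return 1.0 - gap_after / gap_before

def finetune_fraction(finetune_tokens, total_tokens):
    # Share of the total training budget spent in the KL+MSE phase.
    return finetune_tokens / total_tokens

# Hypothetical losses: clean model 2.50, MSE-only SAE 2.60, fine-tuned SAE 2.57.
reduction = gap_reduction(2.50, 2.60, 2.57)
# 25M fine-tuning tokens (from the abstract) out of a hypothetical 500M budget.
fraction = finetune_fraction(25_000_000, 500_000_000)
print(f"gap reduction: {reduction:.0%}, fine-tune share of budget: {fraction:.0%}")
```

With these illustrative inputs the gap shrinks from 0.10 to 0.07, a reduction of roughly 30%, which falls inside the 20-50% range the abstract reports, while the fine-tuning phase accounts for 5% of the assumed budget.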
