Low-Rank Adapting Models for Sparse Autoencoders

Matthew Chen, Joshua Engels, Max Tegmark

TL;DR

This work tackles interpretability of language models by leveraging Sparse Autoencoders (SAEs) to decompose hidden representations and proposes finetuning around a fixed SAE using Low-Rank Adaptation (LoRA). By freezing both the SAE and the base model and training low-rank adapters, the approach reduces the SAE loss gap $L_{\text{SAE}}-L_{\text{BASE}}$ by roughly 30% to 55%, while delivering 2× to 20× speedups over end-to-end SAE training and enabling multiple SAEs to be adapted concurrently without harming general capabilities. Across diverse SAE families (TopK, JumpReLU) and model scales, the method yields consistent downstream benefits, as evidenced by SAEBench metrics and steering evaluations, and maintains general-domain performance on MMLU, HellaSwag, and TruthfulQA. The findings suggest that Pareto improvements in interpretability can be achieved not only through post-hoc decomposition but also via targeted, parameter-efficient modifications to the underlying language model when SAEs are present.
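
The "loss gap" figure can be made concrete with a small calculation. The sketch below assumes that "percentage of the gap closed" means the drop in cross entropy loss after adaptation divided by the original gap $L_{\text{SAE}}-L_{\text{BASE}}$. The function name `gap_closed` and the numbers in the usage line are hypothetical and are not results from the paper.

```python
def gap_closed(l_base: float, l_sae: float, l_adapted: float) -> float:
    """Fraction of the CE loss gap (l_sae - l_base) closed after adaptation.

    l_base:    CE loss of the original model with no SAE inserted.
    l_sae:     CE loss with the SAE reconstruction inserted, before adaptation.
    l_adapted: CE loss with the SAE inserted, after training the adapters.
    """
    gap = l_sae - l_base
    if gap <= 0:
        raise ValueError("expected l_sae > l_base (inserting the SAE should add loss)")
    return (l_sae - l_adapted) / gap


# Illustrative numbers only (assumed, not taken from the paper).
print(f"{gap_closed(2.50, 2.70, 2.62):.0%} of the gap closed")
```

A value of 0 means adaptation did not help, and a value of 1 means the adapted model with the SAE inserted matches the base model's loss.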

Abstract

Sparse autoencoders (SAEs) decompose language model representations into a sparse set of linear latent vectors. Recent works have improved SAEs using language model gradients, but these techniques require many expensive backward passes during training and still cause a significant increase in cross entropy loss when SAE reconstructions are inserted into the model. In this work, we address these limitations by taking a fundamentally different approach: we use low-rank adaptation (LoRA) to finetune the language model itself around a previously trained SAE. We analyze our method across SAE sparsity, SAE width, language model size, LoRA rank, and model layer on the Gemma Scope family of SAEs. In these settings, our method reduces the cross entropy loss gap by 30% to 55% when SAEs are inserted during the forward pass. We also find that compared to end-to-end (e2e) SAEs, our approach achieves the same downstream cross entropy loss 3× to 20× faster on Gemma-2-2B and 2× to 10× faster on Llama-3.2-1B. We further show that our technique improves downstream metrics and can adapt multiple SAEs at once without harming general language model capabilities. Our results demonstrate that improving model interpretability is not limited to post-hoc SAE training; Pareto improvements can also be achieved by directly optimizing the model itself.
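
As a rough sketch of the setup described in the abstract, the standard LoRA parameterization can be written as follows. The symbols $W$, $A$, $B$, $r$, $x$, and $\hat{x}$ are notation assumed for this sketch and are not taken from the source.

```latex
$$
W' = W + BA, \qquad
B \in \mathbb{R}^{d_{\text{out}} \times r}, \quad
A \in \mathbb{R}^{r \times d_{\text{in}}}, \quad
r \ll \min(d_{\text{in}}, d_{\text{out}})
$$

$$
\hat{x} = \mathrm{SAE}(x) \quad \text{replaces } x \text{ at the SAE layer during the forward pass}
$$
```

Here $W$ is a frozen weight matrix of the base model and only $A$ and $B$ are trained. The SAE is also frozen. $L_{\text{SAE}}$ is then the cross entropy loss of the model when later layers consume $\hat{x}$ instead of $x$, and the adapters are trained so that this loss moves toward $L_{\text{BASE}}$, the loss of the unmodified model.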

Paper Structure

This paper contains 28 sections, 10 equations, 11 figures, 6 tables.

Figures (11)

  • Figure 1: Cross entropy loss vs. training time over 2B tokens for Gemma-2-2B TopK SAEs with $\mathrm{width} = 18,432, \mathrm{L_0} = 64$. We find that our method (TopK + LoRA in the plot) outperforms an e2e SAE and vanilla TopK SAE.
  • Figure 2: Visual representation of our method, with a local SAE trained on layer $12$ and low-rank adapters trained on MLP and attention components on all layers.
  • Figure 3: Cross entropy loss improvement (Top: absolute, Bottom: percentage of CE loss gap closed) using our method for Gemma Scope SAEs on Gemma-2-2B. Left: Scaling across sparsity with fixed width=16k and layer=12. The effect of our method by percentage is largest at lower sparsities but remains substantial at higher sparsities. Middle: Scaling across width with fixed $L_0=68$ and layer=12. The effect by percentage is highest at low width, but the difference is not large. Right: Scaling across layer with fixed $L_0=68$ and width=16k. The effect by percentage is highest at layer $9$ but is mostly unaffected by layer.
  • Figure 4: Cross entropy loss improvement (Top: absolute, Bottom: percentage) for Gemma Scope SAEs of width $16k$ and $L_0$ closest to $70$ on Gemma-2-2B, 9B, and 27B. We find that our method works increasingly well on larger models.
  • Figure 5: Cross entropy loss vs. training time for Llama-3.2-1B with TopK SAEs of $L_0 = 64$ and width 16384. Our method (TopK + LoRA) achieves lower CE loss sooner than e2e SAE or vanilla TopK SAEs.
  • ...and 6 more figures