Enhancing Neural Network Interpretability with Feature-Aligned Sparse Autoencoders
Luke Marks, Alasdair Paren, David Krueger, Fazl Barez
TL;DR
This work tackles a key limitation of sparse autoencoders (SAEs) for interpretability, their tendency to learn features that are not features of the input, by introducing Mutual Feature Regularization (MFR), which trains multiple SAEs in parallel to learn common, input-aligned features. MFR combines a targeted reinitialization strategy for inactive TopK features with an auxiliary penalty that increases cross-SAE feature similarity. This leads to better recovery of input features on synthetic datasets and on real-world data (GPT-2 Small activations and EEG data). Empirically, MFR improves reconstruction loss (by up to 21.21\% on GPT-2 Small and 6.67\% on EEG) and strengthens the alignment between SAE features and input features, supporting its potential for enhancing SAE interpretability. The results also reveal trade-offs in computational cost and across parameter regimes, highlighting avenues for more efficient mutual-learning strategies in future work.
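To make the objective concrete, the following is a minimal pure-Python sketch of an MFR-style loss for TopK SAEs. It rests on our own assumptions rather than on details stated here: the SAEs are bias-free, the helper names (`sae_forward`, `cross_sae_penalty`, `mfr_loss`) and the weight `alpha` are hypothetical, and the auxiliary penalty is taken to be one minus the mean best-match cosine similarity between the decoder features of two SAEs. The reinitialization of inactive features is omitted, and the code only evaluates the loss; in practice, gradients would come from an autodiff framework.

```python
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def cosine(u, v):
    """Cosine similarity; defined as 0 if either vector is zero."""
    nu, nv = math.sqrt(dot(u, u)), math.sqrt(dot(v, v))
    return dot(u, v) / (nu * nv) if nu > 0 and nv > 0 else 0.0


def topk(values, k):
    """TopK activation: keep the k largest pre-activations, zero the rest."""
    keep = set(sorted(range(len(values)), key=lambda i: values[i], reverse=True)[:k])
    return [v if i in keep else 0.0 for i, v in enumerate(values)]


def sae_forward(encoder, decoder, x, k):
    """Encode x with a bias-free TopK SAE, then decode.

    encoder and decoder are lists of feature vectors (n_features x dim)."""
    z = topk([dot(w, x) for w in encoder], k)
    x_hat = [sum(z_i * f[j] for z_i, f in zip(z, decoder)) for j in range(len(x))]
    return z, x_hat


def mse(x, x_hat):
    return sum((a - b) ** 2 for a, b in zip(x, x_hat)) / len(x)


def cross_sae_penalty(dec_a, dec_b):
    """Hypothetical penalty: 1 minus the mean, over features of SAE A, of the
    best cosine match among features of SAE B. It equals 0 when every
    feature of A also appears (up to scale) in B."""
    best = [max(cosine(f, g) for g in dec_b) for f in dec_a]
    return 1.0 - sum(best) / len(best)


def mfr_loss(saes, batch, k, alpha):
    """Summed reconstruction loss of all SAEs plus alpha times the mean
    penalty over ordered pairs of distinct SAEs (needs at least 2 SAEs)."""
    recon = 0.0
    for enc, dec in saes:
        recon += sum(mse(x, sae_forward(enc, dec, x, k)[1]) for x in batch) / len(batch)
    pairs = [(a, b) for a in range(len(saes)) for b in range(len(saes)) if a != b]
    penalty = sum(cross_sae_penalty(saes[a][1], saes[b][1]) for a, b in pairs) / len(pairs)
    return recon + alpha * penalty


# Two identical 2-feature SAEs: the penalty vanishes, leaving only reconstruction error.
eye = [[1.0, 0.0], [0.0, 1.0]]
print(mfr_loss([(eye, eye), (eye, eye)], batch=[[3.0, 1.0]], k=1, alpha=0.1))  # prints 1.0
```

Averaging over ordered pairs matters because this best-match penalty is asymmetric: every feature of one SAE may appear in the other without the converse holding.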
Abstract
Sparse Autoencoders (SAEs) have shown promise in improving the interpretability of neural network activations, but they can learn features that are not features of the input, which limits their effectiveness. We propose \textsc{Mutual Feature Regularization} \textbf{(MFR)}, a regularization technique that improves feature learning by encouraging SAEs trained in parallel to learn similar features. We motivate \textsc{MFR} by showing that features learned by multiple SAEs are more likely to correlate with features of the input. Training on synthetic data, where the input features are known and can be compared directly with the features an SAE learns, we show that \textsc{MFR} can help SAEs learn those features. We then scale \textsc{MFR} to SAEs trained to denoise electroencephalography (EEG) data and to SAEs trained to reconstruct GPT-2 Small activations. We show that \textsc{MFR} can improve the reconstruction loss of SAEs by up to 21.21\% on GPT-2 Small and 6.67\% on EEG data. Our results suggest that the similarity between features learned by different SAEs can be leveraged to improve SAE training, enhancing both the performance of SAEs and their usefulness for model interpretability.
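The synthetic-data argument depends on having ground-truth features to compare against. A minimal sketch of such a setup follows, under stated assumptions: the generator `make_synthetic` (random unit feature directions mixed with sparse, nonnegative coefficients) and the scoring function `mean_max_cosine` are hypothetical illustrations, and mean max cosine similarity is one common recovery metric rather than necessarily the one used in this work.

```python
import math
import random


def unit(v):
    n = math.sqrt(sum(a * a for a in v))
    return [a / n for a in v]


def cosine(u, v):
    """Cosine similarity; defined as 0 if either vector is zero."""
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    if nu == 0 or nv == 0:
        return 0.0
    return sum(a * b for a, b in zip(u, v)) / (nu * nv)


def make_synthetic(n_features, dim, n_samples, p_active, seed=0):
    """Hypothetical generator: random unit ground-truth feature directions;
    each sample is a sparse, nonnegative combination of them."""
    rng = random.Random(seed)
    features = [unit([rng.gauss(0.0, 1.0) for _ in range(dim)]) for _ in range(n_features)]
    data = []
    for _ in range(n_samples):
        # Each feature is active with probability p_active, with a Uniform[0, 1) magnitude.
        coeffs = [rng.random() if rng.random() < p_active else 0.0 for _ in range(n_features)]
        data.append([sum(c * f[j] for c, f in zip(coeffs, features)) for j in range(dim)])
    return features, data


def mean_max_cosine(true_features, learned_features):
    """Average, over ground-truth features, of the best cosine similarity to any
    learned (e.g. SAE decoder) feature; 1.0 means every true feature is
    recovered up to scale."""
    best = [max(cosine(t, l) for l in learned_features) for t in true_features]
    return sum(best) / len(best)


features, data = make_synthetic(n_features=8, dim=4, n_samples=100, p_active=0.25, seed=1)
print(round(mean_max_cosine(features, features), 6))  # prints 1.0 for a perfect dictionary
```

In this setup, an SAE trained on `data` would be scored by passing its decoder rows as `learned_features`; the ground truth is available only because the data are synthetic.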
