Enhancing Neural Network Interpretability with Feature-Aligned Sparse Autoencoders

Luke Marks, Alasdair Paren, David Krueger, Fazl Barez

TL;DR

This work tackles the interpretability challenge of sparse autoencoders (SAEs) by introducing Mutual Feature Regularization (MFR), which trains multiple SAEs in parallel to learn common, input-aligned features. MFR combines a targeted reinitialization strategy for inactive TopK features with an auxiliary penalty that increases cross-SAE feature similarity, leading to better recovery of input features in both synthetic datasets and real-world activations from GPT-2 Small and EEG data. Empirically, MFR improves reconstruction loss (up to 21.21% on GPT-2 Small and 6.67% on EEG) and strengthens the alignment between SAE features and input features, supporting its potential for enhancing SAE interpretability. The results also reveal trade-offs in computational cost and parameter regime, highlighting avenues for more efficient mutual-learning strategies in future work.

Abstract

Sparse Autoencoders (SAEs) have shown promise in improving the interpretability of neural network activations, but can learn features that are not features of the input, limiting their effectiveness. We propose Mutual Feature Regularization (MFR), a regularization technique for improving feature learning by encouraging SAEs trained in parallel to learn similar features. We motivate MFR by showing that features learned by multiple SAEs are more likely to correlate with features of the input. By training on synthetic data with known features of the input, we show that MFR can help SAEs learn those features, as we can directly compare the features learned by the SAE with the input features for the synthetic data. We then scale MFR to SAEs that are trained to denoise electroencephalography (EEG) data and SAEs that are trained to reconstruct GPT-2 Small activations. We show that MFR can improve the reconstruction loss of SAEs by up to 21.21% on GPT-2 Small, and 6.67% on EEG data. Our results suggest that the similarity between features learned by different SAEs can be leveraged to improve SAE training, thereby enhancing performance and the usefulness of SAEs for model interpretability.
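
The motivating claim above, that features shared across SAEs are more likely to match input features, relies on some way of scoring cross-SAE feature similarity. The following is a minimal sketch of one such score, assuming each SAE's features are given as a list of direction vectors and that similarity is measured by best-match cosine similarity. The function names and the `1 - mean similarity` penalty form are illustrative assumptions and may differ from the formulation used in the paper.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors (0.0 if either is zero)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0.0 or norm_v == 0.0:
        return 0.0
    return dot / (norm_u * norm_v)


def max_similarities(features_a, features_b):
    """For each feature in A, the highest cosine similarity to any feature in B."""
    return [max(cosine(fa, fb) for fb in features_b) for fa in features_a]


def mutual_similarity_penalty(features_a, features_b):
    """Hypothetical auxiliary penalty: 1 minus the mean best-match similarity.

    The score is computed in both directions (A to B and B to A) so that
    neither SAE is treated as the reference.
    """
    a_to_b = max_similarities(features_a, features_b)
    b_to_a = max_similarities(features_b, features_a)
    mean_sim = (sum(a_to_b) + sum(b_to_a)) / (len(a_to_b) + len(b_to_a))
    return 1.0 - mean_sim


# Toy feature sets: each inner list is one feature direction.
sae_a = [[1.0, 0.0], [0.0, 1.0]]
sae_b = [[0.0, 2.0], [3.0, 0.0]]   # same directions as sae_a, permuted and rescaled
sae_c = [[1.0, 1.0], [1.0, -1.0]]  # directions rotated by 45 degrees

penalty_ab = mutual_similarity_penalty(sae_a, sae_b)
penalty_ac = mutual_similarity_penalty(sae_a, sae_c)
```

Because the score uses best-match cosine similarity, it is insensitive to the ordering and scale of features: `sae_a` and `sae_b` incur no penalty, while the rotated `sae_c` does.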

Paper Structure

This paper contains 14 sections, 9 equations, 10 figures, 3 tables.

Figures (10)

  • Figure 1: Our experimental pipeline for training SAEs with MFR. In step one, we extract activations from a neural network, represented by the interconnected nodes on the left. These activations are the inputs to our SAEs. In step two, we train multiple SAEs on the extracted activations. Each SAE learns to reconstruct the input activations under a sparsity constraint on the hidden layer. MFR involves several steps: we first check for inactive features in the SAE hidden state after applying the TopK activation function. If too many inactive features are detected, we reinitialize the weights of the affected SAE. We also include an auxiliary penalty that encourages the SAEs to learn similar features, shown in the final text box.
  • Figure 2: The relationship between feature similarity across SAEs and feature similarity with the input features for two baseline SAEs.
  • Figure 3: The relationship between feature similarity across SAEs and feature similarity with the input features for two SAEs with conditionally reinitialized weights.
  • Figure 4: The relationship between feature similarity across SAEs, feature similarity with the input features, and the likelihood that a feature is active after the TopK activation function is applied to the hidden representation, for two baseline SAEs.
  • Figure 5: The relationship between feature similarity across SAEs, feature similarity with the input features, and the likelihood that a feature is active after the TopK activation function is applied to the hidden representation, for two SAEs trained with MFR.
  • ...and 5 more figures
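
The inactive-feature check described in the Figure 1 caption can be sketched as follows. This is a minimal illustration, assuming hidden states are plain lists of floats and that a feature counts as inactive when it is never non-zero after TopK across a batch. The caption does not state what counts as "too many" inactive features, so the `max_inactive` threshold and all function names here are hypothetical.

```python
def topk_activation(hidden, k):
    """Keep the k largest entries of a hidden vector and zero out the rest."""
    order = sorted(range(len(hidden)), key=lambda i: hidden[i], reverse=True)
    keep = set(order[:k])
    return [h if i in keep else 0.0 for i, h in enumerate(hidden)]


def inactive_fraction(batch_hidden, k):
    """Fraction of features that are never non-zero after TopK over the batch."""
    n_features = len(batch_hidden[0])
    active = [False] * n_features
    for hidden in batch_hidden:
        for i, value in enumerate(topk_activation(hidden, k)):
            if value != 0.0:
                active[i] = True
    return active.count(False) / n_features


def should_reinitialize(batch_hidden, k, max_inactive=0.5):
    """Flag an SAE for reinitialization when too many features are inactive.

    max_inactive is an assumed threshold chosen for illustration only.
    """
    return inactive_fraction(batch_hidden, k) > max_inactive


# Toy batch of three hidden states with four features each.
batch = [
    [0.9, 0.1, 0.2, 0.05],
    [0.8, 0.3, 0.1, 0.0],
    [0.7, 0.2, 0.1, 0.1],
]

sparse_first = topk_activation(batch[0], k=2)
frac_k1 = inactive_fraction(batch, k=1)
frac_k2 = inactive_fraction(batch, k=2)
```

With `k=1` only the first feature ever survives TopK, so three of the four features are inactive and the SAE would be flagged under the assumed threshold. With `k=2` only one feature stays inactive and no reinitialization is triggered.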