Time-Aware Feature Selection: Adaptive Temporal Masking for Stable Sparse Autoencoder Training

T. Ed Li, Junyu Ren

TL;DR

This paper tackles the instability of sparse autoencoders (SAEs) used for interpreting large language models by addressing feature absorption. It proposes Adaptive Temporal Masking (ATM), which replaces fixed sparsity with a time-aware, probabilistic masking scheme guided by an Importance score computed from exponential moving averages of activation magnitude, activation frequency, and reconstruction contribution. Key contributions include a principled, adaptive thresholding rule $\theta_t = \mu_t + c\sigma_t$ and a masking mechanism with probability $p(\text{masked}) = 1 - \exp(-r(\theta_t - \text{Importance}(t))/\theta_t)$, integrated into a training objective that jointly minimizes reconstruction error and sparsity penalties. Empirically, ATM achieves substantially lower feature absorption than TopK and JumpReLU baselines while maintaining reconstruction quality, and yields strong sparse probing performance across 35 binary tasks, thereby enabling more stable and interpretable representations for model analysis.
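The thresholding and masking rules above can be sketched in a few lines of NumPy. This is only an illustration under stated assumptions, not the authors' implementation: the summary does not say how the three moving averages are combined into the Importance score (a plain product is assumed here), nor how $\mu_t$ and $\sigma_t$ are computed (the mean and standard deviation of the scores across features are assumed), nor what happens when a feature's score exceeds $\theta_t$ (the probability is assumed to be clipped to 0). The function names and the default values for `beta`, `c`, and `r` are hypothetical.

```python
import numpy as np


def update_ema(prev, value, beta=0.99):
    """One exponential-moving-average step: beta * prev + (1 - beta) * value."""
    return beta * prev + (1.0 - beta) * value


def importance_scores(ema_magnitude, ema_frequency, ema_contribution):
    # ASSUMPTION: the three tracked statistics are combined by a plain product;
    # the source only says the score is "computed from" them.
    return ema_magnitude * ema_frequency * ema_contribution


def mask_probability(importance, c=1.0, r=1.0, eps=1e-8):
    """Return (p_masked, theta) for a vector of per-feature importance scores."""
    # ASSUMPTION: mu_t and sigma_t are statistics over features at this step.
    mu = importance.mean()
    sigma = importance.std()
    theta = mu + c * sigma  # theta_t = mu_t + c * sigma_t

    # p(masked) = 1 - exp(-r * (theta_t - Importance(t)) / theta_t)
    raw = 1.0 - np.exp(-r * (theta - importance) / (theta + eps))

    # ASSUMPTION: scores above the threshold give a negative raw value,
    # which is clipped so that such features are never masked.
    return np.clip(raw, 0.0, 1.0), theta


def sample_keep_mask(p_masked, rng):
    """Draw a boolean mask; True means the feature is kept."""
    return rng.random(p_masked.shape) >= p_masked


importance = np.array([0.1, 0.5, 2.0, 0.0])
p_masked, theta = mask_probability(importance, c=1.0, r=1.0)
keep = sample_keep_mask(p_masked, np.random.default_rng(0))
```

Under these assumptions, the masking probability falls monotonically as a feature's score approaches the threshold, so low-importance features are dropped often but not deterministically, in contrast to a hard TopK cutoff.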

Abstract

Understanding the internal representations of large language models is crucial for ensuring their reliability and safety, with sparse autoencoders (SAEs) emerging as a promising interpretability approach. However, current SAE training methods face feature absorption, where features (or neurons) are absorbed into each other to minimize the $L_1$ penalty, making it difficult to consistently identify and analyze model behaviors. We introduce Adaptive Temporal Masking (ATM), a novel training approach that dynamically adjusts feature selection by tracking activation magnitudes, frequencies, and reconstruction contributions to compute importance scores that evolve over time. ATM applies a probabilistic masking mechanism based on statistical thresholding of these importance scores, creating a more natural feature selection process. Through extensive experiments on the Gemma-2-2b model, we demonstrate that ATM achieves substantially lower absorption scores than existing methods such as TopK and JumpReLU SAEs, while maintaining excellent reconstruction quality. These results establish ATM as a principled solution for learning stable, interpretable features in neural networks, providing a foundation for more reliable model analysis.

Paper Structure

This paper contains 15 sections, 7 equations, 1 figure, 1 table.

Figures (1)

  • Figure 1: Visualization of feature absorption. Panel A represents the target scenario, where the SAE learns two features in two neurons: "starts with E" (blue) and "elephant" (red). When the underlying token is <elephant>, both neurons should light up, resulting in an overall purple activation vector for token <elephant>. However, panel B reveals what the SAE actually learns due to the $L_1$ loss: the "elephant" feature can absorb the "starts with E" feature, which effectively reduces the number of active latents (lower $L_1$ norm) when the underlying token is <elephant>. While this increases sparsity, it diminishes interpretability since the "starts with E" feature no longer activates independently. Instead, the "elephant" feature acquires an unintended downstream effect, making feature activations less modular. The figure is adapted from https://colab.research.google.com/drive/1ePkM8oBHIEZ2kcqAiA3waeAmz8RSdHmq [colab2025].
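
The trade-off described in the Figure 1 caption can be checked with a toy numerical example. The sketch below is not taken from the paper: it assumes two orthonormal directions in a 2-dimensional activation space and unit-norm decoder rows, and all variable names are hypothetical. It shows that the absorbed solution reconstructs the <elephant> activation exactly while paying a smaller $L_1$ penalty, which is why an $L_1$-regularized SAE can prefer it.

```python
import numpy as np

# ASSUMPTION: toy 2-D activation space with orthonormal feature directions.
d_starts_with_e = np.array([1.0, 0.0])  # "starts with E" direction
d_elephant = np.array([0.0, 1.0])       # elephant-specific direction

# Activation for the token <elephant>: both underlying features are present.
x = d_starts_with_e + d_elephant

# Panel A (target): two separate latents, both active on <elephant>.
decoder_a = np.stack([d_starts_with_e, d_elephant])
latents_a = np.array([1.0, 1.0])

# Panel B (absorption): the "elephant" decoder row also carries the
# "starts with E" direction, so the "starts with E" latent stays silent.
absorbed_row = x / np.linalg.norm(x)    # unit-norm decoder row
decoder_b = np.stack([d_starts_with_e, absorbed_row])
latents_b = np.array([0.0, np.linalg.norm(x)])

recon_a = latents_a @ decoder_a
recon_b = latents_b @ decoder_b
l1_a = np.abs(latents_a).sum()
l1_b = np.abs(latents_b).sum()

print("reconstruction error A:", np.linalg.norm(recon_a - x))
print("reconstruction error B:", np.linalg.norm(recon_b - x))
print("L1 norm A:", l1_a, "L1 norm B:", l1_b)
```

Both configurations reconstruct `x`, but panel B uses one active latent instead of two and has the lower $L_1$ norm, at the cost of the "starts with E" latent no longer firing on <elephant>.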