
Control Reinforcement Learning: Token-Level Mechanistic Analysis via Learned SAE Feature Steering

Seonglae Cho, Zekun Wu, Adriano Koshiyama

TL;DR

This work proposes Control Reinforcement Learning (CRL), a framework that trains a policy to select which sparse autoencoder (SAE) feature to amplify at each token, yielding interpretable per-token intervention logs. CRL formulates feature selection as a Markov decision process (MDP) over SAE features and combines adaptive feature masking with PPO optimization. It improves task performance and also supports mechanistic diagnostics, including branch point analysis and critic trajectory analysis. Layer-wise comparisons show that the selected features are syntactic in early layers and semantic in later layers. Empirically, CRL improves performance on Gemma-2 2B across MMLU, BBQ, GSM8K, HarmBench, and XSTest, and it generalizes to LLaMA-3.1 8B. Its per-token logs complement static feature analyses and attribution methods with dynamic intervention probes.
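PPO is named above but not spelled out. For reference, the sketch below implements the standard clipped surrogate objective from PPO (Schulman et al., 2017). It treats each selected SAE feature as the action at one token. This summary does not state whether CRL uses exactly this variant, the clip range, or any entropy or value-loss terms, so `clip_eps=0.2` and the toy rollout are illustrative assumptions.

```python
import numpy as np

def ppo_clip_loss(logp_new, logp_old, advantages, clip_eps=0.2):
    """Standard PPO clipped surrogate loss (negated so it can be minimized).

    logp_new / logp_old: log-probabilities of the chosen SAE feature (the
    action) at each token under the current and the data-collecting policy.
    advantages: per-token advantage estimates, e.g. derived from a critic.
    """
    ratio = np.exp(logp_new - logp_old)
    unclipped = ratio * advantages
    clipped = np.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    # PPO maximizes the elementwise minimum of the two terms.
    return -np.mean(np.minimum(unclipped, clipped))

# Hypothetical rollout: three tokens, one selected feature per token.
logp_old = np.log(np.array([0.5, 0.2, 0.1]))
logp_new = np.log(np.array([0.6, 0.3, 0.05]))
adv = np.array([1.0, -0.5, 2.0])
loss = ppo_clip_loss(logp_new, logp_old, adv)
print(loss)
```

The clipping keeps a single update from pushing the probability of any token's chosen feature too far from the policy that generated the rollout.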

Abstract

Sparse autoencoders (SAEs) decompose language model activations into interpretable features, but existing methods reveal only which features activate, not which ones change model outputs when amplified. We introduce Control Reinforcement Learning (CRL), which trains a policy to select an SAE feature for steering at each token and produces interpretable intervention logs: the learned policy identifies features that change model outputs when amplified. Adaptive Feature Masking encourages diverse feature discovery while preserving single-feature interpretability. The framework yields three new analysis capabilities. Branch point tracking locates tokens where feature choice determines output correctness. Critic trajectory analysis separates policy limitations from value estimation errors. Layer-wise comparison reveals syntactic features in early layers and semantic features in later layers. On Gemma-2 2B, CRL improves performance across MMLU, BBQ, GSM8K, HarmBench, and XSTest while providing per-token intervention logs. These results establish learned feature steering as a mechanistic interpretability tool that complements static feature analysis with dynamic intervention probes.
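The abstract does not specify how a selected feature is "amplified". A common way to steer with an SAE feature is to add a scaled copy of that feature's decoder direction to the residual-stream activation. The sketch below assumes that convention, which may differ from CRL's actual intervention. The coefficient `alpha`, the toy dimensions, and the argmax "policy" are all hypothetical. The dictionary at the end mimics one entry of a per-token intervention log: the feature that was chosen and how far the activation moved along it.

```python
import numpy as np

def steer_with_feature(resid, w_dec, feature_idx, alpha):
    """Amplify one SAE feature by adding its scaled decoder direction.

    resid: residual-stream activation at one token, shape (d_model,)
    w_dec: SAE decoder matrix, shape (n_features, d_model)
    feature_idx: index of the feature chosen by the policy
    alpha: steering coefficient (hypothetical value)
    """
    return resid + alpha * w_dec[feature_idx]

rng = np.random.default_rng(0)
d_model, n_features = 8, 32
w_dec = rng.standard_normal((n_features, d_model))
w_dec /= np.linalg.norm(w_dec, axis=1, keepdims=True)  # unit-norm decoder rows
resid = rng.standard_normal(d_model)

# Stand-in for a learned policy: score every feature, pick the best one.
policy_logits = rng.standard_normal(n_features)
k = int(np.argmax(policy_logits))
steered = steer_with_feature(resid, w_dec, k, alpha=4.0)

# Log entry: chosen feature and displacement along its (unit) direction.
log_entry = {"feature": k, "shift": float((steered - resid) @ w_dec[k])}
print(log_entry)
```

Because the decoder rows are normalized here, the logged shift equals `alpha`. With unnormalized rows it would scale with the squared row norm.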

Paper Structure

This paper contains 48 sections, 11 equations, 23 figures, 6 tables, and 1 algorithm.

Figures (23)

  • Figure 1: CRL overview. The policy observes residual stream activations and selects an SAE feature to amplify at each token. Each intervention produces an interpretable log: which feature was selected and how it affected output at that position. The critic estimates state values for PPO optimization.
  • Figure 2: Branch point analysis comparing feature semantics at layer 10 versus layer 20. Each panel shows the context and the competing features from both layers, with the correct features highlighted.
  • Figure 3: Critic value distributions for single-token tasks. Top: MMLU. Bottom: BBQ. Colors: green = unchanged-correct, red = unchanged-incorrect, blue = corrected, yellow = misguided (see the terminology subsection for definitions).
  • Figure 4: Critic network value trajectories for the GSM8K task. Colors: green = unchanged-correct, red = unchanged-incorrect, blue = corrected, yellow = misguided (see the terminology subsection for definitions).
  • Figure 5: GSM8K corrected cases. Top: Feature steering activates relevant numerical tokens, strengthening reasoning chains. Bottom: Semantically coherent feature activations align with equality tokens, guiding correct computation.
  • ...and 18 more figures
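Figures 3 and 4 group examples into four outcome categories. The sketch below reads those categories as a comparison between the unsteered (baseline) answer and the steered answer. These definitions are an assumption inferred from the category names; the paper's terminology subsection is the authoritative source. The per-example results are made up for illustration.

```python
from collections import Counter

def classify_outcome(baseline_correct: bool, steered_correct: bool) -> str:
    """Map baseline vs. steered correctness to the four figure categories."""
    if baseline_correct and steered_correct:
        return "unchanged-correct"    # green
    if not baseline_correct and not steered_correct:
        return "unchanged-incorrect"  # red
    if not baseline_correct and steered_correct:
        return "corrected"            # blue: steering fixed the answer
    return "misguided"                # yellow: steering broke a correct answer

# Hypothetical per-example results: (baseline correct?, steered correct?)
results = [(True, True), (False, True), (True, False), (False, False), (False, True)]
counts = Counter(classify_outcome(b, s) for b, s in results)
print(counts)
```

Under this reading, the gap between "corrected" and "misguided" counts is the net effect of steering on accuracy.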