Improving Neuron-level Interpretability with White-box Language Models

Hao Bai, Yi Ma

TL;DR

This work addresses the challenge of neuron-level interpretability in autoregressive language models by moving sparsity from post-hoc tooling into the model itself. It introduces CRATE, a white-box, sparse-coding language model that integrates an MSSA compression module and an overcomplete ISTA sparsification module into a transformer-like architecture tailored for next-token prediction. Across model sizes from 1 to 12 layers, all pretrained on the Pile, CRATE is substantially more interpretable than GPT-2, with relative gains of up to 103% on the OpenAI TaR/Random-only and Anthropic metrics, while maintaining competitive predictive performance. The findings underscore the potential of built-in sparse coding for robust neuron-level explanations across layers and for steerability. They also point to future work on balancing interpretability against raw predictive accuracy and on extending these mechanisms to downstream model edits.

Abstract

Neurons in auto-regressive language models like GPT-2 can be interpreted by analyzing their activation patterns. Recent studies have shown that techniques such as dictionary learning, a form of post-hoc sparse coding, enhance this neuron-level interpretability. In our research, we are driven by the goal to fundamentally improve neural network interpretability by embedding sparse coding directly within the model architecture, rather than applying it as an afterthought. In our study, we introduce a white-box transformer-like architecture named Coding RAte TransformEr (CRATE), explicitly engineered to capture sparse, low-dimensional structures within data distributions. Our comprehensive experiments showcase significant improvements (up to 103% relative improvement) in neuron-level interpretability across a variety of evaluation metrics. Detailed investigations confirm that this enhanced interpretability is steady across different layers irrespective of the model size, underlining CRATE's robust performance in enhancing neural network interpretability. Further analysis shows that CRATE's increased interpretability comes from its enhanced ability to consistently and distinctively activate on relevant tokens. These findings point towards a promising direction for creating white-box foundation models that excel in neuron-level interpretation.

Paper Structure

This paper contains 37 sections, 14 equations, 14 figures, 9 tables, 3 algorithms.

Figures (14)

  • Figure 1: Instances are systematically identified where the interpretability of CRATE (ours, row 1) outperforms GPT-2 (row 2). For each neuron (rounded box), we show two top activated text excerpts (excerpt 1 and 2) and one randomly activated excerpt (excerpt 3). Results show that CRATE consistently activates on and only on semantically relevant text excerpts (first two excerpts), leading to more precise explanations predicted by agents like Mistral.
  • Figure 2: Block architecture for the CRATE language model, where $S_{\lambda}(x) = \operatorname{ReLU}(x - \eta\cdot\lambda\cdot 1)$. Differences from the original architecture in Yu et al. (2023) are marked bold: we (i) add a causal mask $\text{Mask}(\cdot)$ and (ii) over-parameterize the ISTA block.
  • Figure 3: CRATE iteratively compresses (MSSA block) and sparsifies (ISTA block) the token representations (colored points) across its layers from $1$ to $L$, transforming them into parsimonious representations aligned on axes (colored lines) with distinct semantic meanings.
  • Figure 4: Left: loss curve when pre-training CRATE-Base and GPT2-Base on the Pile dataset. Right: zero-shot validation loss of CRATE evaluated on a variety of datasets (Pile, LAMBADA, OpenWebText and WikiText).
  • Figure 5: Left: Validation loss of CRATE compared to GPT-2 on the Pile dataset, with respect to the model size. Right: Qualitative examples of predictions made by CRATE and GPT-2. The tokens in blue are considered good. We compare CRATE-Base to GPT2-Base on the next word prediction task.
  • ...and 9 more figures
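
The thresholding operator in the Figure 2 caption, $S_{\lambda}(x) = \operatorname{ReLU}(x - \eta\cdot\lambda\cdot 1)$, can be illustrated with a minimal sketch. The function name `shifted_relu` and the example values are hypothetical, and reading $\eta$ as a step size and $\lambda$ as a sparsity weight is an assumption, since the caption does not define them. The sketch covers only this elementwise operator, not the full ISTA block.

```python
def shifted_relu(x, eta, lam):
    """Elementwise S_lambda(x) = ReLU(x - eta * lam * 1).

    x   : list of floats (one token's pre-activation values; illustrative)
    eta : assumed to be a step size
    lam : assumed to be a sparsity weight
    """
    threshold = eta * lam  # the same shift is subtracted from every entry
    # Entries at or below the threshold become exactly zero,
    # which is what makes the output sparse and non-negative.
    return [max(v - threshold, 0.0) for v in x]


# Hypothetical input: with eta * lam = 0.25, two of four entries are zeroed.
example = shifted_relu([1.0, 0.25, -0.5, 0.75], eta=0.5, lam=0.5)
print(example)  # [0.75, 0.0, 0.0, 0.5]
```

Raising the product `eta * lam` zeroes out more entries, so under this reading the two parameters together control how sparse the resulting activations are.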