
Towards Interpretability Without Sacrifice: Faithful Dense Layer Decomposition with Mixture of Decoders

James Oldfield, Shawn Im, Sharon Li, Mihalis A. Nicolaou, Ioannis Patras, Grigorios G Chrysos

TL;DR

This work targets the interpretability of dense MLP layers in large language models by introducing Mixture of Decoders (MxD), a layer-level sparse decomposition that preserves the original layer's expressive power through a Hadamard-factorized weight tensor. The authors prove that each expert's weight matrix preserves the rank of the shared decoder, enabling faithful reconstruction under heavy sparsity, and derive an efficient forward pass that reduces to a Hadamard product of compact projections. Empirically, MxDs outperform state-of-the-art sparse MLP approximations on the sparsity-accuracy frontier across multiple models (up to 3B parameters) while remaining competitive on interpretability evaluations (sparse probing and feature steering). The results suggest that layer-level specialization can yield decompositions that are both interpretable and faithful, offering a path to sparsity without sacrificing performance in practical language-model settings.

Abstract

Multilayer perceptrons (MLPs) are an integral part of large language models, yet their dense representations render them difficult to understand, edit, and steer. Recent methods learn interpretable approximations via neuron-level sparsity, yet fail to faithfully reconstruct the original mapping, significantly increasing the model's next-token cross-entropy loss. In this paper, we advocate for moving to layer-level sparsity to overcome the accuracy trade-off in sparse layer approximation. Under this paradigm, we introduce Mixture of Decoders (MxDs). MxDs generalize MLPs and Gated Linear Units, expanding pre-trained dense layers into tens of thousands of specialized sublayers. Through a flexible form of tensor factorization, each sparsely activating MxD sublayer implements a linear transformation with full-rank weights, preserving the original decoders' expressive capacity even under heavy sparsity. Experimentally, we show that MxDs significantly outperform state-of-the-art methods (e.g., Transcoders) on the sparsity-accuracy frontier in language models with up to 3B parameters. Further evaluations on sparse probing and feature steering demonstrate that MxDs learn similarly specialized features of natural language, opening up a promising new avenue for designing interpretable yet faithful decompositions. Our code is available at: https://github.com/james-oldfield/MxD/.

Paper Structure

This paper contains 63 sections, 2 theorems, 24 equations, 23 figures, 10 tables.

Key Result

Lemma 1

We can materialize linear expert $n$'s weight matrix as $\boldsymbol{\mathcal{W}}(n,:,:)=\mathbf{W}_n=\mathbf{D}\,\mathrm{diag}\left( {\mathbf{c}}_n \right)\in\mathbb{R}^{H\times O}$. Assuming $\mathrm{diag}\left( {\mathbf{c}}_n \right)\in\mathbb{R}^{O\times O}$ is a diagonal matrix with no zeros along its diagonal (and is therefore invertible), $\mathrm{rank}(\mathbf{W}_n)=\mathrm{rank}(\mathbf{D})$.
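A quick numerical check of the setting in Lemma 1 may help: when $\mathrm{diag}(\mathbf{c}_n)$ has no zero on its diagonal, $\mathbf{W}_n=\mathbf{D}\,\mathrm{diag}(\mathbf{c}_n)$ has the same rank as $\mathbf{D}$, and a single zero entry is enough to lose a rank. This is a minimal sketch with hypothetical toy sizes and random weights; the names `C` (stacking the $\mathbf{c}_n$ as rows) and `expert_weight` are assumptions for illustration, not taken from the authors' code.

```python
import numpy as np

# Hypothetical toy sizes (not the paper's): hidden width H, output width O, N experts.
H, O, N = 64, 16, 8
rng = np.random.default_rng(0)

# Shared decoder D in R^{H x O}; a Gaussian matrix with H > O has rank O almost surely.
D = rng.standard_normal((H, O))

# One scaling vector c_n per expert (row n of C), magnitudes in [0.5, 2] so no entry is zero.
C = rng.uniform(0.5, 2.0, size=(N, O)) * rng.choice([-1.0, 1.0], size=(N, O))

def expert_weight(n: int) -> np.ndarray:
    """Materialize W_n = D diag(c_n), i.e. D with column j scaled by c_n[j]."""
    return D @ np.diag(C[n])

ranks = [np.linalg.matrix_rank(expert_weight(n)) for n in range(N)]

# Counter-check: zeroing one diagonal entry of diag(c_n) zeroes a column of W_n,
# so the rank drops by one.
c_zeroed = C[0].copy()
c_zeroed[0] = 0.0
rank_with_zero = np.linalg.matrix_rank(D @ np.diag(c_zeroed))

print("rank(D) =", np.linalg.matrix_rank(D), "expert ranks =", ranks, "with a zero:", rank_with_zero)
```

Since right-multiplying by a diagonal matrix only rescales columns, `D * C[n]` (NumPy broadcasting over the last axis) builds the same $\mathbf{W}_n$ without forming $\mathrm{diag}(\mathbf{c}_n)$.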

Figures (23)

  • Figure 1: Units of specialization for sparse layer variants: neuron-level sparsity of existing sparse MLPs dunefsky2024transcoders; paulo2025transcoders (center) vs layer-level sparsity (right), which the proposed Mixture of Decoders (MxD) layer enables at scale. For GPT2, the dimensions are: $O=768$, $H^*=O\cdot 4$, $H\approx N\approx O\cdot32$.
  • Figure 2: Mixture of Decoders extends the base MLP/GLU layers with a conditional 'expert' branch, modulating the MLP's outputs.
  • Figure 3: Model cross-entropy loss preserved when replacing MLPs with Transcoders dunefsky2024transcoders, Skip Transcoders paulo2025transcoders, and MxDs, as a function of the number of active units $K$ (hidden neurons/experts). We highlight that MxDs have consistently lower loss at all levels of sparsity.
  • Figure 4: Proportion of $512$ generated samples that contain $n$ predicted future words identical to the original model's output when replacing the base LLM's MLP layer with the sparse layers.
  • Figure 5: Highest F1 score among individual features/experts when probing for 'news category' gulli2005agcorpus. As expected, MxDs remain competitive with the Transcoder baselines and outperform TopK-SAEs.
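The GPT2 dimensions in the Figure 1 caption give a sense of why the factorization matters. Treating $H$ and $N$ as exactly $O\cdot 32$ (the caption only says approximately), storing every expert matrix $\mathbf{W}_n\in\mathbb{R}^{H\times O}$ explicitly would take $N\cdot H\cdot O$ entries, whereas a shared $\mathbf{D}$ plus one length-$O$ vector $\mathbf{c}_n$ per expert takes $(H+N)\cdot O$. This back-of-envelope count is a sketch under those assumptions: it ignores encoder and gating weights and is not a figure reported in the paper.

```python
# Sizes from the Figure 1 caption for GPT2; H and N are only approximately O * 32,
# so using exactly O * 32 here is an assumption.
O = 768
H_star = O * 4   # hidden width of the original dense MLP
H = N = O * 32   # MxD hidden width and number of experts

dense_mlp_decoder = H_star * O   # original MLP decoder weights
naive_expert_tensor = N * H * O  # every W_n (H x O) stored explicitly
factorized = H * O + N * O       # shared D (H x O) plus one c_n (length O) per expert

print(H_star, H, dense_mlp_decoder, naive_expert_tensor, factorized,
      naive_expert_tensor // factorized)
```

Under these assumptions the explicit tensor has about $4.6\times10^{11}$ entries versus about $3.8\times10^{7}$ for the factorized form, a 12,288-fold reduction ($N\cdot H/(H+N)$, which equals $N/2$ when $H=N$).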

Theorems & Definitions (4)

  • Lemma 1: Decoder rank preservation
  • Lemma 2: Hadamard-factorized MoE forward pass
  • Proof of Lemma 1
  • Proof of Lemma 2
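Lemma 2 concerns a Hadamard-factorized forward pass. Assuming expert weights of the form $\mathbf{W}_n=\mathbf{D}\,\mathrm{diag}(\mathbf{c}_n)$ from Lemma 1, and a layer output given by the coefficient-weighted mixture $\mathbf{y}=\sum_n a_n\mathbf{W}_n^\top\mathbf{z}$ over hidden activations $\mathbf{z}$, the sum collapses to $(\mathbf{C}^\top\mathbf{a})\odot(\mathbf{D}^\top\mathbf{z})$, where $\mathbf{C}$ stacks the $\mathbf{c}_n$ as rows. The sketch below checks that identity numerically. It is our reading of the lemma, not the authors' implementation: the encoder `E`, gate `G`, ReLU nonlinearity, and top-K-then-softmax gating are all hypothetical choices made for illustration.

```python
import numpy as np

# Hypothetical toy sizes: input width I, hidden width H, output width O,
# number of experts N, and number of active experts K.
I, H, O, N, K = 12, 40, 10, 32, 4
rng = np.random.default_rng(1)

x = rng.standard_normal(I)
E = rng.standard_normal((I, H))  # hypothetical encoder producing hidden units z
G = rng.standard_normal((I, N))  # hypothetical gate producing expert logits
D = rng.standard_normal((H, O))  # shared decoder, as in Lemma 1
C = rng.standard_normal((N, O))  # row n holds the expert-specific vector c_n

# Hidden activations; ReLU is a stand-in for the base layer's nonlinearity.
z = np.maximum(E.T @ x, 0.0)

# Sparse expert coefficients: keep the K largest logits and softmax over them
# (an assumed gating rule), leaving the other N - K coefficients at zero.
logits = G.T @ x
top = np.argpartition(logits, -K)[-K:]
a = np.zeros(N)
a[top] = np.exp(logits[top] - logits[top].max())
a[top] /= a[top].sum()

# Naive mixture: materialize each active W_n = D diag(c_n) and sum a_n * (W_n^T z).
y_naive = sum(a[n] * ((D @ np.diag(C[n])).T @ z) for n in top)

# Factorized forward pass: (C^T a) elementwise-times (D^T z); no per-expert matrix is built.
y_factorized = (C.T @ a) * (D.T @ z)

print(np.allclose(y_naive, y_factorized))
```

The factorized form appears to match the description in Figure 2: the decoder output $\mathbf{D}^\top\mathbf{z}$ is modulated elementwise by an expert-dependent vector. Because only $K$ coefficients are nonzero, `C[top].T @ a[top]` computes the same modulation while touching only the active rows of `C`.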