Towards Interpretability Without Sacrifice: Faithful Dense Layer Decomposition with Mixture of Decoders
James Oldfield, Shawn Im, Sharon Li, Mihalis A. Nicolaou, Ioannis Patras, Grigorios G Chrysos
TL;DR
This work targets the interpretability of dense MLP layers in large language models by introducing Mixture of Decoders (MxD), a layer-level sparse decomposition that preserves the original layer’s expressive power through a Hadamard-factorized weight tensor. The authors prove that each expert's weights preserve rank, enabling faithful reconstruction under heavy sparsity, and derive a simple, efficient forward pass that reduces to a Hadamard product of compact projections. Empirically, MxDs outperform state-of-the-art sparse MLP approximations on the sparsity-accuracy frontier across multiple models (up to 3B parameters) while remaining competitive on interpretability evaluations (sparse probing and feature steering). The results suggest that layer-level specialization can yield interpretable yet faithful decompositions, offering a path toward sparsity without sacrificing performance in practical language-model settings.
Abstract
Multilayer perceptrons (MLPs) are an integral part of large language models, yet their dense representations render them difficult to understand, edit, and steer. Recent methods learn interpretable approximations via neuron-level sparsity, but fail to faithfully reconstruct the original mapping, significantly increasing the model's next-token cross-entropy loss. In this paper, we advocate for moving to layer-level sparsity to overcome the accuracy trade-off in sparse layer approximation. Under this paradigm, we introduce Mixture of Decoders (MxDs). MxDs generalize MLPs and Gated Linear Units, expanding pre-trained dense layers into tens of thousands of specialized sublayers. Through a flexible form of tensor factorization, each sparsely activating MxD sublayer implements a linear transformation with full-rank weights, preserving the original decoders' expressive capacity even under heavy sparsity. Experimentally, we show that MxDs significantly outperform state-of-the-art methods (e.g., Transcoders) on the sparsity-accuracy frontier in language models with up to 3B parameters. Further evaluations on sparse probing and feature steering demonstrate that MxDs learn similarly specialized features of natural language, opening up a promising new avenue for designing interpretable yet faithful decompositions. Our code is available at: https://github.com/james-oldfield/MxD/.
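
To make the factorization idea concrete, the following NumPy sketch checks numerically that a mixture over Hadamard-factorized expert weights collapses to an elementwise product of two small projections, and that each expert keeps the rank of the shared decoder. It is an illustration under stated assumptions, not the authors' implementation: the parameterization W_n = D diag(c_n), the top-k gate, and all names and sizes (`C`, `D`, `top_k_gate`, `N`, `H`, `O`) are hypothetical choices made for this example.

```python
import numpy as np

rng = np.random.default_rng(0)
N, H, O = 8, 5, 4  # hypothetical sizes: experts, hidden units, output dim

C = rng.standard_normal((N, O))  # one scaling vector c_n per expert (assumed form)
D = rng.standard_normal((H, O))  # decoder shared by all experts (assumed form)
z = rng.standard_normal(H)       # hidden activations for one token
g = rng.standard_normal(N)       # raw expert gate scores for one token


def top_k_gate(scores, k):
    """Keep the k largest scores and zero the rest (an assumed sparse gate)."""
    a = np.zeros_like(scores)
    idx = np.argsort(scores)[-k:]
    a[idx] = scores[idx]
    return a


a = top_k_gate(g, k=2)

# Naive route: materialize every expert W_n = D diag(c_n), shape (N, H, O),
# then sum the gated expert outputs: y = sum_n a_n * (W_n^T z).
W = D[None, :, :] * C[:, None, :]
y_naive = np.einsum("n,h,nho->o", a, z, W)

# Factorized route: Hadamard product of two compact projections.
y_fast = (a @ C) * (z @ D)

# Both routes agree, so the (N, H, O) tensor never needs to be built.
assert np.allclose(y_naive, y_fast)

# Scaling the columns of D by nonzero entries leaves its rank unchanged.
rank_D = np.linalg.matrix_rank(D)
expert_ranks = [np.linalg.matrix_rank(W[n]) for n in range(N)]
assert all(r == rank_D for r in expert_ranks)
```

In this sketch the factorized route costs two matrix-vector products and one elementwise multiply, regardless of how many experts are defined, which is why the expert count can grow large without materializing a per-expert weight matrix.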
