Wasserstein Distances, Neuronal Entanglement, and Sparsity
Shashata Sawmya, Linghao Kong, Ilia Markov, Dan Alistarh, Nir Shavit
TL;DR
The paper examines how neuronal entanglement affects the performance of large language models under weight sparsity. It proposes the Wasserstein distance of a neuron's output distribution to a Gaussian, $W_1(n', \mathcal{N})$, as a measure of entanglement, and identifies a small set of highly entangled "Wasserstein neurons" whose irregular, non-Gaussian output distributions critically influence accuracy. To study disentanglement, the authors propose Sparse Expansion, a one-shot, input-aware framework that builds a mixture of sparse experts by clustering each layer's inputs and pruning with SparseGPT, which disentangles each neuron's input-output relationship without retraining. They report that the Wasserstein distance is the best predictor of the improvement gained from disentanglement and that Sparse Expansion achieves state-of-the-art sparsification performance across multiple model families and scales. They also discuss theoretical bounds on computation under entanglement, with implications for mechanistic interpretability and entanglement-aware sparsification strategies.
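As a rough illustration of the entanglement measure, the sketch below estimates the 1-Wasserstein distance between one neuron's outputs and a standard Gaussian. It is a minimal sketch under stated assumptions, not the authors' implementation: the function name `wasserstein_to_gaussian` is hypothetical, the outputs are assumed to be standardized to zero mean and unit variance before comparison, and the distance is approximated by matching sorted samples against Gaussian quantiles at midpoint probabilities.

```python
from statistics import NormalDist

import numpy as np


def wasserstein_to_gaussian(outputs):
    """Approximate W1 between a neuron's standardized outputs and N(0, 1).

    Assumption (not taken from the paper): outputs are shifted and scaled
    to zero mean and unit variance before the comparison.
    """
    x = np.asarray(outputs, dtype=np.float64).ravel()
    if x.size < 2 or x.std() == 0.0:
        raise ValueError("need at least two non-identical output samples")

    # Standardize, then sort to obtain the empirical quantile function.
    z = np.sort((x - x.mean()) / x.std())
    n = z.size

    # Standard normal quantiles at the midpoint probabilities (i + 0.5) / n.
    std_normal = NormalDist()
    q = np.array([std_normal.inv_cdf((i + 0.5) / n) for i in range(n)])

    # In one dimension, W1 is the mean absolute gap between quantile functions.
    return float(np.mean(np.abs(z - q)))


# Synthetic outputs: one roughly Gaussian neuron, one strongly bimodal neuron.
rng = np.random.default_rng(0)
gaussian_like = rng.normal(2.0, 5.0, size=5000)
bimodal = np.concatenate([rng.normal(-3.0, 0.3, 2500), rng.normal(3.0, 0.3, 2500)])

wd_gauss = wasserstein_to_gaussian(gaussian_like)
wd_bimodal = wasserstein_to_gaussian(bimodal)
print(f"gaussian-like: {wd_gauss:.4f}, bimodal: {wd_bimodal:.4f}")
```

On this synthetic data the bimodal neuron scores much higher than the Gaussian-like one, which is the kind of separation the metric is meant to expose.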
Abstract
Disentangling polysemantic neurons is at the core of many current approaches to interpretability of large language models. Here we attempt to study how disentanglement can be used to understand performance, particularly under weight sparsity, a leading post-training optimization technique. We suggest a novel measure for estimating neuronal entanglement: the Wasserstein distance of a neuron's output distribution to a Gaussian. Moreover, we show the existence of a small number of highly entangled "Wasserstein Neurons" in each linear layer of an LLM, characterized by their highly non-Gaussian output distributions, their role in mapping similar inputs to dissimilar outputs, and their significant impact on model accuracy. To study these phenomena, we propose a new experimental framework for disentangling polysemantic neurons. Our framework separates each layer's inputs to create a mixture of experts where each neuron's output is computed by a mixture of neurons of lower Wasserstein distance, each better at maintaining accuracy when sparsified without retraining. We provide strong evidence that this is because the mixture of sparse experts is effectively disentangling the input-output relationship of individual neurons, in particular the difficult Wasserstein neurons.
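The idea of separating a layer's inputs and giving each group its own sparse copy of the layer can be sketched as follows. This is a simplified illustration, not the authors' pipeline. Plain k-means stands in for the input separation step, and a simple activation-weighted magnitude criterion is swapped in for SparseGPT. All function names (`kmeans`, `assign`, `prune_input_aware`, `sparse_expansion`, `forward`) and the toy dimensions are hypothetical.

```python
import numpy as np


def assign(X, centroids):
    """Index of the nearest centroid (squared Euclidean) for each row of X."""
    d2 = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
    return d2.argmin(axis=1)


def kmeans(X, k, n_iter=20, seed=0):
    """Plain k-means; returns a (k, in_features) array of centroids."""
    rng = np.random.default_rng(seed)
    centroids = X[rng.choice(len(X), size=k, replace=False)].copy()
    for _ in range(n_iter):
        labels = assign(X, centroids)
        for c in range(k):
            members = X[labels == c]
            if len(members) > 0:  # keep the old centroid if a cluster empties
                centroids[c] = members.mean(axis=0)
    return centroids


def prune_input_aware(W, X, sparsity):
    """Zero the lowest-scoring weights in each row of W.

    Stand-in for SparseGPT: score_ij = |W_ij| * ||X[:, j]||_2.
    """
    scores = np.abs(W) * np.linalg.norm(X, axis=0)[None, :]
    n_drop = int(round(sparsity * W.shape[1]))
    W_sparse = W.copy()
    if n_drop > 0:
        drop = np.argsort(scores, axis=1)[:, :n_drop]
        np.put_along_axis(W_sparse, drop, 0.0, axis=1)
    return W_sparse


def sparse_expansion(W, X_calib, k, sparsity):
    """Cluster calibration inputs, then prune one copy of W per cluster."""
    centroids = kmeans(X_calib, k)
    labels = assign(X_calib, centroids)
    experts = []
    for c in range(k):
        X_c = X_calib[labels == c]
        if len(X_c) == 0:  # fall back to all calibration data
            X_c = X_calib
        experts.append(prune_input_aware(W, X_c, sparsity))
    return centroids, experts


def forward(X, centroids, experts):
    """Route each input to its cluster's sparse expert and apply it."""
    labels = assign(X, centroids)
    out = np.empty((X.shape[0], experts[0].shape[0]))
    for c, W_c in enumerate(experts):
        mask = labels == c
        out[mask] = X[mask] @ W_c.T
    return out


# Toy layer (8 outputs, 16 inputs) and two well-separated input groups.
rng = np.random.default_rng(0)
W = rng.standard_normal((8, 16))
X = np.concatenate([rng.normal(-2.0, 1.0, (200, 16)), rng.normal(2.0, 1.0, (200, 16))])

centroids, experts = sparse_expansion(W, X, k=2, sparsity=0.5)
Y = forward(X, centroids, experts)
```

Each expert keeps the same shape as the dense layer but retains only half of its weights, chosen using only the inputs routed to that expert. Routing at inference time reuses the same nearest-centroid rule used to form the clusters.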
