Wasserstein Distances, Neuronal Entanglement, and Sparsity

Shashata Sawmya, Linghao Kong, Ilia Markov, Dan Alistarh, Nir Shavit

TL;DR

The paper examines how neuronal entanglement affects the performance of large language models under weight sparsity. It introduces the Wasserstein distance (WD) of a neuron's output distribution to a Gaussian, $W_1(n', \mathcal{N})$, as a metric for entanglement, and identifies a small set of highly entangled 'Wasserstein neurons' whose irregular output distributions critically influence accuracy. To study disentanglement, the authors propose Sparse Expansion, a one-shot, input-aware framework that builds a mixture of sparse experts by clustering each layer's inputs and pruning with SparseGPT, enabling deeper input-output (IO) disentanglement without retraining. They show that WD is the best predictor of how much a neuron improves from disentanglement, and that Sparse Expansion achieves state-of-the-art sparsification performance across multiple model families and scales. They also discuss theoretical bounds on computation under entanglement, with implications for mechanistic interpretability and entanglement-aware sparsification strategies.
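The entanglement metric can be illustrated with a short sketch. It assumes (this is not confirmed by the summary) that a neuron's outputs are standardized to zero mean and unit variance before being compared with a standard Gaussian, so that WD reflects the shape of the distribution rather than its scale or offset. In one dimension, $W_1$ equals the mean absolute gap between matched quantiles, which the hypothetical helper `wd_to_gaussian` approximates by pairing sorted outputs with Gaussian quantiles at bin midpoints.

```python
import numpy as np
from statistics import NormalDist


def wd_to_gaussian(outputs):
    """Approximate W1 between a neuron's standardized outputs and N(0, 1).

    Hypothetical helper; standardizing first is an assumption of this sketch.
    """
    x = np.asarray(outputs, dtype=float)
    std = x.std()
    if std == 0:
        raise ValueError("constant outputs have no defined shape")
    x = np.sort((x - x.mean()) / std)  # standardize, then sort (quantile coupling)
    n = len(x)
    # Gaussian quantiles at the midpoints of n equal-probability bins
    gauss_q = np.array([NormalDist().inv_cdf((i + 0.5) / n) for i in range(n)])
    # In 1-D the optimal transport plan matches sorted samples to sorted quantiles
    return float(np.mean(np.abs(x - gauss_q)))


rng = np.random.default_rng(0)
gaussian_like = rng.normal(size=10_000)  # a well-behaved neuron
bimodal = np.concatenate([rng.normal(-3, 0.1, 5_000), rng.normal(3, 0.1, 5_000)])
print(wd_to_gaussian(gaussian_like), wd_to_gaussian(bimodal))
```

A Gaussian-shaped output distribution scores near zero, while a bimodal one, which must send nearby inputs to far-apart outputs, scores much higher. This mirrors the contrast between the 0.050 and 0.524 neurons in Figure 1, although the exact values depend on the normalization assumed here.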

Abstract

Disentangling polysemantic neurons is at the core of many current approaches to interpretability of large language models. Here we attempt to study how disentanglement can be used to understand performance, particularly under weight sparsity, a leading post-training optimization technique. We suggest a novel measure for estimating neuronal entanglement: the Wasserstein distance of a neuron's output distribution to a Gaussian. Moreover, we show the existence of a small number of highly entangled "Wasserstein Neurons" in each linear layer of an LLM, characterized by their highly non-Gaussian output distributions, their role in mapping similar inputs to dissimilar outputs, and their significant impact on model accuracy. To study these phenomena, we propose a new experimental framework for disentangling polysemantic neurons. Our framework separates each layer's inputs to create a mixture of experts where each neuron's output is computed by a mixture of neurons of lower Wasserstein distance, each better at maintaining accuracy when sparsified without retraining. We provide strong evidence that this is because the mixture of sparse experts is effectively disentangling the input-output relationship of individual neurons, in particular the difficult Wasserstein neurons.
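The framework described above (cluster a layer's inputs, build one sparse expert per cluster, and route each input to its cluster's expert) can be sketched for a single linear layer. This is not the authors' implementation. Plain k-means on raw inputs stands in for the paper's clustering step, and a Wanda-style input-aware score, $|W_{ij}| \cdot \lVert X_{:,j} \rVert_2$ computed on each cluster's inputs, stands in for SparseGPT, which additionally updates the remaining weights. All function names are hypothetical.

```python
import numpy as np


def assign(X, centroids):
    # Squared Euclidean distance from every input row to every centroid
    d = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
    return d.argmin(axis=1)


def kmeans(X, k, iters=20, seed=0):
    rng = np.random.default_rng(seed)
    centroids = X[rng.choice(len(X), size=k, replace=False)].astype(float)
    for _ in range(iters):
        labels = assign(X, centroids)
        for c in range(k):
            members = X[labels == c]
            if len(members):
                centroids[c] = members.mean(axis=0)
    return centroids, assign(X, centroids)


def input_aware_prune(W, X, sparsity):
    # Stand-in for SparseGPT: score weights by |W_ij| * ||X[:, j]||_2 on this
    # cluster's inputs, then zero the lowest-scoring fraction in each output row.
    score = np.abs(W) * np.linalg.norm(X, axis=0)[None, :]
    n_drop = int(W.shape[1] * sparsity)
    mask = np.ones(W.shape, dtype=bool)
    if n_drop > 0:
        drop = np.argsort(score, axis=1)[:, :n_drop]
        np.put_along_axis(mask, drop, False, axis=1)
    return W * mask


def build_experts(W, X, k, sparsity):
    """One-shot expert creation: W is (d_out, d_in), X is (n_samples, d_in)."""
    centroids, labels = kmeans(X, k)
    experts = []
    for c in range(k):
        Xc = X[labels == c]
        # Fall back to all inputs if a cluster ends up empty
        experts.append(input_aware_prune(W, Xc if len(Xc) else X, sparsity))
    return centroids, experts


def expanded_forward(x, centroids, experts):
    # Route the input to its nearest centroid and apply that sparse expert only
    c = int(((centroids - x) ** 2).sum(axis=1).argmin())
    return experts[c] @ x


rng = np.random.default_rng(1)
X = np.vstack([rng.normal(-5, 1, (200, 8)), rng.normal(5, 1, (200, 8))])
W = rng.normal(size=(4, 8))
centroids, experts = build_experts(W, X, k=2, sparsity=0.5)
y = expanded_forward(X[0], centroids, experts)
```

Because each expert is pruned against a different slice of the input distribution, the experts can keep different weights, which is what allows a neuron's output to be covered piecewise rather than by a single sparse approximation.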

Paper Structure

This paper contains 35 sections, 2 equations, 24 figures, 5 tables, and 2 algorithms.

Figures (24)

  • Figure 1: The output distributions of neurons in Llama-2-7B computed densely and at 90% sparsity on Wikitext-2. WD refers to the Wasserstein distance of the output distribution to a Gaussian. RI refers to the relative improvement of Sparse Expansion over SparseGPT. (a) The dense output distribution of a random neuron with a WD of 0.050 is well captured by SparseGPT, and (b) expanding this neuron via Sparse Expansion yields only a small (18%) improvement. (c) The outputs of the individual clusters lie close to one another. (d) SparseGPT struggles to capture the dense distribution of an entangled neuron with a WD of 0.524. (e) Following expansion, the sparse output of the entangled neuron is captured much better, yielding a larger improvement (77%). (f) Each expert specializes in a different portion of the distribution.
  • Figure 2: A measure of neuronal entanglement. (a) The output distribution of a random neuron. (b) The normalized $L^2$ plot of a random neuron's pairs of inputs and outputs. (c) The output distribution of a Wasserstein neuron. (d) The normalized $L^2$ plot of a Wasserstein neuron's pairs of inputs and outputs. This neuron must map fairly similar inputs to outputs that are very far apart through its dot product operation. The neurons are from the up projection matrix of the second FFN block in Pythia-1.4B. (e) The MD of a neuron is highly correlated with its WD. The selected random and Wasserstein neurons are highlighted in their respective colors.
  • Figure 3: Entangled neurons are much more sensitive to compression. In Llama-3-8B, 3% of neurons from every FFN linear layer are sparsified via SparseGPT in an unstructured manner with a subset of the Wikitext-2 train dataset as calibration data. (a) Sparsifying Wasserstein neurons (blue) impairs the model more than sparsifying neurons with the highest output distribution means (orange) and variances (green), those with the highest average mean weight magnitude (purple), and considerably more than random neurons (red). Perplexity is measured on the Wikitext-2 test dataset. (b-d) Sparsifying the Wasserstein neurons (blue) affects general and mathematical reasoning much more than random neurons (red), as shown in the capability charts. At higher levels of neuron sparsity ($\geq 95\%$), ablating Wasserstein neurons leads to a collapse in model performance, which does not occur with random neurons.
  • Figure 4: The Sparse Expansion process. One-shot expert creation process of Sparse Expansion (left). Inference process in an FFN of an expanded model (right).
  • Figure 5: Sparse Expansion recovers performance of Wasserstein neurons. (a) Although Wasserstein neurons are penalized more under sparsity, they also recover more under Sparse Expansion than random neurons do. We quantify this recovery using normalized perplexity relative to the dense model. Data from Llama-3-8B. (b) As a result of Sparse Expansion, the median decrease in WD per neuron is 19%. Although a few neurons with an initially low dense WD exhibit a higher average weighted WD, the majority (68%) of all neurons show a decrease in weighted WD. This is especially true for the top 10% of neurons by original WD, the Wasserstein neurons. (c) Sparse Expansion also decreases the weighted MD by a median of 2% per neuron. 70% of all neurons and 96% of Wasserstein neurons show a decrease in weighted MD, the latter with a median decrease of 9% per neuron. (b, c) Data collected from the up projection matrix in the second FFN of Pythia-1.4B.
  • ...and 19 more figures
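The neuron-level bookkeeping behind Figures 3 and 5 can be sketched as follows: selecting the top fraction of a layer's neurons by WD (3% in Figure 3, 10% for the Wasserstein neurons in Figure 5) and summarizing how WD changes after expansion. The definition of "weighted WD" used here, each expert's WD for a neuron weighted by the share of inputs routed to that expert, is an assumption, as are all helper names; the synthetic numbers below are illustrative, not results from the paper.

```python
import numpy as np


def top_fraction_by_wd(wd, fraction):
    """Indices of the neurons whose WD is in the top `fraction` of the layer."""
    k = max(1, int(round(len(wd) * fraction)))
    return np.argsort(wd)[::-1][:k]


def weighted_wd(expert_wds, cluster_sizes):
    """Assumed definition: expert_wds is (n_experts, n_neurons); weight each
    expert's WD by the fraction of inputs routed to it."""
    w = np.asarray(cluster_sizes, dtype=float)
    return (np.asarray(expert_wds) * w[:, None]).sum(axis=0) / w.sum()


def wd_change_summary(dense_wd, expanded_wd, fraction=0.10):
    rel = (expanded_wd - dense_wd) / dense_wd  # relative change per neuron
    top = top_fraction_by_wd(dense_wd, fraction)
    return {
        "median_decrease": float(-np.median(rel)),
        "frac_decreased_all": float(np.mean(rel < 0)),
        "frac_decreased_top": float(np.mean(rel[top] < 0)),
    }


# Synthetic illustration: nine neurons shrink by 20%, one low-WD neuron grows
dense = np.linspace(0.1, 1.0, 10)
expanded = dense * 0.8
expanded[0] = 0.15
summary = wd_change_summary(dense, expanded)
```

Reporting the median relative change, rather than the mean, keeps a few low-WD neurons whose WD grows after expansion from dominating the summary, which matches how the Figure 5 caption phrases its statistics.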