Neural Networks Learn Statistics of Increasing Complexity

Nora Belrose, Quintin Pope, Lucia Quirke, Alex Mallen, Xiaoli Fern

TL;DR

This work presents compelling new evidence for the distributional simplicity bias by showing that networks automatically learn to perform well on maximum-entropy distributions whose low-order statistics match those of the training set early in training, then lose this ability later.

Abstract

The distributional simplicity bias (DSB) posits that neural networks learn low-order moments of the data distribution first, before moving on to higher-order correlations. In this work, we present compelling new evidence for the DSB by showing that networks automatically learn to perform well on maximum-entropy distributions whose low-order statistics match those of the training set early in training, then lose this ability later. We also extend the DSB to discrete domains by proving an equivalence between token $n$-gram frequencies and the moments of embedding vectors, and by finding empirical evidence for the bias in LLMs. Finally we use optimal transport methods to surgically edit the low-order statistics of one class to match those of another, and show that early-training networks treat the edited samples as if they were drawn from the target class. Code is available at https://github.com/EleutherAI/features-across-time.
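
The abstract mentions editing the low-order statistics of one class to match another with optimal transport. One standard way to match the first two moments is the closed-form optimal transport map between Gaussians, sketched below. This is an illustration under stated assumptions, not the authors' code: the function names `psd_sqrt` and `gaussian_ot_map`, the ridge term `eps`, and the synthetic data are all hypothetical.

```python
import numpy as np

def psd_sqrt(mat):
    """Square root of a symmetric positive semi-definite matrix."""
    vals, vecs = np.linalg.eigh(mat)
    vals = np.clip(vals, 0.0, None)  # guard against tiny negative eigenvalues
    return (vecs * np.sqrt(vals)) @ vecs.T

def gaussian_ot_map(source, target, eps=1e-6):
    """Return T(x) = mu_t + A (x - mu_s), the optimal transport map between
    Gaussians fitted to `source` and `target` (rows are flattened samples).
    `eps` is a small ridge added for numerical stability (an assumption)."""
    d = source.shape[1]
    mu_s, mu_t = source.mean(axis=0), target.mean(axis=0)
    cov_s = np.cov(source, rowvar=False) + eps * np.eye(d)
    cov_t = np.cov(target, rowvar=False) + eps * np.eye(d)
    root_s = psd_sqrt(cov_s)
    inv_root_s = np.linalg.inv(root_s)
    # A = cov_s^{-1/2} (cov_s^{1/2} cov_t cov_s^{1/2})^{1/2} cov_s^{-1/2}
    A = inv_root_s @ psd_sqrt(root_s @ cov_t @ root_s) @ inv_root_s
    return lambda x: mu_t + (x - mu_s) @ A.T

# Synthetic stand-ins for two classes of flattened images.
rng = np.random.default_rng(0)
src = rng.normal(size=(2000, 3))
mix = np.array([[2.0, 0.0, 0.0], [0.5, 1.0, 0.0], [0.0, 0.3, 0.5]])
tgt = rng.normal(size=(2000, 3)) @ mix + 4.0

transport = gaussian_ot_map(src, tgt)
moved = transport(src)  # source samples with target-like mean and covariance
```

After the map is applied, the edited source samples have (up to the ridge term) the same mean and covariance as the target class, while higher-order structure of the source samples is left as it was.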

Paper Structure

This paper contains 43 sections, 6 theorems, 15 equations, 17 figures, 1 algorithm.

Key Result

Theorem 2.1

[$n$-gram statistics are moments] Let $\mathcal{V}^N$ be the set of token sequences of length $N$ drawn from a finite vocabulary $\mathcal{V}$, let $P$ be a distribution on $\mathcal{V}^N$, and let $f : \mathcal{V}^{N} \rightarrow \{0, 1\}^{N \cdot | \mathcal{V} |}$ be the function that encodes a le
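
The theorem statement above is cut off, so the following is only a rough illustration of the idea that $n$-gram frequencies are moments of an encoding. It assumes that $f$ concatenates one one-hot vector per position, which fits the stated codomain $\{0, 1\}^{N \cdot | \mathcal{V} |}$ but is not confirmed by the visible text. The toy corpus and the helper `one_hot_encode` are hypothetical.

```python
import numpy as np

def one_hot_encode(seq, vocab_size):
    """Concatenate one one-hot vector per position: shape (N * vocab_size,)."""
    out = np.zeros(len(seq) * vocab_size)
    for pos, tok in enumerate(seq):
        out[pos * vocab_size + tok] = 1.0
    return out

# Toy corpus: 4 sequences of length N = 3 over a vocabulary of size 2.
corpus = [(0, 1, 1), (0, 0, 1), (1, 1, 0), (0, 1, 0)]
V = 2
X = np.stack([one_hot_encode(s, V) for s in corpus])  # shape (4, 6)

mean = X.mean(axis=0)            # first moment of the encoding
second = X.T @ X / len(corpus)   # uncentered second moment

# Direct counts for comparison.
unigram_freq = sum(s[0] == 0 for s in corpus) / len(corpus)
bigram_freq = sum(s[0] == 0 and s[1] == 1 for s in corpus) / len(corpus)

# Entry (pos * V + tok) of the mean is the frequency of `tok` at `pos`.
print(mean[0 * V + 0], unigram_freq)
# Entry (0 * V + 0, 1 * V + 1) of the second moment is the frequency of
# token 0 at position 0 followed by token 1 at position 1.
print(second[0 * V + 0, 1 * V + 1], bigram_freq)
```

In this encoding, each product of coordinates is an indicator for a particular combination of tokens at particular positions, so its expectation is the frequency of that combination.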

Figures (17)

  • Figure 1: (left) Pekinese dog image from the ImageNet training set. (center) Image after quantile normalizing its pixels to match the marginal distribution of the goldfish class on ImageNet. The grass is now a slightly darker shade of green and the dog's fur has a reddish hue. (right) Synthetic "goldfish" generated by sampling each pixel independently from its marginal distribution.
  • Figure 2: Non-cherrypicked "fake" images produced by maximum entropy sampling using only the first two moments of the class-conditional distributions, and a hypercube constraint. Fake MNIST digits are clearly recognizable, SVHN digits less so, whereas fake CIFAR-10 images look nothing like their respective classes: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck.
  • Figure 3: Accuracy of computer vision models when evaluated on images edited with optimal transport maps as described in Sec. \ref{sec:ot}, using the target class, not the source class, as the label. Between roughly $2^4$ and $2^{12}$ training steps, all models classify the CQN-edited images as coming from the target class, with a peak in accuracy at $2^{9}$ steps.
  • Figure 4: Rows 1-3 show how Gaussian optimal transport affects the example CIFAR-10 airplane, bird and truck images. Each row starts with the original unedited image on the left, with each subsequent column showing the effects of editing that image's first two moments to match the class-conditional distributions of a particular target class (top).
  • Figure 5: Accuracy of computer vision models trained on the standard CIFAR-10 training set and evaluated on maximum-entropy synthetic data with matching statistics of varying order.
  • ...and 12 more figures
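
The two operations described in the Figure 1 caption can be sketched as follows: quantile normalizing each pixel to a target class's marginal distribution, and sampling each pixel independently from that marginal. This is a minimal sketch using empirical quantiles on synthetic data. The function names `quantile_normalize` and `sample_independent` are hypothetical, and the authors' implementation may differ.

```python
import numpy as np

def quantile_normalize(x, source, target):
    """Map each coordinate of `x` to the value at the same quantile of the
    target class's marginal for that coordinate.
    x: (d,), source: (n_s, d), target: (n_t, d), all flattened images."""
    out = np.empty_like(x, dtype=float)
    for j in range(x.shape[0]):
        src_sorted = np.sort(source[:, j])
        # Empirical CDF of x[j] under the source marginal of coordinate j.
        q = np.searchsorted(src_sorted, x[j], side="right") / len(src_sorted)
        out[j] = np.quantile(target[:, j], np.clip(q, 0.0, 1.0))
    return out

def sample_independent(target, rng):
    """Draw each coordinate independently from its empirical marginal."""
    n, d = target.shape
    rows = rng.integers(0, n, size=d)  # one random row index per coordinate
    return target[rows, np.arange(d)]

# Synthetic stand-ins for a source class and a target class.
rng = np.random.default_rng(0)
source = rng.normal(0.0, 1.0, size=(500, 4))
target = rng.normal(5.0, 2.0, size=(500, 4))

edited = quantile_normalize(source[0], source, target)
fake = sample_independent(target, rng)
```

Quantile normalization changes only the per-coordinate marginals, which is consistent with the caption's description of the edited image keeping its content while its colors shift. Independent sampling discards all correlations between coordinates.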

Theorems & Definitions (11)

  • Theorem 2.1
  • Theorem 2.2
  • Theorem 4.1
  • proof
  • Theorem 5.1
  • proof
  • Definition 6.1: $n$-gram statistic
  • Theorem 6.1
  • proof
  • Theorem 6.1
  • ...and 1 more