SGD on Neural Networks Learns Functions of Increasing Complexity

Preetum Nakkiran, Gal Kaplun, Dimitris Kalimeris, Tristan Yang, Benjamin L. Edelman, Fred Zhang, Boaz Barak

TL;DR

This work investigates why deep neural networks trained with SGD generalize well in the overparameterized regime. It introduces a measure based on conditional mutual information that quantifies how much of a network's early performance can be attributed to a simple linear classifier. The measure reveals a two-phase learning dynamic: a simple predictor explains the early gains and is retained as learning progresses to more complex functions. The authors provide experimental evidence across datasets and architectures, plus a theoretical result in a simplified model showing that starting from a simple, generalizable predictor can yield good population accuracy even as training fits the data. Together, these results offer an information-theoretic lens on the inductive bias of SGD and on phase-wise learning and generalization in deep networks.

Abstract

We perform an experimental study of the dynamics of Stochastic Gradient Descent (SGD) in learning deep neural networks for several real and synthetic classification tasks. We show that in the initial epochs, almost all of the performance improvement of the classifier obtained by SGD can be explained by a linear classifier. More generally, we give evidence for the hypothesis that, as iterations progress, SGD learns functions of increasing complexity. This hypothesis can be helpful in explaining why SGD-learned classifiers tend to generalize well even in the over-parameterized regime. We also show that the linear classifier learned in the initial stages is "retained" throughout the execution even if training is continued to the point of zero training error, and complement this with a theoretical result in a simplified model. Key to our work is a new measure of how well one classifier explains the performance of another, based on conditional mutual information.
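The measure mentioned at the end of the abstract can be illustrated with a small plug-in estimator. The sketch below assumes the definition $\mu_Y(F; L) = I(F;Y) - I(F;Y \mid L)$, which this page does not spell out, and assumes discrete (for example binary) predictions. The function names are hypothetical and are not taken from the authors' code.

```python
import numpy as np

def entropy(*cols):
    """Plug-in Shannon entropy (in bits) of the joint distribution of discrete columns."""
    joint = np.stack([np.asarray(c) for c in cols], axis=1)
    _, counts = np.unique(joint, axis=0, return_counts=True)
    p = counts / counts.sum()
    return float(-(p * np.log2(p)).sum())

def mutual_info(f, y):
    """I(F;Y) = H(F) + H(Y) - H(F,Y)."""
    return entropy(f) + entropy(y) - entropy(f, y)

def cond_mutual_info(f, y, l):
    """I(F;Y|L) = H(F,L) + H(Y,L) - H(F,Y,L) - H(L)."""
    return entropy(f, l) + entropy(y, l) - entropy(f, y, l) - entropy(l)

def performance_correlation(f, y, l):
    """Assumed definition: mu_Y(F;L) = I(F;Y) - I(F;Y|L)."""
    return mutual_info(f, y) - cond_mutual_info(f, y, l)

# Toy usage: f = network predictions, l = linear-classifier predictions, y = labels.
y = np.array([0, 1, 0, 1])
f = y.copy()           # the network predicts every label correctly
l_same = y.copy()      # the linear classifier agrees with the network everywhere
l_const = np.zeros(4)  # a constant classifier carries no information

mu_same = performance_correlation(f, y, l_same)
mu_const = performance_correlation(f, y, l_const)
```

In this toy case `mu_same` equals $I(F;Y)$, so the linear classifier accounts for all of the network's performance, while `mu_const` is zero. Plug-in estimates of this kind are biased on small samples, so a real experiment would use a large held-out set.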

Paper Structure

This paper contains 23 sections, 3 theorems, 12 equations, 8 figures, 3 tables.

Key Result

Theorem 1

Consider training a linear classifier via minimizing the empirical square loss using SGD. Let $\varepsilon > 0$ be a small constant and let the initial vector $\bm{w}_0$ satisfy $\bm{w}_0(1) \geq -n^{0.99}$, and $|\bm{w}_0(i)| \leq 1 - 2p - \varepsilon$ for all $i>1$. Then, with high probability, sa

Figures (8)

  • Figure 1: Left: An illustration of our hypothesis of how SGD dynamics progress. Initially, all progress in learning can be attributed to a "simple" classifier (in a precise sense to be defined later); SGD then goes on to learn more complex but still meaningful classifiers. Finally, the classifier interpolates the training data while retaining the correlation with simpler classifiers that allows it to generalize. Right: A plot of how the decision boundary evolves as a neural network is trained on a simple classification task. The data distribution is uniform in a 2-dimensional ball of radius 1, labeled by a sinusoidal curve with 10% label noise. An almost linear decision boundary emerges in the first phases of training, before more complex classifiers are learned. In the last stages, the network overfits to the label noise while still retaining the concept.
  • Figure 2: Beyond linear classifiers. The two phases of SGD learning in Figure 1 can be broken into several sub-phases. Phase $i$ involves learning classifiers of lower "complexity" than phase $i+1$. The precise notion of complexity may depend on the algorithm, the initialization, and the architecture. In practice, we expect that the phases will not be completely disjoint and that classifiers of differing complexity will, to some extent, be learned at the same time.
  • Figure 3: $I(F;Y)$ as a function of $\mathbb{P}[F=Y]$ for unbiased binary $F,Y$ such that $\mathbb{P}[F=Y] \geq 1/2$.
  • Figure 4: SGD dynamics for various classification tasks. In each figure, we plot both the value of the mutual information and the corresponding accuracy. Observe that in the initial phases the bulk of the increase in performance is attributed to the linear classifier, since $\mu_Y(F_t; L) \approx I(F_t; Y)$.
  • Figure 5: Distinguishing between the first vs. the last 5 classes of CIFAR10. CNN$k$ denotes a convolutional neural network of $k$ layers. We clearly see a separation in phases of learning, where all curves $\mu_Y(F_t; G_i)$ are initially close to $I(F_t; Y)$, before each successively plateaus as training progresses. The plot matches the conjectured behavior illustrated in Figure 2.
  • ...and 3 more figures
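
The curve described in the Figure 3 caption has a simple closed form. It is not stated on this page; the short derivation below follows from the standard definition of mutual information, under the caption's assumption that $F$ and $Y$ are unbiased binary variables. The symbols $p$ and $H_b$ are introduced here for illustration.

```latex
% Assumption: F and Y are binary with uniform marginals, and p = P[F = Y] >= 1/2.
% The marginals force P[F=Y=0] = P[F=Y=1] = p/2
% and P[F=0,Y=1] = P[F=1,Y=0] = (1-p)/2, so Y given F is wrong with probability 1-p.
\begin{align*}
I(F;Y) &= H(Y) - H(Y \mid F) \\
       &= 1 - H_b(p),
\qquad H_b(p) = -p \log_2 p - (1-p) \log_2 (1-p).
\end{align*}
% Endpoints: p = 1/2 gives I(F;Y) = 0 bits; p = 1 gives I(F;Y) = 1 bit.
```

The relationship is monotone but nonlinear on $[1/2, 1]$, so equal steps in accuracy do not correspond to equal steps in mutual information.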

Theorems & Definitions (11)

  • Claim 1: Informal
  • Claim 2: Informal
  • Remark 1: Beyond linear classifiers
  • Remark 2: Beyond binary classification
  • Definition 1
  • Theorem 1
  • proof : Proof sketch
  • Lemma 1: Convergence of gradient descent
  • proof
  • Theorem 1
  • ...and 1 more