Generalisation dynamics of online learning in over-parameterised neural networks

Sebastian Goldt, Madhu S. Advani, Andrew M. Saxe, Florent Krzakala, Lenka Zdeborová

TL;DR

This work analyses why over-parameterised two-layer neural networks can generalise well by studying a teacher–student setup and deriving an ODE description of online SGD dynamics in terms of order parameters. It shows that the asymptotic generalisation error scales linearly with the excess number of hidden units $L=K-M$, where the student has $K$ and the teacher $M$ hidden units. For example, $\epsilon_g^* \sim \eta \sigma^2 L$ for small learning rates $\eta$, with $\sigma^2$ the variance of the teacher's output noise. This scaling persists across sigmoidal, linear, and ReLU activations, though via different mechanisms (specialisation versus redundancy). The results imply that SGD alone does not regularise over-parameterised models and that regularisation depends on the interplay between the optimisation algorithm, learning rate, model architecture, and data, including finite-data effects and mini-batch settings. Overall, the paper provides a principled, quantitative framework for predicting generalisation dynamics in over-parameterised two-layer networks and highlights directions for improving generalisation beyond plain SGD.
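The reduction to order parameters can be illustrated for one concrete case: for Gaussian inputs, the generalisation error depends on the weights only through the overlaps $Q_{ik}=w_i\cdot w_k/N$, $R_{in}=w_i\cdot v_n/N$ and $T_{nm}=v_n\cdot v_m/N$ between student weights $w$ and teacher weights $v$. The sketch below assumes (these choices are not stated in the summary above) a sigmoidal activation $g(x)=\operatorname{erf}(x/\sqrt{2})$, second-layer weights fixed at 1, pre-activations $w_k\cdot x/\sqrt{N}$, and $\epsilon_g$ defined as half the mean squared deviation from the noise-free teacher. The function name `generalisation_error_erf` is hypothetical.

```python
import numpy as np


def generalisation_error_erf(Q, R, T):
    """Generalisation error of a two-layer erf network from its order parameters.

    Assumptions (illustrative, not taken from the paper's text above):
    activation g(x) = erf(x / sqrt(2)), second-layer weights fixed at 1,
    inputs x ~ N(0, I_N), pre-activations w_k . x / sqrt(N), and
    eps_g = 0.5 * E[(student(x) - teacher(x))^2] with a noise-free teacher.

    Q: (K, K) student-student overlaps Q_ik = w_i . w_k / N
    R: (K, M) student-teacher overlaps R_in = w_i . v_n / N
    T: (M, M) teacher-teacher overlaps T_nm = v_n . v_m / N
    """
    Q, R, T = (np.asarray(a, dtype=float) for a in (Q, R, T))
    q = np.diag(Q)  # student self-overlaps Q_ii
    t = np.diag(T)  # teacher self-overlaps T_nn
    # For jointly Gaussian (u, v) with covariance C and g = erf(. / sqrt(2)):
    # E[g(u) g(v)] = (2 / pi) * arcsin(C_12 / sqrt((1 + C_11) * (1 + C_22)))
    ss = np.arcsin(Q / np.sqrt(np.outer(1 + q, 1 + q))).sum()
    tt = np.arcsin(T / np.sqrt(np.outer(1 + t, 1 + t))).sum()
    st = np.arcsin(R / np.sqrt(np.outer(1 + q, 1 + t))).sum()
    # 0.5 * (2 / pi) * (ss + tt - 2 * st)
    return (ss + tt - 2 * st) / np.pi
```

A student that exactly reproduces the teacher ($Q=R=T$) gives $\epsilon_g=0$; the ODE description then amounts to tracking how $Q$ and $R$ evolve during training while $T$ stays fixed.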

Abstract

Deep neural networks achieve stellar generalisation on a variety of problems, despite often being large enough to easily fit all their training data. Here we study the generalisation dynamics of two-layer neural networks in a teacher-student setup, where one network, the student, is trained using stochastic gradient descent (SGD) on data generated by another network, called the teacher. We show how, for this problem, the dynamics of SGD are captured by a set of differential equations. In particular, we demonstrate analytically that the generalisation error of the student increases linearly with the network size, with other relevant parameters held constant. Our results indicate that achieving good generalisation in neural networks depends on the interplay of at least the algorithm, its learning rate, the model architecture, and the data set.
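The teacher-student setup described in the abstract can be simulated directly. The sketch below is a minimal, hypothetical version, not the authors' code: it uses ReLU activations (one of the activations the paper studies), keeps the second-layer weights fixed at 1, draws a fresh Gaussian input at every step (online learning), adds Gaussian noise of standard deviation $\sigma$ to the teacher's label, and assumes the update $w_k \leftarrow w_k - (\eta/\sqrt{N})\,\Delta\, g'(\lambda_k)\, x$ with $\lambda_k = w_k\cdot x/\sqrt{N}$ and $\Delta$ the student's error on the current sample. Training time is measured as $\alpha = \text{steps}/N$. The function name, the small input dimension $N=100$ and all default values are illustrative choices.

```python
import numpy as np


def relu(x):
    return np.maximum(x, 0.0)


def relu_grad(x):
    return (x > 0).astype(float)


def simulate_online_sgd(N=100, M=2, K=4, eta=0.2, sigma=0.1, alpha_max=50.0,
                        n_test=5000, n_checkpoints=10, seed=0):
    """Online SGD of a K-unit ReLU student on a noisy M-unit ReLU teacher.

    Hypothetical sketch: second-layer weights are fixed at 1 and the update
    rule uses an assumed eta / sqrt(N) scaling. Returns (alpha, eps_g) pairs.
    """
    rng = np.random.default_rng(seed)
    v = rng.standard_normal((M, N))  # teacher first-layer weights
    w = rng.standard_normal((K, N))  # student first-layer weights
    x_test = rng.standard_normal((n_test, N))
    y_test = relu(x_test @ v.T / np.sqrt(N)).sum(axis=1)  # noise-free teacher

    def eps_g():
        # half the mean squared deviation from the noise-free teacher on a test set
        phi = relu(x_test @ w.T / np.sqrt(N)).sum(axis=1)
        return 0.5 * np.mean((phi - y_test) ** 2)

    n_steps = int(alpha_max * N)
    every = max(1, n_steps // n_checkpoints)
    curve = [(0.0, eps_g())]
    for step in range(1, n_steps + 1):
        x = rng.standard_normal(N)  # fresh sample: each input is seen once
        lam = w @ x / np.sqrt(N)    # student pre-activations
        nu = v @ x / np.sqrt(N)     # teacher pre-activations
        y = relu(nu).sum() + sigma * rng.standard_normal()  # noisy label
        delta = relu(lam).sum() - y
        # SGD step on 0.5 * delta**2 with respect to every student weight vector
        w -= (eta / np.sqrt(N)) * delta * np.outer(relu_grad(lam), x)
        if step % every == 0:
            curve.append((step / N, eps_g()))
    return curve
```

Calling `simulate_online_sgd()` returns a learning curve $\epsilon_g(\alpha)$ of the same kind as those compared against the ODE predictions in the figures; sweeping `K`, `eta` or `sigma` gives a rough, finite-$N$ way to probe the dependence on network size, learning rate and noise.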

Paper Structure

This paper contains 32 sections, 53 equations, and 13 figures.

Figures (13)

  • Figure 1: Neural network with a single hidden layer. A network with $K$ hidden units and weights $w$ implements a scalar function of its inputs $x$, $y=\sum_{k=1}^{K} g(w_k x)$, where $g: \mathbb{R}\to\mathbb{R}$ is the non-linear activation function of the network.
  • Figure 2: The analytical description of the generalisation dynamics of sigmoidal networks (solid) matches simulations (crosses). We show learning curves $\epsilon_g(\alpha)$ obtained by integrating the ODEs, i.e. the equations of motion for the order parameters (solid lines). From left to right, we vary the standard deviation $\sigma$ of the teacher's output noise, the learning rate $\eta$, and the number of hidden units $K$ in the student. For each combination of parameters shown in the plots, we ran a single simulation of a network with $N=784$ and plot the observed generalisation error (crosses). $\kappa=0$ in all cases.
  • Figure 3: Theoretical predictions for $\epsilon_g^*$ match simulations. We plot theoretical predictions for $\epsilon_g^*/\sigma^2$ for sigmoidal networks (solid line) and linear networks (dashed line), both obtained from the paper's analytical expressions for $\epsilon_g^*$, together with the result from a single simulation of a network with $N=784$. Parameters: $\eta=0.05$, $\sigma=0.01$.
  • Figure 4: Sigmoidal networks learn different representations from a noisy teacher than ReLU networks. For sigmoidal networks (left), we see clear signs of specialisation as described in the section on sigmoidal networks: for $K=3$, one unit is simply shut down: $w_1=0$. As we increase $K$, some units become exactly anti-correlated to each other (e.g. units 1 and 4 for $K=5$), so that their outputs cancel, which effectively sets their weights to zero (we keep the weights of the second layer fixed at unity). ReLU networks instead find representations in which several units are used. Parameters: $N=784$, $\eta=0.3$, $\sigma=0.1$, $\kappa=0$.
  • Figure 5: The final generalisation error of over-parameterised ReLU networks scales as $\epsilon_g^* \sim \eta \sigma^2 L$. Simulations confirm that the asymptotic generalisation error $\epsilon_g^*$ of a ReLU student learning from a ReLU teacher scales with the learning rate $\eta$, the variance $\sigma^2$ of the teacher's output noise, and the number of additional hidden units $L$ as $\epsilon_g^* \sim \eta \sigma^2 L$. This is the same scaling as the one found analytically for sigmoidal networks to first order in the learning rate. Straight lines are linear fits to the data, with slope $1$ in (a) and (b). Parameters: $M=2$, $K=6$ (a, b) and $M=4, 16$ with $K=M+L$ (c); in all panels, $N=784$, $\kappa=0$.
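Checking a scaling such as $\epsilon_g^* \sim \eta \sigma^2 L$ from simulations comes down to fitting a power-law exponent. A minimal sketch, assuming the fits are done on log-log axes so that a slope of 1 means $\epsilon_g^*$ is proportional to the varied parameter: the helper `loglog_slope` below is hypothetical and simply performs a least-squares fit of $\log y$ against $\log x$ with `numpy.polyfit`.

```python
import numpy as np


def loglog_slope(x, y):
    """Least-squares slope of log(y) against log(x).

    A slope of 1 indicates y proportional to x (e.g. eps_g* against eta);
    a slope of 2 would indicate y proportional to x**2.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if np.any(x <= 0) or np.any(y <= 0):
        raise ValueError("a log-log fit needs strictly positive data")
    slope, _intercept = np.polyfit(np.log(x), np.log(y), 1)
    return slope
```

As a synthetic check (made-up numbers, not results from the paper), `loglog_slope([0.1, 0.2, 0.4], [0.3, 0.6, 1.2])` returns 1 up to floating-point rounding. Pairs of learning rates and measured final errors, for instance from a simulation sweep, can be passed in the same way.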