Learning time-scales in two-layers neural networks
Raphaël Berthier, Andrea Montanari, Kangjie Zhou
TL;DR
This work analyzes gradient-flow dynamics of wide two-layer neural networks under a single-index data model, revealing a structured, multiscale learning process. By reducing the high-dimensional dynamics to a mean-field, $d$-independent flow and applying singular perturbation theory, the authors show that learning proceeds incrementally along Hermite polynomial components of the target function, with distinct time scales controlled by the learning-rate parameter $\varepsilon$. They establish a connection to mean-field PDEs, demonstrate the existence of a zero-risk global infimum under generic conditions, and provide finite-sample SGD guarantees that align with the canonical learning order. The results offer a principled picture of plateaus, waterfalls, and incremental generalization in deep learning, with implications for understanding implicit biases and designing training protocols that exploit feature learning in early phases.
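To make the phrase "Hermite polynomial components of the target function" concrete, here is a minimal sketch that computes the Hermite coefficients of a one-dimensional link function under a standard Gaussian input. It assumes the probabilists' Hermite polynomials $\mathrm{He}_k$ with the normalization $c_k = \mathbb{E}[\varphi(G)\,\mathrm{He}_k(G)]/k!$, which may differ from the paper's convention. The helper name `hermite_coefficients` and the example link function are illustrative choices, not taken from the paper.

```python
import math

import numpy as np
from numpy.polynomial.hermite_e import hermegauss, hermeval


def hermite_coefficients(phi, max_degree, n_nodes=32):
    """Return c_0..c_max_degree with phi(z) ~ sum_k c_k He_k(z), z ~ N(0, 1).

    Hypothetical helper: uses Gauss-Hermite quadrature for the weight
    exp(-z^2 / 2) and the normalization c_k = E[phi(G) He_k(G)] / k!.
    """
    nodes, weights = hermegauss(n_nodes)
    # Rescale so the weights sum to 1, i.e. they integrate against N(0, 1).
    weights = weights / math.sqrt(2.0 * math.pi)
    values = phi(nodes)
    coeffs = []
    for k in range(max_degree + 1):
        basis = np.zeros(k + 1)
        basis[k] = 1.0  # coefficient vector selecting He_k
        he_k = hermeval(nodes, basis)
        coeffs.append(np.sum(weights * values * he_k) / math.factorial(k))
    return np.array(coeffs)


# Illustrative link function: z^2 + z^3 = He_0 + 3 He_1 + He_2 + He_3.
coeffs = hermite_coefficients(lambda z: z**2 + z**3, max_degree=4)
print(np.round(coeffs, 6))  # approximately [1, 3, 1, 1, 0]
```

In the picture summarized above, each nonzero coefficient corresponds to one component of the target that the network would pick up in its own phase of training.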
Abstract
Gradient-based learning in multi-layer neural networks displays a number of striking features. In particular, the decrease rate of empirical risk is non-monotone even after averaging over large batches. Long plateaus in which one observes barely any progress alternate with intervals of rapid decrease. These successive phases of learning often take place on very different time scales. Finally, models learnt in an early phase are typically 'simpler' or 'easier to learn', although in a way that is difficult to formalize. Although theoretical explanations of these phenomena have been put forward, each of them captures at best certain specific regimes. In this paper, we study the gradient flow dynamics of a wide two-layer neural network in high dimension, when data are distributed according to a single-index model (i.e., the target function depends on a one-dimensional projection of the covariates). Based on a mixture of new rigorous results, non-rigorous mathematical derivations, and numerical simulations, we propose a scenario for the learning dynamics in this setting. In particular, the proposed evolution exhibits separation of timescales and intermittency. These behaviors arise naturally because the population gradient flow can be recast as a singularly perturbed dynamical system.
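The setting described in the abstract can be explored numerically with a small experiment. The following is a rough sketch, not the authors' simulation code: it replaces population gradient flow with online mini-batch SGD, and the activation (`tanh`), the link function, the dimensions, and the step size are all assumptions chosen for illustration. Whether plateaus are visible in the recorded risk curve depends on these choices.

```python
import numpy as np


def simulate_sgd(d=16, m=64, steps=1000, batch=128, lr=0.05, seed=0):
    """Online SGD for f(x) = (1/m) * sum_i a_i * tanh(<w_i, x>).

    Data follow a single-index model y = phi(<u, x>) with x ~ N(0, I_d).
    All hyperparameters and the link function phi are illustrative.
    """
    rng = np.random.default_rng(seed)
    u = np.zeros(d)
    u[0] = 1.0  # hidden direction of the single-index model

    def phi(z):
        # Hypothetical target: He_1(z) + He_2(z) / 2
        return z + 0.5 * (z**2 - 1.0)

    W = rng.standard_normal((m, d)) / np.sqrt(d)  # first-layer weights
    a = rng.standard_normal(m)                    # second-layer weights
    risks = np.empty(steps)
    for t in range(steps):
        X = rng.standard_normal((batch, d))       # fresh samples each step
        y = phi(X @ u)
        hid = np.tanh(X @ W.T)                    # (batch, m) hidden activations
        err = hid @ a / m - y                     # residuals on this batch
        risks[t] = 0.5 * np.mean(err**2)          # mini-batch estimate of the risk
        # Gradients of the batch risk, multiplied by m (mean-field step scaling).
        grad_a = hid.T @ err / batch
        grad_W = (err[:, None] * (1.0 - hid**2) * a[None, :]).T @ X / batch
        a -= lr * grad_a
        W -= lr * grad_W
    return risks


risks = simulate_sgd(steps=500)
print(risks[:5], risks[-5:])
```

Plotting `risks` against the step index on a logarithmic time axis is one way to look for the alternation of plateaus and rapid decreases that the abstract describes.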
