SGD learning on neural networks: leap complexity and saddle-to-saddle dynamics
Emmanuel Abbe, Enric Boix-Adsera, Theodor Misiakiewicz
TL;DR
This work investigates the time complexity of SGD learning on regular fully-connected neural networks trained on isotropic data with low latent dimensionality. It introduces the leap complexity to quantify how hierarchical a target function is and proves, for a Gaussian setting with a 2-layer network, that SGD learns a target in time scaling as $\tilde{\Theta}(d^{\max(\mathrm{Leap}(f),2)})$, up to $\mathrm{poly}(1/\varepsilon)$ factors, via a sequential saddle-to-saddle learning dynamic. The resulting upper bounds match CSQ lower bounds, and SGD is framed as implementing an adaptive curriculum that learns low-level features first and builds up to higher-order monomials. The results generalize prior leap-1 analyses and give a rigorous connection between practical SGD dynamics and CSQ lower bounds, with experimental evidence and avenues for extension to broader architectures and data distributions.
Abstract
We investigate the time complexity of SGD learning on fully-connected neural networks with isotropic data. We put forward a complexity measure -- the leap -- which measures how "hierarchical" target functions are. For $d$-dimensional uniform Boolean or isotropic Gaussian data, our main conjecture states that the time complexity to learn a function $f$ with low-dimensional support is $\tilde{\Theta}(d^{\max(\mathrm{Leap}(f),2)})$. We prove a version of this conjecture for a class of functions on Gaussian isotropic data and 2-layer neural networks, under additional technical assumptions on how SGD is run. We show that the training sequentially learns the function support with a saddle-to-saddle dynamic. Our result departs from [Abbe et al. 2022] by going beyond leap 1 (merged-staircase functions), and by going beyond the mean-field and gradient flow approximations that prohibit the full complexity control obtained here. Finally, we note that this gives an SGD complexity for the full training trajectory that matches that of Correlational Statistical Query (CSQ) lower bounds.
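To make the leap concrete, the following is a minimal Python sketch. It assumes the definition used in the paper body (it is not stated in this abstract): for a target whose nonzero monomials are supported on coordinate sets $S_1, \dots, S_m$, the leap is the minimum over orderings of those sets of the largest number of new coordinates any single set adds. The function names `leap` and `sgd_time_exponent` are illustrative and not taken from the authors' code, and the brute-force search over orderings is only practical for a handful of monomials.

```python
from itertools import permutations


def leap(monomials):
    """Brute-force leap of a function given the supports of its monomials.

    `monomials` is an iterable of coordinate sets, e.g. [{1}, {1, 2}].
    Assumed definition: minimize, over orderings of the monomials, the
    largest number of not-yet-seen coordinates introduced by one monomial.
    """
    supports = [frozenset(s) for s in monomials]
    best = None
    for order in permutations(supports):
        seen = set()
        worst_jump = 0
        for support in order:
            # Coordinates this monomial adds beyond those already learned.
            worst_jump = max(worst_jump, len(support - seen))
            seen |= support
        best = worst_jump if best is None else min(best, worst_jump)
    return best


def sgd_time_exponent(leap_value):
    """Exponent of d in the conjectured time complexity d^{max(Leap, 2)}."""
    return max(leap_value, 2)


# Staircase x1 + x1*x2 + x1*x2*x3: each monomial adds one coordinate.
print(leap([{1}, {1, 2}, {1, 2, 3}]))  # 1
# Isolated monomial x1*x2*x3: all three coordinates appear at once.
print(leap([{1, 2, 3}]))  # 3
# x1 + x1*x2*x3: after x1, the second monomial adds two coordinates.
print(leap([{1}, {1, 2, 3}]))  # 2
```

Under the stated conjecture, the staircase example would have exponent `sgd_time_exponent(1) == 2`, while the isolated degree-3 monomial would have exponent 3; logarithmic factors hidden by $\tilde{\Theta}$ are ignored here.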
