Rigorous dynamical mean field theory for stochastic gradient descent methods

Cedric Gerbelot, Emanuele Troiani, Francesca Mignacco, Florent Krzakala, Lenka Zdeborova

TL;DR

This work addresses the exact high-dimensional behavior of first-order gradient-based methods, such as SGD, Langevin dynamics, and momentum methods, on Gaussian data. Using iterative Gaussian conditioning, it derives a discrete-time dynamical mean-field theory (DMFT) that expresses the dynamics through memory kernels and Gaussian processes whose covariances, together with the kernels, satisfy self-consistent equations over all time steps. The main contributions are the handling of stochastic gradient noise, non-separable updates, and general data covariance, together with a numerically tractable solver for the DMFT equations and demonstrations on SGD variants. The results provide a principled framework for analyzing the convergence and stability of training dynamics in the high-dimensional regime, bridging statistical learning and dynamical mean-field theory.

Abstract

We prove closed-form equations for the exact high-dimensional asymptotics of a family of first-order gradient-based methods, learning an estimator (e.g. M-estimator, shallow neural network, ...) from observations on Gaussian data with empirical risk minimization. This includes widely used algorithms such as stochastic gradient descent (SGD) or Nesterov acceleration. The obtained equations match those resulting from the discretization of dynamical mean-field theory (DMFT) equations from statistical physics when applied to gradient flow. Our proof method allows us to give an explicit description of how memory kernels build up in the effective dynamics, and to include non-separable update functions, allowing datasets with non-identity covariance matrices. Finally, we provide numerical implementations of the equations for SGD with generic extensive batch size and with constant learning rates.
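The finite-dimensional side of such comparisons is a plain SGD simulation in which the mini-batch size grows proportionally with the dimension. The sketch below is a hypothetical illustration, not the paper's experimental setup: it assumes a noiseless linear teacher, a ridge-regularized square loss, and a batch in which each sample is included independently with probability b at every step, and it tracks the cosine similarity between the iterate and the hidden signal.

```python
import numpy as np

def sgd_extensive_batch(d=400, alpha=2.0, lam=0.5, gamma=0.1, b=0.2,
                        n_steps=300, seed=0):
    """Hypothetical example: ridge-regularized least squares on a noiseless
    linear teacher with i.i.d. standard Gaussian inputs.

    At every step each of the n = alpha * d samples enters the mini-batch
    independently with probability b, so the expected batch size b * n is
    proportional to d ("extensive")."""
    rng = np.random.default_rng(seed)
    n = int(alpha * d)
    X = rng.standard_normal((n, d))
    w_star = rng.standard_normal(d)            # hidden signal
    y = X @ w_star / np.sqrt(d)                # teacher labels
    w = 0.1 * rng.standard_normal(d)           # small random initialization

    def cosine(v):
        return v @ w_star / (np.linalg.norm(v) * np.linalg.norm(w_star))

    history = [cosine(w)]
    for _ in range(n_steps):
        mask = rng.random(n) < b               # Bernoulli(b) sample selection
        residual = y[mask] - X[mask] @ w / np.sqrt(d)
        # Gradient of 0.5 * sum_batch residual^2 + 0.5 * lam * ||w||^2
        grad = -X[mask].T @ residual / np.sqrt(d) + lam * w
        w = w - gamma * grad
        history.append(cosine(w))
    return np.array(history)

cos_curve = sgd_extensive_batch()
```

In this toy model the batch-size parameter b both rescales the effective regularization (in expectation the iterates approach a ridge solution with penalty lam / b) and sets the size of the stochastic fluctuations around it.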

Paper Structure (27 sections, 8 theorems, 139 equations, 3 figures)

Key Result

Theorem 3.2

(High-dimensional dynamics of gradient-based methods) Consider the following discrete-time stochastic process, initialized with $\bm{\nu}^{0} = \bm{v}^{0}$, where $\bm{u}^{t}$ and $\bm{\omega}^{t}$ have i.i.d. rows in $\mathbb{R}^{q}$ that are Gaussian processes with covariances $C_{g}^{s,t}$ and $C_{\theta}^{s,t}$. In the above, the notation $\frac{\partial \bar{g}^{t}_{i}}{\partial \cdots}$ …
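A computational primitive behind effective processes of this kind is drawing matrices whose rows are i.i.d. Gaussian processes in time, with values in $\mathbb{R}^{q}$ and covariance blocks $C^{s,t}$. The sketch below is a minimal illustration, assuming the blocks are stored as a hypothetical array `C_blocks` of shape `(T, T, q, q)` that assembles into a positive-definite joint covariance; it samples through a Cholesky factor of that joint matrix.

```python
import numpy as np

def sample_gp_rows(C_blocks, n_rows, rng):
    """Draw n_rows i.i.d. rows, each a Gaussian process over T time steps
    with values in R^q.

    C_blocks[s, t] is the q x q covariance between the values at times s
    and t (hypothetical storage layout, shape (T, T, q, q))."""
    T, _, q, _ = C_blocks.shape
    # Assemble the joint (T*q) x (T*q) covariance; index (s, a) -> s*q + a.
    joint = C_blocks.transpose(0, 2, 1, 3).reshape(T * q, T * q)
    L = np.linalg.cholesky(joint)              # joint = L @ L.T
    z = rng.standard_normal((n_rows, T * q))
    samples = z @ L.T                          # each row ~ N(0, joint)
    return samples.reshape(n_rows, T, q)       # samples[i, t] lies in R^q

# Usage: an arbitrary positive-definite target covariance with T = 4, q = 2.
rng = np.random.default_rng(0)
T, q = 4, 2
A = rng.standard_normal((T * q, T * q))
joint_target = A @ A.T / (T * q) + np.eye(T * q)
C_blocks = joint_target.reshape(T, q, T, q).transpose(0, 2, 1, 3)
u = sample_gp_rows(C_blocks, n_rows=200_000, rng=rng)
flat = u.reshape(len(u), -1)
emp_cov = flat.T @ flat / len(u)               # empirical covariance of the rows
```

Sampling the whole trajectory at once keeps the correlations between time steps exactly those prescribed by the blocks; for long horizons one could instead extend the Cholesky factor one time step at a time.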

Figures (3)

  • Figure 1: Gradient descent with sample splitting, where $f'(z) = \tanh(z)$. Owing to the regularity of the update function and the sample-splitting assumption, concentration is very fast, and the theoretical and empirical curves match almost perfectly even at low dimension ($n=50$, $d=100$) and without averaging.
  • Figure 1: Average cosine similarity with the signal as a function of time for different values of the learning rate $\gamma$ (left panel) and batch size $b$ (right panel). Parameters: $\lambda=1$ and $\alpha=0.9$, with $b=0.2$ on the left and $\gamma=0.04$ on the right. Solid pale lines: solution of the DMFT equations in the high-dimensional limit. Dots: simulations with $d=1000$. On the left, different colors indicate different learning rates $\gamma$.
  • Figure 2: Evolution of the magnetization obtained from the DMFT equations as the algorithm iterates (lines). We fix the sample-to-dimension ratio $\alpha=3$, the regularization $\lambda=0.5$, the learning rate $\eta=0.1$, the mini-batch size $b=1$, and the initial magnetization $0.2$. The stochastic process in the DMFT equations is sampled more than $2500$ times for each iteration. The kernels are updated with damping: each new estimate keeps $70\%$ of the newly computed kernel and $30\%$ of the previous one. Points: magnetization from SGD simulations on a dataset with dimension $d=1000$.
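The damped kernel update described in the Figure 2 caption can be illustrated on a toy problem. This is a sketch, not the paper's DMFT solver: the scalar effective process $x^{t+1} = \rho x^{t} + u^{t}$, the closure rule for the covariance of $u$, and all function and parameter names are hypothetical, chosen only to show a Monte Carlo fixed-point iteration over a time-time covariance kernel with 70/30 damping. Actual DMFT equations also update response (memory) kernels, which this toy omits.

```python
import numpy as np

def toy_effective_process(C, z, x0, rho):
    """Hypothetical scalar effective process x^{t+1} = rho * x^t + u^t,
    where u = (u^0, ..., u^{T-1}) is a Gaussian process with covariance C."""
    L = np.linalg.cholesky(C)
    u = z @ L.T                        # each row of u ~ N(0, C)
    n, T = z.shape
    x = np.empty((n, T + 1))
    x[:, 0] = x0
    for t in range(T):
        x[:, t + 1] = rho * x[:, t] + u[:, t]
    return x

def toy_kernel_update(C, z, x0, rho, kappa, sigma2):
    """Monte Carlo estimate of the new kernel from a toy closure rule:
    C_new[s, t] = kappa * E[tanh(x^{s+1}) tanh(x^{t+1})] + sigma2 * delta_st."""
    x = toy_effective_process(C, z, x0, rho)
    phi = np.tanh(x[:, 1:])            # shape (n_samples, T)
    return kappa * (phi.T @ phi) / phi.shape[0] + sigma2 * np.eye(C.shape[0])

def solve_damped(T=10, n_samples=2500, rho=0.5, kappa=0.1, sigma2=0.1,
                 keep_new=0.7, tol=1e-10, max_iter=1000, seed=0):
    rng = np.random.default_rng(seed)
    # Fixed draws reused at every iteration (common random numbers), so the
    # Monte Carlo map is deterministic and the residual can shrink to zero.
    z = rng.standard_normal((n_samples, T))
    x0 = rng.standard_normal(n_samples)
    C = np.eye(T)
    residual = np.inf
    for it in range(1, max_iter + 1):
        C_new = toy_kernel_update(C, z, x0, rho, kappa, sigma2)
        # Damping: keep 70% of the new kernel and 30% of the old one.
        C_next = keep_new * C_new + (1.0 - keep_new) * C
        residual = np.max(np.abs(C_next - C))
        C = C_next
        if residual < tol:
            break
    return C, residual, it
```

Damping trades speed for stability: a mixing weight below one slows each step but can prevent oscillation when the undamped map is only weakly contracting. Redrawing fresh samples at every iteration would instead leave the residual at the level of the Monte Carlo noise.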

Theorems & Definitions (10)

  • Definition 3.1: pseudo-Lipschitz function
  • Theorem 3.2
  • Corollary 3.3
  • Theorem 4.1
  • Lemma A.1: Gaussian matrices under linear constraints
  • Lemma A.2: Gaussian concentration of pseudo-Lipschitz functions
  • Lemma A.3: Stein's lemma, matrix version
  • Lemma A.4: Miscellaneous results on Gaussian random matrices
  • Lemma A.5
  • Proof 1
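For reference, two of the tools listed above in the standard forms in which they usually appear in the approximate message passing literature; the paper's own statements (for instance its matrix version of Stein's lemma) may differ in normalization or generality.

```latex
% Pseudo-Lipschitz function of order k (standard form):
f:\mathbb{R}^{m}\to\mathbb{R} \text{ is pseudo-Lipschitz of order } k \text{ if } \exists L>0:\;
\forall x,y\in\mathbb{R}^{m},\quad
|f(x)-f(y)| \le L\left(1+\|x\|^{k-1}+\|y\|^{k-1}\right)\|x-y\| .

% Stein's lemma, scalar and multivariate forms (f weakly differentiable,
% expectations finite):
Z\sim\mathcal{N}(0,\sigma^{2}):\quad \mathbb{E}[Z\,f(Z)] = \sigma^{2}\,\mathbb{E}[f'(Z)],
\qquad
Z\sim\mathcal{N}(0,\Sigma):\quad \mathbb{E}[Z\,f(Z)] = \Sigma\,\mathbb{E}[\nabla f(Z)] .
```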