Fluctuation-dissipation relations for stochastic gradient descent

Sho Yaida

TL;DR

Problem: relate minibatch noise in SGD to parameter dynamics during stationary training. Approach: derive exact, stationarity-based fluctuation-dissipation relations using a discrete-time master-equation framework that accommodates non-Gaussian noise and nonconvex landscapes. Key findings: FDR1 provides a practical equilibration metric and adaptive learning-rate schedule; FDR2 enables probing the loss landscape via the Hessian and anharmonicity. Empirical validation on MNIST and CIFAR-10 confirms the relations and demonstrates the practical utility of adaptive scheduling.

Abstract

The notion of the stationary equilibrium ensemble has played a central role in statistical mechanics. In machine learning as well, training serves as generalized equilibration that drives the probability distribution of model parameters toward stationarity. Here, we derive stationary fluctuation-dissipation relations that link measurable quantities and hyperparameters in the stochastic gradient descent algorithm. These relations hold exactly for any stationary state and can in particular be used to adaptively set a training schedule. We can further use the relations to efficiently extract information pertaining to a loss-function landscape such as the magnitudes of its Hessian and anharmonicity. Our claims are empirically verified.
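
As a rough illustration of how the first relation could drive a training schedule, the sketch below lowers the learning rate once the half-running averages of the two observables $\mathcal{O}_{\mathrm{L}}$ and $\mathcal{O}_{\mathrm{R}}$ (the two sides of the relation, as named in the Figure 1 caption) agree to within a tolerance. The function name `adapt_learning_rate`, the `tolerance` and `decay` defaults, and the reading of "half-running average" as the mean over the most recent half of the recorded values are assumptions made for illustration, not the paper's stated procedure.

```python
def adapt_learning_rate(eta, o_left_history, o_right_history,
                        tolerance=0.01, decay=0.1):
    """Return (new_eta, equilibrated).

    Hypothetical rule: if the half-running averages of O_L and O_R agree
    to within `tolerance` (relative), treat the run as equilibrated and
    shrink the learning rate by a factor (1 - decay).
    """
    n = min(len(o_left_history), len(o_right_history))
    if n < 2:
        # Not enough measurements to form a meaningful average.
        return eta, False

    # Assumed meaning of "half-running average": mean over the latest half.
    recent_left = o_left_history[n // 2:n]
    recent_right = o_right_history[n // 2:n]
    mean_left = sum(recent_left) / len(recent_left)
    mean_right = sum(recent_right) / len(recent_right)

    equilibrated = abs(mean_left / mean_right - 1.0) < tolerance
    if equilibrated:
        return eta * (1.0 - decay), True
    return eta, False
```

In a training loop one might call this at the end of each epoch with the per-step values of $\mathcal{O}_{\mathrm{L}}$ and $\mathcal{O}_{\mathrm{R}}$ recorded since the last change of $\eta$, and reset both histories whenever the rate is lowered.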

Paper Structure

This paper contains 21 sections, 31 equations, 6 figures.

Figures (6)

  • Figure 1: Approaches toward stationarity during the initial trainings for the MLP on the MNIST data (a) and for the CNN on the CIFAR-10 data (b). Top panels depict the half-running average $\overline{f^{\mathcal{B}}}(t)$ (dark green) and the instantaneous value $f^{\mathcal{B}}(t)$ (light green) of the mini-batch loss. Bottom panels depict the convergence of the half-running averages of the observables $\mathcal{O}_{\mathrm{L}}=\bm{\theta}\cdot\bm{\nabla} f^{\mathcal{B}}$ and $\mathcal{O}_{\mathrm{R}}=\frac{(1+\mu)}{2(1-\nu)}\eta \mathbf{v}^2$, whose stationary-state averages should agree according to the relation (\ref{FDR1G}).
  • Figure 2: Approaches toward stationarity during the sequential runs for various learning rates $\eta$, seen through the half-running averages of the observables $\mathcal{O}_{\mathrm{L}}=\bm{\theta}\cdot\bm{\nabla} f^{\mathcal{B}}$ (solid) and $\mathcal{O}_{\mathrm{R}}=\frac{(1+\mu)}{2(1-\nu)}\eta \mathbf{v}^2$ (dotted light-colored). They agree at sufficiently long times but the relaxation time to reach such a stationary regime increases as the learning rate $\eta$ decreases.
  • Figure 3: The stationary-state average of the full-batch observable $\mathcal{O}_{\mathrm{FB}}$ as a function of the learning rate $\eta$, estimated through half-running averages. Dots and error bars denote mean values and $95\%$ confidence intervals over several distinct runs, respectively. The straight red line connects the origin and the point with the smallest $\eta$ explored. (a) For the MLP on the MNIST data, linear dependence on $\eta$ for $\eta\lesssim0.01$ supports the validity of the harmonic approximation there. (b) For the CNN on the CIFAR-10 data, anharmonicity is pronounced even down to $\eta\sim0.001$.
  • Figure 4: Comparison of preset training schedule (black) and adaptive training schedule (blue), employing SGD without momentum both for the MLP on the MNIST data (a) and the CNN on the CIFAR-10 data (b), along with the AMSGrad algorithm (green). From top to bottom, plotted are the learning rate $\eta$, the full-batch training loss $f$, and prediction accuracies on the training-set images (solid) and the $10000$ test-set images (dashed).
  • Figure S1: Comparison of AMSGrad (green) and Adam (orange) algorithms for the MLP on the MNIST data (a) and the CNN on the CIFAR-10 data (b). Top rows plot the full-batch training loss $f$ while bottom rows plot prediction accuracies on the training-set images (solid) and the $10000$ test-set images (dashed).
  • ...and 1 more figure
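
The agreement between $\mathcal{O}_{\mathrm{L}}$ and $\mathcal{O}_{\mathrm{R}}$ described in the captions of Figures 1 and 2 can be tried out on a toy problem. The sketch below runs mini-batch SGD on a small noisy least-squares loss and compares the half-running averages of the two observables. Several things here are assumptions rather than details taken from the source: the update convention $\mathbf{v}\leftarrow\mu\mathbf{v}+(1-\nu)\bm{\nabla} f^{\mathcal{B}}$, $\bm{\theta}\leftarrow\bm{\theta}-\eta\mathbf{v}$ (momentum $\mu$, dampening $\nu$), the reading of "half-running average" as the mean over the most recent half of the steps, the synthetic data, and the names `run_sgd` and `half_running_average`.

```python
import numpy as np


def half_running_average(values):
    """Mean over the most recent half of the recorded values (assumed definition)."""
    values = np.asarray(values)
    return values[len(values) // 2:].mean()


def run_sgd(eta=0.05, mu=0.0, nu=0.0, steps=40000, batch=10, seed=0):
    """Run SGD on a toy least-squares loss; return half-running averages (O_L, O_R)."""
    rng = np.random.default_rng(seed)
    n, d = 1000, 5
    X = rng.normal(size=(n, d))
    # Noisy linear targets, so mini-batch gradients keep fluctuating near the minimum.
    y = X @ rng.normal(size=d) + rng.normal(size=n)

    theta = np.zeros(d)
    v = np.zeros(d)
    o_left, o_right = [], []
    for _ in range(steps):
        idx = rng.integers(0, n, size=batch)
        Xb, yb = X[idx], y[idx]
        grad = Xb.T @ (Xb @ theta - yb) / batch  # gradient of the mini-batch loss

        # Assumed update convention: velocity accumulates (dampened) gradients.
        v = mu * v + (1.0 - nu) * grad

        o_left.append(theta @ grad)  # O_L = theta . grad f^B
        o_right.append((1.0 + mu) / (2.0 * (1.0 - nu)) * eta * (v @ v))  # O_R

        theta = theta - eta * v
    return half_running_average(o_left), half_running_average(o_right)
```

Calling `run_sgd()` or `run_sgd(mu=0.5)` returns the two averages as a pair; for a run that has reached stationarity they should be close, and the gap between them can serve as a simple equilibration check in the spirit of the figures above.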