Online Learning and Information Exponents: On The Importance of Batch size, and Time/Complexity Tradeoffs

Luca Arnaboldi, Yatin Dandi, Florent Krzakala, Bruno Loureiro, Luca Pesce, Ludovic Stephan

TL;DR

This work characterizes the optimal batch size minimizing the iteration time as a function of the hardness of the target, as characterized by the information exponents, and shows that performing gradient updates with large batches minimizes the training time without changing the total sample complexity.

Abstract

We study the impact of the batch size $n_b$ on the iteration time $T$ of training two-layer neural networks with one-pass stochastic gradient descent (SGD) on multi-index target functions of isotropic covariates. We characterize the optimal batch size minimizing the iteration time as a function of the hardness of the target, as characterized by the information exponents. We show that performing gradient updates with large batches $n_b \lesssim d^{\frac{\ell}{2}}$ minimizes the training time without changing the total sample complexity, where $\ell$ is the information exponent of the target to be learned [arous2021online] and $d$ is the input dimension. However, larger batch sizes $n_b \gg d^{\frac{\ell}{2}}$ are detrimental to the time complexity of SGD. We provably overcome this fundamental limitation via a different training protocol, Correlation loss SGD, which suppresses the auto-correlation terms in the loss function. We show that the training progress can be tracked by a system of low-dimensional ordinary differential equations (ODEs). Finally, we validate our theoretical results with numerical experiments.
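For a single-index link function $h$, the information exponent [arous2021online] is the smallest $k \ge 1$ such that the Hermite coefficient $\mathbb{E}_{z\sim\mathcal{N}(0,1)}[h(z)\,\text{He}_k(z)]$ is nonzero. The following is a minimal plain-Python sketch restricted to the single-index case. It assumes a trapezoidal quadrature for the Gaussian expectation and uses a hypothetical numerical tolerance `tol` in place of "nonzero". The helper names are illustrative, not taken from the paper.

```python
import math


def hermite_he(k, z):
    """Probabilists' Hermite polynomial He_k(z), via He_{j+1} = z He_j - j He_{j-1}."""
    if k == 0:
        return 1.0
    h_prev, h = 1.0, z  # He_0, He_1
    for j in range(1, k):
        h_prev, h = h, z * h - j * h_prev
    return h


def gaussian_expectation(f, half_width=12.0, n=20001):
    """E[f(z)] for z ~ N(0, 1), by the trapezoidal rule on [-half_width, half_width]."""
    step = 2.0 * half_width / (n - 1)
    total = 0.0
    for i in range(n):
        z = -half_width + i * step
        weight = 0.5 if i in (0, n - 1) else 1.0
        total += weight * f(z) * math.exp(-0.5 * z * z)
    return total * step / math.sqrt(2.0 * math.pi)


def information_exponent(h, max_k=10, tol=1e-6):
    """Smallest k >= 1 whose Hermite coefficient E[h(z) He_k(z)] exceeds tol in magnitude.

    Returns None if no such k <= max_k is found (tol and max_k are illustrative choices).
    """
    for k in range(1, max_k + 1):
        coeff = gaussian_expectation(lambda z: h(z) * hermite_he(k, z))
        if abs(coeff) > tol:
            return k
    return None
```

With these helpers, $\text{He}_3$ (the activation used in Figures 2 and 5) has information exponent 3, $\operatorname{erf}$ (Figure 3) has information exponent 1, and $z^2$ has information exponent 2. Multi-index targets need a multivariate generalization of this definition, which the sketch does not attempt.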

Paper Structure (51 sections, 15 theorems, 139 equations, 10 figures, 1 table)

Key Result

Theorem 1 (Projected SGD weak recovery)

Consider the projected SGD algorithm with square loss (Eqs. eq:main:gd_update_weights and eq:main:projected_sgd), and suppose that Assumptions assump:poly_growth through assump:init hold. There exist absolute constants $c_\gamma, C_\gamma$ such that if …, then for large enough $d$ we have … with probability $1-ce^{-c\log(n)^2}$.
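To make the protocol concrete, here is a minimal plain-Python sketch of one-pass projected minibatch SGD. It relies on simplifying assumptions that are not taken from the paper: a single-neuron student matching a single-index teacher with $\sigma = h^\star = \text{He}_3$ (the setting of Figure 2), standard Gaussian inputs, and arbitrary hyperparameters. The names `projected_sgd_step` and `overlap_trajectory` are hypothetical. The `loss` switch contrasts the square loss with Correlation loss SGD. Expanding $\frac12(y-\hat y)^2 = \frac12 y^2 - y\hat y + \frac12\hat y^2$ shows that the correlation loss $-y\hat y$ drops the self-interaction term $\frac12\hat y^2$.

```python
import math
import random


def he3(z):
    """He_3(z) = z^3 - 3z, used here as both teacher and student activation (assumption)."""
    return z ** 3 - 3.0 * z


def he3_prime(z):
    """Derivative of He_3."""
    return 3.0 * z * z - 3.0


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def normalize(v):
    norm = math.sqrt(dot(v, v))
    return [a / norm for a in v]


def projected_sgd_step(w, w_star, n_b, gamma, loss, rng):
    """One minibatch SGD step on n_b fresh samples, then projection onto the unit sphere.

    loss="square":      l = 0.5 * (y - y_hat)^2
    loss="correlation": l = -y * y_hat   (no 0.5 * y_hat^2 self-interaction term)
    """
    d = len(w)
    grad = [0.0] * d
    for _ in range(n_b):
        x = [rng.gauss(0.0, 1.0) for _ in range(d)]  # fresh sample: one-pass SGD
        y = he3(dot(w_star, x))                      # teacher label
        pre = dot(w, x)
        y_hat = he3(pre)                             # student prediction
        if loss == "square":
            residual = y - y_hat
        elif loss == "correlation":
            residual = y
        else:
            raise ValueError(f"unknown loss: {loss!r}")
        # dl/dw = -residual * sigma'(w . x) * x, averaged over the minibatch
        coef = -residual * he3_prime(pre) / n_b
        for i in range(d):
            grad[i] += coef * x[i]
    return normalize([wi - gamma * gi for wi, gi in zip(w, grad)])


def overlap_trajectory(d, n_b, gamma, steps, loss, seed=0):
    """Run projected SGD from a random unit vector; return the overlaps m_t = w_t . w_star."""
    rng = random.Random(seed)
    w_star = [1.0] + [0.0] * (d - 1)
    w = normalize([rng.gauss(0.0, 1.0) for _ in range(d)])
    overlaps = [dot(w, w_star)]
    for _ in range(steps):
        w = projected_sgd_step(w, w_star, n_b, gamma, loss, rng)
        overlaps.append(dot(w, w_star))
    return overlaps
```

For example, `overlap_trajectory(d=32, n_b=64, gamma=0.01, steps=200, loss="correlation")` returns the overlaps $m_t = \bm{w}_t\cdot\bm{w}^\star$. Weak recovery (Definition 2) roughly corresponds to $m_t$ reaching a value of order one, independent of $d$. The pure-Python loops keep the sketch dependency-free but are only practical for small $d$.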

Figures (10)

  • Figure 1: Time / Batch size tradeoff for weak recovery: Phase diagram illustrating the different SGD learning regimes as a function of the batch size exponent $\mu=\log_d n_b$ and the weak recovery time exponent $\theta=\log_d T$. The analysis depends on the target's information exponent $\ell$; this particular plot is valid for $\ell\ge3$. Not correlating region: SGD is not able to achieve weak recovery. Self-interaction regime: SGD is not able to achieve weak recovery, but Correlation loss SGD overcomes this limitation. Weak recovery region: SGD successfully achieves weak recovery. Note that there exists an optimal batch size $n_b =O(d^{\ell/2})$ that minimizes the number of iterations needed by SGD, and another optimal point at $n_b=O(d^{\ell-1})$ for Correlation Loss SGD. The critical line where $n_b=\Omega(d^{\ell-1})$ is not covered by our formal results. See Appendix app:sec:weakrecoveryglm for details about the other two regions (Polylog regime and One-step regime [dandi2023twolayer]).
  • Figure 2: Correlation Loss SGD weak recovery: Comparison between the performance of plain SGD and Correlation Loss SGD in different regions of the phase diagram and for different input dimensions $d$. The plot shows the test error as a function of the optimization steps. Both the teacher and the student activation functions are fixed to $\sigma = h^\star =\text{He}_3$, so the information exponent is $\ell=3$. Across the three plots we vary $\mu$, with $\delta = \mu - \ell/2$. Theorem 2 (Correlation Loss SGD weak recovery) predicts that Correlation Loss SGD weakly recovers the target direction while SGD fails when $\delta<0$, in accordance with the plot. Note that the number of steps needed for target recovery decreases drastically as $\mu$ grows, in accordance with Theorems 1 and 2.
  • Figure 3: Exact asymptotic description: Exact asymptotic characterization of the dynamics of two-layer networks trained with SGD as a function of the batch size ($n_b$) and the learning rate ($\gamma$). Left: Illustration of the different dynamical regimes in a compact phase diagram. Population flow region: the dynamics is equivalent to population gradient flow. Noise learning region: the high-dimensional noise terms dominate the dynamics. Saad&Solla line: the learning dynamics reaches a plateau characterized by the noise variance [saad.solla_1995_line]. Dynamics not defined: the deterministic low-dimensional description of eq. eq:spherical_closed_form_ode is not valid. Right: comparison of numerical simulations (dots) and theoretical predictions (continuous lines) for three instances $(\delta,\mu)$ associated with different learning regimes (identified by the corresponding colors). For both theory and simulations, the test error is plotted as a function of SGD iterations. We consider matching architectures, i.e. $h^\star = \sigma = \operatorname{erf}$ activations, with hidden units $p=2$, $k=2$.
  • Figure 4: Phase diagram for the learning rate: The plot identifies different learning behaviors of standard SGD and Correlation Loss SGD for different values of learning rate and batch size when considering randomly initialized networks, i.e. $m_0 = O(1/\sqrt{d})$.
  • Figure 5: Learning a single-index teacher with a wide student when the information exponent is $\ell=3$: $f^\star(\bm{x})=\text{He}_3(\bm{w}^\star\cdot\bm{x})$, $f(\bm{x})=\frac{1}{4}\sum_{i=1}^4\text{He}_3(\bm{w}_i\cdot\bm{x})$. Our theory extends to this case, showing that when $\mu>\ell/2$ only correlation loss can weakly recover the target ($d=256$, $\gamma = \gamma_0\cdot p\, n_b d^{-\ell//2}$).
  • ...and 5 more figures
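The time/batch size tradeoff of Figure 1 can be read as simple exponent arithmetic. The sketch below assumes, consistent with the abstract's claim that large batches leave the total sample complexity unchanged, that for $\ell \ge 3$ weak recovery needs about $n \approx d^{\ell-1}$ samples (the online SGD rate of [arous2021online], log factors dropped). Under that assumption $T \approx n/n_b$, so $\theta = \ell - 1 - \mu$. The validity ranges ($\mu \le \ell/2$ for SGD, $\mu < \ell - 1$ for Correlation Loss SGD) are a reading of Figure 1, not a restatement of the theorems, and the function name is hypothetical.

```python
def weak_recovery_time_exponent(ell, mu, loss="square"):
    """Heuristic exponent theta = log_d T of the weak recovery time for n_b = d**mu.

    Assumption (ell >= 3, log factors dropped): the total number of samples
    needed for weak recovery stays n ~ d**(ell - 1) over the valid range of
    batch sizes, so T = n / n_b and theta = (ell - 1) - mu.
    Valid ranges (a reading of Figure 1):
      loss="square":      0 <= mu <= ell / 2
      loss="correlation": 0 <= mu <  ell - 1  (the critical line is excluded)
    Returns None when mu lies outside the valid range.
    """
    if ell < 3:
        raise ValueError("this sketch only covers information exponents ell >= 3")
    if mu < 0:
        raise ValueError("the batch size exponent mu must be non-negative")
    if loss == "square":
        valid = mu <= ell / 2
    elif loss == "correlation":
        valid = mu < ell - 1
    else:
        raise ValueError(f"unknown loss: {loss!r}")
    return (ell - 1) - mu if valid else None
```

For $\ell=3$ this heuristic gives $\theta = 1/2$ at the SGD optimum $\mu = 3/2$, while Correlation Loss SGD can push $\theta$ toward $0$ as $\mu \to 2$. In the $(\mu, \theta)$ plane, the line $\mu + \theta = \ell - 1$ corresponds to a fixed total sample budget $T\,n_b \approx d^{\ell-1}$.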

Theorems & Definitions (27)

  • Definition 1: Information Exponent [arous2021online]
  • Definition 2: Weak recovery
  • Theorem 1: Projected SGD weak recovery
  • Theorem 2: Correlation Loss SGD weak recovery
  • Remark 1
  • Proposition 1
  • Remark 2
  • Definition 3
  • Lemma 1
  • Lemma 2
  • ...and 17 more