
The Benefits of Reusing Batches for Gradient Descent in Two-Layer Networks: Breaking the Curse of Information and Leap Exponents

Yatin Dandi, Emanuele Troiani, Luca Arnaboldi, Luca Pesce, Lenka Zdeborová, Florent Krzakala

TL;DR

It is shown that, upon reusing batches, the network achieves an overlap with the target subspace in just two time steps, even for functions not satisfying the staircase property; the (broad) class of functions efficiently learned in finite time is also characterized.

Abstract

We investigate the training dynamics of two-layer neural networks when learning multi-index target functions. We focus on multi-pass gradient descent (GD) that reuses the batches multiple times and show that it significantly changes the conclusion about which functions are learnable compared to single-pass gradient descent. In particular, multi-pass GD with finite stepsize is found to overcome the limitations of gradient flow and single-pass GD given by the information exponent (Ben Arous et al., 2021) and leap exponent (Abbe et al., 2023) of the target function. We show that upon re-using batches, the network achieves in just two time steps an overlap with the target subspace even for functions not satisfying the staircase property (Abbe et al., 2021). We characterize the (broad) class of functions efficiently learned in finite time. The proof of our results is based on the analysis of the Dynamical Mean-Field Theory (DMFT). We further provide a closed-form description of the dynamical process of the low-dimensional projections of the weights, and numerical experiments illustrating the theory.
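
The information exponent mentioned above (Ben Arous et al., 2021) is commonly defined, for a single-index link function $g$ and $z \sim \mathcal N(0,1)$, as the smallest $k \ge 1$ for which the Hermite coefficient $c_k = \mathbb E[g(z)\,\mathrm{He}_k(z)]$ is non-zero. The following is a minimal sketch that computes it numerically with Gauss-Hermite quadrature, assuming this standard single-index definition (the paper's own Definition 3.6 may be phrased differently, e.g. for multi-index targets). The helper names `hermite_coefficients`, `information_exponent` and `is_even` are hypothetical.

```python
import numpy as np
from numpy.polynomial import hermite_e as H


def hermite_coefficients(g, k_max=8, n_nodes=80):
    """Return c_k = E[g(z) He_k(z)] for k = 0..k_max, with z ~ N(0, 1).

    Uses probabilists' Gauss-Hermite quadrature (weight exp(-x^2/2)).
    """
    x, w = H.hermegauss(n_nodes)
    w = w / np.sqrt(2.0 * np.pi)  # normalise weights to the standard Gaussian density
    gx = g(x)
    coeffs = []
    for k in range(k_max + 1):
        basis = np.zeros(k + 1)
        basis[k] = 1.0                     # coefficient vector selecting He_k
        coeffs.append(np.sum(w * gx * H.hermeval(x, basis)))
    return np.array(coeffs)


def information_exponent(g, k_max=8, tol=1e-8):
    """Smallest k >= 1 with a non-zero Hermite coefficient (None if none up to k_max)."""
    c = hermite_coefficients(g, k_max)
    for k in range(1, k_max + 1):
        if abs(c[k]) > tol:
            return k
    return None


def is_even(g, k_max=8, tol=1e-8):
    """Numerical proxy for evenness: all odd Hermite coefficients up to k_max vanish."""
    c = hermite_coefficients(g, k_max)
    return all(abs(c[k]) <= tol for k in range(1, k_max + 1, 2))
```

Applied to the three single-index targets of Figure 1, this gives $\ell = 1$ for $\tanh$, $\ell = 3$ for $\mathrm{He}_3(z) = z^3 - 3z$ and $\ell = 4$ for $\mathrm{He}_4(z) = z^4 - 6z^2 + 3$, and only $\mathrm{He}_4$ is flagged as even. Quadrature is used instead of Monte Carlo so that the vanishing coefficients come out as zero up to rounding rather than up to sampling noise.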

Paper Structure

This paper contains 34 sections, 11 theorems, 106 equations, and 5 figures.

Key Result

Theorem 3.2

Suppose that $n/d=\alpha>0$. Let $\mathbf{v}^\star\in P^\star_\perp$ denote an arbitrary direction in the orthogonal complement of the subspace $P^\star$ defined in Definition def:two_step_hard, with norm $\sqrt{d}$ and a fixed representation in the basis $W^\star$. Suppose further that the activation […] with high probability as $n,d \rightarrow \infty$. Furthermore, for large enough $p$, $\mathbf{W}^{\cdots}$ […]
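
The regime of Theorem 3.2, with $n = \alpha d$ samples, can be explored numerically. The sketch below contrasts one-pass GD (a fresh batch of $n$ samples at every step) with multi-pass GD (the same batch reused at every step) on a single-index target, tracking the overlap with the teacher direction as in Figure 1. It is a minimal sketch under assumptions that this summary does not fix: a $p=1$ ReLU student $\sigma(\mathbf w\cdot\mathbf x/\sqrt d)$ with no trainable second layer, the square loss, full-batch updates $\mathbf w \leftarrow \mathbf w - \eta\,\frac{\sqrt d}{n}\sum_i (\hat y_i - y_i)\,\sigma'(h_i)\,\mathbf x_i$ (a step-size scaling chosen so that the contribution of a reused sample to its own preactivation stays of order one), and a cosine overlap. The names `run_gd` and `he3` are hypothetical, and $d$ is kept small so it runs quickly.

```python
import numpy as np


def relu(h):
    return np.maximum(h, 0.0)


def relu_grad(h):
    return (h > 0).astype(h.dtype)


def he3(z):
    # Third probabilists' Hermite polynomial, He_3(z) = z^3 - 3z.
    return z**3 - 3.0 * z


def run_gd(target, d=500, alpha=3.0, eta=0.1, steps=5, reuse=True, seed=0):
    """Overlap trajectory of a p = 1 ReLU student trained by full-batch GD.

    reuse=True  -> multi-pass: the same n = alpha * d samples at every step.
    reuse=False -> one-pass: a fresh batch of n samples at every step.
    """
    rng = np.random.default_rng(seed)
    n = int(alpha * d)
    w_star = rng.standard_normal(d)
    w_star *= np.sqrt(d) / np.linalg.norm(w_star)  # teacher direction, norm sqrt(d)
    w = rng.standard_normal(d)
    w *= np.sqrt(d) / np.linalg.norm(w)            # random student init, norm sqrt(d)

    def sample():
        x = rng.standard_normal((n, d))
        y = target(x @ w_star / np.sqrt(d))        # single-index labels y = g*(z)
        return x, y

    def overlap(v):
        # Cosine overlap; equals |<w*, v>| / d when both norms are sqrt(d).
        return abs(w_star @ v) / (np.linalg.norm(w_star) * np.linalg.norm(v))

    x, y = sample()
    overlaps = [overlap(w)]
    for t in range(steps):
        if not reuse and t > 0:
            x, y = sample()                        # one-pass: never revisit old data
        h = x @ w / np.sqrt(d)                     # preactivations
        err = relu(h) - y                          # square-loss residuals
        w = w - eta * np.sqrt(d) / n * (x.T @ (err * relu_grad(h)))
        overlaps.append(overlap(w))
    return np.array(overlaps)


m_multi = run_gd(he3, reuse=True)
m_one = run_gd(he3, reuse=False)
```

With the same seed both runs share the initialization and the first batch, so their overlaps coincide up to step 1: the two algorithms can only differ once a sample is seen a second time, which is where the two-step effect described in the TL;DR can arise. Figure 1 reports the resulting gap at $d=5000$ averaged over 32 runs; a single small-$d$ run of this sketch need not reproduce it cleanly.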

Figures (5)

  • Figure 1: One-pass and multi-pass GD for single-index models. The overlap $\left|\frac{\langle \mathbf w^\star, \hat{\mathbf{w}}\rangle }{d}\right|$ between the learned weight and the target (teacher) direction is plotted as a function of the iteration time for both single-pass (red) and multi-pass (blue) GD. Continuous lines are theoretical predictions; dots are simulations. Left: easy finite-T learnable single-index target $g^\star \!=\! \tanh$: both one-pass and multi-pass GD obtain a positive correlation after a finite number of iterations, as the information exponent of the target is $\ell\!=\!1$. Center: multi-pass finite-T learnable single-index target $g^\star \!=\! \mathrm{He}_3$: multi-pass GD achieves a non-zero correlation in just two steps, whereas the one-pass algorithm learns nothing. Right: finite-time non-learnable single-index target $g^\star \!=\! \mathrm{He}_4$: the target function is even and thus, as stated in Thm. thm:main:2step_learning, breaking this symmetry is hard in a finite number of steps, resulting in a vanishing correlation with the teacher direction $\mathbf w^\star$ for both algorithms at any finite time. (Simulations are averaged over $32$ runs, $d=5000$, with $\sigma = \rm relu$, $n=3d$, $p=1$, $\eta = 0.1$.)
  • Figure 2: One-pass and multi-pass GD for multi-index models. The overlaps of the student weights along the first direction learned, namely $\mathbf{C}[f^\star]$, and along its orthogonal complement are plotted versus the number of iterations for three different classes of functions. Left: easy finite-T learnable multi-index target: both algorithms learn all the relevant directions when an "easy" function is used as a target (here $p=8$). Center: multi-pass finite-T learnable multi-index target: both algorithms learn the first Hermite direction $\mathbf{C}[f^\star]$, but only multi-pass GD achieves a non-null correlation in the orthogonal complement. This illustrates how reusing samples allows us to surpass the staircase limitation of single-pass approaches ($p=2$). Right: finite-time non-learnable multi-index target: neither of the two algorithms can learn $\mathbf{C}[f^\star]^\bot$ with this target ($p=2$). (Simulations are averaged over $32$ runs, $d\!=\!5000$, with $\sigma \!=\! \rm relu$, $n=3d$, $\eta \!=\! 0.1$.)
  • Figure 3: An illustration of a hard, non-even target $f^\star(\mathbf z) \!=\! z_1 z_2 z_3 \!+\! \mathrm{He}_3(z_4)$ being learned by a student with $p=4$ hidden units. Even when reusing the batch, the student can only learn the direction associated with $z_4$, while keeping a zero overlap with the other directions. The continuous lines are from the DMFT numerical integration; the dots are simulations with $d\!=\!10000$. In the legend, the overlap with the $n$-th direction is the projection of the student weights onto the subspace associated with $z_n$. For this figure, $\sigma \!=\! \rm relu$, $n=5d$, $\eta \!=\! 0.2$.
  • Figure 4: Comparison of theory and experiments for Gradient Descent on the target $z_1z_2z_3 + \mathrm{He}_3(z_4)$. Each gradient step uses a mini-batch of $n/5$ samples. On the left, the data are used sequentially; on the right, each batch is sampled from the dataset with replacement. The continuous lines are from the DMFT numerical integration; the dots are simulations with $d\!=\!10000$ averaged over $32$ realisations. In the legend, the overlap with the $n$-th direction is the projection of the student weights onto the subspace associated with $z_n$. For this figure, $\sigma \!=\! \rm relu$, $n=5d$, $\eta \!=\! 0.2$.
  • Figure 5: Experiments for Gradient Descent on the target $z_1z_2z_3 + \mathrm{He}_3(z_4)$, using mini-batches of a single sample. On the left, the data are used sequentially; on the right, each data point is sampled from the dataset with replacement. The dots are simulations with $d\!=\!10000$ averaged over $32$ realisations. In the legend, the overlap with the $n$-th direction denotes the projection of the student weights onto the subspace associated with $z_n$. For this figure, $\sigma \!=\! \rm relu$, $n=5d$, $\eta \!=\! 0.2$.
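
Figures 4 and 5 contrast two ways of drawing mini-batches from a fixed dataset of $n$ samples. The following is a minimal sketch of the two index schedules, assuming that "sequential" means cycling through the samples in a fixed order (wrapping around after each epoch) and that "with replacement" means every index of a mini-batch is drawn uniformly from the $n$ samples. The generator names are hypothetical.

```python
import numpy as np


def sequential_batches(n, batch_size, steps):
    """Yield index arrays that walk through the dataset in order, wrapping after each epoch."""
    for t in range(steps):
        start = (t * batch_size) % n
        yield np.arange(start, start + batch_size) % n


def with_replacement_batches(n, batch_size, steps, seed=0):
    """Yield index arrays whose entries are drawn uniformly with replacement from range(n)."""
    rng = np.random.default_rng(seed)
    for _ in range(steps):
        yield rng.integers(0, n, size=batch_size)


# Figure 4 setting: mini-batches of n/5 samples; Figure 5 setting: batch_size = 1.
n = 10
seq = [b.tolist() for b in sequential_batches(n, n // 5, 6)]
rep = list(with_replacement_batches(n, n // 5, 6))
```

With mini-batches of $n/5$ samples, the sequential schedule revisits each sample only every five steps, whereas under sampling with replacement a sample can reappear at any step, including within the same mini-batch.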

Theorems & Definitions (22)

  • Definition 3.1
  • Theorem 3.2
  • Definition 3.3
  • Definition 3.4
  • Proposition 3.5
  • Definition 3.6: Information Exponent
  • Remark 3.7
  • Theorem A.1: Corollary of Theorem 3.2 in Gerbelot et al. (2023)
  • Theorem A.2
  • Proof
  • ...and 12 more