
Repetita Iuvant: Data Repetition Allows SGD to Learn High-Dimensional Multi-Index Functions

Luca Arnaboldi, Yatin Dandi, Florent Krzakala, Luca Pesce, Ludovic Stephan

TL;DR

This work analyzes how gradient-based training of shallow two-layer networks can identify low-dimensional multi-index structure in high-dimensional data. By introducing data repetition (two gradient steps per sample), the authors show that almost all relevant directions can be learned in $O(d \log d)$ steps, the exceptions being hard cases such as sparse parities, and they develop a Generative Exponent framework to capture this improved efficiency. They prove rigorous theorems for single-index and multi-index targets, showing that polynomially transformable targets attain near-optimal sample and time complexities, and they illustrate hierarchical learning mechanisms when directions are coupled. The results suggest that reusing data in SGD-like training can overcome the limits previously established for one-pass SGD, offering practical insight for training strategies that exploit data repetition without heavy preprocessing.
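To make data repetition concrete, here is a minimal sketch contrasting a one-pass SGD step with a step that takes two gradient steps on the same sample. It assumes a single spherical neuron with tanh activation, a per-sample square loss, and projection back onto the unit sphere after each sample. The helper names (`grad`, `one_pass_step`, `repeated_step`), the step size, and these modeling choices are illustrative assumptions, not the paper's algorithm or hyperparameters. What it shows: the second gradient is evaluated at a point that already depends on the label $y$, so the combined update contains nonlinear functions of $y$.

```python
import numpy as np

def grad(w, x, y):
    # Gradient of the per-sample square loss 0.5 * (y - tanh(w @ x))**2 w.r.t. w.
    u = w @ x
    return -(y - np.tanh(u)) * (1.0 - np.tanh(u) ** 2) * x

def project(w):
    # Assumed normalization: keep the student weight on the unit sphere.
    return w / np.linalg.norm(w)

def one_pass_step(w, x, y, lr):
    # Online SGD: one gradient step, then the sample is discarded.
    return project(w - lr * grad(w, x, y))

def repeated_step(w, x, y, lr):
    # Data repetition: a second gradient step on the SAME (x, y) pair.
    # The second gradient is evaluated at w_half, which already depends on y.
    w_half = w - lr * grad(w, x, y)
    return project(w_half - lr * grad(w_half, x, y))

rng = np.random.default_rng(0)
d = 500
lr = 1.0 / d                          # illustrative step size, not the paper's
w_star = np.zeros(d)
w_star[0] = 1.0                       # hidden direction (assumption: e_1)
w = project(rng.standard_normal(d))   # random start: overlap of order 1/sqrt(d)
x = rng.standard_normal(d)
z = w_star @ x
y = z ** 3 - 3 * z                    # He_3 link, as in Figure 5
w_one = one_pass_step(w, x, y, lr)
w_two = repeated_step(w, x, y, lr)
```

A single online step only ever sees $y$ linearly. In the repeated step, the pre-activation at `w_half` is shifted by an amount proportional to $y$, so `tanh` and its derivative at that point are nonlinear in $y$. This is the intuition for why repetition can exploit transformations of the label.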

Abstract

Neural networks can identify low-dimensional relevant structures within high-dimensional noisy data, yet our mathematical understanding of how they do so remains scarce. Here, we investigate the training dynamics of shallow two-layer neural networks trained with gradient-based algorithms, and discuss how they learn pertinent features in multi-index models, that is, target functions with low-dimensional relevant directions. In the high-dimensional regime, where the input dimension $d$ diverges, we show that a simple modification of the idealized single-pass gradient descent training scenario, where data can now be repeated or iterated upon twice, drastically improves its computational efficiency. In particular, it surpasses the limitations previously believed to be dictated by the Information and Leap exponents associated with the target function to be learned. Our results highlight the ability of networks to learn relevant structures from data alone without any pre-processing. More precisely, we show that (almost) all directions are learned in at most $O(d \log d)$ steps. Among the exceptions is a set of hard functions that includes sparse parities. In the presence of coupling between directions, however, these can be learned sequentially through a hierarchical mechanism that generalizes the notion of staircase functions. Our results are proven by a rigorous study of the evolution of the relevant statistics for high-dimensional dynamics.
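The information exponent mentioned above has a standard definition via the Hermite expansion of the link function: the smallest $k \ge 1$ with $\mathbb E[h^\star(z)\,\mathrm{He}_k(z)] \neq 0$ for $z \sim \mathcal N(0,1)$. The sketch below computes it for polynomial links. It also computes a quantity we use as a stand-in for the polynomial generative exponent $\ell^\star$: the smallest $k \ge 1$ for which some polynomial transformation $T$ of the label gives $\mathbb E[T(h^\star(z))\,\mathrm{He}_k(z)] \neq 0$. This reading of $\ell^\star$ is our assumption and may differ in details from the paper's Definition 2. It does reproduce $(\ell,\ell^\star)=(3,1)$ for $\mathrm{He}_3$, the pair quoted for Figure 5. The Leap exponent for multi-index targets is not covered. Links are passed as coefficient lists in the probabilists' Hermite basis.

```python
from numpy.polynomial import HermiteE

def first_nonzero_degree(coef, rel_tol=1e-12):
    """Smallest k >= 1 whose He_k coefficient is nonzero, or None."""
    scale = max(1.0, max(abs(c) for c in coef))
    for k in range(1, len(coef)):
        if abs(coef[k]) > rel_tol * scale:
            return k
    return None

def information_exponent(h_coef):
    """Smallest k >= 1 with E[h(z) He_k(z)] != 0, z ~ N(0, 1)."""
    return first_nonzero_degree(HermiteE(h_coef).coef)

def polynomial_generative_exponent(h_coef, max_power=4):
    """Smallest k >= 1 with E[T(h(z)) He_k(z)] != 0 for some polynomial T.

    E[T(y) He_k(z)] is linear in T, so checking the monomials T(y) = y**p
    is enough for every polynomial T of degree <= max_power.
    """
    h = HermiteE(h_coef)
    found = [first_nonzero_degree((h ** p).coef) for p in range(1, max_power + 1)]
    found = [k for k in found if k is not None]
    return min(found) if found else None

# Single-index links He_1, ..., He_4.
for deg in range(1, 5):
    h_coef = [0] * deg + [1]
    print(deg, information_exponent(h_coef), polynomial_generative_exponent(h_coef))
# Expected output:
# 1 1 1
# 2 2 2
# 3 3 1
# 4 4 2
```

For $\mathrm{He}_3$, the transformation $T(y)=y^3$ produces a nonzero first Hermite coefficient, so $\ell^\star=1$ even though $\ell=3$. For even links such as $\mathrm{He}_4$, every transformation stays even, so $\ell^\star \ge 2$.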

Paper Structure

This paper contains 35 sections, 18 theorems, 138 equations, and 7 figures.

Key Result

Theorem 1 (informal)

There is a choice of hyperparameters such that, if $h^\star$ is a polynomial, the paper's main data-repetition algorithm achieves weak recovery with $O(d\log(d)^2)$ samples.
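Weak recovery is measured through the overlap between the learned first-layer weights and the target direction, the same cosine similarity plotted in Figures 1 to 4. The following is a minimal sketch under two assumptions. First, the figures report the largest absolute cosine similarity over hidden neurons. Second, a fixed threshold serves as a hypothetical finite-$d$ stand-in for the asymptotic definition. In the usual sense, weak recovery means an overlap that stays bounded away from zero as $d$ grows, while a random initialization only reaches overlaps of order $1/\sqrt{d}$. The function names and the cutoff of 0.5 are illustrative choices, not taken from the paper.

```python
import numpy as np

def max_cosine_similarity(W, w_star):
    """Largest |cos(w_i, w_star)| over the rows w_i of the first-layer matrix W."""
    W = np.atleast_2d(W)
    cos = (W @ w_star) / (np.linalg.norm(W, axis=1) * np.linalg.norm(w_star))
    return float(np.max(np.abs(cos)))

def weakly_recovered(W, w_star, threshold=0.5):
    # Hypothetical finite-d proxy: weak recovery is an asymptotic notion
    # (overlap bounded away from zero as d grows); 0.5 is an arbitrary cutoff.
    return max_cosine_similarity(W, w_star) >= threshold

rng = np.random.default_rng(0)
d, p = 10_000, 8                       # input dimension, hidden neurons
w_star = rng.standard_normal(d)
W_init = rng.standard_normal((p, d))   # random initialization
# Compare the overlap at initialization with the 1/sqrt(d) random baseline.
print(max_cosine_similarity(W_init, w_star), 1 / np.sqrt(d))
```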

Figures (7)

  • Figure 1: Learning single-index targets -- Evolution of the cosine similarity attained by EgD (crosses) and SGD (dots) as a function of the normalized iteration time. The dashed horizontal line at $\frac{1}{\sqrt{d}}$ marks the level of random performance. (a) $(\ell,\ell^\star)=(1,1)$: both algorithms learn in linear time. (b) $(\ell,\ell^\star)=(2,2)$: both algorithms learn in $O(d\log d)$ time. (c) $(\ell,\ell^\star)=(3,1)$: EgD learns in linear time, while SGD requires $O(d^2)$ time. (d) $(\ell,\ell^\star)=(4,2)$: EgD learns in $O(d\log d)$ time, while SGD requires $O(d^3)$ time. Details in the implementation appendix.
  • Figure 2: Learning multi-index targets -- Evolution of the maximum cosine similarities attained by EgD (crosses) and SGD (dots) as a function of the normalized iteration time. The different directions $\{\mathbf w^\star_r\}_{r \in [k]}$ are identified by colors: $\mathbf w^\star_1$ (blue), $\mathbf w^\star_2$ (orange), $\mathbf w^\star_3$ (green). The dashed horizontal line at $\frac{1}{\sqrt{d}}$ marks the level of random performance. (a) $(\ell,\ell^\star)=(1,1)$: both algorithms learn the first direction in linear time, as well as the second one in $O(d)$ steps through the staircase mechanism. (b) $(\ell,\ell^\star)=(2,2)$: both algorithms learn the two directions simultaneously in $O(d\log d)$ steps. (c) $(\ell_{\mathbf w^\star_2},\ell_{\mathbf w^\star_2}^\star)=(3,1)$: both algorithms learn $\mathbf w^\star_1$ in linear time, but only EgD learns $\mathbf w^\star_2$ in $O(d)$ steps; SGD instead requires $O(d^2)$. (d) $(\ell,\ell^\star)=(3,2)$: EgD learns all 3 directions simultaneously in $O(d\log d)$ steps, while SGD needs $O(d^2)$ time. Details in the implementation appendix.
  • Figure 3: Hierarchical learning -- Evolution of the maximum cosine similarities with the target directions $\{\mathbf w^\star_r\}_{r \in [k]}$ as a function of the normalized iteration time; directions are identified by colors: $\mathbf w^\star_1$ (blue), $\mathbf w^\star_2$ (orange), $\mathbf w^\star_3$ (green). (a) Both algorithms learn $\mathbf w^\star_1$ in linear time, but only EgD also learns $\mathbf w^\star_2$ in $O(d)$ steps via a staircase mechanism; SGD requires another $O(d^2)$ steps to exploit the staircase and learn $\mathbf w^\star_2$. (b) EgD performance for the target functions $f^\star_{\mathrm{stair}}$ (dots) and $f^\star_{\mathrm{sign}}$ (squares); the overlaps with the different directions are plotted in separate subplots to highlight the hierarchical learning mechanism. Learning the first target direction $\mathbf w^\star_1$ triggers the hierarchical mechanism, and EgD weakly recovers the full target subspace $\mathrm{Span}(\mathbf w^\star_1, \mathbf w^\star_2, \mathbf w^\star_3)$ in $O(d \log d)$ steps; this does not happen when $\mathrm{He}_2(z_1)$ is removed from the target. See the implementation appendix for additional details.
  • Figure 4: Hierarchical learning -- Evolution of the maximum cosine similarities attained by EgD with the target directions $\{\mathbf w^\star_r\}_{r \in [k]}$ as a function of the normalized iteration time; directions are identified by colors: $\mathbf w^\star_1$ (blue), $\mathbf w^\star_2$ (orange), $\mathbf w^\star_3$ (green). The target function is $f^\star_{\mathrm{sq-stair}}(\mathbf z)= \mathrm{He}_4(z_1) + \mathrm{sign}(z_1z_2z_3)$, which is an SQ-staircase (or grand staircase) but not a CSQ-staircase. EgD learns $\mathbf w^\star_1$ in $O(d\log d)$ steps and uses this information to learn the other two directions in another $O(d\log d)$ steps. SGD (not shown in the plot) cannot exploit the staircase mechanism, since $\mathrm{He}_4(z_1)$ and $\mathrm{sign}(z_1z_2z_3)$ have information exponents $\ell = 4$ and $\ell = 3$, respectively. Details in the implementation appendix.
  • Figure 5: Examples of two different algorithms with data repetition learning a hard single-index target $h^\star(\mathbf z)=\mathrm{He}_3(z_1)$, with $\ell =3$ and $\ell^\star=1$. Left: SAM with $\rho_0=0.1$, $\gamma_0=0.01$. Right: 2-Lookahead with $\gamma_0=0.1$. See the implementation appendix for details.
  • ...and 2 more figures
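Figure 5 reports that sharpness-aware minimization (SAM) also learns the hard target $\mathrm{He}_3(z_1)$. One way to see the connection to data repetition is that a SAM update evaluates two gradients on the same sample. The first builds the ascent perturbation, and the second is taken at the perturbed point. The sketch below shows one per-sample step using the standard SAM rule $\varepsilon = \rho\, g/\|g\|$, $w \leftarrow w - \gamma\,\nabla L(w+\varepsilon)$, applied to a toy spherical tanh neuron. Several things here are our own assumptions rather than the paper's: the model, the square loss, the projection, and the values of `lr` and `rho`. We do not reproduce how the paper scales $\rho_0$ and $\gamma_0$ with $d$, and the 2-Lookahead variant is not sketched.

```python
import numpy as np

def grad(w, x, y):
    # Gradient of 0.5 * (y - tanh(w @ x))**2 with respect to w.
    u = w @ x
    return -(y - np.tanh(u)) * (1.0 - np.tanh(u) ** 2) * x

def project(w):
    # Assumed normalization: keep the weight vector on the unit sphere.
    return w / np.linalg.norm(w)

def sam_step(w, x, y, lr, rho):
    # First gradient on (x, y): defines the ascent perturbation of radius rho.
    g = grad(w, x, y)
    eps = rho * g / (np.linalg.norm(g) + 1e-12)
    # Second gradient on the SAME (x, y), taken at the perturbed point.
    return project(w - lr * grad(w + eps, x, y))

rng = np.random.default_rng(1)
d = 500
x = rng.standard_normal(d)
z = x[0]                                  # hidden direction w_star = e_1 (assumption)
y = z ** 3 - 3 * z                        # He_3 link, as in Figure 5
w = project(rng.standard_normal(d))
lr, rho = 0.01, 0.1                       # illustrative values only
w_sam = sam_step(w, x, y, lr, rho)
w_sgd = project(w - lr * grad(w, x, y))   # plain one-pass SGD step for comparison
```

With `rho = 0` the perturbation vanishes and the update reduces to a single online SGD step. Any positive radius makes the update depend on a second evaluation of the same sample.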

Theorems & Definitions (35)

  • Definition 1: Weak recovery
  • Theorem 1: Informal
  • Definition 2: Polynomial Generative Information exponent
  • Theorem 2
  • Lemma 1
  • Theorem 3
  • Definition 3
  • Theorem 4
  • Proposition 1
  • Proof
  • ...and 25 more