
Neural network learns low-dimensional polynomials with SGD near the information-theoretic limit

Jason D. Lee, Kazusato Oko, Taiji Suzuki, Denny Wu

TL;DR

This work tackles the problem of learning a single-index model $f_*(\boldsymbol{x})=\sigma_*(\langle \boldsymbol{x},\boldsymbol{\theta}\rangle)$ under Gaussian inputs, addressing the gap between information-theoretic limits and the computational complexity of gradient-based methods. It shows that a two-layer neural network trained by SGD with reused minibatches can achieve vanishing generalization error with sample complexity $n=\tilde{\Theta}_d\big(d^{(p_*-1)\vee 1}\big)$, where $p_*$ is the generative exponent. For polynomial $\sigma_*$ this reduces to a near-linear sample complexity $n=\tilde{O}(d)$. The key mechanism is a monomial transformation of the labels, extracted from the SGD updates, which lowers the information exponent and enables a two-phase training strategy: Phase I (weak and then strong recovery of the first-layer weights) and Phase II (ridge-regularized fitting of the second layer to the unknown link). The results indicate that SGD with batch reuse can implement a full SQ algorithm that goes beyond correlational queries, achieving near-optimal statistical efficiency for a broad class of single-index targets and suggesting practical avenues for representation learning under low-dimensional structure. Overall, the paper narrows the gap between statistical optimality and computational tractability for low-dimensional targets in high-dimensional Gaussian settings, with implications for understanding the role of data reuse and higher-order information in gradient-based learning.
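Phase II becomes a convex problem once the first layer is fixed. The sketch below is a minimal illustration that rests on our own assumptions rather than the paper's exact setup. It assumes Phase I has already aligned every first-layer weight with $\boldsymbol{\theta}$, and that hidden unit $j$ computes $\mathrm{ReLU}(s_j\langle\boldsymbol{x},\boldsymbol{\theta}\rangle + b_j)$ with random signs $s_j$ and biases $b_j$. It then fits the output weights by closed-form ridge regression. The helper names `relu_features` and `fit_second_layer`, the bias range, and $\lambda$ are hypothetical illustrative choices.

```python
import numpy as np


def relu_features(z, signs, biases):
    """Hidden-layer outputs relu(s_j * z + b_j) for a batch of projections z."""
    return np.maximum(0.0, np.outer(z, signs) + biases)


def fit_second_layer(z, y, signs, biases, lam):
    """Ridge regression for the output weights with the first layer frozen:
    minimize (1/n) * ||Phi a - y||^2 + lam * ||a||^2 in closed form."""
    phi = relu_features(z, signs, biases)
    n, m = phi.shape
    gram = phi.T @ phi / n + lam * np.eye(m)
    return np.linalg.solve(gram, phi.T @ y / n)


def he3(z):
    return z**3 - 3 * z  # link sigma_*(z) = He_3(z)


rng = np.random.default_rng(0)
d, n, m = 50, 5000, 400
theta = rng.standard_normal(d)
theta /= np.linalg.norm(theta)

# Assumption: Phase I already aligned every first-layer weight with theta,
# so each hidden unit only sees the projection z = <x, theta>.
signs = rng.choice([-1.0, 1.0], size=m)
biases = rng.uniform(-5.0, 5.0, size=m)  # hypothetical bias range

x_train = rng.standard_normal((n, d))
z_train = x_train @ theta
a = fit_second_layer(z_train, he3(z_train), signs, biases, lam=1e-4)

# Evaluate on fresh Gaussian inputs, normalized by the target variance.
x_test = rng.standard_normal((n, d))
z_test = x_test @ theta
pred = relu_features(z_test, signs, biases) @ a
rel_mse = np.mean((pred - he3(z_test)) ** 2) / np.var(he3(z_test))
print(f"relative test MSE: {rel_mse:.4f}")
```

Freezing the first layer makes the objective strongly convex in the output weights. Here a closed-form solve stands in for whatever optimizer is used in practice.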

Abstract

We study the problem of gradient descent learning of a single-index target function $f_*(\boldsymbol{x}) = \sigma_*\left(\langle\boldsymbol{x},\boldsymbol{\theta}\rangle\right)$ under isotropic Gaussian data in $\mathbb{R}^d$, where the unknown link function $\sigma_*:\mathbb{R}\to\mathbb{R}$ has information exponent $p$ (defined as the lowest degree $k \ge 1$ with a nonzero coefficient in the Hermite expansion). Prior works showed that gradient-based training of neural networks can learn this target with $n\gtrsim d^{\Theta(p)}$ samples, and such complexity is predicted to be necessary by the correlational statistical query lower bound. Surprisingly, we prove that a two-layer neural network optimized by an SGD-based algorithm (on the squared loss) learns $f_*$ with a complexity that is not governed by the information exponent. Specifically, for arbitrary polynomial single-index models, we establish a sample and runtime complexity of $n \simeq T = \Theta(d\cdot \mathrm{polylog}\, d)$, where $\Theta(\cdot)$ hides a constant only depending on the degree of $\sigma_*$; this dimension dependence matches the information-theoretic limit up to polylogarithmic factors. More generally, we show that $n\gtrsim d^{(p_*-1)\vee 1}$ samples are sufficient to achieve low generalization error, where $p_* \le p$ is the \textit{generative exponent} of the link function. Core to our analysis is the reuse of minibatches in the gradient computation, which gives rise to higher-order information beyond correlational queries.
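The information exponent is defined through the Hermite expansion of $\sigma_*$. The following minimal sketch assumes NumPy's `numpy.polynomial.hermite_e` module. Its helper names, `hermite_coefficients` and `information_exponent`, are hypothetical and not taken from the paper. It computes $c_k=\mathbb{E}_{z\sim\mathcal{N}(0,1)}[\sigma_*(z)\,\mathsf{He}_k(z)]$ by Gauss-Hermite quadrature, then checks that squaring the label $y=\mathsf{He}_3(z)$ lowers the information exponent from 3 to 2.

```python
import math

import numpy as np
from numpy.polynomial import hermite_e as H


def hermite_coefficients(f, max_degree, quad_points=32):
    """Return c_k = E_{z ~ N(0,1)}[f(z) He_k(z)] for k = 0, ..., max_degree.

    hermegauss integrates against the weight exp(-z^2 / 2); dividing the
    weights by sqrt(2 * pi) turns the weighted sum into a Gaussian expectation.
    """
    nodes, weights = H.hermegauss(quad_points)
    weights = weights / math.sqrt(2.0 * math.pi)
    fz = f(nodes)
    coeffs = []
    for k in range(max_degree + 1):
        he_k = H.hermeval(nodes, [0] * k + [1])  # He_k evaluated at the nodes
        coeffs.append(float(np.sum(weights * fz * he_k)))
    return np.array(coeffs)


def information_exponent(f, max_degree, tol=1e-6):
    """Lowest degree k >= 1 with a nonzero Hermite coefficient, or None."""
    c = hermite_coefficients(f, max_degree)
    for k in range(1, max_degree + 1):
        if abs(c[k]) > tol:
            return k
    return None


def he3(z):
    return H.hermeval(z, [0, 0, 0, 1])  # He_3(z) = z^3 - 3z


print(information_exponent(he3, 6))                    # 3
print(information_exponent(lambda z: he3(z) ** 2, 6))  # 2
```

With 32 nodes the quadrature is exact for polynomial integrands up to degree 63. For polynomial links of moderate degree the coefficients are therefore exact up to floating-point error. For non-polynomial links they are only approximations.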

Paper Structure

This paper contains 33 sections, 11 theorems, 27 equations, 2 figures, and 1 algorithm.

Key Result

Theorem 1

A shallow NN with $N=\tilde{\Theta}_d(1)$ neurons can learn arbitrary single-index models up to small population loss: $\mathbb{E}_{\boldsymbol{x}}[(f_{\boldsymbol{\Theta}}(\boldsymbol{x}) - f_*(\boldsymbol{x}))^2] = o_{d,\mathbb{P}}(1)$, if we employ an SGD-based algorithm (with reused training dat

Figures (2)

  • Figure 1: We train the ReLU student network with $N=1024$ neurons using SGD (squared loss) with step size $\eta=1/d$ to learn a single-index target $f_*(\boldsymbol{x}) = \mathsf{He}_3(\langle\boldsymbol{x},\boldsymbol{\theta}\rangle)$; heatmap values are averaged over 10 runs. $(a)$ Online SGD with batch size $B=8$; $(b)$ GD on the same batch of size $n$ for $T=2^{14}$ steps. For online SGD we only report weak recovery (i.e., the averaged overlap between neuron $\boldsymbol{w}$ and target $\boldsymbol{\theta}$), since the test error does not drop.
  • Figure 2: Complexity of learning a single-index model whose link function $\sigma_*$ is a degree-$q$ polynomial with information exponent $p$. For the CSQ lower bound, we translate the tolerance $\tau$ to sample complexity using the i.i.d. concentration heuristic $\tau\approx n^{-1/2}$. We restrict ourselves to algorithms using polynomial compute; this excludes the sphere-covering procedure in damian2024computational and the exponential-width neural networks in bach2017breaking and takakura2024mean.
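Figure 1 measures weak recovery through the overlap between a neuron and $\boldsymbol{\theta}$. For $\sigma_*=\mathsf{He}_3$, squared labels carry directional signal that the raw labels lack, and the sketch below gives a rough illustration of why. It uses a spectral estimator on monomially transformed labels. This is a swapped-in technique, not the SGD procedure from the paper. The function name `transformed_label_spectral_estimate`, the dimensions, and the seed are hypothetical choices.

```python
import numpy as np


def transformed_label_spectral_estimate(x, y, power):
    """Top eigenvector (by |eigenvalue|) of M = (1/n) sum_i T(y_i) (x_i x_i^T - I),
    using the monomial label transformation T(y) = y**power.

    E[M] = E[T(y) He_2(z)] * theta theta^T with z = <x, theta>, so M carries
    directional signal only if T(y) has a nonzero degree-2 Hermite coefficient.
    """
    n, d = x.shape
    t = y ** power
    m = (x.T * t) @ x / n - t.mean() * np.eye(d)
    eigvals, eigvecs = np.linalg.eigh(m)
    return eigvecs[:, np.argmax(np.abs(eigvals))]


rng = np.random.default_rng(1)
d, n = 20, 40000
theta = rng.standard_normal(d)
theta /= np.linalg.norm(theta)
x = rng.standard_normal((n, d))
z = x @ theta
y = z**3 - 3 * z  # He_3 link: information exponent 3

v1 = transformed_label_spectral_estimate(x, y, power=1)  # E[y He_2(z)] = 0: no signal
v2 = transformed_label_spectral_estimate(x, y, power=2)  # E[y^2 He_2(z)] = 36
print("overlap using y:  ", abs(v1 @ theta))
print("overlap using y^2:", abs(v2 @ theta))
```

With `power=1` the expected matrix is zero, so `v1` is essentially a random direction. With `power=2` we expect the planted rank-one term $36\,\boldsymbol{\theta}\boldsymbol{\theta}^\top$ to dominate the sampling fluctuations at this sample size.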

Theorems & Definitions (13)

  • Theorem (informal)
  • Definition 1: Information exponent
  • Definition 2: Generative exponent
  • Lemma 3
  • Theorem 1
  • Lemma 4
  • Theorem 2
  • Proposition 5
  • Proposition 6
  • Lemma 7
  • ...and 3 more
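Definitions 1 and 2 can be stated compactly as below. This is a sketch that uses the common single-index conventions: probabilists' Hermite polynomials, and label transformations $\mathcal{T}$ with finite second moment. The paper's own statements may differ in normalization. The $\mathsf{He}_3$ example shows how squaring the label lowers the information exponent.

```latex
% z ~ N(0,1); He_k is the k-th probabilists' Hermite polynomial, E[He_j(z) He_k(z)] = k! 1{j = k}
\sigma_*(z) = \sum_{k \ge 0} \frac{c_k}{k!}\,\mathsf{He}_k(z), \qquad
c_k = \mathbb{E}_{z}\big[\sigma_*(z)\,\mathsf{He}_k(z)\big], \qquad
p(\sigma_*) = \min\{k \ge 1 : c_k \neq 0\},

p_*(\sigma_*) = \min_{\mathcal{T}:\ \mathbb{E}[\mathcal{T}(y)^2] < \infty} p(\mathcal{T}\circ\sigma_*), \qquad y = \sigma_*(z).

% Worked example, sigma_* = He_3, via He_m He_n = \sum_{r=0}^{\min(m,n)} r! \binom{m}{r} \binom{n}{r} He_{m+n-2r}:
\mathsf{He}_3^2 = \mathsf{He}_6 + 9\,\mathsf{He}_4 + 18\,\mathsf{He}_2 + 6
\;\Longrightarrow\; p(\mathsf{He}_3) = 3, \qquad p(\mathsf{He}_3^2) = 2, \qquad p_*(\mathsf{He}_3) \le 2.
```

Taking $\mathcal{T}(y)=y$ shows that $p_* \le p$ always holds. The square transformation $\mathcal{T}(y)=y^2$ exposes a nonzero degree-2 coefficient that the raw label lacks.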