Neural Networks Learn Generic Multi-Index Models Near Information-Theoretic Limit

Bohan Zhang, Zihao Wang, Hengyu Fu, Jason D. Lee

TL;DR

The paper asks why neural networks can efficiently learn high-dimensional features, and answers by analyzing gradient-descent training of a two-layer network on a Gaussian multi-index model $f^*(\boldsymbol{x}) = g^*(\boldsymbol{U}\boldsymbol{x})$. The analysis reveals two-stage dynamics: in the first stage, the first layer undergoes a power-iteration-like process on a preconditioned Hessian and must be stopped at an intermediate time to recover the full hidden subspace; a second stage then completes end-to-end learning. Under mild, generic assumptions and for polynomial link functions, the authors prove near information-theoretic optimality: $\tilde{O}(d)$ samples and $\tilde{O}(d^2)$ time suffice to achieve $o_d(1)$ test error, with $m=\tilde{O}(1)$ and second-stage training time $T_2=\tilde{O}(1)$. They also show that the result is robust to a broad class of activations and losses, and provide numerical illustrations of the phase transitions and of the importance of loss choice. The work shows that standard gradient-based training can attain near-optimal sample efficiency for generic multi-index models.
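
The role of an intermediate stopping time can be motivated by a generic power-iteration calculation. This is a heuristic sketch, not the paper's argument. It assumes a symmetric matrix $\boldsymbol{H}$ standing in for the preconditioned Hessian, with eigenpairs $(\lambda_i, \boldsymbol{v}_i)$ ordered so that $|\lambda_1| \ge \dots \ge |\lambda_r| > |\lambda_{r+1}| \ge \dots \ge |\lambda_d|$ and with $\boldsymbol{v}_1, \dots, \boldsymbol{v}_r$ spanning the hidden subspace. The vector $\boldsymbol{w}^{(0)}$ is a hypothetical initialization of a single neuron.

$$
\boldsymbol{w}^{(t)} \propto \boldsymbol{H}^t \boldsymbol{w}^{(0)} = \sum_{i=1}^{d} \lambda_i^t \,\langle \boldsymbol{v}_i, \boldsymbol{w}^{(0)} \rangle\, \boldsymbol{v}_i,
\qquad
\frac{|\langle \boldsymbol{v}_{j}, \boldsymbol{w}^{(t)} \rangle|}{|\langle \boldsymbol{v}_{i}, \boldsymbol{w}^{(t)} \rangle|}
= \left| \frac{\lambda_{j}}{\lambda_{i}} \right|^{t} \frac{|\langle \boldsymbol{v}_{j}, \boldsymbol{w}^{(0)} \rangle|}{|\langle \boldsymbol{v}_{i}, \boldsymbol{w}^{(0)} \rangle|}.
$$

Taking $i = r$ and $j = r+1$, the out-of-subspace component shrinks relative to the weakest in-subspace component only geometrically in $t$, so too few steps leave it unsuppressed. Taking $i = 1$ and $j = 2$ with $|\lambda_2| < |\lambda_1|$, the same identity shows that very many steps push every neuron toward $\boldsymbol{v}_1$ alone, so the remaining directions of the subspace are lost. Under these assumptions, a stopping time between the two regimes is what preserves the whole span.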

Abstract

In deep learning, a central issue is to understand how neural networks efficiently learn high-dimensional features. To this end, we explore the gradient descent learning of a general Gaussian multi-index model $f(\boldsymbol{x})=g(\boldsymbol{U}\boldsymbol{x})$ with hidden subspace $\boldsymbol{U}\in \mathbb{R}^{r\times d}$, which is the canonical setup for studying representation learning. We prove that under generic non-degenerate assumptions on the link function, a standard two-layer neural network trained via layer-wise gradient descent can agnostically learn the target with $o_d(1)$ test error using $\widetilde{\mathcal{O}}(d)$ samples and $\widetilde{\mathcal{O}}(d^2)$ time. The sample and time complexity both align with the information-theoretic limit up to leading order and are therefore optimal. For the first stage of gradient descent learning, the proof proceeds by showing that the inner weights perform a power-iteration process. This process implicitly mimics a spectral start for the whole span of the hidden subspace and eventually eliminates finite-sample noise and recovers this span. Surprisingly, this indicates that optimal results can only be achieved if the first layer is trained for more than $\mathcal{O}(1)$ steps. This work demonstrates the ability of neural networks to effectively learn hierarchical functions with respect to both sample and time efficiency.
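
To make the phrase "spectral start for the whole span" concrete, here is a minimal standard-library sketch. It is not the paper's algorithm: it replaces gradient descent on a network with explicit orthogonal (block power) iteration on the empirical matrix $\frac{1}{n}\sum_i y_i(\boldsymbol{x}_i\boldsymbol{x}_i^\top - \boldsymbol{I})$. The link function `link`, the sizes `d`, `r`, `n`, and all function names are assumptions chosen for illustration, and the link is picked so that this second-order matrix carries the whole subspace.

```python
import random


def gram_schmidt(vectors):
    """Orthonormalize a list of vectors with modified Gram-Schmidt."""
    basis = []
    for v in vectors:
        w = list(v)
        for b in basis:
            coeff = sum(wi * bi for wi, bi in zip(w, b))
            w = [wi - coeff * bi for wi, bi in zip(w, b)]
        norm = sum(wi * wi for wi in w) ** 0.5
        basis.append([wi / norm for wi in w])
    return basis


def link(z):
    # Hypothetical link g*(z) = z_1^2 + 0.5 * z_2^2 (illustration only).
    return z[0] ** 2 + 0.5 * z[1] ** 2


def sample(U, n, rng):
    """Draw x ~ N(0, I_d) and return (xs, ys) with y = g*(U x)."""
    d = len(U[0])
    xs, ys = [], []
    for _ in range(n):
        x = [rng.gauss(0.0, 1.0) for _ in range(d)]
        z = [sum(uj * xj for uj, xj in zip(u, x)) for u in U]
        xs.append(x)
        ys.append(link(z))
    return xs, ys


def hermite_moment(xs, ys):
    """Empirical M = (1/n) * sum_i y_i * (x_i x_i^T - I)."""
    n, d = len(xs), len(xs[0])
    M = [[0.0] * d for _ in range(d)]
    for x, y in zip(xs, ys):
        c = y / n
        for a in range(d):
            row, ca = M[a], c * x[a]
            for b in range(d):
                row[b] += ca * x[b]
            row[a] -= c  # subtract the identity term on the diagonal
    return M


def orthogonal_iteration(M, r, steps, rng):
    """Block power iteration: multiply by M, then re-orthonormalize."""
    d = len(M)
    Q = gram_schmidt([[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(r)])
    for _ in range(steps):
        Q = gram_schmidt(
            [[sum(M[a][b] * q[b] for b in range(d)) for a in range(d)] for q in Q]
        )
    return Q


def subspace_alignment(U, Q):
    """||U Q^T||_F^2 / r, equal to 1 when the two row spaces coincide."""
    return sum(sum(a * b for a, b in zip(u, q)) ** 2 for u in U for q in Q) / len(U)


def run_demo(d=8, r=2, n=4000, steps=30, seed=0):
    rng = random.Random(seed)
    U = gram_schmidt([[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(r)])
    xs, ys = sample(U, n, rng)
    Q = orthogonal_iteration(hermite_moment(xs, ys), r, steps, rng)
    return subspace_alignment(U, Q)


if __name__ == "__main__":
    print(f"subspace alignment: {run_demo():.3f}")
```

For this link, the population matrix is $\boldsymbol{U}^\top \mathrm{diag}(2, 1)\, \boldsymbol{U}$, so its top two eigenvectors span the hidden subspace and the remaining eigenvalues are pure finite-sample noise. Re-orthonormalizing the block at every step is what keeps both directions; a single vector iterated the same way would converge to the top eigenvector only.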

Paper Structure

This paper contains 63 sections, 69 theorems, 441 equations, 2 figures, 1 algorithm.

Key Result

Theorem 1

Under assumptions assumption_link_first, assumption_activation, assumption_loss, assumption_ell, assumption_cov, def_activation, ass:epsilon0, or under the additional assumption that $\beta=0$ and assumptions assumption_link_first, assumption_activation, assumption_loss_relaxed, assumption_ell, assu there exists a proper choice of the hyperparameters and second stage training time $\mathrm{T}_2$ s

Figures (2)

  • Figure 1: Minimal sample size exponent $\alpha$ (where $n = d^\alpha$) required to achieve test error $\leq \epsilon$ as a function of dimension $d$. Three thresholds are shown: $\epsilon = 1.0$ (light blue), $\epsilon = 0.1$ (medium blue), and $\epsilon = 0.01$ (dark blue). More stringent error thresholds require larger $\alpha$ (more samples), but all curves exhibit similar downward trends.
  • Figure 2: Comparison of MSE and Huber loss for learning $\mathrm{h}_4(x_1) + \mathrm{h}_4(x_2)$ across dimensions $d \in \{200, 300, 500\}$. The $x$-axis shows $n/d$, and the $y$-axis shows the cosine similarity $\cos(\text{best})$. Lines show the median, with error bars indicating the $30$th to $70$th percentiles over $10$ random seeds. Blue shades represent MSE and red/orange shades represent Huber loss.
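
The quantities named in the two captions can be written down in a few lines. This is a sketch under stated assumptions: $\mathrm{h}_4$ is taken to be the probabilists' Hermite polynomial $\mathrm{He}_4$ normalized to unit variance under $\mathcal{N}(0,1)$, the Huber threshold `delta` and the factor $1/2$ in the squared loss are conventional choices, and none of the function names come from the paper.

```python
import math


def hermite_he4(x):
    """Probabilists' Hermite polynomial He_4(x) = x^4 - 6x^2 + 3."""
    return x ** 4 - 6.0 * x ** 2 + 3.0


def h4(x):
    """He_4 divided by sqrt(4!), so that E[h4(x)^2] = 1 for x ~ N(0, 1).

    The normalization is an assumed convention for the caption's h_4.
    """
    return hermite_he4(x) / math.sqrt(24.0)


def target(x):
    """Figure 2 target h_4(x_1) + h_4(x_2), using the first two coordinates."""
    return h4(x[0]) + h4(x[1])


def squared_loss(pred, y):
    """Squared error with the conventional factor 1/2."""
    return 0.5 * (pred - y) ** 2


def huber_loss(pred, y, delta=1.0):
    """Quadratic for residuals up to delta, linear beyond it."""
    r = abs(pred - y)
    if r <= delta:
        return 0.5 * r * r
    return delta * (r - 0.5 * delta)


def sample_exponent(n, d):
    """Exponent alpha with n = d ** alpha, as on the Figure 1 axis."""
    return math.log(n) / math.log(d)
```

The two losses agree for residuals below `delta` and differ only in the tails, where the Huber gradient is capped at `delta` in magnitude while the squared-loss gradient grows linearly with the residual.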

Theorems & Definitions (126)

  • Theorem 1
  • Corollary 3.1
  • Lemma A.1: Total Perturbation Error
  • Lemma A.2: Operator Norm Bound for ${\boldsymbol H}_{\ell}$
  • Lemma A.3
  • Lemma A.4: Operator Norm Bound for ${\boldsymbol Q}^{(t)}$
  • Corollary A.5
  • Lemma A.6: Norm Bound after $\mathrm{T}_1$ Steps
  • Lemma A.7
  • Lemma A.8
  • ...and 116 more