Neural Networks Learn Generic Multi-Index Models Near Information-Theoretic Limit
Bohan Zhang, Zihao Wang, Hengyu Fu, Jason D. Lee
TL;DR
The paper asks why neural networks can efficiently learn high-dimensional features, and analyzes gradient descent on a Gaussian multi-index model $f^*(\boldsymbol{x}) = g^*(\boldsymbol{U}\boldsymbol{x})$ with a two-layer network. The analysis reveals two-stage dynamics. In the first stage, the first layer undergoes a power-iteration-like process on a preconditioned Hessian and must be stopped at an intermediate time to recover the full hidden subspace; the second stage then completes end-to-end learning. Under mild, generic assumptions and for polynomial link functions, the authors prove near information-theoretic optimality: $\widetilde{\mathcal{O}}(d)$ samples and $\widetilde{\mathcal{O}}(d^2)$ time suffice to achieve $o_d(1)$ test error, with $m=\widetilde{\mathcal{O}}(1)$ and $T_2=\widetilde{\mathcal{O}}(1)$. They also show that the result is robust to a broad class of activations and losses, and provide numerical illustrations of the phase transitions and of the importance of the loss choice. Overall, the work shows that standard gradient-based training can attain near-optimal sample efficiency for generic multi-index models.
Abstract
In deep learning, a central issue is to understand how neural networks efficiently learn high-dimensional features. To this end, we explore the gradient descent learning of a general Gaussian multi-index model $f(\boldsymbol{x})=g(\boldsymbol{U}\boldsymbol{x})$ with hidden subspace $\boldsymbol{U}\in \mathbb{R}^{r\times d}$, which is the canonical setup to study representation learning. We prove that under generic non-degenerate assumptions on the link function, a standard two-layer neural network trained via layer-wise gradient descent can agnostically learn the target with $o_d(1)$ test error using $\widetilde{\mathcal{O}}(d)$ samples and $\widetilde{\mathcal{O}}(d^2)$ time. The sample and time complexity both align with the information-theoretic limit up to leading order and are therefore optimal. The proof proceeds by showing that, during the first stage of gradient descent learning, the inner weights perform a power-iteration process. This process implicitly mimics a spectral start for the whole span of the hidden subspace, eventually eliminating finite-sample noise and recovering this span. Surprisingly, this indicates that optimal results can be achieved only if the first layer is trained for more than $\mathcal{O}(1)$ steps. This work demonstrates the ability of neural networks to effectively learn hierarchical functions with respect to both sample and time efficiency.
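The power-iteration picture can be illustrated with a toy spectral sketch. The code below is not the paper's layer-wise gradient descent procedure. It is a minimal NumPy illustration under stated assumptions: a hypothetical quadratic link $g(\boldsymbol{z})=\sum_{k} k\, z_k^2$, a Stein-type second-moment estimator $\frac{1}{n}\sum_i (y_i-\bar{y})\,\boldsymbol{x}_i\boldsymbol{x}_i^\top$ whose population version has its range inside the span of the rows of $\boldsymbol{U}$, and an explicit orthogonal (subspace power) iteration on that matrix. All function names and parameter values are illustrative choices, not taken from the paper.

```python
import numpy as np


def make_data(n, d, r, rng):
    """Sample a toy Gaussian multi-index dataset y = g(U x).

    Assumption: g(z) = sum_k k * z_k^2 is an illustrative polynomial link,
    chosen so that the expected Hessian of g has full rank r.
    """
    # Hidden subspace: r orthonormal rows in R^d.
    Q, _ = np.linalg.qr(rng.standard_normal((d, r)))
    U = Q.T                                  # shape (r, d)
    X = rng.standard_normal((n, d))          # isotropic Gaussian inputs
    Z = X @ U.T                              # projections, shape (n, r)
    coeffs = np.arange(1, r + 1)             # 1, 2, ..., r
    y = (Z ** 2) @ coeffs
    return X, y, U


def second_moment_matrix(X, y):
    """Empirical (1/n) * sum_i (y_i - mean(y)) x_i x_i^T, symmetrized."""
    n = X.shape[0]
    yc = y - y.mean()
    M = (X.T * yc) @ X / n
    return (M + M.T) / 2


def subspace_power_iteration(M, r, n_iter, rng):
    """Orthogonal iteration: multiply by M, then re-orthonormalize via QR."""
    d = M.shape[0]
    W, _ = np.linalg.qr(rng.standard_normal((d, r)))
    for _ in range(n_iter):
        W, _ = np.linalg.qr(M @ W)
    return W                                 # shape (d, r), orthonormal columns


def alignment(U, W):
    """Overlap ||U W||_F^2 / r in [0, 1]; equals 1 when the spans coincide."""
    r = U.shape[0]
    return np.linalg.norm(U @ W, "fro") ** 2 / r


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    X, y, U = make_data(n=30000, d=30, r=2, rng=rng)
    M = second_moment_matrix(X, y)
    for n_iter in (1, 5, 30):
        W = subspace_power_iteration(M, r=2, n_iter=n_iter, rng=rng)
        print(n_iter, alignment(U, W))
```

In this toy setting, repeated multiplication by the empirical matrix amplifies the signal directions relative to the finite-sample noise, which is the intuition behind training the first layer for more than a constant number of steps. Varying `n_iter` and the ratio `n / d` lets one inspect how the recovered span changes.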
