Learning Over-Parametrized Two-Layer ReLU Neural Networks beyond NTK
Yuanzhi Li, Tengyu Ma, Hongyang R. Zhang
TL;DR
This work analyzes gradient descent for an over-parametrized two-layer ReLU network trained on Gaussian inputs to learn a ground-truth network f^*(x)=a^T|W^*x| with orthonormal W^*. It develops a two-stage infinite-width analysis that decomposes the population loss into an infinite sum of tensor decomposition terms. Stage 1 reduces the 0th- and 2nd-order terms; the higher-order terms then enable recovery of the ground-truth weights. A finite-width reduction shows that these dynamics are preserved with polynomial width and polynomially many samples. The authors prove a separation from kernel methods (including NTK) that use polynomially many samples: the learned network attains population loss L̂(Ŵ)=O(1/d^{1+Q}), while such kernel methods incur Ω(1/d). Simulations corroborate the stage-wise convergence and highlight the need for sufficient over-parametrization. The results go beyond NTK by using the higher-order tensor decomposition terms to obtain a smaller population loss under gradient descent.
Abstract
We consider the dynamics of gradient descent for learning a two-layer neural network. We assume the input $x\in\mathbb{R}^d$ is drawn from a Gaussian distribution and the label of $x$ satisfies $f^{\star}(x) = a^{\top}|W^{\star}x|$, where $a\in\mathbb{R}^d$ is a nonnegative vector and $W^{\star} \in\mathbb{R}^{d\times d}$ is an orthonormal matrix. We show that an over-parametrized two-layer neural network with ReLU activation, trained by gradient descent from random initialization, can provably learn the ground truth network with population loss at most $o(1/d)$ in polynomial time with polynomial samples. On the other hand, we prove that any kernel method, including Neural Tangent Kernel, with a polynomial number of samples in $d$, has population loss at least $\Omega(1/d)$.
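To make the setting concrete, the following is a minimal NumPy sketch of the data model and of a two-layer ReLU learner trained by full-batch gradient descent. It is an illustration under assumptions and not the authors' exact algorithm: the function names, the scaling of $a$, the fixed second layer `c = 1/m`, the initialization scale, the step size, and the use of a finite training set are all choices made here for the example.

```python
import numpy as np

def make_teacher(d, rng):
    """Sample a nonnegative vector a and an orthonormal matrix W_star."""
    # The Q factor of a Gaussian matrix is orthonormal.
    w_star, _ = np.linalg.qr(rng.standard_normal((d, d)))
    # Nonnegative entries; the 1/d scaling is an arbitrary choice for this sketch.
    a = np.abs(rng.standard_normal(d)) / d
    return a, w_star

def teacher(x, a, w_star):
    """Ground truth f*(x) = a^T |W* x| for a batch x of shape (n, d)."""
    return np.abs(x @ w_star.T) @ a

def student(x, w, c):
    """Two-layer ReLU net sum_j c_j * relu(w_j^T x), with w of shape (m, d)."""
    return np.maximum(x @ w.T, 0.0) @ c

def loss_and_grad(x, y, w, c):
    """Half mean squared error and its gradient with respect to w."""
    pre = x @ w.T                          # pre-activations, shape (n, m)
    resid = np.maximum(pre, 0.0) @ c - y   # residuals, shape (n,)
    loss = 0.5 * np.mean(resid ** 2)
    # dL/dw_j = mean_i resid_i * c_j * 1[pre_ij > 0] * x_i
    grad = ((resid[:, None] * (pre > 0)) * c).T @ x / len(y)
    return loss, grad

def run_demo(d=8, m=64, n=1024, steps=100, seed=0):
    """Train the hidden layer by gradient descent; return the loss per step."""
    rng = np.random.default_rng(seed)
    a, w_star = make_teacher(d, rng)
    x = rng.standard_normal((n, d))        # Gaussian inputs
    y = teacher(x, a, w_star)
    w = rng.standard_normal((m, d)) / np.sqrt(d)  # random initialization (assumed scale)
    c = np.full(m, 1.0 / m)                # fixed second layer (assumption)
    lr = 0.1 * m                           # step size scaled with width (assumption)
    losses = []
    for _ in range(steps):
        loss, grad = loss_and_grad(x, y, w, c)
        losses.append(loss)
        w -= lr * grad
    return losses

if __name__ == "__main__":
    history = run_demo()
    print("initial loss:", history[0], "final loss:", history[-1])
```

Since $|z| = \mathrm{ReLU}(z) + \mathrm{ReLU}(-z)$, a ReLU network with a positive second layer can represent the target exactly, which is why the sketch keeps `c` fixed and positive and trains only the hidden weights.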
