
Learning Over-Parametrized Two-Layer ReLU Neural Networks beyond NTK

Yuanzhi Li, Tengyu Ma, Hongyang R. Zhang

TL;DR

This work analyzes gradient descent for an over-parameterized two-layer ReLU network trained on Gaussian inputs to learn a ground-truth network f^*(x)=a^T|W^*x| with orthonormal W^*. It develops a two-stage infinite-width analysis that decomposes the population loss into an infinite sum of tensor decomposition terms. Stage 1 reduces the 0th- and 2nd-order terms, after which the higher-order terms drive recovery of the ground-truth weights. A finite-width reduction then shows that these dynamics are preserved with polynomial width and a polynomial number of samples. The authors prove a separation from polynomial-size kernel methods (including NTK): the trained network reaches population loss L̂(Ŵ)=O(1/d^{1+Q}), while such kernel methods incur Ω(1/d). Simulations corroborate the stage-wise convergence and highlight the need for sufficient over-parameterization. Overall, the results go beyond NTK by using higher-order tensor decompositions to show that gradient dynamics learn the target with better generalization than kernel methods.
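The tensor decomposition mentioned above can be made concrete with a Hermite expansion. The sketch below assumes (this is not taken from the paper) the standard identity for unit vectors $u, v$ and $x \sim N(0, I_d)$: $E[|u^{\top}x|\,|v^{\top}x|] = \sum_k c_k^2 (u^{\top}v)^k$, where $c_k$ are the normalized probabilists' Hermite coefficients of $|z|$. It computes $c_k$ numerically with a trapezoid rule. The odd coefficients vanish because $|z|$ is even, which is consistent with the loss splitting into 0th-, 2nd-, 4th-, 6th-order terms and so on. The function name `abs_hermite_coeffs` and its grid parameters are hypothetical.

```python
import math
import numpy as np
from numpy.polynomial import hermite_e as H  # probabilists' Hermite polynomials He_k


def abs_hermite_coeffs(max_order: int, half_width: float = 12.0, n_half: int = 100_000):
    """Return c_k = E[|z| He_k(z)] / sqrt(k!) for z ~ N(0, 1), k = 0..max_order.

    Hypothetical helper: the expectation is approximated by a trapezoid rule on
    [-half_width, half_width]. The grid is exactly symmetric and contains z = 0,
    so the kink of |z| sits on a grid point and odd-order integrands cancel.
    """
    h = half_width / n_half
    z = h * np.arange(-n_half, n_half + 1)
    phi = np.exp(-z**2 / 2) / math.sqrt(2 * math.pi)  # standard Gaussian density
    coeffs = []
    for k in range(max_order + 1):
        onehot = np.zeros(k + 1)
        onehot[k] = 1.0                                 # selects He_k in the series
        y = np.abs(z) * H.hermeval(z, onehot) * phi
        integral = h * (y.sum() - 0.5 * (y[0] + y[-1]))  # trapezoid rule
        coeffs.append(integral / math.sqrt(math.factorial(k)))  # normalize He_k
    return np.array(coeffs)
```

For example, `abs_hermite_coeffs(12)` should give $c_0 \approx \sqrt{2/\pi}$ and $c_2 \approx 1/\sqrt{\pi}$, with all odd $c_k \approx 0$. The truncated series $\sum_k c_k^2 \rho^k$ can be checked against the closed form $E|XY| = \tfrac{2}{\pi}\big(\sqrt{1-\rho^2} + \rho \arcsin \rho\big)$ for standard Gaussians $X, Y$ with correlation $\rho$.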

Abstract

We consider the dynamics of gradient descent for learning a two-layer neural network. We assume the input $x\in\mathbb{R}^d$ is drawn from a Gaussian distribution and the label of $x$ satisfies $f^{\star}(x) = a^{\top}|W^{\star}x|$, where $a\in\mathbb{R}^d$ is a nonnegative vector and $W^{\star} \in\mathbb{R}^{d\times d}$ is an orthonormal matrix. We show that an over-parametrized two-layer neural network with ReLU activation, trained by gradient descent from random initialization, can provably learn the ground truth network with population loss at most $o(1/d)$ in polynomial time with polynomially many samples. On the other hand, we prove that any kernel method, including the Neural Tangent Kernel, with a polynomial number of samples in $d$ has population loss at least $\Omega(1/d)$.
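The data model in the abstract can be simulated directly. The following is a minimal NumPy sketch, not the authors' Algorithm 1: it draws $x \sim N(0, I_d)$, labels it with $f^{\star}(x) = a^{\top}|W^{\star}x|$ for a random orthonormal $W^{\star}$ (obtained by QR) and a nonnegative $a$, and runs full-batch gradient descent on the squared loss for a wide student $\sum_i b_i\,\mathrm{ReLU}(w_i^{\top}x)$. The width, initialization scale, step size, the distribution of $a$, training both layers, and the helper names `make_teacher`, `teacher`, `student`, and `train` are all illustrative assumptions.

```python
import numpy as np


def make_teacher(d, rng):
    # Orthonormal W* from the QR factorization of a Gaussian matrix.
    w_star, _ = np.linalg.qr(rng.standard_normal((d, d)))
    a = rng.uniform(0.5, 1.0, size=d)  # nonnegative a (assumed distribution)
    return a, w_star


def teacher(x, a, w_star):
    # f*(x) = a^T |W* x|, evaluated for every row of x.
    return np.abs(x @ w_star.T) @ a


def student(x, w, b):
    # Two-layer ReLU network: sum_i b_i ReLU(w_i . x).
    return np.maximum(x @ w.T, 0.0) @ b


def train(d=8, m=100, n=2000, lr=0.01, steps=500, seed=0):
    """Full-batch GD on 0.5 * mean squared error; returns the loss at each step."""
    rng = np.random.default_rng(seed)
    a, w_star = make_teacher(d, rng)
    x = rng.standard_normal((n, d))
    y = teacher(x, a, w_star)
    w = rng.standard_normal((m, d)) / np.sqrt(d)  # hypothetical init scale
    b = rng.standard_normal(m) / np.sqrt(m)
    losses = []
    for _ in range(steps):
        pre = x @ w.T                   # (n, m) pre-activations
        act = np.maximum(pre, 0.0)      # ReLU
        resid = act @ b - y             # (n,) residuals
        losses.append(0.5 * np.mean(resid**2))
        grad_b = act.T @ resid / n
        # dL/dw_i = mean(resid * b_i * 1[w_i . x > 0] * x)
        grad_w = ((resid[:, None] * (pre > 0)) * b).T @ x / n
        b -= lr * grad_b
        w -= lr * grad_w
    return losses
```

Calling `train()` returns the per-step training losses. With these settings the loss should fall well below its initial value once the second layer fits the mean of the labels, but the sketch says nothing about the $o(1/d)$ rate, which depends on the paper's specific parameterization, width, and step schedule.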


Paper Structure

This paper contains 48 sections, 25 theorems, 309 equations, 2 figures, 1 algorithm.

Key Result

Theorem 1.1

Let $\mathcal{Z}$ be a training dataset with $N = \mathrm{poly}_{\kappa}(d)$ samples generated by the model described above. Let $\mathrm{poly}(d)$ denote a polynomial of $d$ and $\mathrm{poly}_\kappa(d)$ denote a polynomial whose degree may depend on $\kappa$. …

Figures (2)

  • Figure 1: Illustrating the convergence of each tensor during the gradient descent dynamics using absolute value activations.
  • Figure 2: For properly parametrized gradient descent, the 4th and 6th order tensors get stuck using absolute value activations.
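Figure 1 tracks the convergence of each tensor. One way to compute such a per-order error without materializing $d^k$ entries uses $\langle u^{\otimes k}, v^{\otimes k}\rangle = (u^{\top}v)^k$. This sketch assumes (an assumption about the exact quantities plotted) that the $k$-th order tensors are $T_k = \sum_i \beta_i \hat w_i^{\otimes k}$ for the student and $T_k^{\star} = \sum_j a_j (w_j^{\star})^{\otimes k}$ for the target. The student is taken to use the absolute value activation, as in the figures, so $|w^{\top}x| = \|w\|\,|\hat w^{\top}x|$ lets $\|w_i\|$ be absorbed into $\beta_i = b_i\|w_i\|$. The function name `tensor_error_sq` is hypothetical.

```python
import numpy as np


def tensor_error_sq(w, b, a, w_star, k):
    """Squared Frobenius norm ||T_k - T_k*||^2 (hypothetical helper).

    Student tensor: T_k  = sum_i b_i ||w_i|| (w_i / ||w_i||)^{(x)k}
    Target tensor:  T_k* = sum_j a_j (w*_j)^{(x)k}
    Inner products of rank-one tensors reduce to (u . v)^k, so only m x m,
    m x d, and d x d Gram matrices are formed. Assumes every row w_i is nonzero.
    """
    norms = np.linalg.norm(w, axis=1)
    beta = b * norms                 # |c z| = c |z| for c >= 0 absorbs ||w_i||
    u = w / norms[:, None]           # unit directions w_i / ||w_i||
    ss = beta @ (u @ u.T) ** k @ beta          # <T_k, T_k>
    st = beta @ (u @ w_star.T) ** k @ a        # <T_k, T_k*>
    tt = a @ (w_star @ w_star.T) ** k @ a      # <T_k*, T_k*>
    return ss - 2.0 * st + tt
```

Under the same Hermite-expansion assumption, the population loss of such an absolute-value student equals $\sum_k c_k^2 \|T_k - T_k^{\star}\|_F^2$, so evaluating `tensor_error_sq` for $k = 0, 2, 4, 6$ along a training run gives per-order curves of the kind Figure 1 describes.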

Theorems & Definitions (91)

  • Theorem 1.1: Main result
  • Theorem 1.2: Lower bound
  • Claim 2.1
  • Definition 3.1: Conditional-symmetry
  • Theorem 3.1: Infinite-width case
  • Claim 3.1
  • Claim 3.2
  • Definition A.1: Truncated neuron space
  • Definition A.2
  • Proposition A.1: Inductive hypothesis $\mathcal{H}_1$ for Stage 1
  • ...and 81 more