Learning quadratic neural networks in high dimensions: SGD dynamics and scaling laws
Gérard Ben Arous, Murat A. Erdogdu, Nuri Mert Vural, Denny Wu
TL;DR
This work analyzes gradient-based training of a two-layer neural network with quadratic activation on isotropic Gaussian data in the high-dimensional, extensive-width regime. It derives sharp scaling laws for both the population gradient flow and online SGD, showing how alignment with the teacher subspace and the prediction risk depend on the decay exponent $\alpha$, the widths $r$ and $r_s$, and the optimization time. The authors introduce a matrix Riccati ODE framework to describe the coupled feature-learning dynamics and develop operator-norm-based discretization techniques to control the discrete-time SGD iterates. This yields sample complexities close to the information-theoretic limit in the isotropic case and reveals additive, phase-transition-like learning behavior across the many signal directions. The results show how anisotropic, power-law signal strengths shape convergence and sample efficiency, offering a principled view of feature learning in extensive-width nonlinear networks and guiding future extensions to more general activations. Overall, the paper provides rigorous scaling laws that connect optimization time, model width, and sample size to the learned subspace and the residual risk of quadratic neural networks.
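To make the online-SGD setting concrete, here is a minimal NumPy sketch under stated assumptions; it is not the authors' code or their exact student model. It assumes the teacher $y = \sum_{j=1}^{r} \lambda_j \mathrm{He}_2(\langle \boldsymbol{\theta}_j, \boldsymbol{x}\rangle)$ with orthonormal $\boldsymbol{\theta}_j$, $\mathrm{He}_2(z) = z^2 - 1$ and $\lambda_j = j^{-\alpha}$ exactly, which equals $\boldsymbol{x}^\top A \boldsymbol{x} - \operatorname{tr}(A)$ for $A = \sum_j \lambda_j \boldsymbol{\theta}_j \boldsymbol{\theta}_j^\top$. It also assumes a hypothetical student $f(\boldsymbol{x}) = \boldsymbol{x}^\top W W^\top \boldsymbol{x} - \|W\|_F^2$ with $r_s$ trainable columns and second-layer weights fixed to 1 (reading $r_s$ as the student width is itself an assumption), squared loss, and illustrative values of $d$, $r$, $r_s$, $\alpha$, the step size and the initialization scale. The helpers `online_sgd`, `population_risk` and `subspace_alignment` are hypothetical names. For Gaussian inputs the population risk has the closed form $\|WW^\top - A\|_F^2$, which the sketch tracks together with the fraction of $\|W\|_F^2$ lying in the teacher subspace.

```python
import numpy as np


def population_risk(W, A):
    """0.5 * E[(f(x) - y)^2] over x ~ N(0, I_d), in closed form ||W W^T - A||_F^2.

    Uses E[(x^T M x - tr M)^2] = 2 ||M||_F^2 for symmetric M.
    """
    M = W @ W.T - A
    return float(np.sum(M * M))


def subspace_alignment(W, theta):
    """Fraction of ||W||_F^2 lying in span{theta_1, ..., theta_r} (theta orthonormal)."""
    return float(np.sum((theta.T @ W) ** 2) / np.sum(W ** 2))


def online_sgd(theta, lam, r_s, steps, lr, rng, init_scale=0.1, log_every=100):
    """Online SGD on the loss 0.5 * (f(x) - y)^2, drawing one fresh sample per step.

    Hypothetical student: f(x) = sum_i (<w_i, x>^2 - ||w_i||^2) = x^T W W^T x - ||W||_F^2,
    where the r_s columns w_i of W are trained and second-layer weights are fixed to 1.
    """
    d = theta.shape[0]
    A = (theta * lam) @ theta.T            # teacher matrix sum_j lam_j theta_j theta_j^T
    tr_A = float(np.trace(A))
    W = init_scale * rng.standard_normal((d, r_s)) / np.sqrt(d)   # small random init
    risks = [population_risk(W, A)]
    for t in range(steps):
        x = rng.standard_normal(d)         # fresh sample x ~ N(0, I_d)
        y = x @ A @ x - tr_A               # equals sum_j lam_j He_2(<theta_j, x>)
        Wx = x @ W                         # vector of <w_i, x>, shape (r_s,)
        f = Wx @ Wx - np.sum(W * W)
        # Gradient of 0.5 * (f - y)^2 w.r.t. W is 2 (f - y) (x x^T - I) W.
        W -= lr * 2.0 * (f - y) * (np.outer(x, Wx) - W)
        if (t + 1) % log_every == 0:
            risks.append(population_risk(W, A))
    return W, np.array(risks)


rng = np.random.default_rng(1)
d, r, r_s, alpha = 20, 4, 8, 0.5
theta, _ = np.linalg.qr(rng.standard_normal((d, r)))     # orthonormal teacher directions
lam = np.arange(1, r + 1, dtype=float) ** (-alpha)       # lam_j = j^(-alpha)
W, risks = online_sgd(theta, lam, r_s, steps=20_000, lr=0.01 / d, rng=rng)
print(f"risk: {risks[0]:.3f} -> {risks[-1]:.3f}, alignment: {subspace_alignment(W, theta):.3f}")
```

In this toy parametrization the population gradient flow $\dot W = -4(WW^\top - A)W$ closes in $P = WW^\top$ as $\dot P = 4(AP + PA) - 8P^2$, a matrix Riccati equation. From a small initialization, the component of $W$ along $\boldsymbol{\theta}_j$ initially grows like $e^{4\lambda_j t}$, so directions with larger $\lambda_j$ are picked up first. This only illustrates why Riccati-type dynamics can arise; the equation analyzed in the paper may differ.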
Abstract
We study the optimization and sample complexity of gradient-based training of a two-layer neural network with quadratic activation function in the high-dimensional regime, where the data is generated as $y \propto \sum_{j=1}^{r} \lambda_j \sigma(\langle \boldsymbol{\theta}_j, \boldsymbol{x}\rangle)$, $\boldsymbol{x} \sim \mathcal{N}(0, \boldsymbol{I}_d)$, $\sigma$ is the 2nd Hermite polynomial, and $\{\boldsymbol{\theta}_j\}_{j=1}^{r} \subset \mathbb{R}^d$ are orthonormal signal directions. We consider the extensive-width regime $r \asymp d^{\beta}$ for $\beta \in [0, 1)$, and assume a power-law decay on the (non-negative) second-layer coefficients $\lambda_j \asymp j^{-\alpha}$ for $\alpha \geq 0$. We present a sharp analysis of the SGD dynamics in the feature learning regime, for both the population limit and the finite-sample (online) discretization, and derive scaling laws for the prediction risk that highlight the power-law dependencies on the optimization time, sample size, and model width. Our analysis combines a precise characterization of the associated matrix Riccati differential equation with novel matrix monotonicity arguments to establish convergence guarantees for the infinite-dimensional effective dynamics.
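As a concrete reading of this data model, the following NumPy sketch samples a teacher with $r = \mathrm{round}(d^{\beta})$ orthonormal directions (from a QR factorization of a Gaussian matrix) and weights set exactly to $\lambda_j = j^{-\alpha}$, one arbitrary instance of $r \asymp d^{\beta}$ and $\lambda_j \asymp j^{-\alpha}$. It uses the unnormalized probabilists' Hermite polynomial $\mathrm{He}_2(z) = z^2 - 1$ and sets the proportionality constant in $y$ to 1; both are assumptions, and the function names are hypothetical. It also checks the identity $\sum_j \lambda_j \mathrm{He}_2(\langle \boldsymbol{\theta}_j, \boldsymbol{x}\rangle) = \boldsymbol{x}^\top A \boldsymbol{x} - \operatorname{tr}(A)$ with $A = \sum_j \lambda_j \boldsymbol{\theta}_j \boldsymbol{\theta}_j^\top$, which holds because each $\boldsymbol{\theta}_j$ has unit norm; with orthonormal directions, $A$ has eigenvalues $\lambda_1, \dots, \lambda_r$, so the target is a centered quadratic form of rank $r$.

```python
import numpy as np


def make_teacher(d, beta, alpha, rng):
    """Hypothetical helper: orthonormal directions theta (d x r) and lam_j = j**(-alpha).

    Assumes r = round(d**beta), one concrete choice of r ~ d^beta.
    """
    r = max(1, int(round(d ** beta)))
    # Reduced QR of a d x r Gaussian matrix gives r orthonormal columns in R^d.
    theta, _ = np.linalg.qr(rng.standard_normal((d, r)))
    lam = np.arange(1, r + 1, dtype=float) ** (-alpha)
    return theta, lam


def he2(z):
    """Unnormalized probabilists' 2nd Hermite polynomial: He_2(z) = z^2 - 1."""
    return z ** 2 - 1.0


def teacher_labels(X, theta, lam):
    """y_i = sum_j lam_j * He_2(<theta_j, x_i>) for each row x_i of X (constant set to 1)."""
    return he2(X @ theta) @ lam


def teacher_labels_quadratic(X, theta, lam):
    """Same labels written as x^T A x - tr(A), with A = sum_j lam_j theta_j theta_j^T."""
    A = (theta * lam) @ theta.T          # scales column j of theta by lam_j
    return np.sum((X @ A) * X, axis=1) - np.trace(A)


rng = np.random.default_rng(0)
d, beta, alpha = 256, 0.5, 1.0
theta, lam = make_teacher(d, beta, alpha, rng)   # r = round(256 ** 0.5) = 16
X = rng.standard_normal((1000, d))               # rows x_i ~ N(0, I_d)
y = teacher_labels(X, theta, lam)
print(theta.shape, y.mean())
```

Since each $\langle \boldsymbol{\theta}_j, \boldsymbol{x}\rangle \sim \mathcal{N}(0, 1)$ and $\mathbb{E}[\mathrm{He}_2(g)] = 0$ for standard Gaussian $g$, the labels have mean zero, so the printed empirical mean should be small.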
