Learning quadratic neural networks in high dimensions: SGD dynamics and scaling laws

Gérard Ben Arous, Murat A. Erdogdu, Nuri Mert Vural, Denny Wu

TL;DR

This work analyzes gradient-based training of a two-layer neural network with a quadratic activation on isotropic Gaussian data in the high-dimensional, extensive-width regime. It delivers sharp scaling laws for both the population gradient flow and online SGD, showing how alignment to the teacher subspace and prediction risk depend on the tail exponent $\alpha$, the widths $r$ and $r_s$, and the optimization time. The authors introduce a matrix Riccati ODE framework to describe coupled feature-learning dynamics and develop operator-norm based discretization techniques to bound discrete SGD, obtaining nearly information-theoretic sample complexities in the isotropic case and revealing additive, phase-like learning behavior across many latent directions. The results demonstrate how anisotropic, power-law signal strengths shape convergence and sample efficiency, offering a principled view of feature learning in extensive-width nonlinear networks and guiding future extensions to more general activations. Overall, the paper provides rigorous scaling laws that connect optimization time, model width, and data size to the learned subspace and residual risk in quadratic neural nets.

Abstract

We study the optimization and sample complexity of gradient-based training of a two-layer neural network with quadratic activation function in the high-dimensional regime, where the data is generated as $y \propto \sum_{j=1}^{r} \lambda_j \sigma\left(\langle \boldsymbol{\theta}_j, \boldsymbol{x}\rangle\right)$, $\boldsymbol{x} \sim N(0,\boldsymbol{I}_d)$, $\sigma$ is the 2nd Hermite polynomial, and $\lbrace\boldsymbol{\theta}_j \rbrace_{j=1}^{r} \subset \mathbb{R}^d$ are orthonormal signal directions. We consider the extensive-width regime $r \asymp d^\beta$ for $\beta \in [0, 1)$, and assume a power-law decay on the (non-negative) second-layer coefficients $\lambda_j \asymp j^{-\alpha}$ for $\alpha \geq 0$. We present a sharp analysis of the SGD dynamics in the feature learning regime, for both the population limit and the finite-sample (online) discretization, and derive scaling laws for the prediction risk that highlight the power-law dependencies on the optimization time, sample size, and model width. Our analysis combines a precise characterization of the associated matrix Riccati differential equation with novel matrix monotonicity arguments to establish convergence guarantees for the infinite-dimensional effective dynamics.
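
The data model described in the abstract can be sketched in a few lines of NumPy. This is a minimal illustration rather than the authors' code: the function names (`make_teacher`, `he2`, `sample`), the unnormalized convention $\mathrm{He}_2(z) = z^2 - 1$, the exact choice $\lambda_j = j^{-\alpha}$, and the unit-variance scaling of $y$ are all assumptions made here, since the abstract specifies the model only up to proportionality.

```python
import numpy as np

def make_teacher(d, r, alpha, rng):
    """Return r orthonormal directions in R^d (as rows) and weights j^{-alpha}."""
    # Reduced QR of a d x r Gaussian matrix yields r orthonormal columns.
    q, _ = np.linalg.qr(rng.standard_normal((d, r)))
    theta = q.T                                   # shape (r, d)
    lam = np.arange(1, r + 1, dtype=float) ** (-alpha)
    return theta, lam

def he2(z):
    """Second probabilists' Hermite polynomial (unnormalized): z^2 - 1."""
    return z ** 2 - 1.0

def sample(theta, lam, n, rng):
    """Draw n pairs (x, y) with x ~ N(0, I_d) and y = c * sum_j lam_j He_2(<theta_j, x>)."""
    d = theta.shape[1]
    x = rng.standard_normal((n, d))
    z = x @ theta.T                               # projections <theta_j, x>, shape (n, r)
    # Assumed normalization: the projections are independent N(0, 1) and
    # Var[He_2(z)] = 2, so this constant gives y unit variance.
    c = 1.0 / (np.sqrt(2.0) * np.linalg.norm(lam))
    y = c * (he2(z) @ lam)
    return x, y

rng = np.random.default_rng(0)
d, r, alpha = 64, 8, 1.0                          # r = d^beta with beta = 1/2
theta, lam = make_teacher(d, r, alpha, rng)
x, y = sample(theta, lam, 20000, rng)
```

Drawing a fresh batch from `sample` at every step mirrors the online (one-pass) setting the abstract refers to; the student network and its update rule are not specified in this summary and are deliberately left out.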

Paper Structure

This paper contains 70 sections, 57 theorems, 439 equations, 6 figures, 1 table, 1 algorithm.

Key Result

Theorem 1

Let $\lambda_j = j^{-\alpha}$ and $r \asymp d^\beta$ for some $\alpha \geq 0$ and $\beta \in (0,1)$. Consider the regime [...]. Define the effective student width and effective timescale as [...]. Then the population gradient flow dynamics (eq:gf) satisfy the following with probability $1 - o(1/d^2) - \Omega(1/r_s^2)$: [...]

Figures (6)

  • Figure 1: (a) Illustration of the additive model hypothesis, i.e., sum of emergent learning curves at different timescales yields a power law in the cumulative loss. (b) Population loss vs. compute for two-layer quadratic NNs trained with online SGD with batch size $d$ on squared loss. We set $d=3200$, and for the teacher model $r=2400$, $\alpha=1$.
  • Figure 2: Illustration of the limiting risk trajectories and scaling behavior given in Corollary \ref{cor:asympriskcont}.
  • Figure 3: Solutions of the matrix Riccati ODE in \ref{eq:contriccati} with $\lambda_1 = 2$, $\lambda_2 = 1$, $r_s = 2$. (a) To visualize the dynamics under matrix order, we plot the level sets of $\boldsymbol{G}(t)$ at times $t \in \{0, 0.25, 0.5\}$ for two initializations: $\boldsymbol{G}(0)$ (solid) and a scaled version $1.25\,\boldsymbol{G}(0)$ (dashed). The dashed ellipses remain enclosed within the solid ones at all times, illustrating monotonicity of the Riccati flow with respect to initialization. However, note that $\boldsymbol{G}(t)$ is not monotone in Loewner order over time, as seen from the lack of nesting among the solid ellipses. (b) Entry-wise evolution of $\boldsymbol{G}(t)$ under a random initialization with $d = 1024$. The diagonal entry $\boldsymbol{G}_{22}(t)$ exhibits non-monotonic behavior, illustrating that the solution trajectory $\boldsymbol{G}(t)$ need not be monotone in time; the off-diagonal entry $\boldsymbol{G}_{12}(t)$ is also shown for reference.
  • Figure 4: Population loss vs. compute for two-layer ReLU network (power-law second-layer with exponent $\alpha$) trained with population gradient descent. The student network adopts the 2-homogeneous parameterization as in \ref{eq:student}. Observe that after the initial loss drop due to the $\mathrm{He}_1$ component, the risk curves follow a power-law scaling where the exponent (dashed lines) nearly matches our theoretical prediction for the quadratic setting $\frac{1-2\alpha}{\alpha}$.
  • Figure 5: Solutions of the matrix Riccati ODE in \ref{eq:contriccati} with $\lambda_1 = 2$, $\lambda_2 = 1$, $r_s = 2$. (a) To visualize the dynamics under matrix order, we plot the level sets of $\boldsymbol{G}(t)$ at times $t \in \{0, 0.25, 0.5\}$ for two initializations: $\boldsymbol{G}(0)$ (solid) and a scaled version $1.25\,\boldsymbol{G}(0)$ (dashed). The dashed ellipses remain enclosed within the solid ones at all times, illustrating monotonicity of the Riccati flow with respect to initialization. However, note that $\boldsymbol{G}(t)$ is not monotone in Loewner order over time, as seen from the lack of nesting among the solid ellipses. (b) Entry-wise evolution of $\boldsymbol{G}(t)$ under a random initialization with $d = 1024$. The diagonal entry $\boldsymbol{G}_{22}(t)$ exhibits non-monotonic behavior, illustrating that the solution trajectory $\boldsymbol{G}(t)$ need not be monotone in time.
  • ...and 1 more figure
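
The additive model hypothesis in the Figure 1(a) caption can be illustrated with a toy calculation. The sketch below assumes step-shaped learning curves, where direction $j$ contributes $\lambda_j^2$ to the loss until a hypothetical emergence time $1/\lambda_j$ and nothing afterwards. Both the curve shape and the emergence time are simplifying assumptions made for illustration, not the paper's dynamics, and the helper `additive_loss` is a name introduced here.

```python
import numpy as np

def additive_loss(t, lam):
    """Toy cumulative loss: direction j contributes lam_j^2 until time 1 / lam_j."""
    t = np.atleast_1d(np.asarray(t, dtype=float))
    # Assumption: each direction has a step-shaped learning curve that
    # switches off at the emergence time T_j = 1 / lam_j.
    emergence = 1.0 / lam
    unlearned = t[:, None] < emergence[None, :]   # boolean mask, shape (len(t), r)
    return (unlearned * lam[None, :] ** 2).sum(axis=1)

alpha = 1.0
r = 100_000
lam = np.arange(1, r + 1, dtype=float) ** (-alpha)

times = np.geomspace(10.0, 1000.0, 30)
loss = additive_loss(times, lam)

# Least-squares slope of log(loss) against log(time).
slope = np.polyfit(np.log(times), np.log(loss), 1)[0]
```

Under these toy assumptions the summed loss is the tail sum $\sum_{j > t^{1/\alpha}} j^{-2\alpha}$, so for $\alpha = 1$ the fitted `slope` should be close to $\frac{1-2\alpha}{\alpha} = -1$, the exponent quoted in the Figure 4 caption. The individual curves are abrupt, yet their sum traces a smooth power law.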

Theorems & Definitions (111)

  • Remark 1
  • Theorem 1
  • Remark 2
  • Corollary 1
  • Remark 3
  • Theorem 2
  • Remark 4
  • Corollary 2
  • Proposition 1
  • Remark 5
  • ...and 101 more