
Asymptotics of feature learning in two-layer networks after one gradient-step

Hugo Cui, Luca Pesce, Yatin Dandi, Florent Krzakala, Yue M. Lu, Lenka Zdeborová, Bruno Loureiro

TL;DR

This work analyzes how two‑layer neural networks learn features after a single large gradient step on the first layer, demonstrating that the trained network can be modeled by a spiked Random Features (sRF) framework. Using a replica‑method analysis and a conditional Gaussian equivalence, the authors derive exact high‑dimensional asymptotics for the sRF generalization error, revealing quantitative feature learning beyond the kernel regime and detailing how spike strength and data availability control learnability along the target direction. They map gradient‑trained networks to sRFs with explicit parameters $(c,r,\gamma)$, show the sRF can express nonlinear functions of the spike projection while remaining linear in orthogonal directions, and provide self‑consistent equations that yield the test error in closed form. The results explain when and how a single gradient step enables beating kernels, quantify the role of data in the first step, and discuss how variability in readout initialization expands the expressivity of the model. Overall, the paper provides a sharp, nonperturbative framework for understanding feature learning in gradient‑trained two‑layer networks and lays groundwork for extending to more steps and deeper architectures.
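The claim that the sRF is nonlinear along the spike projection but linear in the orthogonal directions can be made concrete with a small sketch. It uses a hypothetical parametrization, not Definition 3.1 verbatim: spiked weights W = cF + r u θᵀ, so neuron i has preactivation c h_i + r u_i κ. Here κ = θᵀx/√d and h_i = F_iᵀx/√d, which is approximately standard Gaussian and independent of κ for standard Gaussian F and x at large d. Averaging over h_i with Gauss-Hermite quadrature gives two quantities. The first is the conditional mean μ₀(r u_i κ), a nonlinear function of κ. The second is μ₁, the first Hermite coefficient of the part linear in h_i. The parameter γ is not modeled.

```python
import numpy as np

# Probabilists' Gauss-Hermite rule: nodes/weights for integrals against exp(-h^2/2).
NODES, WEIGHTS = np.polynomial.hermite_e.hermegauss(80)
WEIGHTS = WEIGHTS / np.sqrt(2.0 * np.pi)  # normalized: WEIGHTS @ g(NODES) ~ E[g(h)], h ~ N(0, 1)


def conditional_moments(sigma, c, shift):
    """Average over the orthogonal part h ~ N(0, 1) of the preactivation c * h + shift,
    where shift = r * u_i * kappa is the (hypothetical) spike contribution:
        mu0 = E_h[sigma(c * h + shift)]        (conditional mean, nonlinear in kappa)
        mu1 = E_h[h * sigma(c * h + shift)]    (first Hermite coefficient in h)
    """
    shift = np.asarray(shift, dtype=float)[..., None]  # broadcast against the nodes
    vals = sigma(c * NODES + shift)
    return vals @ WEIGHTS, (vals * NODES) @ WEIGHTS


# Hypothetical example: c = 1, r = 0.9, three neurons with different spike weights u_i.
kappa = np.linspace(-3.0, 3.0, 7)
for u_i in (0.5, 1.0, 3.0):
    mu0, _ = conditional_moments(np.tanh, 1.0, 0.9 * u_i * kappa)
    print(f"u_i={u_i}: mu0(kappa) =", np.round(mu0, 3))
```

Neurons with different u_i produce differently shaped functions of κ, so a linear readout can combine them into nonlinear functions of κ. With identical u_i, every neuron would have the same conditional mean. This is consistent with the remark above that variability in the readout initialization expands expressivity.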

Abstract

In this manuscript, we investigate the problem of how two-layer neural networks learn features from data, and improve over the kernel regime, after being trained with a single gradient descent step. Leveraging the insight from (Ba et al., 2022), we model the trained network by a spiked Random Features (sRF) model. Further building on recent progress on Gaussian universality (Dandi et al., 2023), we provide an exact asymptotic description of the generalization error of the sRF in the high-dimensional limit where the number of samples, the width, and the input dimension grow at a proportional rate. The resulting characterization for sRFs also captures closely the learning curves of the original network model. This enables us to understand how adapting to the data is crucial for the network to efficiently learn non-linear functions in the direction of the gradient -- where at initialization it can only express linear functions in this regime.
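The insight from Ba et al. (2022) that the abstract builds on can be checked numerically: one gradient step on the first layer adds an approximately rank-one spike to the weights. The NumPy sketch below is illustrative only and rests on assumptions not taken from the paper:

- the network is f(x) = aᵀσ(Wx/√d)/√p with Gaussian initialization;
- the loss is the squared loss, and inputs are standard Gaussian;
- the target is tanh(θᵀx/√d);
- the step size is η = η̃p, which is our reading of the captions' η = 2.5p and η̃ = 1.

```python
import numpy as np


def first_layer_gradient(W, a, X, y, sigma, dsigma):
    """Gradient w.r.t. W of the squared loss (1 / 2n) * sum_mu (y_mu - f(x_mu))^2 for
        f(x) = a . sigma(W x / sqrt(d)) / sqrt(p)   (assumed normalization)."""
    n, d = X.shape
    p = W.shape[0]
    H = X @ W.T / np.sqrt(d)                # (n, p) preactivations
    resid = y - sigma(H) @ a / np.sqrt(p)   # (n,) residuals y - f(x)
    # df/dW_ij = a_i * sigma'(H_mu,i) * x_mu,j / sqrt(p * d)
    weights = resid[:, None] * dsigma(H) * a[None, :]  # (n, p)
    return -(weights.T @ X) / (n * np.sqrt(p * d))     # (p, d)


def dtanh(z):
    return 1.0 - np.tanh(z) ** 2


rng = np.random.default_rng(0)
d = p = 500
n0 = 2 * d                                  # alpha_0 = n0 / d = 2
theta = rng.standard_normal(d)              # target direction, ||theta||^2 ~ d
W0 = rng.standard_normal((p, d))            # first layer at initialization
a0 = rng.standard_normal(p)                 # readout at initialization (kept fixed)
X0 = rng.standard_normal((n0, d))
y0 = np.tanh(X0 @ theta / np.sqrt(d))       # single-index target, sigma_star = tanh

G = first_layer_gradient(W0, a0, X0, y0, np.tanh, dtanh)
U, s, Vt = np.linalg.svd(G)
print("top singular values:", np.round(s[:3], 4))
print("|overlap| of top right singular vector with theta:",
      abs(Vt[0] @ theta) / np.linalg.norm(theta))

eta = 1.0 * p                               # eta = eta_tilde * p with eta_tilde = 1 (assumed)
W1 = W0 - eta * G                           # ~ W0 + (rank-one spike) + small bulk
```

Replacing σ′ by its Gaussian mean in the gradient leaves a rank-one term proportional to the initial readout a (left factor) times Xᵀ(y − f)/n₀ (right factor). The right factor has a macroscopic component along θ. A step size of order p then turns this term into an O(1) shift of every neuron's preactivation along θᵀx/√d.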

Paper Structure

This paper contains 51 sections, 158 equations, and 8 figures.

Figures (8)

  • Figure 1: Numerical estimation of the function $f(x_\theta)=\mathbb{E}_{\boldsymbol{x}}[f_{W^{(1)},\hat{\boldsymbol{a}}_\lambda}(\boldsymbol{x})\,|\,\boldsymbol{\theta}^\top \boldsymbol{x}/\sqrt{d}=x_\theta]$ implemented by the trained network in the direction spanned by the weights $\boldsymbol{\theta}$ of the target function. The activations are $\sigma=\sigma_\star=\tanh$, and simulations were run in dimensions $d=p=2000$, with learning rate $\eta=2.5 p$ and readout regularization $\lambda=0.01$. The readout was trained with $n_1=2d$ samples. Different colors correspond to different sample complexities $\alpha_0\equiv n_0/d$ used to implement the gradient step on the first-layer weights, with $\alpha_0=0$ corresponding to not implementing the step.
  • Figure 2: Crosses: numerical evaluation of the test error achieved by a two-layer network whose first layer has been trained following the protocol detailed in the motivation section, with learning rate $\tilde{\eta}=1$, activation $\sigma=\tanh$, and readout regularization $\lambda=0.01$ (left) and $\lambda=0.1$ (right). The target is a single-index model with tanh (left) and sine (right) activation. Numerical experiments were performed in $d=2000$, and all points were averaged over $5$ instances. Different colors represent different initial sample complexities $\alpha_0=n_0/d$ used for the first gradient step. Solid lines: theoretical characterization of the main asymptotic result for the equivalent sRFs. The dashed black line represents the lowest MSE achievable by kernel/linear methods, namely $h^\star_2-(h^\star_1)^2$ (Ba et al., 2022).
  • Figure 3: Illustration of the functions realizing the upper bound (orange) and the lower bound (blue), for $\sigma=\tanh$, $c=1$, $r=0.9$, and $\gamma=1$, with target $\sigma_\star=\sin$ (dashed black).
  • Figure 4: Test error for an sRF with $\sigma=\sin$ (left) and $\sigma=\tanh$ (right) activation, learning from the single-index model $\mathrm{sign}(\boldsymbol{\theta}^\top \boldsymbol{x}/\sqrt{d})$ (left) and $\tanh(\boldsymbol{\theta}^\top \boldsymbol{x}/\sqrt{d})$ (right), with regularization $\lambda=0.1$. Solid lines: theoretical characterization of the main asymptotic result. Crosses: numerical simulations in dimensions $d=p=2000$, each point averaged over $10$ instances of the problem. Different colors correspond to different spike strengths $r$ in the first-layer weights $W$, with $r=0$ corresponding to the vanilla RF model.
  • Figure 5: Test error as a function of the network width $\beta$, for $\sigma=\tanh$, $\sigma_\star=\sin$, $\alpha_0=1.5$, $\alpha=1.2$, and $\lambda=0.1$, as predicted by the theoretical characterization of the main asymptotic result (blue) and as measured in numerical simulations in $d=5000$ (red). Error bars represent one standard deviation over $20$ trials.
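The kernel/linear baseline $h^\star_2-(h^\star_1)^2$ from the Figure 2 caption can be evaluated by Gauss-Hermite quadrature. The caption does not define $h^\star_1$ and $h^\star_2$ here. The sketch assumes $h^\star_1=\mathbb{E}[z\,\sigma_\star(z)]$ and $h^\star_2=\mathbb{E}[\sigma_\star(z)^2]$ for $z\sim\mathcal{N}(0,1)$. Under that assumption, the quantity is the error of the best linear predictor of a zero-mean target such as tanh or sin.

```python
import numpy as np


def linear_floor(sigma_star, deg=80):
    """h2 - h1**2 with (assumed notation) h1 = E[z sigma_star(z)], h2 = E[sigma_star(z)**2],
    z ~ N(0, 1). For a zero-mean target y = sigma_star(z), this is the MSE of the
    best linear predictor of y from z."""
    z, w = np.polynomial.hermite_e.hermegauss(deg)
    w = w / np.sqrt(2.0 * np.pi)      # normalize the Gauss-Hermite weights to N(0, 1)
    h1 = np.sum(w * z * sigma_star(z))
    h2 = np.sum(w * sigma_star(z) ** 2)
    return h2 - h1 ** 2


for name, fn in (("tanh", np.tanh), ("sin", np.sin)):
    print(name, round(linear_floor(fn), 4))
```

For $\sigma_\star=\sin$ the quantity has the closed form $(1-e^{-2})/2-e^{-1}\approx 0.0645$, which the quadrature reproduces.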

Theorems & Definitions (1)

  • Definition 3.1: sRF model
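Definition 3.1 is not reproduced in this summary. As a hedged illustration only, the sketch below assumes spiked first-layer weights W = cF + r u θᵀ with:

- standard Gaussian F;
- a spike strength r, where r = 0 gives the vanilla RF model, as in Figure 4;
- a hypothetical vector u of per-neuron spike coefficients.

The role of γ and the paper's exact scalings are not modeled. The readout is fit by ridge regression with an assumed normalization, and the target is sin(θᵀx/√d).

```python
import numpy as np


def srf_features(X, theta, F, u, c, r, sigma=np.tanh):
    """Spiked random features sigma(W x / sqrt(d)) with the hypothetical weights
        W = c * F + r * outer(u, theta),
    so neuron i sees c * (F_i . x) / sqrt(d) + r * u_i * kappa, kappa = theta . x / sqrt(d)."""
    d = X.shape[1]
    W = c * F + r * np.outer(u, theta)
    return sigma(X @ W.T / np.sqrt(d))


def ridge_readout(Phi, y, lam):
    """Minimize (1/n) * ||y - Phi a||^2 + lam * ||a||^2 (assumed normalization)."""
    n, p = Phi.shape
    return np.linalg.solve(Phi.T @ Phi + lam * n * np.eye(p), Phi.T @ y)


rng = np.random.default_rng(1)
d = p = 400
n_train, n_test = 2 * d, 2000
theta = rng.standard_normal(d)               # target direction
F = rng.standard_normal((p, d))              # random bulk of the first layer
u = rng.standard_normal(p)                   # hypothetical per-neuron spike coefficients
X_tr = rng.standard_normal((n_train, d))
X_te = rng.standard_normal((n_test, d))
y_tr = np.sin(X_tr @ theta / np.sqrt(d))     # single-index target, sigma_star = sin
y_te = np.sin(X_te @ theta / np.sqrt(d))

for r in (0.0, 1.0):                         # r = 0: vanilla random features
    Phi_tr = srf_features(X_tr, theta, F, u, c=1.0, r=r)
    Phi_te = srf_features(X_te, theta, F, u, c=1.0, r=r)
    a = ridge_readout(Phi_tr, y_tr, lam=0.1)
    print(f"r={r}: test MSE = {np.mean((y_te - Phi_te @ a) ** 2):.4f}")
```

The printed errors can be compared with the linear floor for a sine target. Whether the spike helps at this small size depends on the draw and is not asserted here. The paper's learning curves come from its asymptotic characterization, not from this sketch.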