Asymptotics of feature learning in two-layer networks after one gradient-step
Hugo Cui, Luca Pesce, Yatin Dandi, Florent Krzakala, Yue M. Lu, Lenka Zdeborová, Bruno Loureiro
TL;DR
This work analyzes how two-layer neural networks learn features after a single large gradient step on the first layer, showing that the trained network can be modeled by a spiked Random Features (sRF) model. Using a replica-method analysis and a conditional Gaussian equivalence, the authors derive exact high-dimensional asymptotics for the sRF generalization error. These asymptotics quantify feature learning beyond the kernel regime and show how spike strength and data availability control learnability along the target direction. The authors map gradient-trained networks to sRFs with explicit parameters $(c,r,\gamma)$, show that the sRF can express nonlinear functions of the spike projection while remaining linear in orthogonal directions, and provide a set of self-consistent equations whose solution yields the test error. The results explain when and how a single gradient step lets the network beat kernel methods, quantify the role of the data used in the first step, and discuss how variability in the readout initialization expands the expressivity of the model. Overall, the paper provides a sharp, nonperturbative framework for understanding feature learning in gradient-trained two-layer networks and lays the groundwork for extensions to more steps and deeper architectures.
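To make the sRF picture concrete, here is a minimal NumPy sketch under a parametrization of our own choosing. The first layer is a random Gaussian bulk W0 plus a rank-one spike theta * u v^T, with v aligned with a single target direction w_star, and only the second layer is fit, by ridge regression. The names theta, u, v, lam, the tanh activation and the He3 target are all hypothetical illustrations. In particular, theta is a generic spike strength and does not correspond to the paper's $(c,r,\gamma)$ parametrization.

```python
import numpy as np

def spiked_features(X, W0, u, v, theta):
    """Features tanh(F x / sqrt(d)) for a spiked first layer F = W0 + theta * u v^T (hypothetical parametrization)."""
    d = X.shape[1]
    F = W0 + theta * np.outer(u, v)          # p x d: random bulk plus a rank-one spike
    return np.tanh(X @ F.T / np.sqrt(d))     # n x p feature matrix

def ridge_readout(Phi, y, lam):
    """Second layer fit by ridge regression: argmin_a ||Phi a - y||^2 + lam ||a||^2."""
    p = Phi.shape[1]
    return np.linalg.solve(Phi.T @ Phi + lam * np.eye(p), Phi.T @ y)

def srf_test_error(theta, n=600, n_test=2000, d=300, p=300, lam=1e-2, seed=0):
    """Test MSE of the sketch sRF on a hypothetical single-index target."""
    rng = np.random.default_rng(seed)
    w_star = rng.standard_normal(d)
    w_star /= np.linalg.norm(w_star)

    def target(X):
        # He3 of the projection: nonlinear, with no linear component (illustrative choice)
        z = X @ w_star
        return z ** 3 - 3 * z

    X_tr = rng.standard_normal((n, d))
    X_te = rng.standard_normal((n_test, d))
    W0 = rng.standard_normal((p, d))         # random first-layer bulk
    u = rng.standard_normal(p)               # left spike factor (assumption: Gaussian)
    v = np.sqrt(d) * w_star                  # right spike factor, so v.x / sqrt(d) = w_star.x
    a = ridge_readout(spiked_features(X_tr, W0, u, v, theta), target(X_tr), lam)
    pred = spiked_features(X_te, W0, u, v, theta) @ a
    return float(np.mean((pred - target(X_te)) ** 2))

for theta in (0.0, 1.0):
    print(f"theta={theta}: test MSE = {srf_test_error(theta):.3f}")
```

Setting theta = 0 recovers a plain random-features model, so comparing it with theta > 0 is meant to mirror, informally, the kernel-versus-spike comparison discussed above. The sizes, lam and the target are arbitrary, and the printed numbers are not results from the paper.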
Abstract
In this manuscript, we investigate the problem of how two-layer neural networks learn features from data, and improve over the kernel regime, after being trained with a single gradient descent step. Leveraging the insight from (Ba et al., 2022), we model the trained network by a spiked Random Features (sRF) model. Further building on recent progress on Gaussian universality (Dandi et al., 2023), we provide an exact asymptotic description of the generalization error of the sRF in the high-dimensional limit where the number of samples, the width, and the input dimension grow at a proportional rate. The resulting characterization for sRFs also closely captures the learning curves of the original network model. This enables us to understand how adapting to the data is crucial for the network to efficiently learn non-linear functions in the direction of the gradient, a direction in which, at initialization, it can only express linear functions in this regime.
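The rank-one structure that motivates the sRF can be probed numerically. The sketch below assumes a network f(x) = a^T tanh(W x / sqrt(d)) / sqrt(p) with a fixed Gaussian readout a and a squared loss. These are our own choices and not necessarily the scaling used in the paper or in (Ba et al., 2022). The code computes the full-batch gradient with respect to W at initialization, then inspects its singular values and the overlap of its top right singular vector with the target direction. The target and the step size eta are hypothetical.

```python
import numpy as np

def network_output(W, a, X):
    """f(x) = a^T tanh(W x / sqrt(d)) / sqrt(p), evaluated on each row of X (n x d)."""
    p, d = W.shape
    return np.tanh(X @ W.T / np.sqrt(d)) @ a / np.sqrt(p)

def squared_loss(W, a, X, y):
    """Empirical loss (1 / 2n) * sum_i (f(x_i) - y_i)^2."""
    return 0.5 * np.mean((network_output(W, a, X) - y) ** 2)

def first_layer_gradient(W, a, X, y):
    """Gradient of squared_loss with respect to W (p x d), readout a held fixed."""
    n = X.shape[0]
    p, d = W.shape
    Z = X @ W.T / np.sqrt(d)                            # pre-activations, n x p
    residual = network_output(W, a, X) - y              # shape (n,)
    M = residual[:, None] * (1.0 - np.tanh(Z) ** 2) * a[None, :]  # n x p, uses tanh' = 1 - tanh^2
    return M.T @ X / (n * np.sqrt(p) * np.sqrt(d))

rng = np.random.default_rng(0)
n, d, p = 800, 400, 400                                 # proportional regime (illustrative sizes)
X = rng.standard_normal((n, d))
w_star = rng.standard_normal(d)
w_star /= np.linalg.norm(w_star)
z = X @ w_star
y = np.tanh(z) + 0.5 * z ** 2                           # hypothetical single-index target

W0 = rng.standard_normal((p, d))
a = rng.standard_normal(p)

G = first_layer_gradient(W0, a, X, y)
_, s, Vt = np.linalg.svd(G, full_matrices=False)
print("top singular values of the gradient:", s[:3])
print("overlap of top right singular vector with w_star:", abs(Vt[0] @ w_star))

eta = 1.0                                               # hypothetical step size
W1 = W0 - eta * G                                       # first layer after one gradient step
```

Comparing s[0] with s[1] gives a quick check of how close the gradient is to rank one at these sizes. The printed values depend on the seed and the sizes, and they are not results from the paper.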
