Dichotomy of Feature Learning and Unlearning: Fast-Slow Analysis on Neural Networks with Stochastic Gradient Descent
Shota Imai, Sota Nishiyama, Masaaki Imaizumi
TL;DR
This work analyzes gradient-based training of neural networks under SGD from a high-dimensional, infinite-width perspective. By deriving a two-dimensional ODE for macroscopic variables and revealing a fast-slow decomposition, the authors show that feature unlearning emerges from slow dynamics along a critical manifold, with a precise scaling law for the decay of the alignment and the growth of the second-layer weights. They ground the analysis in Tensor Programs and singular perturbation theory, validate it numerically, and corroborate it with experiments on real networks. The findings quantify when unlearning occurs, relate it to the data-generating nonlinearity and the initialization, and show how learning can transition into a lazy regime despite ongoing optimization. The results have implications for understanding long-term feature retention, SGD dynamics, and the stability of learned representations in deep networks.
Abstract
The dynamics of gradient-based training in neural networks often exhibit nontrivial structures, and understanding them remains a central challenge in theoretical machine learning. In particular, feature unlearning, in which a neural network progressively loses previously learned features over long training, has gained attention. In this study, we consider the infinite-width limit of a two-layer neural network updated with large-batch stochastic gradients and derive differential equations with different time scales, revealing the mechanism and the conditions under which feature unlearning occurs. Specifically, we exploit fast-slow dynamics: the alignment of the first-layer weights develops rapidly, while the second-layer weights develop slowly. The direction of the flow on the critical manifold, determined by the slow dynamics, decides whether feature unlearning occurs. We validate the result numerically and derive theoretical grounding and scaling laws for feature unlearning. Our results yield the following insights: (i) the strength of the primary nonlinear term in the data induces feature unlearning, and (ii) the initial scale of the second-layer weights mitigates feature unlearning. Technically, our analysis uses Tensor Programs and singular perturbation theory.
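To make the fast-slow picture concrete, the following is a minimal sketch of a generic two-variable fast-slow system, not the ODE derived in the paper. The right-hand sides, the function `g`, the parameter `eps`, and the variable names `m` (standing in for a fast alignment-like quantity) and `a` (standing in for a slow second-layer-like quantity) are all assumptions chosen for illustration. The fast variable first relaxes onto the critical manifold m = g(a), and the trajectory then drifts slowly along it.

```python
import numpy as np

def simulate(eps=0.01, m0=0.0, a0=0.5, dt=0.01, steps=20000):
    """Forward-Euler integration of a toy fast-slow system (illustrative only).

    fast:  dm/dt = g(a) - m     m relaxes quickly to the critical manifold m = g(a)
    slow:  da/dt = eps * m      a drifts on the much longer time scale 1/eps
    """
    # Hypothetical critical manifold; not taken from the paper.
    g = lambda a: 1.0 / (1.0 + a * a)
    m, a = m0, a0
    traj = np.empty((steps + 1, 2))
    traj[0] = (m, a)
    for k in range(steps):
        dm = g(a) - m      # fast relaxation toward the manifold
        da = eps * m       # slow drift, scaled by the small parameter eps
        m, a = m + dt * dm, a + dt * da
        traj[k + 1] = (m, a)
    return traj, g

if __name__ == "__main__":
    traj, g = simulate()
    m, a = traj[:, 0], traj[:, 1]
    # Distance from the critical manifold once the fast transient has died out.
    gap = np.abs(m[1000:] - g(a[1000:]))
    print("max distance from manifold after t = 10:", gap.max())
    print("m at t = 10 and at the end:", m[1000], m[-1])
    print("a at start and at the end:", a[0], a[-1])
```

In this toy system the fast variable rises quickly and then decays slowly while the slow variable keeps growing, which qualitatively mirrors the rise and later loss of alignment described in the abstract. The actual conditions and scaling laws depend on the paper's equations, which this sketch does not reproduce.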
