Three Mechanisms of Feature Learning in a Linear Network
Yizhou Xu, Liu Ziyin
TL;DR
The paper characterizes both kernel and feature-learning dynamics in finite-width neural networks by solving a minimal model exactly: a two-layer linear network trained on one-dimensional data. For β=1, the gradient-flow dynamics reduce to a one-dimensional system, which yields explicit expressions and a phase diagram separating the kernel and feature-learning regimes as a function of initialization and hyperparameters. The authors identify three feature-learning mechanisms (alignment, disalignment, and output rescaling) that appear only in the feature-learning regime, and they report experiments in which the same mechanisms arise in deeper nonlinear networks. The results give practical guidance on choosing initialization and learning rates to steer training toward feature learning, and they connect finite-width and infinite-width analyses.
Abstract
Understanding the dynamics of neural networks in different width regimes is crucial for improving their training and performance. We present an exact solution for the learning dynamics of a one-hidden-layer linear network, with one-dimensional data, across any finite width, uniquely exhibiting both kernel and feature learning phases. This study marks a technical advancement by enabling the analysis of the training trajectory from any initialization and a detailed phase diagram under varying common hyperparameters such as width, layer-wise learning rates, and scales of output and initialization. We identify three novel prototype mechanisms specific to the feature learning regime: (1) learning by alignment, (2) learning by disalignment, and (3) learning by rescaling, which contrast starkly with the dynamics observed in the kernel regime. Our theoretical findings are substantiated with empirical evidence showing that these mechanisms also manifest in deep nonlinear networks handling real-world tasks, enhancing our understanding of neural network training dynamics and guiding the design of more effective learning strategies.
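To make the setting concrete, here is a minimal numerical sketch of a one-hidden-layer linear network on one-dimensional data. It is not the paper's exact solution. Several things are assumptions made for illustration only: the parameterization f(x) = γ Σ_i u_i w_i x, the target y = 2x, the hyperparameter values, the use of small-step gradient descent as a stand-in for gradient flow, and the cosine between the two weight vectors as a rough proxy for "alignment". The function name `train_two_layer_linear` is hypothetical.

```python
import math
import random


def dot(a, b):
    return sum(p * q for p, q in zip(a, b))


def train_two_layer_linear(width=50, gamma=1.0, init_scale=0.1,
                           lr=1e-3, steps=5000, seed=0):
    """Small-step gradient descent on f(x) = gamma * sum_i u_i * w_i * x.

    All defaults are illustrative assumptions, not values from the paper.
    """
    rng = random.Random(seed)

    # Hypothetical 1D dataset; the target y = 2x is an arbitrary choice.
    xs = [rng.gauss(0.0, 1.0) for _ in range(200)]
    ys = [2.0 * x for x in xs]
    n = len(xs)
    m_xx = sum(x * x for x in xs) / n
    m_xy = sum(x * y for x, y in zip(xs, ys)) / n

    # Both layers start at the same small scale.
    u = [init_scale * rng.gauss(0.0, 1.0) for _ in range(width)]
    w = [init_scale * rng.gauss(0.0, 1.0) for _ in range(width)]

    def snapshot(u, w):
        a = gamma * dot(u, w)  # effective scalar weight of the network
        loss = sum((a * x - y) ** 2 for x, y in zip(xs, ys)) / n
        cosine = dot(u, w) / math.sqrt(dot(u, u) * dot(w, w))
        return loss, cosine

    loss_start, cos_start = snapshot(u, w)
    q_start = [ui * ui - wi * wi for ui, wi in zip(u, w)]

    for _ in range(steps):
        a = gamma * dot(u, w)
        # Shared scalar factor of the mean-squared-error gradient:
        # dL/du_i = g * w_i and dL/dw_i = g * u_i.
        g = 2.0 * gamma * (a * m_xx - m_xy)
        u, w = ([ui - lr * g * wi for ui, wi in zip(u, w)],
                [wi - lr * g * ui for ui, wi in zip(u, w)])

    loss_end, cos_end = snapshot(u, w)
    q_end = [ui * ui - wi * wi for ui, wi in zip(u, w)]

    return {
        "loss_start": loss_start, "loss_end": loss_end,
        "cos_start": cos_start, "cos_end": cos_end,
        "q_start": q_start, "q_end": q_end,
    }


if __name__ == "__main__":
    result = train_two_layer_linear()
    print("loss:", result["loss_start"], "->", result["loss_end"])
    print("cosine(u, w):", result["cos_start"], "->", result["cos_end"])
```

In this sketch the update equations imply that each u_i² − w_i² is constant under exact gradient flow with equal layer learning rates, so the discrete run should preserve `q_start` only approximately. With a small initialization, the cosine between `u` and `w` should rise as the loss falls. This is one simple way to observe the two layers aligning during training, under the assumptions stated above.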
