Connecting NTK and NNGP: A Unified Theoretical Framework for Wide Neural Network Learning Dynamics
Yehonatan Avidan, Qianyi Li, Haim Sompolinsky
TL;DR
This work presents a unified dynamical framework that merges Neural Tangent Kernel (NTK) and Neural Network Gaussian Process (NNGP) theories by modeling deep networks trained with Langevin gradient descent in the small-noise limit. Central to this framework is the Neural Dynamical Kernel (NDK), a time-dependent generalization of NTK that interpolates between NTK at early times and NNGP at long times, enabling a two-phase learning trajectory: a gradient-driven phase followed by a diffusive phase that samples the solution subspace. The authors derive a moment-generating functional for the predictor in the infinite-width limit, obtain recursive relations for time-dependent kernels, and analyze both generic nonlinearities and specific cases (e.g., ReLU, erf). They also connect these dynamics to representational drift in biological circuits and discuss conditions under which invariant codes can persist under drift, providing a comprehensive theory for learning dynamics in wide networks and potential insights into brain computation. The framework offers avenues for extending to finite-width, non-lazy regimes and for exploring how priors and architectural biases shape learning and representation over time.
Abstract
Artificial neural networks have revolutionized machine learning in recent years, but a complete theoretical framework for their learning process is still lacking. Substantial advances were achieved for wide networks, within two disparate theoretical frameworks: the Neural Tangent Kernel (NTK), which assumes linearized gradient descent dynamics, and the Bayesian Neural Network Gaussian Process (NNGP). We unify these two theories using gradient descent learning with additive noise in an ensemble of wide deep networks. We construct an analytical theory for the network input-output function and introduce a new time-dependent Neural Dynamical Kernel (NDK) from which both the NTK and NNGP kernels are derived. We identify two learning phases. The first is a gradient-driven phase, dominated by loss minimization, whose time scale is governed by the initialization variance. It is followed by a slow diffusive phase, in which the parameters sample the solution space, with a time constant determined by the noise level and the Bayesian prior variance. The two variance parameters strongly affect performance in both regimes, especially in networks with sigmoidal neurons. In contrast to the exponential convergence of the mean predictor in the initial phase, the convergence to equilibrium is more complex and may be nonmonotonic. By characterizing the diffusive phase, our work sheds light on representational drift in the brain, explaining how neural activity can change continuously without degrading performance, either through ongoing gradient signals that synchronize the drifts of different synapses or through architectural biases that generate task-relevant information that is robust against the drift process. This work closes the gap between the NTK and NNGP theories, providing a comprehensive framework for the learning process of deep wide neural networks and for analyzing dynamics in biological circuits.
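The two phases described above can be illustrated with a minimal sketch, which is not the authors' implementation. It assumes the standard Langevin convention: an energy made of the squared-error loss plus a weight-decay term T/(2σ²)·|w|², and white noise of amplitude sqrt(2T), discretized with Euler-Maruyama. A toy overparameterized linear model stands in for a wide deep network, and every name and parameter value (`run`, `langevin_step`, `sigma0_2`, the targets) is hypothetical.

```python
import math
import random


def predict(w, X):
    """Linear predictor f(x) = w.x / sqrt(d) for each row x of X."""
    scale = 1.0 / math.sqrt(len(w))
    return [scale * sum(wi * xi for wi, xi in zip(w, x)) for x in X]


def loss(w, X, y):
    """Squared-error training loss, 0.5 * sum (f - y)^2."""
    return 0.5 * sum((f - t) ** 2 for f, t in zip(predict(w, X), y))


def langevin_step(w, X, y, T, sigma2, dt, rng):
    """One Euler-Maruyama step of dw = -grad E dt + sqrt(2 T dt) * N(0, 1),
    with the assumed energy E(w) = loss(w) + T / (2 * sigma2) * |w|^2."""
    scale = 1.0 / math.sqrt(len(w))
    residuals = [f - t for f, t in zip(predict(w, X), y)]
    noise_std = math.sqrt(2.0 * T * dt)
    new_w = []
    for i, wi in enumerate(w):
        grad_loss = scale * sum(r * x[i] for r, x in zip(residuals, X))
        grad_prior = (T / sigma2) * wi  # pull toward the Gaussian prior
        step = -dt * (grad_loss + grad_prior) + noise_std * rng.gauss(0.0, 1.0)
        new_w.append(wi + step)
    return new_w


def run(seed=0, d=30, T=1e-3, sigma2=1.0, sigma0_2=1.0,
        dt=0.05, steps_phase1=500, steps_phase2=2000):
    """Train a toy model and report the loss and weight drift in both phases."""
    rng = random.Random(seed)
    y = [2.0, -2.0, 1.0]  # hypothetical toy targets (3 samples, d >> 3)
    X = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in y]
    # Initialization variance sigma0_2 is separate from the prior variance sigma2.
    w = [rng.gauss(0.0, math.sqrt(sigma0_2)) for _ in range(d)]
    loss_init = loss(w, X, y)

    # Early steps: the gradient term dominates and the loss drops quickly.
    for _ in range(steps_phase1):
        w = langevin_step(w, X, y, T, sigma2, dt, rng)
    w_mid, loss_mid = list(w), loss(w, X, y)

    # Later steps: the loss stays low while the weights keep diffusing
    # along directions that do not change the training predictions.
    for _ in range(steps_phase2):
        w = langevin_step(w, X, y, T, sigma2, dt, rng)
    drift = math.sqrt(sum((a - b) ** 2 for a, b in zip(w, w_mid)))

    return {"loss_init": loss_init, "loss_mid": loss_mid,
            "loss_final": loss(w, X, y), "drift": drift}


if __name__ == "__main__":
    print(run())
```

In this toy setting the training loss should fall to a small, noise-limited value during the first stretch of steps, while `drift` keeps growing afterwards. This loosely mirrors the gradient-driven and diffusive phases, but it does not reproduce the infinite-width kernel theory itself.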
