
Connecting NTK and NNGP: A Unified Theoretical Framework for Wide Neural Network Learning Dynamics

Yehonatan Avidan, Qianyi Li, Haim Sompolinsky

TL;DR

This work presents a unified dynamical framework that merges Neural Tangent Kernel (NTK) and Neural Network Gaussian Process (NNGP) theories by modeling deep networks trained with Langevin gradient descent in the small-noise limit. Central to this framework is the Neural Dynamical Kernel (NDK), a time-dependent generalization of NTK that interpolates between NTK at early times and NNGP at long times, enabling a two-phase learning trajectory: a gradient-driven phase followed by a diffusive phase that samples the solution subspace. The authors derive a moment-generating functional for the predictor in the infinite-width limit, obtain recursive relations for time-dependent kernels, and analyze both generic nonlinearities and specific cases (e.g., ReLU, erf). They also connect these dynamics to representational drift in biological circuits and discuss conditions under which invariant codes can persist under drift, providing a comprehensive theory for learning dynamics in wide networks and potential insights into brain computation. The framework offers avenues for extending to finite-width, non-lazy regimes and for exploring how priors and architectural biases shape learning and representation over time.

Abstract

Artificial neural networks have revolutionized machine learning in recent years, but a complete theoretical framework for their learning process is still lacking. Substantial advances were achieved for wide networks, within two disparate theoretical frameworks: the Neural Tangent Kernel (NTK), which assumes linearized gradient descent dynamics, and the Bayesian Neural Network Gaussian Process (NNGP). We unify these two theories using gradient descent learning with an additional noise in an ensemble of wide deep networks. We construct an analytical theory for the network input-output function and introduce a new time-dependent Neural Dynamical Kernel (NDK) from which both NTK and NNGP kernels are derived. We identify two learning phases: a gradient-driven learning phase, dominated by loss minimization, in which the time scale is governed by the initialization variance. It is followed by a slow diffusive learning stage, where the parameters sample the solution space, with a time constant decided by the noise and the Bayesian prior variance. The two variance parameters strongly affect the performance in the two regimes, especially in sigmoidal neurons. In contrast to the exponential convergence of the mean predictor in the initial phase, the convergence to the equilibrium is more complex and may behave nonmonotonically. By characterizing the diffusive phase, our work sheds light on representational drift in the brain, explaining how neural activity changes continuously without degrading performance, either by ongoing gradient signals that synchronize the drifts of different synapses or by architectural biases that generate task-relevant information that is robust against the drift process. This work closes the gap between the NTK and NNGP theories, providing a comprehensive framework for the learning process of deep wide neural networks and for analyzing dynamics in biological circuits.
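The two phases described in the abstract can be illustrated with a toy simulation. The sketch below is not the authors' deep-network setup: it assumes a linear readout $f({\bf x}) = {\bf w}\cdot{\bf x}/\sqrt{N}$, an energy made of the squared training error plus an L2 term with coefficient $T/2\sigma^{2}$, and white noise of variance $2T$, integrated with a plain Euler-Maruyama step. The function name `langevin_train`, the step size `dt`, and the random data are hypothetical choices made only for illustration.

```python
import numpy as np

def langevin_train(X, y, sigma0=0.2, sigma=1.0, T=0.01, dt=0.05, steps=6000, seed=0):
    """Euler-Maruyama Langevin dynamics for a linear readout f(x) = w.x / sqrt(N).

    Assumed energy: E(w) = 0.5 * sum_mu (f(x_mu) - y_mu)^2 + T / (2 sigma^2) * |w|^2
    Assumed update: w <- w - dt * grad E + sqrt(2 T dt) * standard normal noise
    """
    rng = np.random.default_rng(seed)
    P, N = X.shape
    w = sigma0 * rng.standard_normal(N)        # initialization with std sigma0
    train_mse = np.empty(steps)
    w_std = np.empty(steps)
    for k in range(steps):
        err = X @ w / np.sqrt(N) - y           # training residuals
        train_mse[k] = np.mean(err ** 2)
        w_std[k] = w.std()
        grad = X.T @ err / np.sqrt(N) + (T / sigma ** 2) * w
        w = w - dt * grad + np.sqrt(2 * T * dt) * rng.standard_normal(N)
    return train_mse, w_std

# Overparameterized toy problem: P = 20 examples, N = 200 weights (synthetic data).
data_rng = np.random.default_rng(1)
X = data_rng.standard_normal((20, 200))
y = np.sign(data_rng.standard_normal(20))
train_mse, w_std = langevin_train(X, y)
print("train MSE: start %.3f, end %.3f" % (train_mse[0], train_mse[-500:].mean()))
print("weight std: start %.3f, end %.3f" % (w_std[0], w_std[-1]))
```

Under these assumptions the training error relaxes within a few time units, while the weight components unconstrained by the data follow an Ornstein-Uhlenbeck process with time constant $\sigma^{2}/T$, so their spread drifts from $\sigma_{0}$ toward $\sigma$ on a much slower scale.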

Paper Structure (40 sections, 140 equations, 14 figures)

Figures (14)

  • Figure 1: Two Phases of Learning Dynamics: Simulation results of a deep network with a single hidden layer and error function activation, trained by Langevin dynamics (Eq. \ref{eq:Langevin}) on binary classification using two classes from the CIFAR-10 dataset \cite{krizhevsky2014cifar}. (a) Test loss: The mean squared error (MSE) loss on test data reveals two distinct phases: an initial fast, approximately deterministic stage culminating in convergence to a low error, and a subsequent slow, stochastic exploration phase characterized by large fluctuations. At long times, the network converges to an equilibrium state where the statistics of the weights and the performance stabilize over time. (b) Training loss: The loss on the training data shows rapid relaxation to a state with low training error, with fluctuations on the order of $\mathcal{O}(T)$, indicating restricted diffusive dynamics in the subspace of low training error. (c) Weight dynamics: The weights exhibit a constrained random walk, with their standard deviation gradually increasing from an initial value $\sigma_{0}$ to the equilibrium value $\sigma$ (in this example $\sigma >\sigma_{0}$). This stochastic process remains confined to the solution subspace, as evidenced by the low training error in (b). Parameters and details are in Sec. \ref{sec:Details-of-the}.
  • Figure 2: The Neural Dynamical Kernel (NDK): The figure presents the NDK for various nonlinearities, parameters, and times, using examples from the MNIST dataset (0,1 digits). To focus on the kernel's structure independently of scale, the kernels are normalized by their maximum value. We present the NDK for equal times ($t=t'$, Eq. \ref{eq:equal time}) and for time difference from initialization (evaluating Eq. \ref{eq:recursive kd} at $t'=0$). (a-c) ReLU kernel (equal times) with parameters $\sigma_0=0.2$ and $\sigma=1$. Since ReLU is a homogeneous function, changes in the variance of the distribution do not alter the kernel's structure, which is preserved for all times. (d-f) ReLU kernel (time difference): The ReLU kernel depends on the angles between pairs of input vectors. As the time difference increases, the representations decouple, leaving only information about the amplitude of each example, $\left\Vert {\bf x}_{\mu}\right\Vert$, which is reflected in the rows and columns of the kernel. This uncorrelated kernel is critical for understanding representational drift (see Sec. \ref{sec:representational drift}). (g-i) Error function kernel (equal times) with parameters $\sigma_0=0.2$ and $\sigma=10$. For small $\sigma_0$, the kernel resembles a linear kernel, closely reflecting the structure of the input. A large $\sigma$ causes a step-function-like behavior of the kernel, with a strong peak along the diagonal. (j-l) Error function kernel (time difference) with parameters $\sigma_0=0.2$ and $\sigma=10$. The effect of the time difference is similar to that of a small variance, resulting in a kernel that resembles a linear kernel.
  • Figure 3: Gradient-Driven Phase: NTK theory for a ReLU deep network with one hidden layer. The network is trained on binary classification in CIFAR-10. (a) The dynamics of the test bias, defined as $\left(\left\langle f\left({\bf x},t\right)\right\rangle -Y\right)^{2}$, averaged over the test dataset, show convergence to the NTK equilibrium. (b) The variance of the predictor, $\left\langle \delta f\left(t,{\bf x}\right)\delta f\left(t,{\bf x}\right)\right\rangle$, averaged over the test dataset, decreases with learning to an equilibrium value (Eq. \ref{eq:ntk eq var}). (c) The correlation with the initial condition, $\left\langle \delta f\left(t,{\bf x}\right)\delta f\left({\bf x},0\right)\right\rangle$, does not vanish in the NTK equilibrium at long times but instead approaches an equilibrium value (Eq. \ref{eq:ntk eq corr}). This implies that long-term generalization depends on the random initialization of the weights in a deterministic gradient descent process.
  • Figure 4: Predictor Covariance in a Linear Network: The theory of the predictor's covariance in a linear network during the diffusive phase. (a) The variance of the predictor, $\left\langle \delta f\left(t,{\bf x}\right)\delta f\left(t,{\bf x}\right)\right\rangle$, averaged over the test dataset, is shown for two values of $\sigma$ with $\sigma_0=1$. For $\sigma < \sigma_0$, the variance decreases during the diffusive learning phase due to the additional constraints imposed by L2 regularization. For $\sigma > \sigma_0$, the variance increases as the network explores the solution subspace. (b) The correlation with the initial condition, $\left\langle \delta f\left(t,{\bf x}\right)\delta f\left({\bf x},0\right)\right\rangle$, shows a rapid decrease during the gradient-driven phase followed by an exponential decay in the diffusive learning phase, reflecting the decorrelation caused by random changes in the weights.
  • Figure 5: Diffusive Dynamics: Numerical solutions of the mean field equations (Eqs. \ref{eq:meanftrain}, \ref{eq:meanf}) in the diffusive phase, starting from the NTK equilibrium, which marks the end of the gradient-driven phase. We evaluated the test bias, $\left(\left\langle f\left({\bf x},t\right)\right\rangle -y({\bf x})\right)^{2}$, averaged over the test dataset, for various nonlinear functions and parameters in a binary classification task in CIFAR-10. (a-b) Performance comparison with error function activation for different $\sigma, \sigma_0$. Small values cause the sigmoidal function to behave like a linear function, hindering generalization. Large values lead to strong nonlinearity and improved performance. (a) Small $\sigma_0$ and large $\sigma$. The accuracy after the exploratory diffusive phase is better by 10% compared to the gradient-driven phase. (b) Large $\sigma_0$ and small $\sigma$. The accuracy after the gradient-driven phase is better by 12% compared with the equilibrium accuracy. (c-d) Comparison of the performance with ReLU activation for different $\sigma, \sigma_0$. The values of the NTK and NNGP equilibria do not depend on the values of $\sigma, \sigma_0$. (d) For large $\sigma_0$ and small $\sigma$, the accuracy converges monotonically to the NNGP equilibrium, while in (c), for large $\sigma$ and small $\sigma_0$, the learning is non-monotonic.
  • ...and 9 more figures
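
Two of the qualitative statements in the Figure 2 caption (ReLU homogeneity makes the normalized kernel independent of the weight variance, and a small variance makes the error function kernel look linear) can be checked numerically. The sketch below does not compute the NDK itself. As a stand-in it uses the standard single-hidden-layer Gaussian-process kernels for ReLU (arc-cosine form) and erf (arcsine form), assuming first-layer weights drawn from $\mathcal{N}(0,\sigma^{2}/d)$ for input dimension $d$. The function names and the synthetic inputs are hypothetical.

```python
import numpy as np

def preact_cov(X, sigma):
    """Covariance of first-layer pre-activations for weights ~ N(0, sigma^2 / d)."""
    return sigma ** 2 * (X @ X.T) / X.shape[1]

def relu_gp_kernel(X, sigma):
    """E[relu(u) relu(v)] for jointly Gaussian (u, v): arc-cosine kernel of degree 1."""
    q = preact_cov(X, sigma)
    norms = np.sqrt(np.diag(q))
    scale = np.outer(norms, norms)
    cos = np.clip(q / scale, -1.0, 1.0)        # guard against rounding outside [-1, 1]
    theta = np.arccos(cos)
    return scale * (np.sin(theta) + (np.pi - theta) * cos) / (2 * np.pi)

def erf_gp_kernel(X, sigma):
    """E[erf(u) erf(v)] for jointly Gaussian (u, v): arcsine kernel."""
    q = preact_cov(X, sigma)
    d = 1.0 + 2.0 * np.diag(q)
    return (2 / np.pi) * np.arcsin(2 * q / np.sqrt(np.outer(d, d)))

def normalize(K):
    """Divide by the maximum entry, as done for the kernels in the figure."""
    return K / K.max()

rng = np.random.default_rng(0)
X = rng.standard_normal((10, 50))              # 10 synthetic inputs of dimension 50

relu_small = normalize(relu_gp_kernel(X, 0.2))
relu_large = normalize(relu_gp_kernel(X, 1.0))
erf_small = normalize(erf_gp_kernel(X, 0.01))
linear = normalize(X @ X.T)

print("ReLU kernel shape independent of sigma:", np.allclose(relu_small, relu_large))
print("erf kernel ~ linear at small sigma:", np.allclose(erf_small, linear, atol=1e-3))
```

The ReLU kernel scales as $\sigma^{2}$ times a function of the angles and norms, so dividing by the maximum removes all dependence on $\sigma$. For the erf kernel, the arcsine is close to its argument when the pre-activation variance is small, which recovers a rescaled Gram matrix of the inputs.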