Continual learning with the neural tangent ensemble

Ari S. Benjamin, Christian Pehle, Kyle Daruwalla

TL;DR

This work reframes a neural network classifier as a Bayesian ensemble of neural tangent experts (NTEs). In the lazy regime the experts are fixed, so continual learning without forgetting reduces to posterior weighting over them. A first-order Taylor expansion around a seed point turns a network with $N$ parameters into an ensemble of $N$ classifiers, each outputting a probability distribution over the labels. Learning then corresponds to updating the ensemble weights, and the posterior update is equivalent to a scaled and projected form of stochastic gradient descent (SGD) over the network weights. In finite-width networks in the rich regime, the tangent experts become adaptive, and the authors derive a practical NTE rule that uses current gradients and remains effective as networks scale. Empirically, momentum in SGD increases forgetting, while wider networks forget less unless trained with Adam. The framework gives a principled Bayesian interpretation of forgetting and suggests concrete ways to mitigate it, including staying close to initialization and choosing optimizers carefully.
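The first-order Taylor expansion mentioned above can be illustrated with a minimal sketch. The two-parameter function `f`, the seed point `w0`, the input `x`, and the perturbation `direction` are all assumptions chosen for illustration; they are not the architecture or the expert construction used in the paper. The sketch only shows the linearization step: near the seed point, the linearized model tracks the true function, and the gap grows roughly quadratically with the distance moved in weight space.

```python
import math

def f(x, w):
    """Toy two-parameter 'network' (hypothetical): w[1] * tanh(w[0] * x)."""
    return w[1] * math.tanh(w[0] * x)

def grad_f(x, w):
    """Analytic gradient of f with respect to each parameter."""
    t = math.tanh(w[0] * x)
    return [w[1] * x * (1.0 - t * t), t]

def f_linear(x, w, w0):
    """First-order Taylor expansion of f in the weights, around the seed point w0."""
    g = grad_f(x, w0)
    return f(x, w0) + sum(gi * (wi - w0i) for gi, wi, w0i in zip(g, w, w0))

# Assumed seed point, input, and perturbation direction (illustrative values).
w0 = [0.5, -1.2]
x = 0.8
direction = [1.0, -0.5]

def linearization_error(eps):
    """Absolute gap between f and its linearization after moving eps * direction."""
    w = [w0i + eps * di for w0i, di in zip(w0, direction)]
    return abs(f(x, w) - f_linear(x, w, w0))

err_small = linearization_error(1e-3)  # weights stay close to the seed point
err_large = linearization_error(1e-1)  # weights move further away

print(f"error at eps=1e-3: {err_small:.3e}")
print(f"error at eps=1e-1: {err_large:.3e}")
```

Moving 100 times further from the seed point makes the linearization error several thousand times larger in this toy setting, which is one way to see why the fixed-expert picture is tied to weights staying close to where the expansion was taken.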

Abstract

A natural strategy for continual learning is to weigh a Bayesian ensemble of fixed functions. This suggests that if a (single) neural network could be interpreted as an ensemble, one could design effective algorithms that learn without forgetting. To realize this possibility, we observe that a neural network classifier with N parameters can be interpreted as a weighted ensemble of N classifiers, and that in the lazy regime limit these classifiers are fixed throughout learning. We call these classifiers the neural tangent experts and show they output valid probability distributions over the labels. We then derive the likelihood and posterior probability of each expert given past data. Surprisingly, the posterior updates for these experts are equivalent to a scaled and projected form of stochastic gradient descent (SGD) over the network weights. Away from the lazy regime, networks can be seen as ensembles of adaptive experts which improve over time. These results offer a new interpretation of neural networks as Bayesian ensembles of experts, providing a principled framework for understanding and mitigating catastrophic forgetting in continual learning settings.

Paper Structure

This paper contains 30 sections, 5 theorems, 27 equations, 10 figures, and 1 algorithm.

Key Result

Lemma 1

Invariance to data ordering in Bayesian Ensembles. Let $\mathcal{F} = \{f_1, \dots, f_N\}$ be a set of fixed experts, $\mathcal{W} = \{w_1, \dots, w_N\}$ be their weights, and $\mathcal{D} = \{D_1, \dots, D_T\}$ be a sequence of datasets from $T$ tasks. Then, for any permutation $\pi$ of the indices $1, \dots, T$, $p(f_i
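The ordering invariance named in this lemma can be checked numerically with a small sketch. It assumes the standard Bayesian update, in which the posterior over fixed experts is proportional to the prior times the product of per-dataset likelihoods. The three experts, the toy binary-label datasets, and the uniform prior are hypothetical and are not taken from the paper.

```python
import itertools
import math

# Hypothetical fixed experts: each maps an input x to P(label = 1 | x).
experts = [
    lambda x: 1.0 / (1.0 + math.exp(-x)),  # increasing in x
    lambda x: 1.0 / (1.0 + math.exp(x)),   # decreasing in x
    lambda x: 0.5,                         # uninformative
]

def log_likelihood(expert, dataset):
    """Log-probability that one expert assigns to the labels of one dataset."""
    total = 0.0
    for x, y in dataset:
        p = expert(x)
        total += math.log(p if y == 1 else 1.0 - p)
    return total

def posterior(datasets, prior):
    """Posterior over experts after sequential updates in the given task order."""
    log_post = [math.log(p) for p in prior]
    for dataset in datasets:  # one Bayesian update per task
        log_post = [lp + log_likelihood(f, dataset)
                    for lp, f in zip(log_post, experts)]
    m = max(log_post)  # subtract the max for numerical stability
    unnorm = [math.exp(lp - m) for lp in log_post]
    z = sum(unnorm)
    return [u / z for u in unnorm]

# Assumed toy tasks: lists of (input, binary label) pairs.
tasks = [
    [(1.5, 1), (-0.7, 0)],
    [(0.3, 1), (2.0, 1), (-1.1, 0)],
    [(-0.2, 1), (0.9, 1)],
]
prior = [1 / 3, 1 / 3, 1 / 3]

reference = posterior(tasks, prior)

# Largest posterior difference across every possible ordering of the tasks.
max_gap = max(
    abs(a - b)
    for perm in itertools.permutations(tasks)
    for a, b in zip(posterior(list(perm), prior), reference)
)

print("posterior:", [round(p, 4) for p in reference])
print("max gap across orderings:", max_gap)
```

Because each update only adds a log-likelihood term, and addition is commutative, every task ordering yields the same posterior up to floating-point rounding. The invariance depends on the experts being fixed: if an expert changed between tasks, its likelihood on earlier data would change as well.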

Figures (10)

  • Figure 1: High-level intuition for model averaging and continual learning. Pruning the set of functions $f_i$ to those good for task $\mathcal{A}$, followed by further pruning for tasks $\mathcal{B}$ and $\mathcal{C}$, will result in a set of $f_i$ still good on $\mathcal{A}$.
  • Figure 2: The average squared difference between experts' columns of the Jacobian measured at initialization and at the end of training on MNIST with a 2-layer ReLU MLP and the NTE rule. Error bands indicate the standard deviation over 10 random seeds. As the width of the network increases, the average distance decreases, indicating that larger networks remain closer to the original linearization.
  • Figure 3: a) Gradients of an MLP at time $t$ rapidly lose correlation with the gradients at initialization. b) Training a network with the NTE posterior update rule fails when gradients diverge. Hyperparameters are reported in the Appendix.
  • Figure 4: Effect of momentum in SGD on the Permuted MNIST task for an MLP with 2 layers and 1,000 hidden units. (middle) Test accuracy on the first task at the end of training 5 sequential tasks. (right) Final test accuracy on the first task before seeing the other tasks. Error bars represent standard deviations over seeds. See Appendix for further parameters.
  • Figure 5: Wider networks forget less, unless trained with Adam. See Alg. 1. All networks are 2-layer MLPs with ReLU nonlinearities trained on 5 Permuted MNIST tasks. Loss curves and further parameters in Appendix. Error bars represent standard deviations.
  • ...and 5 more figures

Theorems & Definitions (9)

  • Lemma 1
  • proof
  • Theorem 2
  • Lemma 3
  • proof
  • Lemma 4
  • proof
  • Lemma 5
  • proof