Bayesian continual learning and forgetting in neural networks

Djohan Bonnet, Kellian Cottart, Tifenn Hirtzlin, Tarcisius Januel, Thomas Dalgaty, Elisa Vianello, Damien Querlioz

TL;DR

The paper introduces Metaplasticity from Synaptic Uncertainty (MESU), a Bayesian continual-learning framework that updates neural weights according to their uncertainty to balance learning and forgetting. By maintaining a truncated posterior over the last $N$ tasks and minimizing a variational free energy $\mathcal{F}_t$, MESU achieves principled forgetting, preserves essential past knowledge, and scales the updates of the weight means and standard deviations, $\Delta\bm{\mu}$ and $\Delta\bm{\sigma}$, by the weight uncertainty. The paper shows theoretical links to Hessian-based regularization and to Newton’s method, and reports strong empirical performance on domain-incremental animal classification, Permuted MNIST, and CIFAR-10/100. MESU outperforms boundary-based methods and avoids both catastrophic forgetting and catastrophic remembering, while retaining robust epistemic uncertainty for out-of-distribution detection. The work provides a biologically inspired, boundary-free path toward robust perpetual learning on streaming data.

Abstract

Biological synapses effortlessly balance memory retention and flexibility, yet artificial neural networks still struggle with the extremes of catastrophic forgetting and catastrophic remembering. Here, we introduce Metaplasticity from Synaptic Uncertainty (MESU), a Bayesian framework that updates network parameters according to their uncertainty. This approach allows a principled combination of learning and forgetting that ensures that critical knowledge is preserved while unused or outdated information is gradually released. Unlike standard Bayesian approaches, which risk becoming overly constrained, and popular continual-learning methods, which rely on explicit task boundaries, MESU adapts seamlessly to streaming data. It further provides reliable epistemic uncertainty estimates, allowing out-of-distribution detection; the only additional computational cost is sampling the weights multiple times to obtain proper output statistics. Experiments on image-classification benchmarks demonstrate that MESU mitigates catastrophic forgetting while maintaining plasticity for new tasks. When training on 200 sequential permuted MNIST tasks, MESU outperforms established continual-learning techniques in terms of accuracy, capability to learn additional tasks, and out-of-distribution data detection. Additionally, because it does not rely on task boundaries, MESU consistently outperforms conventional learning techniques on the incremental training of CIFAR-100 tasks across a wide range of scenarios. Our results unify ideas from metaplasticity, Bayesian inference, and Hessian-based regularization, offering a biologically inspired pathway to robust, perpetual learning.
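The abstract's remark that uncertainty estimates come from sampling the weights multiple times can be illustrated with a toy example. The following is a minimal sketch, not the authors' implementation: it assumes a single linear softmax layer with mean-field Gaussian weights sampled as $\bm{\mu} + \bm{\epsilon}\cdot\bm{\sigma}$, and it assumes one common decomposition in which epistemic uncertainty is the entropy of the averaged prediction minus the average entropy of the individual predictions. All function names (`sample_weights`, `predict`, `uncertainty`) are hypothetical.

```python
import math
import random


def softmax(logits):
    # Numerically stable softmax over a list of logits.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(probs):
    # Shannon entropy in nats; zero-probability terms contribute nothing.
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def sample_weights(mu, sigma, rng):
    # One weight sample: omega = mu + eps * sigma with eps ~ N(0, 1).
    return [[m + rng.gauss(0.0, 1.0) * s for m, s in zip(m_row, s_row)]
            for m_row, s_row in zip(mu, sigma)]


def predict(weights, x):
    # Toy model (assumption): one linear layer followed by softmax.
    logits = [sum(w * xi for w, xi in zip(row, x)) for row in weights]
    return softmax(logits)


def uncertainty(mu, sigma, x, n_samples=200, seed=0):
    # Monte Carlo estimate over n_samples weight draws.
    rng = random.Random(seed)
    preds = [predict(sample_weights(mu, sigma, rng), x)
             for _ in range(n_samples)]
    n_classes = len(preds[0])
    mean_pred = [sum(p[c] for p in preds) / n_samples
                 for c in range(n_classes)]
    total = entropy(mean_pred)                             # total uncertainty
    aleatoric = sum(entropy(p) for p in preds) / n_samples  # expected entropy
    epistemic = total - aleatoric                          # disagreement between samples
    return mean_pred, total, aleatoric, epistemic


if __name__ == "__main__":
    mu = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]  # 3 classes, 2 inputs
    sigma = [[0.5, 0.5], [0.5, 0.5], [0.5, 0.5]]
    print(uncertainty(mu, sigma, [1.0, -1.0]))
```

In this sketch the epistemic term shrinks toward zero as the standard deviations shrink, since every weight sample then yields the same prediction. An input scored with a high epistemic value would be a candidate for out-of-distribution flagging.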

Paper Structure

This paper contains 9 sections, 5 theorems, 42 equations, 5 figures, 4 tables.

Key Result

Lemma 1

Let $q_{\bm{\theta}}(\bm{\omega}) \approx p(\bm{\omega}\mid \mathcal{D})$ be a mean-field Gaussian for a Bayesian neural network, where $\bm{\theta}=(\bm{\mu},\bm{\sigma})$ and $\bm{\omega}=\bm{\mu} + \bm{\epsilon}\cdot\bm{\sigma}$, $\bm{\epsilon}\sim \mathcal{N}(\vec{0},\mathbf{I}_s)$. If the prior

Figures (5)

  • Figure 1: Bayesian continual learning and forgetting. Continual learning is a sequential training situation, where several datasets $\mathcal{D}_{i}$ are presented sequentially. In our approach the weights of a neural network follow a probability distribution $q_{\bm{\theta}_t}(\bm{\omega})$. The target for learning is that this distribution approximates $p(\bm{\omega}\mid \mathcal{D}_{t-N},\dots,\mathcal{D}_t),$ a formulation that gracefully balances learning and forgetting.
  • Figure 2: Domain-incremental learning with animal classification. a Example of images in the dataset. Each superclass corresponds to a family of animals. Among 20 sub-classes, five are selected randomly to belong to task i. For a given superclass, the specific species presented during training changes at each new task. b Example of an image that belongs to the "out of distribution" dataset. These images are used to evaluate the model's capability to detect unknown images. c Evolution of the accuracy for each task after we learn a new task, with a Bayesian neural network trained with MESU. The experiment was run 50 times with different possible combinations of datasets. The shaded area represents one standard deviation of the accuracy. d Evolution of the accuracy for each task after we learn a new task, with a deterministic neural network trained with Stochastic Gradient Descent (SGD). The experiment was run 50 times with different possible combinations of datasets. The shaded area represents one standard deviation of the accuracy. e Distribution of the epistemic uncertainty of the Bayesian neural network trained with MESU for the out-of-distribution dataset and the in-distribution dataset (Test dataset). f Distribution of the aleatoric uncertainty of the deterministic neural network trained with SGD for the out-of-distribution dataset and the in-distribution dataset (Test dataset).
  • Figure 3: Comparison between Metaplasticity from Synaptic Uncertainty (MESU), Fixed-point Operator for Online Variational Bayes Diagonal (FOO-VB Diagonal), Elastic Weight Consolidation Online (EWC Online, which uses task boundaries), Elastic Weight Consolidation Stream (EWC Stream, which does not use task boundaries), Synaptic Intelligence (SI), and Stochastic Gradient Descent (SGD) as a baseline on 200 tasks of Permuted MNIST in a streaming learning context with a low number of parameters (50 hidden units). a Comparison between algorithms of the average accuracy on tasks 196 to 200 after learning 200 permutations of MNIST. b Testing accuracy on the test sets of the old tasks after training for 200 tasks. c Comparison of the memory rigidity, defined as the inverse of the absolute difference between the accuracy reached during the first task and the accuracy reached during the last task, $\mathcal{R}_t = \frac{1}{|\mathcal{A}_0 - \mathcal{A}_t|}$. d Testing accuracy on the test set of the newly learned task after training on said task. e Comparison between algorithms of the ability to discriminate between the last permutation learned and Fashion-MNIST, using the area under the curve of the receiver operating characteristic (ROC AUC) between in-distribution data uncertainty and out-of-distribution data uncertainty. ROC AUC is computed at the end of the last trained task, with Permuted MNIST test dataset predictions as in-distribution and Fashion-MNIST test dataset predictions as out-of-distribution. Shadings represent one standard deviation over five runs.
  • Figure 4: MESU is resilient to vanishing uncertainty. a Evolution of the ROC AUC for out-of-distribution (OOD) detection when training on MNIST for 1000 epochs and using Fashion-MNIST as OOD data. We compare MESU (with two different remembering window sizes, $N$), FOO-VB Diagonal, EWC Stream, and a baseline that applies plain SGD with no continual-learning adaptation. The OOD detection procedure is identical to that in Fig. \ref{fig:Figure2}. b Evolution of the mean standard deviation $\sigma$ (averaged over all layers) for the Bayesian models. FOO-VB Diagonal and MESU with $N=10^{15}$ have matching results, consistent with the theoretical equivalence of MESU and FOO-VB in the limit of infinite $N$ (i.e., no forgetting). Shading denotes one standard deviation over five runs.
  • Figure 5: Comparison of task-incremental learning on CIFAR-10 and CIFAR-100 with Metaplasticity from Synaptic Uncertainty (MESU), Elastic Weight Consolidation (EWC), Synaptic Intelligence (SI), and a baseline with no continual-learning adaptation (Adam). a Mean accuracy obtained after training the 11 CIFAR-10/CIFAR-100 tasks, in different continual-learning situations (different numbers of splits, see main text). Shadings represent one standard deviation. b, c Comparison between algorithms of the accuracy across the 11 tasks in the single-split case (b) and in the 16-splits case (c). Lines are guides for the eye.
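
Two evaluation metrics named in the Figure 3 caption, the memory rigidity $\mathcal{R}_t = \frac{1}{|\mathcal{A}_0 - \mathcal{A}_t|}$ and the ROC AUC between in-distribution and out-of-distribution uncertainties, can be sketched as follows. This is an illustrative sketch rather than the authors' evaluation code: the function names are hypothetical, returning infinity when the two accuracies coincide is an assumption (the caption does not say how that case is handled), and the ROC AUC is computed with the standard pairwise-ranking formulation, treating a higher uncertainty score as evidence of out-of-distribution data.

```python
def memory_rigidity(acc_first, acc_last):
    # R_t = 1 / |A_0 - A_t|, with A_0 the accuracy reached during the first
    # task and A_t the accuracy reached during the last task.
    gap = abs(acc_first - acc_last)
    # Assumption: an exactly zero gap is reported as infinite rigidity.
    return float("inf") if gap == 0 else 1.0 / gap


def roc_auc(in_dist_scores, ood_scores):
    # Probability that a random OOD sample has a higher uncertainty score
    # than a random in-distribution sample; ties count as one half.
    wins = 0.0
    for o in ood_scores:
        for i in in_dist_scores:
            if o > i:
                wins += 1.0
            elif o == i:
                wins += 0.5
    return wins / (len(in_dist_scores) * len(ood_scores))


if __name__ == "__main__":
    print(memory_rigidity(0.95, 0.90))
    print(roc_auc([0.1, 0.2], [0.3, 0.4]))
```

A ROC AUC of 1.0 means every out-of-distribution sample is scored as more uncertain than every in-distribution sample, while 0.5 corresponds to no discrimination. The pairwise loop is quadratic in the number of samples, which is fine for a sketch but would be replaced by a rank-based computation on full test sets.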

Theorems & Definitions (5)

  • Lemma 1: The quadratic negative log-likelihood
  • Theorem 1: MESU
  • Lemma 2: Second-order derivative via first-order derivative
  • Theorem 2: Newton’s method in variational inference
  • Proposition 1: Dynamics of standard deviations