Online Structured Laplace Approximations For Overcoming Catastrophic Forgetting

Hippolyt Ritter, Aleksandar Botev, David Barber

TL;DR

This work tackles catastrophic forgetting in neural networks by formulating Bayesian online learning with a Gaussian posterior update and a Laplace-based local approximation. It leverages a block-diagonal, Kronecker-factored Hessian to capture interdependencies among weights within the same layer while remaining scalable. The proposed Online Laplace method, especially with Kronecker factorization, substantially improves performance over diagonal-based approaches and baselines like EWC and SI on long sequences of tasks, including 50 permuted MNIST datasets and multiple vision benchmarks. The results underscore the importance of modeling weight interactions for robust continual learning and offer a scalable framework for Bayesian online continual learning.
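The recursive Gaussian update mentioned above can be written in the generic online Laplace form below. This is a sketch of the standard recursion, assuming tasks are conditionally independent given the weights. It is not a transcription of the paper's equations. The symbols $\mu_t$, $\Lambda_t$ and $H_t$ are introduced here for illustration, and any extra scaling by $\lambda$ is omitted.

```latex
\begin{aligned}
p(\theta \mid \mathcal{D}_{1:t}) &\propto p(\mathcal{D}_t \mid \theta)\, p(\theta \mid \mathcal{D}_{1:t-1})
  \approx p(\mathcal{D}_t \mid \theta)\, \mathcal{N}\!\left(\theta;\, \mu_{t-1},\, \Lambda_{t-1}^{-1}\right) \\
\mu_t &= \arg\max_\theta \Big[\log p(\mathcal{D}_t \mid \theta) - \tfrac{1}{2}(\theta - \mu_{t-1})^\top \Lambda_{t-1} (\theta - \mu_{t-1})\Big] \\
\Lambda_t &= H_t + \Lambda_{t-1}, \qquad H_t = -\nabla_\theta^2 \log p(\mathcal{D}_t \mid \theta)\big|_{\theta = \mu_t}
\end{aligned}
```

The second term inside the $\arg\max$ is the quadratic penalty on weight changes. $\Lambda_t$ is the Hessian of the negative log of the approximate posterior at its mode, which is exactly what the Laplace approximation requires.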

Abstract

We introduce the Kronecker factored online Laplace approximation for overcoming catastrophic forgetting in neural networks. The method is grounded in a Bayesian online learning framework, where we recursively approximate the posterior after every task with a Gaussian, leading to a quadratic penalty on changes to the weights. The Laplace approximation requires calculating the Hessian around a mode, which is typically intractable for modern architectures. In order to make our method scalable, we leverage recent block-diagonal Kronecker factored approximations to the curvature. Our algorithm achieves over 90% test accuracy across a sequence of 50 instantiations of the permuted MNIST dataset, substantially outperforming related methods for overcoming catastrophic forgetting.
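The scalability argument rests on never forming a layer's full curvature block. The sketch below assumes the usual Kronecker-factored structure for a fully connected layer, where the block is $A \otimes G$. Here `A` is an (in, in) factor such as the covariance of the layer's inputs, and `G` is an (out, out) factor such as the Hessian with respect to its pre-activations. The function names and the per-task `factors` list are hypothetical and chosen for illustration. The identity $\mathrm{vec}(X)^\top (A \otimes G)\,\mathrm{vec}(X) = \mathrm{tr}(X^\top G X A)$ holds for column-stacked $\mathrm{vec}$ and symmetric $A$, so the penalty costs only small matrix products.

```python
import numpy as np

def kron_quadratic_penalty(delta_w, A, G):
    """0.5 * vec(dW)^T (A kron G) vec(dW) without building A kron G.

    delta_w: (out, in) change of one layer's weights from the previous mode
    A: (in, in) symmetric factor; G: (out, out) symmetric factor.
    Uses vec(X)^T (A kron G) vec(X) = tr(X^T G X A) (column-stacked vec).
    """
    return 0.5 * np.trace(delta_w.T @ G @ delta_w @ A)

def dense_quadratic_penalty(delta_w, A, G):
    # Reference implementation: explicitly builds the (in*out) x (in*out) block.
    v = delta_w.flatten(order="F")  # column-stacking vec
    return 0.5 * v @ np.kron(A, G) @ v

def layer_penalty(delta_w, factors, lam=1.0):
    # Hypothetical storage scheme: one (A_k, G_k) pair per previous task.
    # The penalty is linear in the precision, so a sum of Kronecker
    # products can be evaluated term by term.
    return lam * sum(kron_quadratic_penalty(delta_w, A, G) for A, G in factors)

# Usage: random symmetric positive semi-definite factors for a 3x4 layer.
rng = np.random.default_rng(0)
dw = rng.normal(size=(3, 4))
Xa = rng.normal(size=(4, 4)); A = Xa @ Xa.T
Xg = rng.normal(size=(3, 3)); G = Xg @ Xg.T
print(np.isclose(kron_quadratic_penalty(dw, A, G), dense_quadratic_penalty(dw, A, G)))
```

Summing `layer_penalty` over layers gives the block-diagonal approximation across layers. Interactions between weights within a layer are kept, while interactions between layers are dropped.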

Paper Structure

This paper contains 21 sections, 17 equations, 6 figures, 1 table.

Figures (6)

  • Figure 1: Mean test accuracy on a sequence of permuted MNIST datasets. We categorize SI as a diagonal method, as it does not account for parameter interactions. The dotted black line shows the performance of a single network trained on all observed data at each task.
  • Figure 2: Effect of $\lambda$ for different curvature approximations for permuted MNIST. Each plot shows the mean, minimum and maximum across the tasks observed so far, as well as the accuracy on the first and most recent task.
  • Figure 3: Disjoint MNIST test accuracy for the Laplace approximation (hyperparameter: $\lambda$) and SI (hyperparameter: $c$). 'Kronecker factored' and 'Diagonal' refer to the respective curvature approximation for the Laplace method.
  • Figure 4: Test accuracy of a convolutional network on a sequence of vision datasets. We train on the datasets separately in the order displayed from top to bottom and show the network's accuracy on each dataset once training on it has started. The dotted black line indicates the performance of a network with the same architecture trained separately on the task. The diagonal and Kronecker factored approximations to the Hessian both use our online Laplace method to prevent forgetting.
  • Figure 5: Contours of a Gaussian likelihood (dashed blue) and prior (shades of purple) for different values of $\lambda$. Values smaller than $1$ shift the joint maximum $\theta^*$, marked by a '${\times}$', towards that of the likelihood, i.e. the new task; for values greater than $1$, it moves towards the prior, i.e. the previous tasks.
  • ...and 1 more figure
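The effect described in the Figure 5 caption can be reproduced in closed form. This sketch assumes, as the caption suggests, that $\lambda$ multiplies the log prior in the objective. For a Gaussian likelihood with mode `m_lik` and precision `H_lik`, and a Gaussian prior with mode `m_prior` and precision `H_prior`, setting the gradient to zero gives $\theta^* = (H_{\text{lik}} + \lambda H_{\text{prior}})^{-1}(H_{\text{lik}} m_{\text{lik}} + \lambda H_{\text{prior}} m_{\text{prior}})$. The specific numbers below are arbitrary and not taken from the figure.

```python
import numpy as np

def joint_maximum(m_lik, H_lik, m_prior, H_prior, lam):
    """Maximiser of log N(theta; m_lik, H_lik^-1) + lam * log N(theta; m_prior, H_prior^-1).

    The gradient H_lik (theta - m_lik) + lam * H_prior (theta - m_prior) vanishes at
    theta* = (H_lik + lam H_prior)^-1 (H_lik m_lik + lam H_prior m_prior).
    """
    H = H_lik + lam * H_prior
    return np.linalg.solve(H, H_lik @ m_lik + lam * H_prior @ m_prior)

# Arbitrary 2-D example: "new task" likelihood vs. "previous tasks" prior.
m_lik = np.array([1.0, 0.0])
H_lik = np.array([[2.0, 0.5], [0.5, 1.0]])
m_prior = np.array([-1.0, 1.0])
H_prior = np.array([[1.0, -0.3], [-0.3, 3.0]])

for lam in [0.01, 1.0, 100.0]:
    # Small lam: theta* near m_lik (new task); large lam: near m_prior (old tasks).
    print(lam, joint_maximum(m_lik, H_lik, m_prior, H_prior, lam))
```

With $\lambda = 1$ this is the ordinary product-of-Gaussians mode. $\lambda$ therefore acts as a single knob trading plasticity on the new task against retention of earlier ones, which is the trade-off examined for permuted MNIST in Figure 2.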