Continual Learning with Deep Generative Replay

Hanul Shin, Jung Kwon Lee, Jaehong Kim, Jiwon Kim

TL;DR

Catastrophic forgetting hinders continual learning in deep networks. The authors propose Deep Generative Replay, which pairs a generator with a solver so that pseudo-data from past tasks can be replayed without storing real past samples; both models are trained on a mix of current-task data and replayed data. The approach is evaluated on MNIST permutation tasks, cross-domain MNIST-SVHN transfer, and disjoint-class learning, and it sustains performance on old tasks while acquiring new ones. The results suggest that generative replay can match joint-training performance when the generator accurately reproduces past input distributions, and that it can complement existing methods such as LwF and EWC.

Abstract

Attempts to train a comprehensive artificial intelligence capable of solving multiple tasks have been impeded by a chronic problem called catastrophic forgetting. Although simply replaying all previous data alleviates the problem, it requires large memory and, even worse, is often infeasible in real-world applications where access to past data is limited. Inspired by the generative nature of the hippocampus as a short-term memory system in the primate brain, we propose Deep Generative Replay, a novel framework with a cooperative dual-model architecture consisting of a deep generative model ("generator") and a task-solving model ("solver"). With only these two models, training data for previous tasks can easily be sampled and interleaved with those for a new task. We test our methods in several sequential learning settings involving image classification tasks.

Paper Structure

This paper contains 13 sections, 2 equations, 7 figures, 1 table.

Figures (7)

  • Figure 1: Sequential training of scholar models. (a) Training a sequence of scholar models is equivalent to continuous training of a single scholar while referring to its most recent copy. (b) A new generator is trained to mimic a mixed data distribution of real samples $\boldsymbol{x}$ and replayed inputs $\boldsymbol{x}'$ from previous generator. (c) A new solver learns from real input-target pairs $(\boldsymbol{x}, \boldsymbol{y})$ and replayed input-target pairs $(\boldsymbol{x}', \boldsymbol{y}')$, where replayed response $\boldsymbol{y}'$ is obtained by feeding generated inputs into previous solver.
  • Figure 2: Results on MNIST pixel permutation tasks. (a) Test performance on each task during sequential training. Performance on previous tasks dropped without replaying real or meaningful fake data. (b) Average test accuracy on learnt tasks. Higher accuracy is achieved when the replayed inputs better resemble real data.
  • Figure 3: Accuracy on classifying samples from two different domains. (a) The models are trained on MNIST then on SVHN dataset or (b) vice versa. When the previous data are recalled by generative replay (orange), knowledge of the first domain is retained as if the real inputs with predicted responses are replayed (green). Sequential training on the solver alone incurs forgetting on the former domain, thereby resulting in low average performance (violet).
  • Figure 4: Samples from trained generator in MNIST to SVHN experiment after training on SVHN dataset for 1000, 2000, 5000, 10000, and 20000 iterations. The samples are diverted into ones that mimic either SVHN or MNIST input images.
  • Figure 5: Performance of LwF and LwF augmented with generative replay (LwF-GR) on classifying samples from each domain. The networks were trained on the SVHN dataset and then on the MNIST dataset. Test accuracy on the SVHN classification task (thick curves) dropped when the shared parameters were fine-tuned, but generative replay greatly tempered the loss (orange). Both networks achieved high accuracy on MNIST classification (dim curves).
  • ...and 2 more figures
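
The replay scheme described in the Figure 1 caption can be illustrated with a minimal NumPy sketch. It only shows how one training batch might be assembled from real pairs and replayed pairs; it is not the authors' implementation. The names `build_mixed_batch`, `real_ratio`, `toy_generator`, and `toy_solver` are hypothetical, the fixed mixing ratio is an assumption, and the toy generator and solver are simple stand-ins for the deep generative model and the trained classifier.

```python
import numpy as np

def toy_generator(n, rng):
    # Hypothetical stand-in for the previous generator: n fake inputs x'.
    return rng.normal(size=(n, 4))

def toy_solver(x):
    # Hypothetical stand-in for the previous solver: labels y' for inputs x'.
    return (x.sum(axis=1) > 0).astype(np.int64)

def build_mixed_batch(real_x, real_y, prev_generator, prev_solver,
                      batch_size, real_ratio, rng):
    """Assemble one batch of real (x, y) and replayed (x', y') pairs.

    Returns inputs, targets, and a boolean mask marking replayed rows.
    With no previous generator (first task), the batch is all real data.
    """
    if prev_generator is None:
        n_real = batch_size
    else:
        n_real = int(round(real_ratio * batch_size))
    n_replay = batch_size - n_real

    # Sample real input-target pairs from the current task.
    idx = rng.integers(0, len(real_x), size=n_real)
    x_parts, y_parts = [real_x[idx]], [real_y[idx]]

    if n_replay > 0:
        # Replayed inputs come from the previous generator; their targets
        # are the previous solver's responses to those inputs.
        x_replay = prev_generator(n_replay, rng)
        y_replay = prev_solver(x_replay)
        x_parts.append(x_replay)
        y_parts.append(y_replay)

    x = np.concatenate(x_parts)
    y = np.concatenate(y_parts)
    is_replay = np.arange(batch_size) >= n_real
    return x, y, is_replay

# Usage with toy data: the current task has 100 samples, all labelled 0.
rng = np.random.default_rng(0)
real_x = rng.normal(size=(100, 4))
real_y = np.zeros(100, dtype=np.int64)
x, y, is_replay = build_mixed_batch(real_x, real_y, toy_generator, toy_solver,
                                    batch_size=32, real_ratio=0.5, rng=rng)
```

In this sketch the new solver would be trained on `(x, y)`, while the new generator would be trained on `x` alone, so that it learns the mixture of current and replayed inputs.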

Theorems & Definitions (2)

  • Definition 1
  • Definition 2