Generative Adversarial Network Training is a Continual Learning Problem

Kevin J Liang, Chunyuan Li, Guoyin Wang, Lawrence Carin

TL;DR

Training of Generative Adversarial Networks (GANs) is often unstable, exhibiting mode collapse and oscillations, in part because the discriminator catastrophically forgets earlier generator samples as the generator's distribution evolves. The authors frame GAN training as a continual learning problem and augment the discriminator with memory-based regularizers inspired by elastic weight consolidation (EWC) and intelligent synapses (IS), implemented in an online, resource-efficient manner to yield EWC-GAN and IS-GAN. Across toy and real datasets (eight Gaussians, CelebA, CIFAR-10, and COCO Captions), these methods improve generation quality (FID/ICP, BLEU) with minimal computational overhead and no additional networks. The work positions GAN training as a realistic continual learning benchmark and shows that a memory-aware discriminator can stabilize training and improve performance.

Abstract

Generative Adversarial Networks (GANs) have proven to be a powerful framework for learning to draw samples from complex distributions. However, GANs are also notoriously difficult to train, with mode collapse and oscillations a common problem. We hypothesize that this is at least in part due to the evolution of the generator distribution and the catastrophic forgetting tendency of neural networks, which leads to the discriminator losing the ability to remember synthesized samples from previous instantiations of the generator. Recognizing this, our contributions are twofold. First, we show that GAN training makes for a more interesting and realistic benchmark for continual learning methods evaluation than some of the more canonical datasets. Second, we propose leveraging continual learning techniques to augment the discriminator, preserving its ability to recognize previous generator samples. We show that the resulting methods add only a light amount of computation, involve minimal changes to the model, and result in better overall performance on the examined image and text generation tasks.
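
To make the idea of augmenting the discriminator with a continual learning regularizer more concrete, the following is a minimal sketch of an EWC-style quadratic penalty added to a discriminator loss. It assumes the standard elastic weight consolidation form, (lam / 2) * sum_i F_i * (theta_i - theta_anchor_i)^2, with a diagonal Fisher estimate built from squared gradients. The function names, the flat-list parameter representation, and the weight `lam` are hypothetical choices for illustration, and the sketch omits the online updating the authors describe; it is not their implementation.

```python
# Illustrative sketch only: an EWC-style penalty on discriminator parameters.
# All names (diagonal_fisher, ewc_penalty, lam, ...) are hypothetical and
# parameters are represented as flat lists of floats for simplicity.


def diagonal_fisher(per_sample_grads):
    """Estimate a diagonal Fisher as the mean of squared per-sample gradients."""
    n = len(per_sample_grads)
    dim = len(per_sample_grads[0])
    return [sum(g[i] ** 2 for g in per_sample_grads) / n for i in range(dim)]


def ewc_penalty(theta, theta_anchor, fisher, lam):
    """Return (lam / 2) * sum_i F_i * (theta_i - theta_anchor_i)^2."""
    return 0.5 * lam * sum(
        f * (t - a) ** 2 for t, a, f in zip(theta, theta_anchor, fisher)
    )


def regularized_discriminator_loss(base_loss, theta, theta_anchor, fisher, lam):
    """Add the penalty to an already-computed discriminator loss value."""
    return base_loss + ewc_penalty(theta, theta_anchor, fisher, lam)
```

In this sketch, `theta_anchor` stands for a snapshot of the discriminator parameters taken at an earlier point in training, and `fisher` weights each parameter by how strongly it affected the loss at that point, so parameters that mattered for recognizing earlier generator samples are penalized more for drifting.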

Paper Structure

This paper contains 19 sections, 7 equations, 6 figures, 5 tables, 2 algorithms.

Figures (6)

  • Figure 1: Real samples from a mixture of eight Gaussians in red; generated samples in blue. (a) The generator is mode-collapsed in the bottom right. (b) The discriminator learns to recognize the generator oversampling this region and pushes the generator away, so the generator gravitates toward a new mode. (c) The discriminator continues to chase the generator, causing the generator to move in a clockwise direction. (d) The generator eventually returns to the same mode as (a). Such oscillations are common while training a vanilla GAN. Best seen as a video: https://youtu.be/91a2gPWngo8.
  • Figure 2: Each line represents the discriminator's test accuracy on the fake GAN datasets. Note the sharp decrease in the discriminator's ability to recognize previous fake samples upon fine-tuning on the next dataset using SGD (left). Forgetting still occurs with EWC (right), but is less severe.
  • Figure 3: Image samples from generated "fake MNIST" datasets.
  • Figure 4: Each row shows the evolution of generator samples at intervals of 5000 training steps for GAN, SN-GAN, and EWC-GAN for two $\alpha$ values. The proposed EWC-GAN models have hyperparameters matching the corresponding $\alpha$ in the eight-Gaussians results table (label tab:8-Gauss). Each frame shows 10000 samples drawn from the true eight Gaussians mixture (red) and 10000 generator samples (blue).
  • Figure 5: Generated image samples from random draws of EWC+GANs.
  • ...and 1 more figure