Continual Learning of Diffusion Models with Generative Distillation

Sergi Masip, Pau Rodriguez, Tinne Tuytelaars, Gido M. van de Ven

TL;DR

Diffusion models are powerful but costly to train, and they are prone to catastrophic forgetting when learning tasks sequentially. The authors introduce generative distillation, which combines generative replay with distillation: a teacher (the model trained on previous tasks) generates the replay samples, and the student is trained to match the teacher's noise predictions over the entire reverse diffusion process. On Fashion-MNIST and CIFAR-10, generative distillation substantially improves FID and KLD over standard generative replay and yields classifier performance approaching that of joint training on all tasks, with only a modest increase in computational cost. The work demonstrates that continual learning of diffusion models is feasible and suggests avenues for scaling to larger datasets and faster samplers.

Abstract

Diffusion models are powerful generative models that achieve state-of-the-art performance in image synthesis. However, training them demands substantial amounts of data and computational resources. Continual learning would allow for incrementally learning new tasks and accumulating knowledge, thus enabling the reuse of trained models for further learning. One potentially suitable continual learning approach is generative replay, where a copy of a generative model trained on previous tasks produces synthetic data that are interleaved with data from the current task. However, standard generative replay applied to diffusion models results in a catastrophic loss in denoising capabilities. In this paper, we propose generative distillation, an approach that distils the entire reverse process of a diffusion model. We demonstrate that our approach substantially improves the continual learning performance of generative replay with only a modest increase in the computational costs.
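
As a rough formalization of the contrast drawn in the abstract, the two training targets for a replay sample can be written as below. This is a sketch under stated assumptions, not the paper's exact formulation: $\hat{x}_0$ (a replay sample produced by the previous-task model) and $\bar{\alpha}_t$ (the cumulative noise schedule of the standard DDPM forward process) are notation introduced here for illustration, and any loss weighting or treatment of current-task data is left out because this summary does not specify it.

$$
\hat{x}_t = \sqrt{\bar{\alpha}_t}\,\hat{x}_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)
$$

$$
\mathcal{L}_{\mathrm{GR}} = \mathbb{E}_{\hat{x}_0,\epsilon,t}\left[\left\| \epsilon - \epsilon_{\theta_i}(\hat{x}_t, t) \right\|^2\right],
\qquad
\mathcal{L}_{\mathrm{GD}} = \mathbb{E}_{\hat{x}_0,\epsilon,t}\left[\left\| \epsilon_{\hat{\theta}_{i-1}}(\hat{x}_t, t) - \epsilon_{\theta_i}(\hat{x}_t, t) \right\|^2\right]
$$

In words: generative replay regresses the student $\epsilon_{\theta_i}$ onto the noise that was added to the replay sample, whereas generative distillation regresses it onto the noise predicted by the previous-task model $\epsilon_{\hat{\theta}_{i-1}}$ at the same noisy input.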

Paper Structure

This paper contains 35 sections, 13 equations, 28 figures, 8 tables, and 1 algorithm.

Figures (28)

  • Figure 1: Conceptual flow of generative replay and generative distillation. In generative replay, the current task model $\epsilon_{\theta_{i}}$ is trained to match the noise added to the replay sample, whereas, in generative distillation, $\epsilon_{\theta_{i}}$ is trained to match the noise prediction of the previous task model $\epsilon_{\hat{\theta}_{i-1}}$.
  • Figure 2: Samples for generative replay on Fashion-MNIST (first row) and CIFAR-10 (second row). To generate samples for previous tasks, the teacher used 2 DDIM steps on Fashion-MNIST and 10 DDIM steps on CIFAR-10. Standard generative replay causes a catastrophic accumulation of error for samples from previous tasks.
  • Figure 3: Qualitative comparison of generative replay and generative distillation. The first row corresponds to results on Fashion-MNIST and the second row on CIFAR-10. (a) The order of the tasks the models were trained on. (b-d) Images generated at the end of training by a diffusion model jointly trained on all tasks (Joint), continually trained using generative replay (GR) or continually trained using generative distillation (GD). Introducing distillation into generative replay prevents the catastrophic accumulation of error and the model is still able to produce detailed images.
  • Figure 4: Quality metrics of the diffusion model throughout the continual learning process. For each approach, the FID and KLD of the diffusion model are computed after every task. Displayed are the means ± standard errors over 3 random seeds. Joint: training using data from all tasks so far (upper target), Naive: continual fine-tuning (lower target), GR: generative replay, GD: generative distillation.
  • Figure 5: Samples for different amounts of teacher steps on Fashion-MNIST. Generative distillation continues to outperform generative replay when the number of DDIM steps to generate replay samples for previous tasks is varied. For a quantitative evaluation, see \ref{tab:results_cl_full} in the Appendix.
  • ...and 23 more figures
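
The difference described in the caption of Figure 1 can be illustrated with a small NumPy sketch. It is only a toy under stated assumptions, not the authors' implementation: the function names (`noise_sample`, `replay_loss`, `distillation_loss`), the linear stand-in "models", and the fixed `alpha_bar_t` value are all hypothetical, and the forward noising step is the standard DDPM form, which this summary does not spell out.

```python
import numpy as np


def noise_sample(x0, alpha_bar_t, eps):
    """Standard DDPM forward noising (assumed form, not taken from this summary)."""
    return np.sqrt(alpha_bar_t) * x0 + np.sqrt(1.0 - alpha_bar_t) * eps


def replay_loss(student, x0_replay, alpha_bar_t, eps, t):
    """Generative replay: the student matches the noise added to the replay sample."""
    x_t = noise_sample(x0_replay, alpha_bar_t, eps)
    return float(np.mean((eps - student(x_t, t)) ** 2))


def distillation_loss(student, teacher, x0_replay, alpha_bar_t, eps, t):
    """Generative distillation: the student matches the teacher's noise prediction."""
    x_t = noise_sample(x0_replay, alpha_bar_t, eps)
    return float(np.mean((teacher(x_t, t) - student(x_t, t)) ** 2))


# Toy stand-ins for the noise-prediction networks (hypothetical, for illustration).
def teacher(x_t, t):
    return 0.5 * x_t


def student(x_t, t):
    return 0.5 * x_t


rng = np.random.default_rng(0)
x0_replay = rng.normal(size=(4, 8))  # pretend these came from the teacher's sampler
eps = rng.normal(size=(4, 8))        # noise used to corrupt the replay samples

gr = replay_loss(student, x0_replay, 0.7, eps, t=10)
gd = distillation_loss(student, teacher, x0_replay, 0.7, eps, t=10)
print(f"replay loss: {gr:.4f}, distillation loss: {gd:.4f}")
```

Because the toy student is an exact copy of the teacher, the distillation loss is zero while the replay loss is not. This mirrors the intuition in the caption: on imperfect replay samples, matching the added noise pulls the student away from the teacher, whereas matching the teacher's prediction leaves a faithful copy unchanged.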