Continual Learning of Diffusion Models with Generative Distillation
Sergi Masip, Pau Rodriguez, Tinne Tuytelaars, Gido M. van de Ven
TL;DR
Diffusion models are powerful but costly to train and prone to catastrophic forgetting when learning tasks sequentially. The authors introduce generative distillation, which distills the entire reverse diffusion process from a teacher (a copy of the model trained on previous tasks) to a student during continual learning, using generative replay to supply the replay samples. On Fashion-MNIST and CIFAR-10, generative distillation substantially improves Fréchet Inception Distance (FID) and Kullback-Leibler divergence (KLD) over standard generative replay and yields classifier performance approaching that of joint-task training, at only a modest additional computational cost. The work demonstrates that continual learning of diffusion models is feasible and suggests avenues for scaling to larger datasets and faster samplers.
Abstract
Diffusion models are powerful generative models that achieve state-of-the-art performance in image synthesis. However, training them demands substantial amounts of data and computational resources. Continual learning would allow for incrementally learning new tasks and accumulating knowledge, thus enabling the reuse of trained models for further learning. One potentially suitable continual learning approach is generative replay, where a copy of a generative model trained on previous tasks produces synthetic data that are interleaved with data from the current task. However, standard generative replay applied to diffusion models results in a catastrophic loss in denoising capabilities. In this paper, we propose generative distillation, an approach that distils the entire reverse process of a diffusion model. We demonstrate that our approach substantially improves the continual learning performance of generative replay with only a modest increase in the computational costs.
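The difference between standard generative replay and a distillation-style objective can be illustrated with a small NumPy sketch. This is not the authors' implementation: the linear noise schedule, the toy linear "denoiser" standing in for a real network, and the exact form of the distillation loss (matching the teacher's noise prediction at the same noisy input and timestep) are all assumptions made for illustration, as are the function names.

```python
import numpy as np

def make_alpha_bars(num_steps=100, beta_start=1e-4, beta_end=0.02):
    """Cumulative products of (1 - beta_t) for a linear noise schedule (assumed)."""
    betas = np.linspace(beta_start, beta_end, num_steps)
    return np.cumprod(1.0 - betas)

def q_sample(x0, t, eps, alpha_bars):
    """Forward diffusion: noise clean samples x0 to their timesteps t."""
    a = alpha_bars[t][:, None]  # one alpha_bar per sample, broadcast over features
    return np.sqrt(a) * x0 + np.sqrt(1.0 - a) * eps

def replay_loss(student, x0_replay, t, eps, alpha_bars):
    """Standard generative replay: replayed samples are treated like real data,
    so the student regresses onto the sampled noise eps."""
    x_t = q_sample(x0_replay, t, eps, alpha_bars)
    return np.mean((student(x_t, t) - eps) ** 2)

def distillation_loss(student, teacher, x0_replay, t, eps, alpha_bars):
    """Distillation-style loss (assumed form): the student regresses onto the
    teacher's noise prediction at the same noisy input and timestep."""
    x_t = q_sample(x0_replay, t, eps, alpha_bars)
    return np.mean((student(x_t, t) - teacher(x_t, t)) ** 2)

def make_toy_denoiser(weight):
    """Hypothetical stand-in for a denoising network: a linear map of x_t."""
    return lambda x_t, t: weight * x_t

rng = np.random.default_rng(0)
alpha_bars = make_alpha_bars()
teacher = make_toy_denoiser(0.5)
student = make_toy_denoiser(0.5)  # the student starts as a copy of the teacher

x0_replay = rng.normal(size=(8, 2))            # stands in for teacher-generated samples
t = rng.integers(0, len(alpha_bars), size=8)   # one random timestep per sample
eps = rng.normal(size=(8, 2))                  # noise used in the forward process

loss_replay = replay_loss(student, x0_replay, t, eps, alpha_bars)
loss_distill = distillation_loss(student, teacher, x0_replay, t, eps, alpha_bars)
```

In this toy setting a student that is an exact copy of the teacher incurs zero distillation loss, whereas the replay loss still pushes it toward the sampled noise for the replayed inputs. In an actual continual learning run, the replay-side loss would presumably be combined with the usual denoising loss on current-task data; how the two are weighted is not specified here.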
