Joint Diffusion models in Continual Learning
Paweł Skierś, Kamil Deja
TL;DR
This work tackles catastrophic forgetting in continual learning by unifying a diffusion-based generator and a classifier into a single, jointly optimized model (JDCL). The method uses a two-stage local-to-global training regime and knowledge distillation to stabilize learning, and it samples rehearsal data from the joint model itself, achieving state-of-the-art results among generative-replay methods on CIFAR-10, CIFAR-100, and ImageNet-100. It further extends to semi-supervised continual learning, where JDCL outperforms buffer-based replay, and the quality of its learned representations is also evaluated in a self-supervised manner. Overall, JDCL provides a scalable, efficient approach that integrates generation and discrimination for robust continual learning and representation learning.
Abstract
In this work, we introduce JDCL, a new method for continual learning with generative rehearsal based on joint diffusion models. Neural networks suffer from catastrophic forgetting, defined as an abrupt loss in a model's performance when it is retrained with additional data coming from a different distribution. Generative-replay-based continual learning methods try to mitigate this issue by retraining a model on a combination of new data and rehearsal data sampled from a generative model. We propose to extend this idea by combining a continually trained classifier and a diffusion-based generative model into a single, jointly optimized neural network. We show that such shared parametrization, combined with knowledge distillation, allows for stable adaptation to new tasks without catastrophic forgetting. We evaluate our approach on several benchmarks, where it outperforms recent state-of-the-art generative replay techniques. Additionally, we extend our method to the semi-supervised continual learning setup, where it outperforms competing buffer-based replay techniques, and we evaluate, in a self-supervised manner, the quality of the trained representations.
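To make the idea of a single, jointly optimized objective more concrete, the following is a minimal sketch of how a denoising term, a classification term, and a distillation term could be combined for one sample. It is an illustration under stated assumptions, not the formulation used in the paper: the function names (`mse`, `cross_entropy`, `joint_loss`), the weights `cls_weight` and `distill_weight`, and the choice to distill on the previous model's noise prediction are all hypothetical.

```python
import math


def mse(a, b):
    """Mean squared error between two equal-length lists of floats."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def cross_entropy(logits, label):
    """Numerically stable softmax cross-entropy for one sample."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def joint_loss(noise_pred, noise_true, logits, label,
               teacher_noise_pred=None,
               cls_weight=1.0, distill_weight=1.0):
    """Hypothetical joint objective for a shared generator-classifier.

    noise_pred / noise_true: predicted and actual diffusion noise.
    logits / label: classifier output and target class.
    teacher_noise_pred: noise predicted by a frozen copy of the model
        from the previous task (assumed distillation target); if None,
        no distillation term is added.
    """
    loss = mse(noise_pred, noise_true)
    loss += cls_weight * cross_entropy(logits, label)
    if teacher_noise_pred is not None:
        loss += distill_weight * mse(noise_pred, teacher_noise_pred)
    return loss


# Example: one rehearsal sample with and without a distillation target.
base = joint_loss([0.0, 0.0], [0.0, 0.0], [0.0, 0.0], 0)
with_teacher = joint_loss([0.0, 0.0], [0.0, 0.0], [0.0, 0.0], 0,
                          teacher_noise_pred=[1.0, 1.0])
```

In this sketch, both the generative and the discriminative terms would backpropagate into the same parameters, which is the sense in which the parametrization is shared; the distillation term only becomes active once a model from a previous task exists.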
