Joint Diffusion models in Continual Learning

Paweł Skierś, Kamil Deja

TL;DR

This work tackles catastrophic forgetting in continual learning by unifying a diffusion-based generator and a classifier into a single, jointly optimized model (JDCL). The method leverages a two-stage local-to-global training regime and knowledge distillation to stabilize learning while generating rehearsal data from the joint model, achieving state-of-the-art results among generative-replay methods on CIFAR-10, CIFAR-100, and ImageNet-100. It further extends to semi-supervised continual learning, where JDCL outperforms buffer-based replay methods, and its learned representations prove strong under self-supervised evaluation. Overall, JDCL provides a scalable, efficient approach that integrates generation and discrimination for robust continual learning and representation learning.
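The idea of one network that both generates and classifies can be illustrated with a toy objective. This is a minimal NumPy sketch, assuming a DDPM-style noise-prediction loss with a linear beta schedule and a classification head that reads features from the same shared encoder as the denoiser. The linear layers, dimensions, the weight `alpha`, and all function names are hypothetical stand-ins, not the authors' UNet-based architecture; the toy denoiser is also not conditioned on the timestep, which a real diffusion model would be.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy sizes: 8-dim inputs, 16-dim shared features, 3 classes, 1000 diffusion steps.
D, H, C, T = 8, 16, 3, 1000

# Linear DDPM noise schedule and cumulative product alpha_bar_t.
betas = np.linspace(1e-4, 0.02, T)
alpha_bar = np.cumprod(1.0 - betas)

# One shared encoder with two heads (stand-ins for a UNet encoder/decoder and a classifier).
W_enc = rng.normal(scale=0.1, size=(D, H))
W_eps = rng.normal(scale=0.1, size=(H, D))  # denoising (noise-prediction) head
W_cls = rng.normal(scale=0.1, size=(H, C))  # classification head


def encode(x):
    # Shared parameters used by BOTH the generative and the discriminative path.
    return np.tanh(x @ W_enc)


def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def joint_loss(x0, y, alpha=1.0):
    """Diffusion noise-prediction MSE + alpha * cross-entropy, both through one encoder."""
    n = x0.shape[0]
    t = rng.integers(0, T, size=n)                     # random timestep per sample
    eps = rng.normal(size=x0.shape)                    # target noise
    ab = alpha_bar[t][:, None]
    x_t = np.sqrt(ab) * x0 + np.sqrt(1.0 - ab) * eps   # forward noising q(x_t | x_0)
    eps_hat = encode(x_t) @ W_eps                      # toy denoiser ignores t
    l_diff = np.mean((eps_hat - eps) ** 2)
    # Assumed: the classifier reads encoder features of the clean input x_0.
    probs = softmax(encode(x0) @ W_cls)
    l_cls = -np.mean(np.log(probs[np.arange(n), y] + 1e-12))
    return l_diff + alpha * l_cls, l_diff, l_cls


x0 = rng.normal(size=(32, D))
y = rng.integers(0, C, size=32)
total, l_diff, l_cls = joint_loss(x0, y)
print(f"total={total:.3f} diffusion={l_diff:.3f} classification={l_cls:.3f}")
```

Because both losses backpropagate into `W_enc`, the features the classifier relies on are also the ones the generator must keep intact, which is the shared parametrization the method builds on.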

Abstract

In this work, we introduce JDCL, a new method for continual learning with generative rehearsal based on joint diffusion models. Neural networks suffer from catastrophic forgetting, defined as an abrupt loss in the model's performance when it is retrained with additional data coming from a different distribution. Generative-replay-based continual learning methods try to mitigate this issue by retraining a model on a combination of new data and rehearsal data sampled from a generative model. We propose to extend this idea by combining a continually trained classifier with a diffusion-based generative model into a single, jointly optimized neural network. We show that such shared parametrization, combined with knowledge distillation, allows for stable adaptation to new tasks without catastrophic forgetting. We evaluate our approach on several benchmarks, where it outperforms recent state-of-the-art generative replay techniques. Additionally, we extend our method to the semi-supervised continual learning setup, where it outperforms competing buffer-based replay techniques, and we evaluate the quality of the trained representations in a self-supervised manner.
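Knowledge distillation here means keeping the fine-tuned model's outputs close to those of a frozen copy of the previous model on the rehearsal data. A minimal NumPy sketch, assuming a standard temperature-scaled KL term on classifier logits plus an MSE term on the denoiser's noise predictions; which outputs JDCL actually distills, and with what weights, is an assumption here, and the names `kd_logit_loss`, `kd_denoiser_loss`, and `distillation_objective` are hypothetical.

```python
import numpy as np


def log_softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))


def kd_logit_loss(student_logits, teacher_logits, temperature=2.0):
    """Hinton-style KD: KL(teacher || student) on temperature-softened softmax, scaled by T^2."""
    log_p_t = log_softmax(teacher_logits / temperature)
    log_p_s = log_softmax(student_logits / temperature)
    p_t = np.exp(log_p_t)
    kl = np.sum(p_t * (log_p_t - log_p_s), axis=-1)
    return temperature ** 2 * np.mean(kl)


def kd_denoiser_loss(student_eps, teacher_eps):
    """Assumed generative-side term: match the frozen teacher's noise predictions (MSE)."""
    return np.mean((student_eps - teacher_eps) ** 2)


def distillation_objective(s_logits, t_logits, s_eps, t_eps,
                           lam_cls=1.0, lam_diff=1.0, temperature=2.0):
    # Hypothetical weighting of the two distillation terms.
    return (lam_cls * kd_logit_loss(s_logits, t_logits, temperature)
            + lam_diff * kd_denoiser_loss(s_eps, t_eps))


rng = np.random.default_rng(0)
t_logits = rng.normal(size=(16, 10))                  # frozen previous model (teacher)
s_logits = t_logits + 0.1 * rng.normal(size=(16, 10)) # model being fine-tuned (student)
t_eps = rng.normal(size=(16, 8))
s_eps = t_eps + 0.1 * rng.normal(size=(16, 8))
print(distillation_objective(s_logits, t_logits, s_eps, t_eps))
```

Scaling the KL term by the squared temperature keeps its gradient magnitude comparable across temperatures, as in the usual Hinton-style formulation; in training, this term would be added to the joint diffusion loss.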

Paper Structure

This paper contains 36 sections, 12 equations, 9 figures, 11 tables, 1 algorithm.

Figures (9)

  • Figure 1: Further training of a classifier using data generated from the diffusion model trained with the same dataset harms the classifier's performance (blue line). However, joint modeling, especially in combination with knowledge distillation, significantly limits this degradation and retains nearly all of the initial performance. We report accuracy averaged across three different seeds.
  • Figure 2: Overview of JDCL. To adapt the global model to task $\tau+1$, we first train a local joint diffusion model on task $\tau+1$ only. We then use the global model and the local model to generate a synthetic dataset consisting of samples from tasks $1, 2, \dots, \tau+1$. Finally, we fine-tune the global model with the generated dataset and a combination of joint diffusion and knowledge distillation losses.
  • Figure 3: Accuracy on each task after each phase of incremental training on CIFAR-100 with 5 tasks.
  • Figure 4: Difference between the distributions of the logit value for class 1 of CIFAR-10 on real and synthetic data. Samples generated by the joint model closely match the logit distribution of the real data, whereas those from the separately trained classifier diverge significantly.
  • Figure 5: Mean accuracy after each task for the ablation methods (left), and accuracy of JDCL on the current and first tasks with and without two-stage training (right).
  • ...and 4 more figures
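The two-stage procedure from Figure 2 can be traced with a toy, purely structural sketch. Here `ToyJointModel` is a hypothetical stand-in whose `fit` only records the classes present in its training data (so anything not replayed is "forgotten") and whose `generate` emits placeholder labelled samples instead of running reverse diffusion. The per-class sample count and training the global model directly on the first task are assumptions.

```python
import random
from dataclasses import dataclass, field


@dataclass
class ToyJointModel:
    """Hypothetical stand-in for a joint diffusion model (generator + classifier)."""
    classes: set = field(default_factory=set)

    def fit(self, dataset):
        # Toy training: the model ends up knowing only the classes it was just trained on,
        # mimicking catastrophic forgetting of anything absent from the training data.
        self.classes = {label for _, label in dataset}

    def generate(self, n_per_class, rng):
        # Placeholder for sampling labelled images with the reverse diffusion process.
        return [((rng.gauss(0.0, 1.0),), label)
                for label in sorted(self.classes) for _ in range(n_per_class)]


def continual_step(global_model, task_data, n_per_class, rng):
    # Stage 1 (local): train a fresh joint model on the new task tau+1 only.
    local = ToyJointModel()
    local.fit(task_data)
    # Synthetic rehearsal set: tasks 1..tau from the global model, task tau+1 from the local one.
    synthetic = global_model.generate(n_per_class, rng) + local.generate(n_per_class, rng)
    # Stage 2 (global): fine-tune the global model on the synthetic set
    # (with joint diffusion + knowledge distillation losses in the real method).
    global_model.fit(synthetic)
    return global_model, synthetic


rng = random.Random(0)
# Five tasks with two classes each, e.g. a split of a 10-class dataset (dummy inputs).
tasks = [[((0.0,), c) for c in (2 * k, 2 * k + 1)] for k in range(5)]

global_model = ToyJointModel()
global_model.fit(tasks[0])  # assumed: the first task trains the global model directly
for task in tasks[1:]:
    global_model, synthetic = continual_step(global_model, task, n_per_class=4, rng=rng)
print(sorted(global_model.classes))  # → [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```

Fine-tuning the same toy model directly on `tasks[1]` without replay would leave it knowing only classes 2 and 3, which is the failure mode the generated rehearsal set is meant to prevent.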