Class-Prototype Conditional Diffusion Model with Gradient Projection for Continual Learning

Khanh Doan, Quyen Tran, Tung Lam Tran, Tuan Nguyen, Dinh Phung, Trung Le

TL;DR

The paper tackles catastrophic forgetting in continual learning by improving generative replay with diffusion models. It introduces the Gradient Projection Class-Prototype Conditional Diffusion Model (GPPDM), which conditions the diffusion process on learnable class prototypes and applies gradient projection to the cross-attention layers to preserve old-task representations. On class-incremental (CI) and class-incremental-with-repetition (CIR) benchmarks (CIFAR-100, ImageNet, CUB-200, CORe50), GPPDM achieves higher average accuracy and lower forgetting than strong baselines such as DDGR, while requiring memory roughly equivalent to storing one prototype per class. The approach yields higher-fidelity replay data at a small memory cost, providing a concrete framework for more robust continual learning with diffusion-based generative replay.

Abstract

Mitigating catastrophic forgetting is a key hurdle in continual learning. Deep Generative Replay (GR) provides techniques focused on generating samples from prior tasks to enhance the model's memory capabilities using generative AI models ranging from Generative Adversarial Networks (GANs) to the more recent Diffusion Models (DMs). A major issue is the deterioration in the quality of generated data compared to the original, as the generator continuously self-learns from its outputs. This degradation can lead to the potential risk of catastrophic forgetting (CF) occurring in the classifier. To address this, we propose the Gradient Projection Class-Prototype Conditional Diffusion Model (GPPDM), a GR-based approach for continual learning that enhances image quality in generators and thus reduces the CF in classifiers. The cornerstone of GPPDM is a learnable class prototype that captures the core characteristics of images in a given class. This prototype, integrated into the diffusion model's denoising process, ensures the generation of high-quality images of the old tasks, hence reducing the risk of CF in classifiers. Moreover, to further mitigate the CF of diffusion models, we propose a gradient projection technique tailored for the cross-attention layer of diffusion models to maximally maintain and preserve the representations of old task data in the current task as close as possible to their representations when they first arrived. Our empirical studies on diverse datasets demonstrate that our proposed method significantly outperforms existing state-of-the-art models, highlighting its satisfactory ability to preserve image quality and enhance the model's memory retention.

Paper Structure

This paper contains 36 sections, 12 equations, 8 figures, 7 tables, 2 algorithms.

Figures (8)

  • Figure 1: FID scores across tasks for images generated for the first task's dataset by the baseline DDGR [gao23e_dmcl] and by our proposed GPPDM. The illustrated images belong to the class "Apple"; the leftmost one is a real photo. For DDGR, the quality of the generated apples deteriorates significantly across tasks, so the apples generated in the last few tasks become hardly recognizable, whereas those generated by GPPDM retain their quality much better. Accordingly, the FID scores of GPPDM are significantly better than those of DDGR, and the gap becomes more pronounced in later tasks.
  • Figure 1: Comparison of final average accuracy $A_{T} (\uparrow)$ and final average forgetting $F_{T} (\downarrow)$ in the setting of CIFAR-100 $(NC=5)$, AlexNet with different class-prototype initialization strategies relative to DDGR.
  • Figure 2: (A) GPPDM framework for CL with generative replay: During task $t$, the classifier $f^t_\phi$ trains on data from the current task $\mathcal{D}^t$ and on replayed data from previous tasks $\mathcal{M}^{1:t-1}$ generated by the diffusion model to reduce catastrophic forgetting. (B) Diffusion training with Gradient Projection: To maintain performance on images generated from previous tasks, the gradients of the diffusion loss $\mathcal{L}_d^t$ w.r.t. $W_K^t$ and $W_V^t$ are projected onto the subspace $\mathcal{S}^{t-1}$, and the orthogonal components, $\Delta W_K^t$ and $\Delta W_V^t$, are used to update $W_K^t$ and $W_V^t$ respectively. (C) Diffusion sampling process: The diffusion model begins by denoising the initial noise drawn from $\mathcal{N}(\mathbf{0}, \mathbf{I})$. Each denoising step incorporates class information from CLIP embeddings $\bm\tau(y)$ and the learned class prototype $\mathbf{c}(y)$. The class prototype is designed to capture the most distinctive features of a given class and assists the generator in synthesizing high-quality images.
  • Figure 2: Results on CIFAR-100 with the AlexNet model as the number of diffusion timesteps varies.
  • Figure 3: Comparison of generated images from GPPDM and DDGR on CIFAR-100
  • ...and 3 more figures
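
The gradient-projection step described in the caption of Figure 2 (B) can be illustrated with a small NumPy sketch. This is not the authors' implementation: it assumes that the subspace $\mathcal{S}^{t-1}$ is spanned by the conditioning vectors that old tasks feed into the cross-attention key projection, and the names `orthonormal_basis`, `project_out`, and `old_context` are hypothetical. Under that assumption, removing the in-subspace component of the gradient leaves the keys computed for old-task inputs unchanged after the update; the same step would apply to $W_V^t$.

```python
import numpy as np


def orthonormal_basis(old_inputs: np.ndarray, tol: float = 1e-10) -> np.ndarray:
    """Orthonormal basis (as columns) for the span of the rows of old_inputs."""
    # The right singular vectors with non-negligible singular values
    # span the row space of the (n_old, d) input matrix.
    _, s, vt = np.linalg.svd(old_inputs, full_matrices=False)
    rank = int(np.sum(s > tol * s.max()))
    return vt[:rank].T  # shape (d, rank)


def project_out(grad: np.ndarray, basis: np.ndarray) -> np.ndarray:
    """Return the component of grad orthogonal to span(basis)."""
    return grad - basis @ (basis.T @ grad)


rng = np.random.default_rng(0)
d_context, d_key, n_old = 8, 4, 3

# Assumption: stand-ins for old-task conditioning vectors, the key
# projection W_K, and the raw gradient of the diffusion loss w.r.t. W_K.
old_context = rng.normal(size=(n_old, d_context))
W_K = rng.normal(size=(d_context, d_key))
grad_W_K = rng.normal(size=(d_context, d_key))

basis = orthonormal_basis(old_context)
delta_W_K = project_out(grad_W_K, basis)  # orthogonal component of the gradient

lr = 0.1
keys_before = old_context @ W_K
W_K_new = W_K - lr * delta_W_K            # update with the projected gradient only
keys_after = old_context @ W_K_new

# A new-task conditioning vector generally lies outside the old subspace,
# so its key is still free to change.
new_context = rng.normal(size=(1, d_context))
new_key_shift = new_context @ (W_K_new - W_K)
```

In this toy setting `keys_before` and `keys_after` agree up to floating-point error, while `new_key_shift` is non-zero, which is the trade-off the caption describes: old-task representations are held fixed and the remaining directions are used to learn the current task.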