Avoid Catastrophic Forgetting with Rank-1 Fisher from Diffusion Models

Zekun Wang, Anant Gupta, Zihan Dong, Christopher J. MacLellan

TL;DR

Catastrophic forgetting is a key challenge in continual learning. The authors show that diffusion models exhibit a nearly rank-1 Fisher in the low-SNR regime, with per-sample gradients aligned to the mean gradient, enabling a practical rank-1 EWC penalty. They derive a concrete formulation for this penalty and combine it with generative distillation to stabilize rehearsal across tasks. On MNIST, FashionMNIST, CIFAR-10, and ImageNet-1k, rank-1 EWC with distillation improves average FID and reduces forgetting relative to replay-only and diagonal-Fisher baselines, nearly eliminating forgetting on MNIST and FashionMNIST. These findings highlight diffusion-model gradients as a source of efficient, high-signal curvature information that strengthens consolidation while preserving generation quality.

Abstract

Catastrophic forgetting remains a central obstacle for continual learning in neural models. Popular approaches -- replay and elastic weight consolidation (EWC) -- have limitations: replay requires a strong generator and is prone to distributional drift, while EWC implicitly assumes a shared optimum across tasks and typically uses a diagonal Fisher approximation. In this work, we study the gradient geometry of diffusion models, which can already produce high-quality replay data. We provide theoretical and empirical evidence that, in the low signal-to-noise ratio (SNR) regime, per-sample gradients become strongly collinear, yielding an empirical Fisher that is effectively rank-1 and aligned with the mean gradient. Leveraging this structure, we propose a rank-1 variant of EWC that is as cheap as the diagonal approximation yet captures the dominant curvature direction. We pair this penalty with a replay-based approach to encourage parameter sharing across tasks while mitigating drift. On class-incremental image generation datasets (MNIST, FashionMNIST, CIFAR-10, ImageNet-1k), our method consistently improves average FID and reduces forgetting relative to replay-only and diagonal-EWC baselines. In particular, forgetting is nearly eliminated on MNIST and FashionMNIST and is more than halved on ImageNet-1k. These results suggest that diffusion models admit an approximately rank-1 Fisher. With a better Fisher estimate, EWC becomes a strong complement to replay: replay encourages parameter sharing across tasks, while EWC effectively constrains replay-induced drift.
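To make the cost claim concrete, here is a minimal sketch of what a rank-1 penalty could look like next to the diagonal one. It assumes the Fisher is approximated by the outer product of the mean gradient, $F \approx \mu\mu^\top$, so the quadratic form collapses to a single squared dot product. The function names, the `lam` weight, and the flat parameter vectors are illustrative assumptions; the paper's exact penalty (normalization, timestep weighting) may differ.

```python
import numpy as np

def rank1_ewc_penalty(theta, theta_star, mu, lam=1.0):
    """Rank-1 penalty under the ASSUMPTION F ~= mu mu^T.

    (lam/2) * delta^T (mu mu^T) delta == (lam/2) * (mu . delta)^2,
    so only one vector `mu` of parameter size has to be stored.
    """
    delta = theta - theta_star
    return 0.5 * lam * float(mu @ delta) ** 2

def diag_ewc_penalty(theta, theta_star, fisher_diag, lam=1.0):
    """Standard diagonal EWC penalty, shown for comparison."""
    delta = theta - theta_star
    return 0.5 * lam * float(np.sum(fisher_diag * delta ** 2))

# Hypothetical usage with random stand-ins for parameters and gradients.
rng = np.random.default_rng(0)
theta_star = rng.normal(size=6)   # parameters after the previous task
mu = rng.normal(size=6)           # mean gradient at theta_star
theta = theta_star + 0.1 * rng.normal(size=6)
print(rank1_ewc_penalty(theta, theta_star, mu))
print(diag_ewc_penalty(theta, theta_star, mu ** 2))
```

Both penalties store one vector the size of the parameters and cost one pass over it, which is the sense in which the rank-1 variant is as cheap as the diagonal one. Under this assumption, the rank-1 form leaves directions orthogonal to `mu` unpenalized.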

Paper Structure

This paper contains 37 sections, 3 theorems, 17 equations, 14 figures, 6 tables.

Key Result

Proposition 1

Let $s_t^\star(x_t)$, $x_t\sim q_t$, be the score of the noisy data distribution at time $t$ in a variance-preserving diffusion process. As SNR decreases, $s_t^\star(x_t)\approx -x_t/(1-\bar{\alpha}_t)$.
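The approximation can be checked numerically on a toy problem. The sketch below assumes the data distribution is uniform over a small set of points, so that $q_t$ is a Gaussian mixture whose score is available in closed form, and compares that exact score with $-x_t/(1-\bar{\alpha}_t)$. The dataset, dimensions, and function names are illustrative assumptions, not the paper's setup.

```python
import numpy as np

def exact_score(x_t, data, alpha_bar):
    """Exact score of q_t when x_0 is uniform over the rows of `data` and
    x_t = sqrt(alpha_bar) * x_0 + sqrt(1 - alpha_bar) * eps (VP process)."""
    var = 1.0 - alpha_bar
    diff = x_t[None, :] - np.sqrt(alpha_bar) * data   # (n, d)
    logw = -0.5 * np.sum(diff ** 2, axis=1) / var     # unnormalized log posterior
    w = np.exp(logw - logw.max())
    w /= w.sum()                                      # posterior over data points
    return -(w[:, None] * diff).sum(axis=0) / var

def low_snr_score(x_t, alpha_bar):
    """Low-SNR approximation from Proposition 1."""
    return -x_t / (1.0 - alpha_bar)

def relative_error(alpha_bar, n=64, d=8, seed=0):
    """Relative gap between exact and approximate score at one noisy sample."""
    rng = np.random.default_rng(seed)
    data = rng.normal(size=(n, d))                    # toy dataset (assumption)
    eps = rng.normal(size=d)
    x_t = np.sqrt(alpha_bar) * data[0] + np.sqrt(1.0 - alpha_bar) * eps
    s = exact_score(x_t, data, alpha_bar)
    return float(np.linalg.norm(s - low_snr_score(x_t, alpha_bar)) / np.linalg.norm(s))

for alpha_bar in (0.9, 0.1, 1e-4):
    print(alpha_bar, relative_error(alpha_bar))
```

In this toy setting the gap between the two scores is $\sqrt{\bar{\alpha}_t}\,\mathbb{E}[x_0\mid x_t]/(1-\bar{\alpha}_t)$, which vanishes as $\bar{\alpha}_t \to 0$, that is, as SNR decreases.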

Figures (14)

  • Figure 1: MSE between model input $x_t$ and the scaled prediction $\hat{x}_t$ at each timestep.
  • Figure 2: Absolute cosine similarities between per-sample gradients $g(\theta;x_t)$ and their expectation $\mu(\theta)$ at different diffusion timesteps. Each pixel represents a per-sample similarity. Higher values (deeper red) indicate stronger collinearity with $\mu(\theta)$.
  • Figure 3: (a): Pairwise cosine similarities of $\mu_t(\theta)$ across each forward process timestep. (b): Top 5 eigenvalues of $F_t(\theta)$ across timesteps in log-scale. (c): The ratio $r_t=\lambda_2/\lambda_1$ across timesteps. (d): Relative Frobenius norm between $F_t(\theta)$ and diagonal and rank-1 approximations across timesteps.
  • Figure 4: Average FID at each task during continual learning on evaluated datasets. Standard errors are computed over 3 random seeds.
  • Figure 5: Examples of generated images from selected classes in ImageNet-1k over continual learning tasks. (a) Hornbill class sampled from models trained on task 1 to 19. (b) Ruffed grouse class from task 1 to 19. (c) Convertible class from task 10 to 19. (d) Digital watch class from task 10 to 19. Top row: generative distillation-only; middle row: diagonal; bottom row: rank-1.
  • ...and 9 more figures
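
The quantities named in the captions of Figures 2 and 3 (cosine similarity to the mean gradient, the eigenvalue ratio $r_t=\lambda_2/\lambda_1$, and the relative Frobenius error of the diagonal and rank-1 approximations) can be computed from a matrix of per-sample gradients. The sketch below is a hedged illustration on synthetic gradients that are nearly collinear by construction; it does not use the paper's model or data, and all function names are assumptions.

```python
import numpy as np

def fisher_diagnostics(G):
    """Diagnostics for an empirical Fisher F = G^T G / n.

    G has shape (n, p): one flattened per-sample gradient per row.
    Forms F explicitly, so this is only meant for small p.
    """
    n = G.shape[0]
    mu = G.mean(axis=0)
    # Absolute cosine similarity of each per-sample gradient with the mean.
    cos = np.abs(G @ mu) / (np.linalg.norm(G, axis=1) * np.linalg.norm(mu) + 1e-12)
    F = G.T @ G / n
    # Eigenvalues of F from the singular values of G (descending order).
    eig = np.linalg.svd(G, compute_uv=False) ** 2 / n
    fro = np.linalg.norm(F)
    return {
        "mean_abs_cos": float(cos.mean()),
        "ratio": float(eig[1] / eig[0]),
        "rank1_err": float(np.linalg.norm(F - np.outer(mu, mu)) / fro),
        "diag_err": float(np.linalg.norm(F - np.diag(np.diag(F))) / fro),
    }

def synthetic_gradients(n=256, p=32, noise=0.05, seed=0):
    """Toy stand-in for low-SNR gradients: a shared direction plus small noise."""
    rng = np.random.default_rng(seed)
    u = rng.normal(size=p)
    u /= np.linalg.norm(u)
    scale = 1.0 + 0.05 * rng.normal(size=(n, 1))
    return scale * u + noise * rng.normal(size=(n, p))

print(fisher_diagnostics(synthetic_gradients()))
```

On gradients this collinear, the rank-1 approximation $\mu\mu^\top$ should track $F$ far more closely than its diagonal does, since the diagonal discards the off-diagonal mass of a dense outer product.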

Theorems & Definitions (6)

  • Proposition 1
  • Proposition 2
  • Theorem 1
  • proof
  • proof
  • proof