Avoid Catastrophic Forgetting with Rank-1 Fisher from Diffusion Models
Zekun Wang, Anant Gupta, Zihan Dong, Christopher J. MacLellan
TL;DR
Catastrophic forgetting is a key challenge in continual learning. The authors show that diffusion models exhibit a nearly rank-1 empirical Fisher in the low-SNR regime, where per-sample gradients align with the mean gradient. This structure enables a practical rank-1 EWC penalty. They derive a concrete formulation for this penalty and combine it with generative distillation to stabilize rehearsal across tasks. On MNIST, Fashion-MNIST, CIFAR-10, and ImageNet-1k, rank-1 EWC with distillation improves average FID and substantially reduces forgetting. It outperforms replay-only and diagonal-Fisher baselines and nearly eliminates forgetting on the simpler datasets (MNIST and Fashion-MNIST). These findings highlight diffusion-model gradients as a source of efficient, high-signal curvature information that strengthens consolidation while preserving generation quality.
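The rank-1 claim can be checked numerically on any batch of per-sample gradients. The minimal pure-Python sketch below computes two diagnostics suggested by the summary. The first is how well each gradient aligns with the mean gradient. The second is the fraction of the empirical Fisher's trace captured by its top eigenvalue; a value near 1 means the Fisher is effectively rank-1. The gradients here are synthetic toy vectors, assumed for illustration rather than taken from a diffusion model. `rank1_diagnostics` is a hypothetical helper, not the authors' code.

```python
import math
import random


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def norm(a):
    return math.sqrt(dot(a, a))


def rank1_diagnostics(grads, iters=50):
    """Summarize how close the empirical Fisher F = (1/N) sum_i g_i g_i^T is to rank-1.

    Returns (mean |cosine| between each g_i and the mean gradient,
             share of trace(F) captured by the top eigenvalue of F).
    F is never formed explicitly: power iteration uses F v = (1/N) sum_i g_i (g_i . v).
    """
    n, d = len(grads), len(grads[0])
    mean_g = [sum(g[j] for g in grads) / n for j in range(d)]
    mean_norm = norm(mean_g)
    mean_cos = sum(abs(dot(g, mean_g)) / (norm(g) * mean_norm) for g in grads) / n

    v = [x / mean_norm for x in mean_g]  # start power iteration from the mean-gradient direction
    for _ in range(iters):
        w = [0.0] * d
        for g in grads:
            c = dot(g, v) / n
            for j in range(d):
                w[j] += c * g[j]
        w_norm = norm(w)
        v = [x / w_norm for x in w]
    top_eig = sum(dot(g, v) ** 2 for g in grads) / n  # Rayleigh quotient v^T F v
    trace = sum(dot(g, g) for g in grads) / n  # trace(F) = mean squared gradient norm
    return mean_cos, top_eig / trace


def noisy_multiple(vec, scale, noise):
    # One per-sample gradient: scale * vec plus small isotropic Gaussian noise.
    return [scale * x + random.gauss(0.0, noise) for x in vec]


# Synthetic stand-ins (assumed, not data from the paper).
random.seed(0)
d, n = 50, 200
direction = [random.gauss(0.0, 1.0) for _ in range(d)]
# "Collinear": every gradient is a positive multiple of one shared direction plus noise.
collinear = [noisy_multiple(direction, random.uniform(0.5, 1.5), 0.05) for _ in range(n)]
# "Isotropic": gradients point in unrelated random directions.
isotropic = [[random.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]

for name, grads in (("collinear", collinear), ("isotropic", isotropic)):
    cos, share = rank1_diagnostics(grads)
    print(f"{name}: mean |cos| to mean gradient = {cos:.3f}, top-eigenvalue share = {share:.3f}")
```

On the collinear batch both numbers come out close to 1. The isotropic batch gives values far below 1. In a real experiment, `grads` would instead hold flattened per-sample gradients of the diffusion loss at a chosen (for example, low-SNR) noise level.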
Abstract
Catastrophic forgetting remains a central obstacle for continual learning in neural models. Two popular approaches, replay and elastic weight consolidation (EWC), both have limitations. Replay requires a strong generator and is prone to distributional drift. EWC implicitly assumes a shared optimum across tasks and typically relies on a diagonal Fisher approximation. In this work, we study the gradient geometry of diffusion models, which can already produce high-quality replay data. We provide theoretical and empirical evidence that, in the low signal-to-noise ratio (SNR) regime, per-sample gradients become strongly collinear, yielding an empirical Fisher that is effectively rank-1 and aligned with the mean gradient. Leveraging this structure, we propose a rank-1 variant of EWC that is as cheap as the diagonal approximation yet captures the dominant curvature direction. We pair this penalty with a replay-based approach to encourage parameter sharing across tasks while mitigating drift. On class-incremental image generation datasets (MNIST, Fashion-MNIST, CIFAR-10, ImageNet-1k), our method consistently improves average FID and reduces forgetting relative to replay-only and diagonal-EWC baselines. In particular, forgetting is nearly eliminated on MNIST and Fashion-MNIST and is more than halved on ImageNet-1k. These results suggest that diffusion models admit an approximately rank-1 Fisher. With this better Fisher estimate, EWC becomes a strong complement to replay: replay encourages parameter sharing across tasks, while EWC effectively constrains replay-induced drift.
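The abstract states that the rank-1 penalty costs about as much as the diagonal one, but it does not give the penalty's form. The sketch below assumes one natural reading. It approximates the old task's empirical Fisher by s * u u^T, where u is the unit mean gradient and s is the mean squared projection of the per-sample gradients onto u. It then penalizes (lambda / 2) * s * (u . (theta - theta*))^2. This is an illustrative assumption and may differ from the authors' derived formulation. The helper names, toy gradients, and lambda = 1 are hypothetical. A full-Fisher quadratic form and the standard diagonal-EWC penalty are included for comparison.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def fit_rank1_fisher(grads):
    """Hypothetical rank-1 fit F ~= s * u u^T.

    u: unit vector along the mean gradient; s = mean_i (g_i . u)^2 = u^T F u.
    Storage is d + 1 floats, essentially the same as a diagonal Fisher (d floats).
    """
    n, d = len(grads), len(grads[0])
    mean_g = [sum(g[j] for g in grads) / n for j in range(d)]
    m = math.sqrt(dot(mean_g, mean_g))
    u = [x / m for x in mean_g]
    s = sum(dot(g, u) ** 2 for g in grads) / n
    return u, s


def fit_diag_fisher(grads):
    """Standard diagonal approximation: F_jj = mean_i g_ij^2."""
    n, d = len(grads), len(grads[0])
    return [sum(g[j] ** 2 for g in grads) / n for j in range(d)]


def rank1_ewc_penalty(theta, theta_star, u, s, lam):
    # (lam / 2) * s * (u . (theta - theta*))^2: only motion along u is penalized.
    proj = dot(u, [a - b for a, b in zip(theta, theta_star)])
    return 0.5 * lam * s * proj ** 2


def diag_ewc_penalty(theta, theta_star, fdiag, lam):
    # (lam / 2) * sum_j F_jj * (theta_j - theta*_j)^2
    return 0.5 * lam * sum(f * (a - b) ** 2 for f, a, b in zip(fdiag, theta, theta_star))


def full_fisher_penalty(theta, theta_star, grads, lam):
    # Reference: (lam / 2) * delta^T F delta with F = (1/N) sum_i g_i g_i^T.
    delta = [a - b for a, b in zip(theta, theta_star)]
    return 0.5 * lam * sum(dot(g, delta) ** 2 for g in grads) / len(grads)


# Toy old-task gradients that are exactly collinear: g_i = a_i * direction.
direction = [0.6, 0.8, 0.0]
grads = [[a * x for x in direction] for a in (0.5, 1.0, 1.5)]
theta_star = [0.0, 0.0, 0.0]   # parameters after the old task
along = [0.3, 0.4, 0.2]        # has a component along `direction`
orthogonal = [0.8, -0.6, 0.5]  # orthogonal to `direction`

u, s = fit_rank1_fisher(grads)
fdiag = fit_diag_fisher(grads)
for name, theta in (("along", along), ("orthogonal", orthogonal)):
    print(name,
          "full:", round(full_fisher_penalty(theta, theta_star, grads, 1.0), 4),
          "rank-1:", round(rank1_ewc_penalty(theta, theta_star, u, s, 1.0), 4),
          "diagonal:", round(diag_ewc_penalty(theta, theta_star, fdiag, 1.0), 4))
```

Because the toy gradients are exactly collinear, the rank-1 penalty matches the full-Fisher quadratic form for both displacements. The diagonal penalty under-penalizes the `along` move and also penalizes the `orthogonal` move that the full Fisher ignores. With real, only approximately collinear gradients, the match would be approximate rather than exact.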
