On the Implicit Adversariality of Catastrophic Forgetting in Deep Continual Learning
Ze Peng, Jian Zhang, Jintao Guo, Lei Qi, Yang Gao, Yinghuan Shi
TL;DR
This work reveals that catastrophic forgetting in deep continual learning is driven by an implicit adversariality: new-task updates align with the high-curvature directions of the old-task loss. Depth and a low-rank bias in the old-task weights funnel forward and backward propagation into a shared low-dimensional subspace, which sustains the alignment and causes rapid forgetting. Gradient projection (GP) methods mitigate the forward alignment but leave the backward alignment unaddressed. The authors propose backGP to constrain backward updates, achieving substantial improvements on standard continual learning benchmarks (on average, 10.8% less forgetting and 12.7% higher accuracy than GP methods), with further gains when combined with plasticity-enhancing regularizers. The findings connect continual learning with adversarial robustness and offer practical strategies and theoretical insight for transfer learning and foundation-model fine-tuning.
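The backward-propagation part of this mechanism can be seen in a toy model. The sketch below is our own illustration under stated assumptions, not the authors' analysis or code: a two-layer linear network f(x) = W2 W1 x with squared loss, where W2 is hand-built to be low rank to mimic the low-rank bias of old-task weights, and the old-task Hessian is taken with respect to W1 only (W2 held fixed). Because the error signal reaching W1 is W2^T times the residual, every new-task gradient for W1 lies in the row space of W2, which is also where the old-task curvature lives. The helper names (make_w2, curvature, alignment_gain) and all sizes are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, h, d_out, n = 20, 100, 100, 1000


def make_w2(singular_values):
    """Build a hypothetical W2 (d_out x h) with prescribed nonzero singular values."""
    r = len(singular_values)
    U, _ = np.linalg.qr(rng.standard_normal((d_out, r)))
    V, _ = np.linalg.qr(rng.standard_normal((h, r)))
    return U @ np.diag(singular_values) @ V.T


def curvature(W2, dW1, X_old):
    """Rayleigh quotient of the old-task Hessian w.r.t. W1 along direction dW1.

    For f(x) = W2 @ W1 @ x with loss ||W2 W1 X - Y||_F^2 / (2n), the loss is
    exactly quadratic in W1, so d^T H d = ||W2 dW1 X_old||_F^2 / n
    (independent of the old-task targets Y).
    """
    n_old = X_old.shape[1]
    quad = np.linalg.norm(W2 @ dW1 @ X_old) ** 2 / n_old
    return quad / np.linalg.norm(dW1) ** 2


def alignment_gain(W2, n_random=200):
    """Curvature along the new-task gradient divided by mean curvature along random directions."""
    X_old = rng.standard_normal((d_in, n))   # isotropic inputs: isolates the backward effect
    X_new = rng.standard_normal((d_in, n))
    Y_new = rng.standard_normal((d_out, n))
    W1 = 0.1 * rng.standard_normal((h, d_in))
    # Backward pass: the error signal reaching W1 is W2.T @ residual, so every
    # column of the gradient lies in the row space of W2.
    residual = W2 @ W1 @ X_new - Y_new
    grad_W1 = W2.T @ residual @ X_new.T / n
    grad_curv = curvature(W2, grad_W1, X_old)
    rand_curv = np.mean([curvature(W2, rng.standard_normal((h, d_in)), X_old)
                         for _ in range(n_random)])
    return grad_curv / rand_curv


low_rank_gain = alignment_gain(make_w2([3.0, 2.0, 1.5]))   # rank-3 W2
full_rank_gain = alignment_gain(make_w2([1.0] * d_out))    # orthogonal, full-rank W2
print(f"low-rank W2:  {low_rank_gain:.1f}x random-direction curvature")
print(f"full-rank W2: {full_rank_gain:.1f}x random-direction curvature")
```

With the rank-3 W2, the curvature along the new-task gradient should come out many times larger than along random directions, while with an orthogonal W2 the ratio should stay near 1. The inputs are isotropic here, so this toy says nothing about the forward (input-side) alignment.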
Abstract
Continual learning seeks to give machine intelligence the human-like ability to accumulate new skills. Its central challenge is catastrophic forgetting, whose underlying cause in deep networks is not yet fully understood. In this paper, we demystify catastrophic forgetting by revealing that new-task training is implicitly an adversarial attack on old-task knowledge. Specifically, the new-task gradients automatically and accurately align with the sharp directions of the old-task loss landscape, rapidly increasing the old-task loss. This adversarial alignment is intriguingly counter-intuitive because the sharp directions are too sparsely distributed for the gradients to align with them by chance. To explain it, we theoretically show that it arises from training's low-rank bias, which, through forward and backward propagation, confines both directions to the same low-dimensional subspace, facilitating alignment. Gradient projection (GP) methods, a representative family of forgetting-mitigating methods, reduce the adversarial alignment caused by forward propagation but cannot address the alignment caused by backward propagation. We propose backGP to address it; backGP reduces forgetting by 10.8% and improves accuracy by 12.7% on average over GP methods.
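Two claims in the abstract can be checked numerically in a few lines. The sketch below assumes NumPy and uses a random orthonormal basis as a stand-in for the top-k Hessian eigenvectors: a uniformly random unit direction in D dimensions places on average only k/D of its squared norm in a fixed k-dimensional subspace, which is why chance alignment with sparse sharp directions is unlikely. It then shows a generic GPM-style gradient projection that removes the part of a layer's weight gradient acting on the old-task input subspace, i.e., the forward side. The helper names (chance_alignment, input_basis, project_gradient), the 99% energy threshold, and the synthetic data are illustrative assumptions; this is not the authors' backGP, whose details are not given here.

```python
import numpy as np

rng = np.random.default_rng(0)


def chance_alignment(dim, k, trials=2000):
    """Mean squared projection of random unit vectors onto a fixed k-dim subspace."""
    # Random orthonormal basis standing in for the top-k Hessian eigenvectors.
    basis, _ = np.linalg.qr(rng.standard_normal((dim, k)))
    u = rng.standard_normal((trials, dim))
    u /= np.linalg.norm(u, axis=1, keepdims=True)
    return np.mean(np.sum((u @ basis) ** 2, axis=1))


def input_basis(X_old, energy=0.99):
    """Orthonormal basis for the span of old-task layer inputs (columns of X_old).

    Keeps the fewest left singular vectors whose squared singular values reach
    the given fraction of total energy (an illustrative threshold).
    """
    U, s, _ = np.linalg.svd(X_old, full_matrices=False)
    cum = np.cumsum(s ** 2) / np.sum(s ** 2)
    m = int(np.searchsorted(cum, energy)) + 1
    return U[:, :m]


def project_gradient(G, U):
    """Remove the component of each gradient row that lies in span(U)."""
    return G - (G @ U) @ U.T


print(f"random-direction alignment: {chance_alignment(1000, 10):.4f} (k/D = 0.0100)")

# Synthetic old-task inputs occupying a 5-dim subspace of a 64-dim input space.
d_in, d_out, rank, n = 64, 32, 5, 500
X_old = rng.standard_normal((d_in, rank)) @ rng.standard_normal((rank, n))
G = rng.standard_normal((d_out, d_in))  # stand-in for a new-task weight gradient
U = input_basis(X_old)
G_proj = project_gradient(G, U)
before = np.linalg.norm(G @ X_old)
after = np.linalg.norm(G_proj @ X_old)
print(f"||G X_old||: {before:.2f} -> {after:.2e} after projection")
```

The projection multiplies the gradient only on its input side, so it does nothing to restrict which output directions the update occupies; that is the gap the abstract attributes to backward propagation.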
