MGSER-SAM: Memory-Guided Soft Experience Replay with Sharpness-Aware Optimization for Enhanced Continual Learning
Xingyu Li, Bo Tang
TL;DR
Catastrophic forgetting in continual learning is mitigated by integrating sharpness-aware optimization with memory replay. The paper presents ER-SAM and MGSER-SAM, where SAM is applied to replay objectives and, in the latter, augmented with soft logits to align memory gradients, effectively minimizing $L_{total}=L_t+\hat{L}_s$ while promoting flat minima. Empirical results across task-, class-, and domain-incremental benchmarks show substantial improvements over ER and DER++ and reveal stronger forgetting control as memory capacity increases. The approach demonstrates practical impact by delivering higher accuracy and more stable performance in sequential learning tasks without extensive architectural changes.
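For reference, applying the standard SAM formulation to this combined objective gives the min-max problem and one-step update below. This is a sketch of plain SAM on the summed loss (the ER-SAM setting), not the exact MGSER-SAM update; the radius $\rho$, step size $\eta$, and perturbation $\hat{\epsilon}$ are assumed notation that the text above does not define.

$$\min_{w}\ \max_{\|\epsilon\|_2\le\rho}\ L_{total}(w+\epsilon), \qquad L_{total}=L_t+\hat{L}_s$$

$$\hat{\epsilon}(w)=\rho\,\frac{\nabla_w L_t(w)+\nabla_w \hat{L}_s(w)}{\big\|\nabla_w L_t(w)+\nabla_w \hat{L}_s(w)\big\|_2}, \qquad w \leftarrow w-\eta\,\nabla_w L_{total}(w)\big|_{w+\hat{\epsilon}(w)}$$

Because a single perturbation is computed from the summed gradient, it is a compromise direction when $\nabla_w L_t$ and $\nabla_w \hat{L}_s$ disagree; this is one way to read the perturbation-direction conflict that MGSER-SAM targets.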
Abstract
Deep neural networks suffer from catastrophic forgetting in continual learning (CL). To address this challenge, we propose MGSER-SAM, a novel memory replay-based algorithm specifically engineered to enhance the generalization capabilities of CL models. We first integrate the SAM optimizer, a component designed to seek flat minima, which fits seamlessly into well-known experience replay frameworks such as ER and DER++. MGSER-SAM then addresses the challenge of reconciling conflicting weight perturbation directions between the ongoing task and previously stored memories, a problem that remains underexplored for the SAM optimizer. This is accomplished by integrating soft logits and aligning memory gradient directions; the resulting regularization terms enable the concurrent minimization of the various training loss terms involved in the CL process. Rigorous experiments across multiple benchmarks show that MGSER-SAM consistently outperforms existing baselines in all three CL scenarios. Compared to the representative memory replay-based baselines ER and DER++, MGSER-SAM not only improves testing accuracy by $24.4\%$ and $17.6\%$, respectively, but also achieves the lowest forgetting on each benchmark.
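The idea of running SAM on a replay objective that includes a soft-logit term can be sketched in plain Python on a toy linear model. Everything here is an illustrative assumption rather than the authors' code: the hypothetical helpers `predict`, `total_loss_and_grad`, and `sam_step`, the hyperparameters `lr`, `rho`, and `alpha`, and the use of squared error in place of cross-entropy for the current-task loss. The memory term follows the DER-style idea of matching stored logits `z`; the memory gradient alignment that distinguishes MGSER-SAM is not implemented.

```python
import math

# Hypothetical toy sketch of SAM on a replay objective; not the paper's implementation.

def predict(w, x):
    # Linear model f(x) = w . x; its scalar output plays the role of a logit.
    return sum(wi * xi for wi, xi in zip(w, x))

def total_loss_and_grad(w, cur_batch, mem_batch, alpha):
    """Return (L_total, grad) with L_total = L_t + L_s (assumed toy losses).

    L_t: mean squared error on the current batch (stand-in for cross-entropy).
    L_s: alpha * mean squared error between current outputs on memory samples
         and their stored soft logits z (DER-style logit replay).
    """
    grad = [0.0] * len(w)
    loss_t = 0.0
    for x, y in cur_batch:
        err = predict(w, x) - y
        loss_t += err * err
        for i, xi in enumerate(x):
            grad[i] += 2.0 * err * xi / len(cur_batch)
    loss_t /= len(cur_batch)

    loss_s = 0.0
    for x, z in mem_batch:
        err = predict(w, x) - z
        loss_s += alpha * err * err
        for i, xi in enumerate(x):
            grad[i] += 2.0 * alpha * err * xi / len(mem_batch)
    loss_s /= len(mem_batch)
    return loss_t + loss_s, grad

def sam_step(w, cur_batch, mem_batch, lr=0.1, rho=0.05, alpha=0.5):
    """One SAM update on the combined replay objective (hypothetical defaults)."""
    _, g = total_loss_and_grad(w, cur_batch, mem_batch, alpha)
    norm = math.sqrt(sum(gi * gi for gi in g)) + 1e-12
    # Ascent step: first-order worst-case perturbation within radius rho.
    w_adv = [wi + rho * gi / norm for wi, gi in zip(w, g)]
    # The gradient at the perturbed weights drives the actual descent step.
    _, g_adv = total_loss_and_grad(w_adv, cur_batch, mem_batch, alpha)
    return [wi - lr * gi for wi, gi in zip(w, g_adv)]

# Usage: two current-task samples and one memory sample with its stored logit.
cur = [([1.0, 0.0], 1.0), ([0.0, 1.0], -1.0)]
mem = [([1.0, 1.0], 0.5)]
w = [0.0, 0.0]
loss0, _ = total_loss_and_grad(w, cur, mem, 0.5)
for _ in range(200):
    w = sam_step(w, cur, mem)
loss1, _ = total_loss_and_grad(w, cur, mem, 0.5)
print(f"loss before: {loss0:.4f}, after: {loss1:.4f}")
```

Each step evaluates the gradient twice (once at `w` to find the ascent direction, once at the perturbed point for the update), which is why SAM costs roughly twice as much per step as plain ER.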
