
STAR: Stability-Inducing Weight Perturbation for Continual Learning

Masih Eskandar, Tooba Imtiaz, Davin Hill, Zifeng Wang, Jennifer Dy

TL;DR

STAR addresses catastrophic forgetting in continual learning by enforcing stability of the output distribution in a local parameter neighborhood through a worst-case perturbation objective. It is a plug-and-play regularizer that can be added to any rehearsal-based CL baseline, optimizing a KL-divergence-based forgetting surrogate under local perturbations. Empirically, STAR yields absolute improvements of up to 15% across multiple baselines and datasets. Ablations confirm the effectiveness of gradient-based perturbations and of focusing the loss on buffer data, and comparisons against other rehearsal enhancements show competitive gains. This work reframes CL stability around local parameter-space behavior, making the model more robust to future updates without requiring task boundaries.
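
To make the forgetting surrogate concrete, here is a minimal NumPy sketch that assumes a toy linear softmax classifier with weight matrix `W` in place of a deep network. The surrogate is taken here as the mean KL divergence between predictions at `W` and at perturbed weights `W + delta`, restricted to buffer samples that the unperturbed model classifies correctly (the $x^*$ variant in the ablation). The names `softmax`, `kl_rows`, and `forgetting_surrogate` are hypothetical illustrations, not the authors' code.

```python
import numpy as np

def softmax(logits):
    # Numerically stable row-wise softmax.
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def kl_rows(p, q, eps=1e-12):
    # KL(p_i || q_i) for each row i; eps guards against log(0).
    return np.sum(p * (np.log(p + eps) - np.log(q + eps)), axis=1)

def forgetting_surrogate(W, delta, X, y):
    """Mean KL between predictions at W and at W + delta, computed only on
    samples that the unperturbed (hypothetical linear) model gets right."""
    p = softmax(X @ W)
    q = softmax(X @ (W + delta))
    correct = p.argmax(axis=1) == y
    if not correct.any():
        return 0.0
    return float(kl_rows(p[correct], q[correct]).mean())

rng = np.random.default_rng(0)
X, W = rng.normal(size=(32, 5)), rng.normal(size=(5, 3))
y = (X @ W).argmax(axis=1)  # toy labels: every buffer sample starts out correct
print(forgetting_surrogate(W, 0.1 * rng.normal(size=W.shape), X, y))
```

Restricting to correctly classified samples means the surrogate only protects predictions that are worth keeping, rather than also stabilizing errors.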

Abstract

Humans can naturally learn new and varying tasks in a sequential manner. Continual learning is a class of learning algorithms that update a learned model as new data (potentially from new tasks) arrive in sequence. A key challenge in continual learning is that, as the model is updated to learn new tasks, it becomes susceptible to catastrophic forgetting, where knowledge of previously learned tasks is lost. A popular approach to mitigating forgetting is to maintain a small buffer of previously seen samples and replay them during training. However, this approach is limited by the small buffer size: forgetting is reduced but still present. In this paper, we propose STAR, a novel loss function that reduces the KL divergence between the model's predictions and those at the worst-case parameter perturbation within its local parameter neighborhood, promoting stability and alleviating forgetting. STAR can be combined with almost any existing rehearsal-based method as a plug-and-play component. We empirically show that STAR consistently improves the performance of existing methods by up to 15% across varying baselines and achieves accuracy superior or competitive to that of state-of-the-art methods aimed at improving rehearsal-based continual learning.
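
The worst-case perturbation and the resulting loss can be sketched as follows, assuming a toy linear softmax model in NumPy. Two choices here are assumptions of this sketch, not details stated in the abstract. First, because the KL gradient is exactly zero at $\delta = 0$, the perturbation is found by VAT-style power iteration: a small random start refined by normalized gradient steps, then rescaled to radius `rho`. Second, the returned gradient treats `delta` and the unperturbed predictions as constants, in the style of sharpness-aware minimization. All names and defaults (`rho`, `steps`, `xi`, `star_loss_and_grad`) are hypothetical.

```python
import numpy as np

def softmax(logits):
    # Numerically stable row-wise softmax.
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def mean_kl(p, q, eps=1e-12):
    # Average over samples of KL(p_i || q_i).
    return float(np.mean(np.sum(p * (np.log(p + eps) - np.log(q + eps)), axis=1)))

def kl_grad(W, delta, X):
    # Gradient of mean_i KL(p_i || softmax(x_i (W + delta))) w.r.t. delta,
    # with p = softmax(X W) held fixed. With p fixed, the same expression is
    # also the gradient w.r.t. W evaluated at the perturbed point.
    p = softmax(X @ W)
    q = softmax(X @ (W + delta))
    return X.T @ (q - p) / X.shape[0]

def worst_case_delta(W, X, rho=0.5, steps=3, xi=1e-3, rng=None):
    # Approximate argmax_{||delta|| = rho} KL(p_W || p_{W + delta}).
    # The gradient vanishes at delta = 0, so start from a random direction.
    rng = np.random.default_rng() if rng is None else rng
    d = rng.normal(size=W.shape)
    d /= np.linalg.norm(d)
    for _ in range(steps):
        g = kl_grad(W, xi * d, X)
        d = g / (np.linalg.norm(g) + 1e-12)
    return rho * d

def star_loss_and_grad(W, X_buf, rho=0.5, steps=3, rng=None):
    # Loss: KL between predictions at W and at the worst-case perturbed weights.
    # Gradient: SAM-style, treating delta and p as constants (an assumption).
    delta = worst_case_delta(W, X_buf, rho, steps, rng=rng)
    loss = mean_kl(softmax(X_buf @ W), softmax(X_buf @ (W + delta)))
    return loss, kl_grad(W, delta, X_buf)
```

In a rehearsal loop, one would call `star_loss_and_grad` on buffer samples, scale the gradient by a hypothetical weight `lam`, and add it to the baseline's own gradient before the optimizer step.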

Paper Structure

This paper contains 32 sections, 12 equations, 6 figures, 8 tables, and 1 algorithm.

Figures (6)

  • Figure 1: STAR improves rehearsal-based CL by considering the change in output distribution in a local parameter neighborhood. Our regularization objective promotes convergence to regions of the parameter space where the loss is stable over a local neighborhood $\delta$ around the parameters.
  • Figure 2: Conceptual illustration of STAR. Vanilla rehearsal-based CL approaches do not consider the consistency of the output distribution in the local parameter neighborhood (a). By leveraging the gradient of our proposed loss $\mathcal{L}_{STAR}$ (b), we navigate towards parameter regions where the model's predictions on previously seen data are consistent (c). Future parameter updates then lead to lower divergence of the model's output distribution, reducing forgetting during training (d).
  • Figure 3: KL divergence on the correctly classified test samples of each task of the S-CIFAR10 dataset, between the model's predictions immediately after training on that task and its predictions at later parameter points during training. "Worst" indicates the KL divergence after multiple gradient ascent steps, as detailed in the gradient approximation section.
  • Figure 3: Effect of method components on average accuracy. The second column indicates whether only correctly classified samples ($x^*$) or all samples ($x$) were used in STAR; the third column indicates whether the weight perturbation used to compute the loss is $\nabla\mathcal{L}_{FG}$ or random noise ($z$). All other hyperparameters are fixed.
  • Figure 4: Effect of the number of gradient steps in the maximization objective on performance, measured as average accuracy.
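
The two ablations in the figures (random noise $z$ versus a gradient-based perturbation, and the number of gradient steps in the maximization) can be mimicked on a toy problem. This sketch assumes a linear softmax model in NumPy and a VAT-style power iteration for the gradient steps; it measures how much KL divergence a perturbation of fixed norm `rho` induces, with `steps=0` standing in for random noise. The function `perturbation_kl` and its parameters are hypothetical, and the sketch does not reproduce the accuracy numbers reported in the paper.

```python
import numpy as np

def softmax(logits):
    # Numerically stable row-wise softmax.
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def mean_kl(p, q, eps=1e-12):
    # Average over samples of KL(p_i || q_i).
    return float(np.mean(np.sum(p * (np.log(p + eps) - np.log(q + eps)), axis=1)))

def perturbation_kl(W, X, rho, steps, seed=0, xi=1e-3):
    """KL(p_W || p_{W + delta}) with ||delta|| = rho.
    steps = 0: delta is rescaled random noise (the 'z' ablation).
    steps > 0: the random start is refined by `steps` normalized gradient
    steps on the KL (VAT-style power iteration, an assumption here)."""
    p = softmax(X @ W)
    d = np.random.default_rng(seed).normal(size=W.shape)
    d /= np.linalg.norm(d)
    for _ in range(steps):
        q = softmax(X @ (W + xi * d))
        g = X.T @ (q - p) / X.shape[0]  # gradient of mean KL w.r.t. delta, p fixed
        d = g / (np.linalg.norm(g) + 1e-12)
    return mean_kl(p, softmax(X @ (W + rho * d)))

rng = np.random.default_rng(0)
X, W = rng.normal(size=(64, 5)), rng.normal(size=(5, 3))
for k in (0, 1, 2, 5):
    print(k, perturbation_kl(W, X, rho=0.1, steps=k))
```

Since every perturbation has the same norm, a larger KL for the gradient-refined direction means it lies closer to the worst case within the neighborhood, which is the property the maximization objective is meant to capture.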