Overcoming the Stability Gap in Continual Learning

Md Yousuf Harun, Christopher Kanan

TL;DR

This work targets model decay in industry by applying continual learning (CL) to large pre-trained models, and it identifies a stability gap: transient forgetting of old tasks when new data arrive. The authors propose Stability Gap Mitigation (SGM), a composite approach that combines data-driven output-layer initialization, dynamic soft targets, LoRA-based limits on hidden-layer plasticity, and freezing of old output classes. SGM substantially reduces the stability, plasticity, and continual knowledge gaps. Across class-incremental and IID data streams, SGM greatly improves learning efficiency. It requires up to 16.7× fewer network updates and 31.9× fewer TFLOPs than joint training, while matching or exceeding the joint upper bound in several settings. The method generalizes across backbones and CL variants, including storage-constrained offline/online CL and non-rehearsal baselines. This highlights its practical potential for curbing model decay in production systems with substantial compute and energy savings.
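Two of the SGM components listed above act directly on the output layer: data-driven initialization of new-class weights and freezing of old output classes. The NumPy sketch below is a minimal illustration under stated assumptions, not the authors' implementation. It assumes that "data-driven" means each new class's weight row is the normalized mean of that class's L2-normalized penultimate-layer features, and that "freezing" means old-class rows receive no gradient updates. The function names and toy data are hypothetical.

```python
import numpy as np

# Hypothetical sketch of two SGM output-layer components (names are illustrative):
# 1) data-driven initialization of new-class weights from class-mean features (assumed),
# 2) freezing old output classes by masking their gradient rows (assumed).

def init_new_class_weights(features, labels, new_classes):
    """Return one unit-norm weight row per new class.

    Assumption: each row is the normalized mean of that class's
    L2-normalized penultimate-layer features.
    """
    feats = features / np.linalg.norm(features, axis=1, keepdims=True)
    rows = []
    for c in new_classes:
        class_mean = feats[labels == c].mean(axis=0)
        rows.append(class_mean / np.linalg.norm(class_mean))
    return np.stack(rows)

def frozen_old_class_step(W, grad, num_old, lr):
    """One SGD step on the output layer that leaves old-class rows unchanged."""
    mask = np.zeros((W.shape[0], 1))
    mask[num_old:] = 1.0  # only new-class rows receive updates
    return W - lr * grad * mask

# Toy usage: 3 old classes, 2 new classes (ids 3 and 4), 4-dim features.
rng = np.random.default_rng(0)
W_old = rng.normal(size=(3, 4))
features = rng.normal(size=(10, 4))
labels = np.array([3] * 5 + [4] * 5)
W = np.vstack([W_old, init_new_class_weights(features, labels, [3, 4])])
W_next = frozen_old_class_step(W, rng.normal(size=W.shape), num_old=3, lr=0.1)
```

Because the mask multiplies the gradient, old-class rows stay exactly as they were after pre-training. New-class rows, meanwhile, start from class-mean directions rather than random values.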

Abstract

Pre-trained deep neural networks (DNNs) are widely deployed by industry to make business decisions and serve users; however, a major problem is model decay, where the DNN's predictions become more erroneous over time, resulting in revenue loss or unhappy users. To mitigate model decay, DNNs are retrained from scratch using old and new data. This is computationally expensive, so retraining happens only once performance significantly decreases. Here, we study how continual learning (CL) could overcome model decay in large pre-trained DNNs and greatly reduce the computational cost of keeping DNNs up to date. We identify the "stability gap" as a major obstacle in our setting. The stability gap refers to a phenomenon where learning new data causes large drops in performance on past tasks before CL mitigation methods eventually compensate for this drop. We test two hypotheses to investigate the factors influencing the stability gap and identify a method that vastly reduces this gap. In large-scale experiments on both easy and hard CL distributions (e.g., class-incremental learning), we demonstrate that our method reduces the stability gap and greatly increases computational efficiency. Our work aligns CL with the goals of the production setting, where CL is needed for many applications.
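The abstract defines the stability gap as a large drop in old-task performance during learning that CL methods only later compensate for. The sketch below is a rough illustration only; the paper's exact metric may differ. It measures the gap for one rehearsal session as the largest drop below the pre-session accuracy. For contrast, it also computes the end-of-session forgetting that traditional measures report (the red diamonds in Figure 1). The function names and accuracy values are made up for illustration.

```python
# Illustrative sketch (not the paper's exact metric): quantify the stability gap
# for one rehearsal session from per-iteration accuracy on the old task.

def stability_gap(acc_before, acc_curve):
    """Largest drop below the pre-session accuracy at any point in the session."""
    return max(0.0, acc_before - min(acc_curve))

def end_of_session_forgetting(acc_before, acc_curve):
    """Drop measured only at the task transition, as traditional metrics do."""
    return acc_before - acc_curve[-1]

# Made-up old-task accuracies (%) logged during one session.
acc_before = 80.0
acc_curve = [55.0, 60.0, 68.0, 74.0, 77.0]

gap = stability_gap(acc_before, acc_curve)                     # 25.0
forgetting = end_of_session_forgetting(acc_before, acc_curve)  # 3.0
```

On this toy curve, the transition-only measure reports a 3-point loss, while the old task was actually 25 points worse early in the session. That transient drop is the difference the stability gap is meant to capture.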

Paper Structure (35 sections, 6 equations, 9 figures, 11 tables)

Figures (9)

  • Figure 1: An overview of the stability gap phenomenon. The stability gap occurs in CL when learning new data: accuracy on previously learned data (y-axis) drops sharply over training iterations (x-axis) once a new distribution is introduced. Fig. (a) illustrates this behavior during class-incremental learning (CIL), where a network pre-trained on ImageNet-1K learns 365 new classes from Places365-LT over five rehearsal sessions. Each rehearsal session consists of 600 iterations that combine samples from the old and new tasks. A gray dotted vertical line marks the end of a rehearsal session or a task transition. When rehearsal begins, the old-task accuracy of conventional rehearsal drops dramatically and then recovers slowly, but it does not regain its original performance on the old data. Traditional measures of catastrophic forgetting consider only performance at task transitions (red diamonds), ignoring the significant forgetting that occurs between transitions. Fig. (b) shows the stability gap in the learning curve averaged over the five rehearsal sessions. In this work, we attempt to mitigate the stability gap.
  • Figure 2: Mitigation methods, averaged over five rehearsal sessions during CIL. (a) Loss on new classes when training only the output layer; soft targets and data-driven weight initialization greatly reduce the initial loss. (b) Accuracy on ImageNet-1K with hard vs. soft targets; soft targets reduce the stability gap. (c) Network plasticity increases the stability gap.
  • Figure 3: Speed of acquiring new knowledge. SGM requires fewer network updates and TFLOPs than vanilla (conventional) rehearsal to reach 99% of the best accuracy on new classes (highlighted).
  • Figure 4: Stability gap over all rehearsal sessions. After pre-training on ImageNet-1K, the model learns 365 new classes from Places365-LT over five rehearsal sessions (600 iterations per session). SGM quickly recovers its performance on old classes at the start of CL, whereas vanilla rehearsal does not fully recover. At the end of each rehearsal session (vertical dotted gray line), the final top-1 accuracy (%) is marked by a diamond (SGM), a star (joint model), and a square (vanilla). The joint model (upper bound) is trained jointly on ImageNet-1K and the Places365-LT CL batches seen so far.
  • Figure 5: Computational efficiency. Our method, SGM, provides a 16.7× speedup in the number of network updates and a 31.9× speedup in TFLOPs compared to a joint model (upper bound) trained on the combined 1,365-class dataset (ImageNet-1K plus Places365-LT). For SGM and conventional rehearsal, we show the stability gap in the learning curve averaged over rehearsal sessions.
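Figures 3 and 5 measure efficiency in two ways. Figure 3 counts the network updates (and TFLOPs) needed to reach 99% of the best accuracy on new classes, and Figure 5 reports speedup relative to the joint model. The sketch below shows how such counts and ratios could be computed. It assumes a per-update accuracy log and uses hypothetical helper names; only the 16.7× and 31.9× values come from the figures.

```python
# Illustrative sketch of the efficiency measures in Figures 3 and 5.
# Function names and the toy accuracy curve are hypothetical.

def updates_to_reach(acc_curve, frac=0.99):
    """1-based index of the first update whose accuracy is >= frac * best accuracy."""
    target = frac * max(acc_curve)
    for step, acc in enumerate(acc_curve, start=1):
        if acc >= target:
            return step
    # Only reachable if frac > 1, i.e. the target exceeds the best accuracy.
    raise ValueError("target accuracy never reached; use frac <= 1")

def speedup(baseline_cost, method_cost):
    """How many times cheaper the method is than the baseline (updates or TFLOPs)."""
    return baseline_cost / method_cost

# Toy curve: best accuracy is 60.0, so the 99% target is 59.4.
steps = updates_to_reach([10.0, 40.0, 59.5, 60.0])  # 3

# Ratios reported in Figure 5 for SGM vs. the joint model.
update_speedup, tflop_speedup = 16.7, 31.9
# Dividing the two ratios gives the average per-update cost ratio (about 1.91).
per_update_cost_ratio = tflop_speedup / update_speedup
```

Because the TFLOP speedup (31.9×) exceeds the update speedup (16.7×), each SGM update is, on average, about 1.9× cheaper than a joint-model update. This per-update figure follows arithmetically from the two reported ratios; the figures do not report it separately.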