
Adaptive Memory Replay for Continual Learning

James Seale Smith, Lazar Valkov, Shaunak Halbe, Vyshnavi Gutta, Rogerio Feris, Zsolt Kira, Leonid Karlinsky

TL;DR

The paper tackles perpetual updating of Foundation Models by addressing catastrophic forgetting during extended pre-training when memory is abundant but compute is limited. It introduces adaptive memory replay, a bandit-based data-sampling framework that uses Boltzmann exploration to selectively replay past data the model is forgetting, conditioned on the current task, while keeping compute cost on par with naive training. The approach is formalized through a forgetting-centric objective, cluster-based forgetting modeling, and a non-stationary K-armed bandit optimization that guides memory-buffer composition. Empirically, it reduces forgetting by up to about 10% on both vision and language pre-training tasks, and a zero-cost variant matches or surpasses baselines under equivalent compute, suggesting the approach is practical for continual learning in large-scale models.
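
The TL;DR mentions cluster-based forgetting modeling and a non-stationary K-armed bandit. Below is a minimal sketch of how per-cluster forgetting estimates could be tracked. It assumes that each arm is one cluster of past data and that the reward is the rise of that cluster's loss above its best recorded value. This summary does not give the paper's exact reward definition. The class `ForgettingBandit` and the parameter `step_size` are hypothetical names. A constant step size makes each estimate an exponential recency-weighted average, which is a standard way to track rewards that drift over time.

```python
import math


class ForgettingBandit:
    """Hypothetical non-stationary K-armed bandit: one arm per cluster of past data.

    Assumed reward: how far a cluster's loss has risen above its lowest observed loss
    (an illustrative forgetting signal, not necessarily the paper's definition).
    """

    def __init__(self, num_clusters: int, step_size: float = 0.1) -> None:
        if not 0.0 < step_size <= 1.0:
            raise ValueError("step_size must be in (0, 1]")
        self.values = [0.0] * num_clusters          # Q_k: estimated forgetting of cluster k
        self.best_loss = [math.inf] * num_clusters  # lowest loss observed per cluster
        self.step_size = step_size                  # constant alpha -> recency-weighted average

    def update(self, cluster: int, loss: float) -> float:
        best = self.best_loss[cluster]
        # Forgetting signal: loss increase over the best value seen (0.0 on the first visit).
        reward = 0.0 if math.isinf(best) else max(loss - best, 0.0)
        self.best_loss[cluster] = min(best, loss)
        # Non-stationary value update: Q <- Q + alpha * (r - Q).
        self.values[cluster] += self.step_size * (reward - self.values[cluster])
        return reward


bandit = ForgettingBandit(num_clusters=3, step_size=0.5)
bandit.update(0, 1.0)  # first visit to cluster 0: reward 0.0
bandit.update(0, 1.4)  # loss rose by about 0.4, so values[0] moves halfway toward it
```

Clusters whose loss keeps climbing accumulate larger values. Clusters that are never revisited keep their old estimates until they are sampled again.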

Abstract

Foundation Models (FMs) have become the hallmark of modern AI; however, these models are trained on massive data, which makes training financially expensive. Updating FMs as new data becomes available is important but can lead to "catastrophic forgetting", where models underperform on tasks related to data sub-populations observed too long ago. This continual learning (CL) phenomenon has been studied extensively, but primarily in a setting where only a small amount of past data can be stored. We advocate for the paradigm where memory is abundant, allowing us to keep all previous data, but computational resources are limited. In this setting, traditional replay-based CL approaches are outperformed by a simple baseline that replays past data selected uniformly at random, indicating that this setting necessitates a new approach. We address this by introducing a framework of adaptive memory replay for continual learning, where sampling of past data is phrased as a multi-armed bandit problem. We use Boltzmann sampling to derive a method that dynamically selects past data for training conditioned on the current task, assuming full data access and emphasizing training efficiency. Through extensive evaluations on both vision and language pre-training tasks, we demonstrate the effectiveness of our approach, which maintains high performance while reducing forgetting by up to 10% at no cost to training efficiency.
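
Boltzmann (softmax) sampling turns per-arm value estimates Q_k into sampling probabilities p_k proportional to exp(Q_k / τ). Below is a minimal sketch of that step. It assumes that the values are estimated forgetting scores per cluster and that the temperature τ is set by hand. The functions `boltzmann_probs` and `sample_clusters` and the example scores are hypothetical.

```python
import math
import random


def boltzmann_probs(values, temperature):
    """Boltzmann (softmax) distribution: p_k proportional to exp(Q_k / temperature)."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    logits = [v / temperature for v in values]
    top = max(logits)  # subtract the max logit for numerical stability
    weights = [math.exp(z - top) for z in logits]
    total = sum(weights)
    return [w / total for w in weights]


def sample_clusters(values, temperature, n, rng):
    """Draw n cluster indices, with replacement, from the Boltzmann distribution."""
    probs = boltzmann_probs(values, temperature)
    return rng.choices(range(len(probs)), weights=probs, k=n)


q = [0.05, 0.30, 0.10]                       # hypothetical forgetting estimates per cluster
probs = boltzmann_probs(q, temperature=0.1)  # cluster 1 receives most of the probability mass
picks = sample_clusters(q, 0.1, 8, random.Random(0))
```

A lower τ concentrates replay on the clusters with the highest estimated forgetting. As τ grows, the distribution approaches uniform random replay, which the abstract identifies as a strong baseline in this setting.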

Paper Structure

This paper contains 18 sections, 6 equations, 4 figures, 4 tables, and 1 algorithm.

Figures (4)

  • Figure 1: Adaptive memory replay for continual pre-training. In our setting, we begin with general image pretraining (Task 0) and then learn a sequence of tasks (e.g., Real, Clipart, Painting, Sketch) with full memory access to all past task data. We choose relevant samples for model training, thereby minimizing catastrophic forgetting while efficiently updating the model with new data. Selected past-task samples replace part of the current-task data, so training cost does not increase.
  • Figure 2: Key difference between our work and prior work. Prior work assumes that memory is expensive and constrains replay data to a fixed budget. Our work assumes memory is cheap and stores all replay data in memory, focusing on how to dynamically select the most useful replay data for computation-budgeted replay.
  • Figure 3: Overview of our adaptive memory replay approach. New task data is integrated with selectively rehearsed old data from full replay memory to update the task model. Unlike simple i.i.d. replay, our rehearsal data is chosen through a combination of bandit estimation and Boltzmann sampling from clusters of old datasets stored in memory. To avoid the computation cost of adding replay data, we randomly discard samples from the current-task training data and replace them with the selected replay data. This balances incorporating new information against retaining knowledge of previous tasks, mitigating catastrophic forgetting with minimal computational overhead.
  • Figure 4: Final loss vs. training time for adaptive memory replay vs. Oracle on the Synthetic Visual Concepts (SyViC) sequence.
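
Figures 1 and 3 describe how training cost stays fixed: current-task samples are randomly discarded, and the freed slots are filled with selected replay data. Below is a minimal sketch of that batch composition. It assumes a fixed replay fraction per batch and precomputed cluster probabilities, for example from Boltzmann sampling. The names `compose_batch`, `replay_fraction`, and `cluster_members` are hypothetical, and the paper's actual replacement schedule may differ.

```python
import random


def compose_batch(current_ids, cluster_members, cluster_probs, replay_fraction, rng):
    """Hypothetical budget-preserving batch builder.

    Returns (kept_current, replay) with len(kept_current) + len(replay) == len(current_ids),
    so per-step compute matches training on current-task data alone.
    cluster_members: list of non-empty lists of past-sample ids, one list per cluster.
    cluster_probs: sampling weight for each cluster (e.g. Boltzmann probabilities).
    """
    if not 0.0 <= replay_fraction <= 1.0:
        raise ValueError("replay_fraction must be in [0, 1]")
    batch_size = len(current_ids)
    n_replay = round(replay_fraction * batch_size)
    # Randomly discard n_replay current-task samples; the batch size stays unchanged.
    kept = rng.sample(list(current_ids), batch_size - n_replay)
    # Pick a cluster for each freed slot, then a uniformly random member of that cluster.
    clusters = rng.choices(range(len(cluster_members)), weights=cluster_probs, k=n_replay)
    replay = [rng.choice(cluster_members[k]) for k in clusters]
    return kept, replay


rng = random.Random(0)
current = list(range(100, 132))  # hypothetical ids of a 32-sample current-task batch
members = [list(range(0, 10)), list(range(10, 20)), list(range(20, 30))]
kept, replay = compose_batch(current, members, [0.2, 0.6, 0.2], 0.25, rng)
```

The batch size never changes, so each step costs the same as naive training. Only the choice of samples that fill the batch differs.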