
Layerwise Proximal Replay: A Proximal Point Method for Online Continual Learning

Jason Yoo, Yunpeng Liu, Frank Wood, Geoff Pleiss

TL;DR

This paper addresses instability in replay-based online continual learning by introducing Layerwise Proximal Replay (LPR), a proximal-point-style optimizer that applies a layer-specific preconditioner to gradient updates. By constraining changes to the hidden activations of replay data through a per-layer preconditioner $P_\ell$, LPR balances learning from new and replay data while stabilizing optimization, and it can be used with a variety of replay losses. Empirical results in both memory-constrained and memory-unconstrained settings on Split-CIFAR100, Split-TinyImageNet, and Online CLEAR show consistent gains in final accuracy, average anytime accuracy, and worst-case accuracy, along with reduced representation drift and more stable gradients. The findings suggest that optimization geometry, not just memory of past data, plays a crucial role in online continual learning, and that proximal-point-style updates can meaningfully improve replay-based methods with modest computational overhead.
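To make the role of a per-layer preconditioner $P_\ell$ concrete, here is one way such a preconditioner can arise from a proximal step on a single linear layer. This is an illustrative derivation under stated assumptions, not necessarily the paper's exact parameterization: it assumes a squared Frobenius penalty on changes to the layer's pre-activations for replay inputs, with a hypothetical penalty weight $\omega \ge 0$ and step size $\eta$. Let $W_\ell \in \mathbb{R}^{d_{\mathrm{out}} \times d_{\mathrm{in}}}$ be the weights of layer $\ell$, $G_\ell$ the gradient of the training loss with respect to $W_\ell$, and $X_\ell \in \mathbb{R}^{n \times d_{\mathrm{in}}}$ the inputs to layer $\ell$ for $n$ replay examples, so the replay pre-activations are $W_\ell X_\ell^\top$.

$$
W_\ell^{(t+1)} = \arg\min_{W} \; \langle G_\ell, W - W_\ell^{(t)} \rangle
+ \frac{1}{2\eta} \lVert W - W_\ell^{(t)} \rVert_F^2
+ \frac{\omega}{2\eta} \lVert (W - W_\ell^{(t)}) X_\ell^\top \rVert_F^2
$$

Setting the gradient with respect to $W$ to zero, with $\Delta = W - W_\ell^{(t)}$:

$$
G_\ell + \frac{1}{\eta}\Delta + \frac{\omega}{\eta}\Delta X_\ell^\top X_\ell = 0
\quad\Longrightarrow\quad
W_\ell^{(t+1)} = W_\ell^{(t)} - \eta\, G_\ell P_\ell,
\qquad
P_\ell = \left(I + \omega X_\ell^\top X_\ell\right)^{-1}.
$$

With $\omega = 0$ this reduces to plain gradient descent; larger $\omega$ shrinks the update along directions spanned by the replay inputs, which is what limits drift in the replay activations.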

Abstract

In online continual learning, a neural network incrementally learns from a non-i.i.d. data stream. Nearly all online continual learning methods employ experience replay to simultaneously prevent catastrophic forgetting and underfitting on past data. Our work demonstrates a limitation of this approach: neural networks trained with experience replay tend to have unstable optimization trajectories, impeding their overall accuracy. Surprisingly, these instabilities persist even when the replay buffer stores all previous training examples, suggesting that this issue is orthogonal to catastrophic forgetting. We minimize these instabilities through a simple modification of the optimization geometry. Our solution, Layerwise Proximal Replay (LPR), balances learning from new and replay data while only allowing for gradual changes in the hidden activation of past data. We demonstrate that LPR consistently improves replay-based online continual learning methods across multiple problem settings, regardless of the amount of available replay memory.
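The idea of balancing new and replay data while only allowing gradual changes in replay activations can be sketched in NumPy for a single linear layer. This is a minimal sketch under assumptions, not the authors' implementation: it assumes a preconditioner of the form $P = (I + \omega X^\top X)^{-1}$, where $X$ stacks the layer's inputs for replay examples, and the names `omega`, `lpr_direction`, and `lpr_step` are hypothetical. It compares how much one step changes the layer's pre-activations on the replay inputs under plain SGD and under the preconditioned step.

```python
import numpy as np


def lpr_direction(grad: np.ndarray, X_replay: np.ndarray, omega: float) -> np.ndarray:
    """Return grad @ P with P = (I + omega * X^T X)^{-1} (assumed form, hypothetical helper).

    grad:     (d_out, d_in) gradient of the loss w.r.t. the layer weights W.
    X_replay: (n, d_in) inputs to this layer for n replay examples.
    omega:    hypothetical strength of the penalty on replay activation changes.
    """
    d_in = X_replay.shape[1]
    A = np.eye(d_in) + omega * X_replay.T @ X_replay  # symmetric positive definite
    # grad @ inv(A) == solve(A, grad.T).T because A is symmetric; solve avoids an explicit inverse.
    return np.linalg.solve(A, grad.T).T


def lpr_step(W: np.ndarray, grad: np.ndarray, X_replay: np.ndarray,
             lr: float, omega: float) -> np.ndarray:
    """One preconditioned (layerwise proximal) step: W <- W - lr * grad @ P."""
    return W - lr * lpr_direction(grad, X_replay, omega)


rng = np.random.default_rng(0)
W = rng.normal(size=(4, 8))          # layer weights (d_out=4, d_in=8)
grad = rng.normal(size=(4, 8))       # stand-in for the gradient of current + replay loss
X_replay = rng.normal(size=(16, 8))  # layer inputs for 16 replay examples
lr, omega = 0.1, 1.0

W_sgd = W - lr * grad
W_lpr = lpr_step(W, grad, X_replay, lr, omega)

# Change in this layer's pre-activations on the replay inputs: (W_new - W) @ X^T.
# For omega >= 0 the preconditioned step never makes this larger than the SGD step.
drift_sgd = np.linalg.norm((W_sgd - W) @ X_replay.T)
drift_lpr = np.linalg.norm((W_lpr - W) @ X_replay.T)
print(f"replay activation change  SGD: {drift_sgd:.3f}  LPR: {drift_lpr:.3f}")
```

Setting `omega = 0` recovers plain SGD. The preconditioned step is the closed-form minimizer of a linearized loss plus the usual squared step-size penalty and a squared penalty on replay pre-activation changes, so increasing `omega` trades plasticity on new data for stability on replay data.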

Paper Structure

This paper contains 35 sections, 29 equations, 8 figures, 10 tables, and 1 algorithm.

Figures (8)

  • Figure 1: Internal-representation and accuracy metrics over the full course of training for Split-CIFAR100's task 1 test data, computed across 5 seeds. All LPR (ER) runs use a learning rate of 0.1. Vertical lines in the right sub-figure mark task boundaries. Left: change in the internal representations of task 1 data over the course of training. Middle: total variation of task 1 test accuracy. Right: task 1 test accuracy. LPR yields lower representation drift and lower total variation of accuracy, indicating that it better preserves the predictive stability of past data. This stability is correlated with LPR's higher overall accuracy, shown in the rightmost plot.
  • Figure 2: Ratio of the post-preconditioning gradient norm to the original gradient norm for the current-data and replay-data losses vs. training iterations, for LPR-augmented experience replay. Vertical lines denote Split-CIFAR100's task boundaries.
  • Figure 3: Holdout set optimization metrics for Split-CIFAR100 experiments with memory size 2000. Shading denotes the minimum and maximum values observed across 5 seeds.
  • Figure 4: Holdout set optimization metrics for Split-TinyImageNet experiments with memory size 4000. Shading denotes the minimum and maximum values observed across 5 seeds.
  • Figure 5: Holdout set optimization metrics for Online CLEAR experiments with memory size 2000. Shading denotes the minimum and maximum values observed across 5 seeds.
  • Figures 6 to 8 are not listed in this overview.
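Two of the quantities in these captions, the total variation of an accuracy curve (Figure 1) and the ratio of post-preconditioning to original gradient norms (Figure 2), can be sketched as follows. The definitions are assumptions, since the captions do not spell them out: total variation is taken as the sum of absolute changes between consecutive evaluations, the norm ratio as $\lVert G P \rVert_F / \lVert G \rVert_F$, and $P = (I + \omega X^\top X)^{-1}$ is an assumed preconditioner form. The accuracy curves and the layer setup are hypothetical toy data, not values from the paper.

```python
import numpy as np


def total_variation(acc: np.ndarray) -> float:
    """Sum of absolute changes between consecutive accuracy evaluations (assumed definition)."""
    return float(np.sum(np.abs(np.diff(acc))))


def grad_norm_ratio(grad: np.ndarray, P: np.ndarray) -> float:
    """||grad @ P||_F / ||grad||_F (assumed definition of the Figure 2 ratio)."""
    return float(np.linalg.norm(grad @ P) / np.linalg.norm(grad))


# Two hypothetical task-1 accuracy curves that end at the same final accuracy.
smooth = np.array([0.60, 0.58, 0.57, 0.57, 0.56])
jumpy = np.array([0.60, 0.45, 0.58, 0.40, 0.56])
tv_smooth, tv_jumpy = total_variation(smooth), total_variation(jumpy)

# Toy linear layer: replay inputs occupy dims 0-3, current inputs occupy dims 4-7.
rng = np.random.default_rng(0)
d_in, d_out, omega = 8, 3, 1.0
X_replay = np.hstack([rng.normal(size=(32, 4)), np.zeros((32, 4))])
X_current = np.hstack([np.zeros((32, 4)), rng.normal(size=(32, 4))])
P = np.linalg.inv(np.eye(d_in) + omega * X_replay.T @ X_replay)

# For a linear layer h = W x, the weight gradient is sum_i delta_i x_i^T,
# so its rows lie in the span of the inputs that produced it.
grad_replay = rng.normal(size=(d_out, 32)) @ X_replay
grad_current = rng.normal(size=(d_out, 32)) @ X_current

ratio_replay = grad_norm_ratio(grad_replay, P)
ratio_current = grad_norm_ratio(grad_current, P)
print(f"TV smooth={tv_smooth:.2f} jumpy={tv_jumpy:.2f}")
print(f"norm ratio replay={ratio_replay:.3f} current={ratio_current:.3f}")
```

Both accuracy curves end at the same value, but the jumpy one has a much larger total variation, which is the kind of instability the Figure 1 caption refers to. In this constructed layer, the replay-loss gradient is shrunk because it lies in the span of the replay inputs, while the current-data gradient, built from inputs in an orthogonal subspace, passes through unchanged; this is a toy illustration, not a claim about the values plotted in Figure 2.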

Theorems & Definitions (1)

  • proof