Layerwise Proximal Replay: A Proximal Point Method for Online Continual Learning
Jason Yoo, Yunpeng Liu, Frank Wood, Geoff Pleiss
TL;DR
This paper addresses optimization instability in replay-based online continual learning by introducing Layerwise Proximal Replay (LPR), a proximal-point style optimizer that applies a layer-specific preconditioner to gradient updates. LPR constrains changes to the hidden activations of replay data through a per-layer preconditioner $P_\ell$. This lets it balance learning from new and replay data while stabilizing optimization, and it can be combined with a variety of replay losses. Experiments on Split-CIFAR100, Split-TinyImageNet, and Online CLEAR, in both memory-constrained and memory-unconstrained settings, show consistent gains in final accuracy, average anytime accuracy, and worst-case accuracy, along with reduced representation drift and more stable gradients. The findings suggest that optimization geometry, not just memory of past data, plays a crucial role in online continual learning, and that proximal-point style updates can meaningfully improve replay-based methods at modest computational overhead.
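One way to see how a proximal-point step yields a per-layer preconditioner is the derivation below. It is a sketch under simplifying assumptions that are not taken from the paper: layer $\ell$ is linear with weights $W_\ell \in \mathbb{R}^{d \times m}$; the rows of $Z_\ell \in \mathbb{R}^{n \times d}$ are the layer's inputs on $n$ replay examples; the loss is linearized around the current weights with gradient $G_\ell$; and $\eta$ and $\omega \ge 0$ are a step size and a hypothetical penalty strength. The last term penalizes $Z_\ell \Delta W$, the change in the layer's outputs on replay data:

$$
\Delta W_\ell^\star = \arg\min_{\Delta W}\; \langle G_\ell, \Delta W \rangle_F + \frac{1}{2\eta}\lVert \Delta W \rVert_F^2 + \frac{\omega}{2\eta}\lVert Z_\ell \Delta W \rVert_F^2
$$

Setting the gradient with respect to $\Delta W$ to zero gives $G_\ell + \frac{1}{\eta}\left(I + \omega Z_\ell^\top Z_\ell\right)\Delta W = 0$, so

$$
W_\ell \leftarrow W_\ell - \eta\, P_\ell G_\ell, \qquad P_\ell = \left(I + \omega Z_\ell^\top Z_\ell\right)^{-1}.
$$

Because $P_\ell$ is symmetric positive definite, $-P_\ell G_\ell$ is still a descent direction. Directions in which the replay inputs vary strongly (large eigenvalues of $Z_\ell^\top Z_\ell$) are damped, and $\omega = 0$ recovers plain gradient descent.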
Abstract
In online continual learning, a neural network incrementally learns from a non-i.i.d. data stream. Nearly all online continual learning methods employ experience replay to simultaneously prevent catastrophic forgetting and underfitting on past data. Our work demonstrates a limitation of this approach: neural networks trained with experience replay tend to have unstable optimization trajectories, impeding their overall accuracy. Surprisingly, these instabilities persist even when the replay buffer stores all previous training examples, suggesting that this issue is orthogonal to catastrophic forgetting. We minimize these instabilities through a simple modification of the optimization geometry. Our solution, Layerwise Proximal Replay (LPR), balances learning from new and replay data while only allowing for gradual changes in the hidden activations of past data. We demonstrate that LPR consistently improves replay-based online continual learning methods across multiple problem settings, regardless of the amount of available replay memory.
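The idea of "only allowing for gradual changes in the hidden activations of past data" can be illustrated with a minimal NumPy sketch. It assumes a single linear layer $y = Z W$ and a preconditioner of the form $P = (I + \omega Z^\top Z)^{-1}$ built from the layer's inputs $Z$ on replay data. The functions `preconditioned_step` and `activation_change`, the strength `omega`, and the random data are hypothetical illustrations, not the paper's implementation. The sketch compares how far a plain gradient step and a preconditioned step move the layer's outputs on replay data.

```python
import numpy as np


def preconditioned_step(w, grad, z_replay, lr, omega):
    """Hypothetical LPR-style update W <- W - lr * P @ grad for a linear layer y = Z @ W.

    z_replay: (n, d) inputs to this layer for n replay examples (assumption).
    omega: >= 0, how strongly changes to replay outputs are penalized (assumption).
    """
    d = w.shape[0]
    a = np.eye(d) + omega * z_replay.T @ z_replay  # symmetric positive definite
    # solve(A, G) computes A^{-1} @ G without explicitly forming the inverse.
    return w - lr * np.linalg.solve(a, grad)


rng = np.random.default_rng(0)
n, d, m = 32, 10, 4                  # replay examples, layer input dim, layer output dim
z_replay = rng.normal(size=(n, d))   # stand-in for hidden activations of replay data
w = rng.normal(size=(d, m))          # current layer weights
grad = rng.normal(size=(d, m))       # stand-in for the loss gradient w.r.t. W
lr, omega = 0.1, 1.0

w_sgd = w - lr * grad
w_lpr = preconditioned_step(w, grad, z_replay, lr, omega)


def activation_change(w_new):
    """Frobenius norm of the change in this layer's outputs on the replay data."""
    return np.linalg.norm(z_replay @ (w_new - w))


print(f"SGD activation change: {activation_change(w_sgd):.3f}")
print(f"LPR activation change: {activation_change(w_lpr):.3f}")
```

In the eigenbasis of $Z^\top Z$, each eigenvalue $\lambda$ contributes with weight $\lambda$ to the plain step's squared activation change but only $\lambda/(1+\omega\lambda)^2$ to the preconditioned step's. So, at the same learning rate, the preconditioned step never moves replay activations more than plain SGD, and setting `omega = 0` recovers plain SGD exactly.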
