On the Theory of Continual Learning with Gradient Descent for Neural Networks
Hossein Taheri, Avishek Ghosh, Arya Mazumdar
TL;DR
This paper analyzes continual learning with gradient descent (GD) in a tractable setting: a two-layer neural network with a quadratic activation, trained sequentially on independent XOR-cluster tasks. It decomposes test-time forgetting into train-time forgetting and a delayed generalization gap, derives bounds based on algorithmic stability, and extends these bounds from the infinite-width regime to finite width via concentration arguments. The authors establish closed-form bounds on train- and test-time forgetting in terms of the dimension $d$, the number of tasks $K$, the per-task sample size $n$, the hidden width $m$, and the number of GD iterations $T$. They identify scaling regimes, such as $n = \tilde{\Theta}(d^2K)$, $m = \tilde{\Theta}(d^8K^4)$, and $T = \tilde{\Theta}(d^2)$, under which forgetting vanishes. Experiments across diverse settings corroborate the theory and show that the findings hold beyond the analyzed XOR model, including for other activations and data distributions.
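One way to read the decomposition of test-time forgetting, assuming (as is common in continual-learning theory) that forgetting after $K$ tasks is measured by the average loss of the final iterate $W_K$ on all tasks seen so far. The symbols below are our own notation and the paper's exact definitions and normalization may differ: $\widehat{L}_k$ is the empirical loss on the $n$ training samples of task $k$, and $L_k$ is the corresponding population loss.

```latex
\begin{aligned}
F_{\mathrm{test}}(K)
  &:= \frac{1}{K}\sum_{k=1}^{K} L_k(W_K) \\
  &= \underbrace{\frac{1}{K}\sum_{k=1}^{K} \widehat{L}_k(W_K)}_{F_{\mathrm{train}}(K)}
   \;+\;
   \underbrace{\frac{1}{K}\sum_{k=1}^{K} \bigl(L_k(W_K) - \widehat{L}_k(W_K)\bigr)}_{\text{delayed generalization gap}}
\end{aligned}
```

The identity itself is trivial; the content lies in the second term. The gap is "delayed" because it is evaluated at $W_K$, which depends on task $k$'s samples only through the GD steps taken at stage $k$ and then keeps moving while training on tasks $k+1, \dots, K$. Under this reading, this is the term an algorithmic-stability argument would control: how much $W_K$ changes when one training sample of task $k$ is replaced.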
Abstract
Continual learning, the ability of a model to adapt to an ongoing sequence of tasks without forgetting earlier ones, is a central goal of artificial intelligence. To shed light on its underlying mechanisms, we analyze the limitations of continual learning in a tractable yet representative setting. Specifically, we study one-hidden-layer quadratic neural networks trained by gradient descent on an XOR cluster dataset with Gaussian noise, where different tasks correspond to different clusters with orthogonal means. We derive bounds on the rate of forgetting at train and test time in terms of the number of iterations, the sample size, the number of tasks, and the hidden-layer size. These bounds reveal how each of these problem parameters affects the rate of forgetting. Numerical experiments across diverse setups confirm our results and demonstrate their validity beyond the analyzed settings.
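To make the setting concrete, here is a small simulation sketch in Python with NumPy. Everything in it is an assumption chosen for illustration, not the authors' code or exact model. This includes the hypothetical helper names (`make_task`, `loss_and_grad`, `run_sequence`), placing task $k$'s two orthogonal cluster means on coordinates $2k$ and $2k+1$, the logistic loss, the fixed $\pm 1$ output layer, the small Gaussian initialization, and full-batch GD for $T$ steps per task. A large fresh sample per task stands in for the population (test) loss. For each task, the sketch reports its loss right after its own training stage and at the end of the whole sequence, on both train and test data.

```python
import numpy as np


def make_task(k, n, d, r, sigma, rng):
    """Hypothetical XOR-cluster task k: means +/-r*e_{2k} get label +1 and
    +/-r*e_{2k+1} get label -1, plus isotropic N(0, sigma^2 I_d) noise."""
    mu1, mu2 = np.zeros(d), np.zeros(d)
    mu1[2 * k], mu2[2 * k + 1] = r, r
    cluster = rng.integers(0, 4, size=n)              # 0:+mu1 1:-mu1 2:+mu2 3:-mu2
    means = np.stack([mu1, -mu1, mu2, -mu2])[cluster]
    y = np.where(cluster < 2, 1.0, -1.0)
    X = means + sigma * rng.standard_normal((n, d))
    return X, y


def loss_and_grad(W, a, X, y):
    """Logistic loss of f(x) = sum_j a_j (w_j . x)^2 and its gradient w.r.t. W."""
    Z = X @ W.T                                       # (n, m) pre-activations
    margin = y * ((Z ** 2) @ a)
    loss = np.mean(np.logaddexp(0.0, -margin))
    g = 0.5 * (1.0 - np.tanh(margin / 2.0))           # = 1 / (1 + exp(margin)), overflow-safe
    coef = (g * y)[:, None] * Z * a[None, :]          # (n, m)
    grad = (-2.0 / len(y)) * (coef.T @ X)             # (m, d)
    return loss, grad


def run_sequence(K=2, d=8, n=64, m=32, T=500, lr=0.02, r=2.0, sigma=0.5,
                 init_scale=0.1, seed=0):
    """Train on tasks 0..K-1 in order with full-batch GD; return per-task losses."""
    assert d >= 2 * K, "each task needs two fresh orthogonal directions"
    rng = np.random.default_rng(seed)
    a = np.concatenate([np.ones(m // 2), -np.ones(m - m // 2)])  # fixed output layer
    W = init_scale * rng.standard_normal((m, d))
    train = [make_task(k, n, d, r, sigma, rng) for k in range(K)]
    test = [make_task(k, 20 * n, d, r, sigma, rng) for k in range(K)]  # test-loss proxy
    own_train, own_test = [], []
    for k in range(K):
        X, y = train[k]
        for _ in range(T):                            # GD sees only task k's data
            W -= lr * loss_and_grad(W, a, X, y)[1]
        own_train.append(loss_and_grad(W, a, *train[k])[0])
        own_test.append(loss_and_grad(W, a, *test[k])[0])
    final_train = [loss_and_grad(W, a, *train[k])[0] for k in range(K)]
    final_test = [loss_and_grad(W, a, *test[k])[0] for k in range(K)]
    return {
        "own_train": np.array(own_train), "own_test": np.array(own_test),
        "final_train": np.array(final_train), "final_test": np.array(final_test),
    }


if __name__ == "__main__":
    res = run_sequence()
    print("train-loss increase per task:", res["final_train"] - res["own_train"])
    print("test-loss increase per task: ", res["final_test"] - res["own_test"])
    print("final test - final train gap:", res["final_test"] - res["final_train"])
```

In this sketch, calling `run_sequence(sigma=0.0)` leaves every task's end-of-sequence loss equal to its loss right after its own stage. Noiseless inputs of later tasks are supported on coordinates that earlier tasks never use, so the later GD steps only change weight columns that earlier tasks' inputs multiply by zero. Any forgetting the sketch shows therefore comes from the Gaussian noise, which couples otherwise orthogonal tasks.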
