On the Theory of Continual Learning with Gradient Descent for Neural Networks

Hossein Taheri, Avishek Ghosh, Arya Mazumdar

TL;DR

This paper analyzes continual learning under gradient descent in a tractable setting: independent XOR-cluster tasks trained sequentially on a two-layer neural network with a quadratic activation. It decomposes test-time forgetting into train-time forgetting and a delayed generalization gap, and derives bounds based on algorithmic stability, which are then extended from the infinite-width regime to finite width via concentration arguments. The authors establish closed-form bounds on train- and test-time forgetting that depend on the dimension $d$, the number of tasks $K$, the per-task sample size $n$, the hidden width $m$, and the number of GD iterations $T$, and they identify scaling regimes, such as $n = \widetilde{\Theta}(d^2K)$, $m = \widetilde{\Omega}(d^8K^4)$, and $T = \widetilde{\Theta}(d^2)$, under which forgetting vanishes. Experiments in diverse settings corroborate the theory and show that the findings hold beyond the analyzed XOR model, including for other activations and data distributions.
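
A worked form of the decomposition, under assumed notation: $\widehat{F}_k$ is the training loss on task $k$ (the notation of Figure 4), $F_k$ is its population (test) loss, and $w_K$ are the weights after all $K$ tasks. If test-time forgetting of task $k$ is measured by $F_k(w_K)$, the split is an exact identity; the paper's precise definitions may differ.

$$
\underbrace{F_k(w_K)}_{\text{test-time forgetting}}
= \underbrace{\widehat{F}_k(w_K)}_{\text{train-time forgetting}}
+ \underbrace{F_k(w_K) - \widehat{F}_k(w_K)}_{\text{delayed generalization gap}},
\qquad 1 \le k \le K.
$$

In this reading, the gap is "delayed" because it is evaluated at $w_K$, after training has continued on tasks $k+1, \dots, K$, rather than at $w_k$, right after task $k$.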

Abstract

Continual learning, the ability of a model to adapt to an ongoing sequence of tasks without forgetting earlier ones, is a central goal of artificial intelligence. To shed light on its underlying mechanisms, we analyze the limitations of continual learning in a tractable yet representative setting. In particular, we study one-hidden-layer quadratic neural networks trained by gradient descent on an XOR cluster dataset with Gaussian noise, where different tasks correspond to different clusters with orthogonal means. We obtain bounds on the rate of forgetting at train and test time in terms of the number of iterations, the sample size, the number of tasks, and the hidden-layer size. These bounds reveal how each of these problem parameters affects the rate of forgetting. Numerical experiments across diverse setups confirm our results and demonstrate their validity beyond the analyzed settings.
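
The setting above can be sketched as a toy simulation in plain Python. This is not the authors' code, and every construction choice is an assumption: task $k$ (0-indexed here) puts its label $+1$ clusters at $\pm 2e_{2k}$ and its label $-1$ clusters at $\pm 2e_{2k+1}$, so the means of different tasks are orthogonal; the network is $f(x) = \frac{1}{m}\sum_j a_j (w_j^\top x)^2$ with a fixed second layer $a_j = \pm 1$; only the first layer is trained, by full-batch GD on the logistic loss. The helper names (`make_xor_task`, `gd_step`, `continual_xor`) and all default sizes are hypothetical.

```python
import math
import random


def make_xor_task(k, n, d, rng, mean_norm=2.0, noise_std=0.5):
    """Sample n labeled points for task k of a toy XOR-cluster sequence.

    Assumed construction: label +1 clusters sit at +/- mean_norm * e_{2k},
    label -1 clusters at +/- mean_norm * e_{2k+1}; isotropic Gaussian noise.
    """
    data = []
    for _ in range(n):
        y = rng.choice([1, -1])
        axis = 2 * k if y == 1 else 2 * k + 1
        x = [rng.gauss(0.0, noise_std) for _ in range(d)]
        x[axis] += rng.choice([1.0, -1.0]) * mean_norm
        data.append((x, y))
    return data


def forward(W, a, x):
    """Quadratic network f(x) = (1/m) * sum_j a_j * (w_j . x)^2."""
    return sum(a_j * sum(wi * xi for wi, xi in zip(w, x)) ** 2
               for w, a_j in zip(W, a)) / len(W)


def logistic_loss(W, a, data):
    """Mean of log(1 + exp(-y f(x))) over the dataset (overflow-safe)."""
    total = 0.0
    for x, y in data:
        z = y * forward(W, a, x)
        total += math.log1p(math.exp(-z)) if z > -30 else -z
    return total / len(data)


def gd_step(W, a, data, lr):
    """One full-batch GD step on the logistic loss; updates W in place."""
    m, d, n = len(W), len(W[0]), len(data)
    grads = [[0.0] * d for _ in range(m)]
    for x, y in data:
        proj = [sum(wi * xi for wi, xi in zip(w, x)) for w in W]
        z = y * sum(a_j * p * p for a_j, p in zip(a, proj)) / m
        # derivative of log(1 + e^{-z}) with respect to z
        dl_dz = -1.0 / (1.0 + math.exp(z)) if z < 30 else 0.0
        for j in range(m):
            # chain rule: df/dw_j = (2 a_j / m) (w_j . x) x
            c = dl_dz * y * 2.0 * a[j] * proj[j] / (m * n)
            for i in range(d):
                grads[j][i] += c * x[i]
    for j in range(m):
        W[j] = [wi - lr * gi for wi, gi in zip(W[j], grads[j])]


def continual_xor(K=3, d=6, n=60, m=10, T=200, lr=0.3, seed=0):
    """Train on tasks 0..K-1 in order with T GD steps each.

    Returns (first, current): the first task's train loss after each task,
    and each task's own train loss right after it was trained.
    """
    assert d >= 2 * K, "each task needs two fresh orthogonal directions"
    rng = random.Random(seed)
    tasks = [make_xor_task(k, n, d, rng) for k in range(K)]
    W = [[rng.gauss(0.0, 0.1) for _ in range(d)] for _ in range(m)]
    a = [1.0 if j % 2 == 0 else -1.0 for j in range(m)]  # fixed, not trained
    first, current = [], []
    for k in range(K):
        for _ in range(T):
            gd_step(W, a, tasks[k], lr)
        current.append(logistic_loss(W, a, tasks[k]))
        first.append(logistic_loss(W, a, tasks[0]))
    return first, current


if __name__ == "__main__":
    first, current = continual_xor()
    print("first-task train loss after each task:", [round(v, 4) for v in first])
    print("each task's loss right after training:", [round(v, 4) for v in current])
```

The list `first` plays the role of $\widehat{F}_1(w_k)$ as a function of $k$, the quantity plotted in the left panel of Figure 4. Varying `n`, `m`, or `T` gives a qualitative way to probe the trends the paper reports; this toy run makes no quantitative claim about them.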

Paper Structure

This paper contains 36 sections, 13 theorems, 105 equations, 12 figures, 1 algorithm.

Key Result

Theorem 1

Consider the $d$-dimensional XOR cluster dataset with $K$ tasks, trained sequentially by gradient descent with step size $\eta$ for $T$ iterations per task, where $\eta T = \Theta(d^2)$, using $n = \widetilde{\Theta}(d^2K)$ samples for each subsequent task and a neural network with $m = \widetilde{\Omega}(d^8K^4)$ hidden neurons. Then, with high probability, the train-time forgetting satisfies […], where $\widetilde{O}(\cdot)$ hides logarithmic factors in $n$, $T$, and $1/\delta$.
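
To get a feel for how demanding these conditions are, the following sketch evaluates the three scalings with all constants and logarithmic factors set to 1. The function name is hypothetical and the numbers are orders of magnitude only: $m$ is a lower bound ($\widetilde{\Omega}$), while $n$ and $\eta T$ are determined up to constants and logarithmic factors.

```python
def theorem1_scaling(d, K):
    """Return (n, m, eta_T) from Theorem 1 with constants and log factors dropped.

    n     ~ d^2 K    samples per subsequent task
    m     ~ d^8 K^4  hidden neurons (a lower bound in the theorem)
    eta_T ~ d^2      step size times GD iterations per task
    """
    return d ** 2 * K, d ** 8 * K ** 4, d ** 2


for d in (10, 20):
    n, m, eta_T = theorem1_scaling(d, K=3)
    print(f"d={d}: n~{n}, m~{m:.1e}, eta*T~{eta_T}")
```

Doubling $d$ multiplies $n$ and $\eta T$ by $4$ but the width requirement by $2^8 = 256$, so width is by far the most demanding of the three conditions.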

Figures (12)

  • Figure 1: Classification test error for each task versus iterations for the XOR cluster with $K=3$ tasks trained on a quadratic network with $n=2500$ (left) and $n=5000$ (right) training samples per task.
  • Figure 2: Classification train error for each task versus iterations for the XOR cluster with $K=3$ tasks trained on a quadratic network. We fix $n=2500$ for the first task and increase the sample size of the second and third tasks across panels. Increasing the sample size stabilizes per-task training and decreases forgetting of previous tasks.
  • Figure 3: We repeat an earlier experiment, this time using the GELU activation and the logistic loss, demonstrating that our findings remain valid across different settings.
  • Figure 4: Left: Training loss of task 1 versus task index (i.e., $\widehat{F}_1(w_k)$ as a function of $k$) for $K=6$ tasks for different sample sizes and training horizons per task. Right: Training loss per task ($\widehat{F}_k(w_k^{(t)})$) versus iteration when $n=2000$, $T=4000$ for each task. We use the GELU activation and the logistic loss. While each task individually attains near-zero training loss, the training loss of the first task grows with both the number of tasks ($K$) and the number of iterations ($T$).
  • Figure 5: Impact of network width ($m$) on the test error for learning the XOR cluster distribution with 3 tasks on quadratic networks. Increasing the width helps continual learning; however, the benefits diminish as $m$ grows.
  • ...and 7 more figures

Theorems & Definitions (19)

  • Theorem 1: Train-time forgetting
  • Theorem 2: Train error in continual learning
  • Theorem 3: Delayed generalization gap
  • Remark 1: Test-time forgetting
  • Theorem 4: Improved generalization gap
  • Remark 2
  • Proposition 1: Regularized continual learning
  • Proposition 2: Single-task XOR
  • Proof
  • Theorem 5: Restatement of the generalization-gap theorem
  • ...and 9 more