Theoretical Insights into Overparameterized Models in Multi-Task and Replay-Based Continual Learning

Amin Banayeeanzade, Mahdi Soltanolkotabi, Mohammad Rostami

TL;DR

This work provides exact, non-asymptotic theory for overparameterized multi-task learning and replay-based continual learning, using linear models as a tractable proxy. It derives closed-form expressions for the average generalization error and knowledge transfer across tasks, and characterizes forgetting under replay buffers, revealing how model size $p$, data size $\bar{n}$, and task similarity govern outcomes. The results show that the interpolation threshold shifts in the multi-task setting and that task similarity can either facilitate knowledge transfer or create interference, depending on model capacity. Empirically, the authors show that deep networks exhibit similar double-descent-like behavior and buffer-dependent forgetting, suggesting that the linear theory offers practical guidance for MTL and CL design in real-world DNNs.
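A minimal NumPy sketch of the threshold shift, under assumptions that are ours rather than the paper's: isotropic Gaussian features, zero noise by default, unit-norm task vectors sharing one pairwise inner product (built by the hypothetical helper `task_vectors`), error measured as $\Vert \hat{w} - w_t^* \Vert^2$ averaged over tasks, and an MTL learner taken to be a single shared least-squares or minimum-norm fit to the pooled data of all $T$ tasks. With $T=10$ and $n=50$ samples per task, each single-task fit interpolates once $p$ exceeds 50, while the pooled fit only does so once $p$ exceeds $Tn = 500$, so its error should spike near $p = 500$.

```python
import numpy as np

def task_vectors(T, p, cos_sim, rng):
    # Hypothetical helper: T unit vectors in R^p whose pairwise inner products
    # all equal cos_sim (requires p >= T and cos_sim > -1/(T-1)).
    gram = (1 - cos_sim) * np.eye(T) + cos_sim * np.ones((T, T))
    L = np.linalg.cholesky(gram)                       # gram = L @ L.T
    Q, _ = np.linalg.qr(rng.standard_normal((p, T)))  # p x T, orthonormal columns
    return L @ Q.T                                     # row t is w_t^*, W @ W.T = gram

def avg_errors(p, T=10, n=50, sigma=0.0, cos_sim=np.cos(np.pi / 8), reps=5, seed=0):
    """Average ||w_hat - w_t^*||^2 over tasks for STL and an assumed pooled MTL learner."""
    rng = np.random.default_rng(seed)
    stl, mtl = [], []
    for _ in range(reps):
        W = task_vectors(T, p, cos_sim, rng)
        Xs = [rng.standard_normal((n, p)) for _ in range(T)]
        ys = [X @ w + sigma * rng.standard_normal(n) for X, w in zip(Xs, W)]
        # STL: a separate fit per task; pinv gives least squares if n > p, min-norm if p > n.
        stl.append(np.mean([np.sum((np.linalg.pinv(X) @ y - w) ** 2)
                            for X, y, w in zip(Xs, ys, W)]))
        # MTL (assumed form): one shared model fit to the pooled data of all T tasks.
        w_shared = np.linalg.pinv(np.vstack(Xs)) @ np.concatenate(ys)
        mtl.append(np.mean(np.sum((w_shared - W) ** 2, axis=1)))
    return float(np.mean(stl)), float(np.mean(mtl))

for p in (100, 450, 500, 550, 1000, 2000):
    stl_err, mtl_err = avg_errors(p)
    print(f"p={p:5d}  STL={stl_err:10.3f}  MTL={mtl_err:10.3f}")
```

With zero noise the single-task learner shows no spike at $p = n$ (a square Gaussian system is solved exactly), whereas the pooled learner still spikes at $p = Tn$, because the mismatch between task vectors acts like label noise for the shared model.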

Abstract

Multi-task learning (MTL) is a machine learning paradigm that aims to improve the generalization performance of a model on multiple related tasks by training it simultaneously on those tasks. Unlike MTL, where the model has instant access to the training data of all tasks, continual learning (CL) involves adapting to new tasks that arrive sequentially over time without forgetting previously acquired knowledge. Despite the wide practical adoption of CL and MTL and the extensive literature on both areas, there remains a gap in the theoretical understanding of these methods when they are used with overparameterized models such as deep neural networks. This paper studies overparameterized linear models as a proxy for more complex models. We develop theoretical results describing the effect of various system parameters on the model's performance in an MTL setup. Specifically, we study the impact of model size, dataset size, and task similarity on the generalization error and knowledge transfer. Additionally, we present theoretical results that characterize the performance of replay-based CL models. Our results reveal the impact of buffer size and model capacity on the forgetting rate in a CL setup and help shed light on some state-of-the-art CL methods. Finally, through extensive empirical evaluations, we demonstrate that our theoretical findings also apply to deep neural networks, offering valuable guidance for designing MTL and CL models in practice.

Paper Structure (45 sections, 13 theorems, 87 equations, 20 figures, 5 tables)

Key Result

Theorem 3.1

(hastie2022surprises) When $n_t \geq p + 2$, the single-task learner described in Equation eq:train-loss achieves and when $p \geq n_t+2$, the single-task learner described in Equation eq:overparam-loss obtains where the expectation is due to randomness of $X_t$ and $z_t$.
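The expressions in Theorem 3.1 can be checked numerically. The sketch below assumes isotropic Gaussian features $x \sim N(0, I_p)$, noise variance $\sigma^2$, error measured as $E\Vert \hat{w}_t - w_t^* \Vert^2$, least squares in the underparameterized case, and the minimum-norm interpolator in the overparameterized case. Under those assumptions, the standard finite-sample forms from hastie2022surprises are $\sigma^2 p/(n_t - p - 1)$ for $n_t \geq p+2$ and $\Vert w_t^* \Vert^2 (1 - n_t/p) + \sigma^2 n_t/(p - n_t - 1)$ for $p \geq n_t + 2$; the paper's own statement may differ in normalization, for example by adding the irreducible noise term $\sigma^2$.

```python
import numpy as np

def closed_form_risk(p, n, w_norm_sq, sigma):
    """Expected ||w_hat - w*||^2 for isotropic Gaussian features (assumed setting)."""
    if n >= p + 2:  # underparameterized: ordinary least squares
        return sigma**2 * p / (n - p - 1)
    if p >= n + 2:  # overparameterized: minimum-norm interpolator
        return w_norm_sq * (1 - n / p) + sigma**2 * n / (p - n - 1)
    raise ValueError("formula is undefined when |p - n| < 2")

def monte_carlo_risk(p, n, sigma, reps=2000, seed=0):
    """Empirical average of ||w_hat - w*||^2 over random X and noise, with ||w*||^2 = 1."""
    rng = np.random.default_rng(seed)
    w_star = rng.standard_normal(p)
    w_star /= np.linalg.norm(w_star)
    errs = np.empty(reps)
    for r in range(reps):
        X = rng.standard_normal((n, p))
        y = X @ w_star + sigma * rng.standard_normal(n)
        # pinv(X) @ y is the OLS solution when n > p and the min-norm interpolator when p > n.
        w_hat = np.linalg.pinv(X) @ y
        errs[r] = np.sum((w_hat - w_star) ** 2)
    return float(errs.mean())

for p, n in [(20, 60), (100, 40)]:
    print(p, n, closed_form_risk(p, n, 1.0, 0.5), monte_carlo_risk(p, n, 0.5))
```

Using `np.linalg.pinv` for both regimes keeps the two learners in one code path; the closed form is deliberately left undefined near $p = n_t$, where the expectations diverge.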

Figures (20)

  • Figure 1: The average generalization error of multi-task, single-task, and continual learners with respect to the model size $p$. (a) and (b) compare MTL and STL for different levels of noise strength $\sigma$. (c) and (d) present the generalization error of replay-based continual learners for different memory sizes $m$ with zero noise. For all plots, $T=10$, and for all tasks $n_t = 50$ and $\Vert w_t^* \Vert ^ 2 = 1$. In subfigures (a) and (c), the tasks are designed to be collaborative by setting $\langle w^*_t,w^*_{t'} \rangle = \cos \frac{\pi}{8}$ for every pair of task vectors. In subfigures (b) and (d), the tasks are chosen to be conflicting by setting $\langle w^*_t,w^*_{t'} \rangle = \cos \frac{7\pi}{8}$. The solid lines show the theoretical predictions. The dots are empirical evaluations averaged over 500 repetitions and align perfectly with the theoretical results in the overparameterized regime.
  • Figure 2: Comparing the performance of single-task vs. multi-task learners on CIFAR-100 using a pretrained ResNet-18 backbone. The vertical axis is the cross-entropy loss and the horizontal axis is the width of the MLP on top of the backbone. $\sigma$ denotes the fraction of the training set whose labels were corrupted by random switching. (a) and (b) show the train and test loss with the single-head classifier, while (c) and (d) show the train and test loss of the multi-head architecture. As with the linear models studied previously, deep MTL models also exhibit double descent, with the test-error peak depending on task similarity and noise strength.
  • Figure 3: The average across-task inner product of task-specific optimal weights at different layers for different values of the MLP width $k$. The weights of the classification head have the most negative inner products, indicating that the tasks are highly conflicting in this layer.
  • Figure 4: Test loss of the single-task (Top) vs multi-task (Bottom) learners on CIFAR-100 with unfrozen ResNet-18. The vertical axis is the cross-entropy loss and the horizontal axis is the scale factor that controls the number of filters in convolutional layers. As we observe here, the test loss behaves similarly to the case where only the MLP head is trained.
  • Figure 5: Test loss of single-task vs. multi-task learners on different datasets with pretrained backbones. The vertical axis is the cross-entropy loss and the horizontal axis is the width of the MLP on top of the backbone. $\sigma$ denotes the fraction of the training set whose labels were corrupted by random switching. The multi-head architecture was used in all MTL experiments. These plots highlight the differences between MTL and STL double descent across several practical backbones and datasets.
  • ...and 15 more figures
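Figure 1(c) and (d) vary the memory size $m$ of replay-based learners. As a rough, hypothetical sketch of such a learner (not necessarily the exact model analyzed in the paper), `replay_cl` below learns each task by the smallest change to the current weights that fits that task's samples together with a buffer of $m$ stored samples from each earlier task. Forgetting is measured as the average increase in $\Vert w - w_k^* \Vert^2$ on earlier tasks between the moment task $k$ was learned and the end of training. The per-task buffer, the update rule, and this forgetting metric are all our assumptions.

```python
import numpy as np

def replay_cl(W, Xs, ys, m, seed=0):
    """Hypothetical replay-based continual learner (a sketch, not the paper's exact model).

    Each task is learned by the minimum-norm change to the current weights that fits
    the task's data together with a buffer of m stored samples from each earlier task.
    """
    rng = np.random.default_rng(seed)
    p = W.shape[1]
    w = np.zeros(p)
    buf_X, buf_y = np.empty((0, p)), np.empty(0)
    err_when_learned = []
    for X, y, w_star in zip(Xs, ys, W):
        A = np.vstack([X, buf_X])
        b = np.concatenate([y, buf_y])
        # Minimum-norm correction; it fits A w = b exactly when A has full row rank.
        w = w + np.linalg.pinv(A) @ (b - A @ w)
        err_when_learned.append(np.sum((w - w_star) ** 2))
        keep = rng.permutation(len(y))[:m]  # indices of samples stored for replay
        buf_X = np.vstack([buf_X, X[keep]])
        buf_y = np.concatenate([buf_y, y[keep]])
    final_err = np.sum((w - W) ** 2, axis=1)
    # Forgetting: average growth of each earlier task's error after later tasks were learned.
    forgetting = float(np.mean(final_err[:-1] - np.array(err_when_learned[:-1])))
    return w, float(np.mean(final_err)), forgetting

rng = np.random.default_rng(0)
T, n, p = 5, 50, 1000
W = rng.standard_normal((T, p))
W /= np.linalg.norm(W, axis=1, keepdims=True)  # unit-norm task vectors
Xs = [rng.standard_normal((n, p)) for _ in range(T)]
ys = [X @ w for X, w in zip(Xs, W)]             # zero noise, as in Figure 1(c)-(d)
for m in (0, 10, 25, 50):
    _, avg_err, forget = replay_cl(W, Xs, ys, m)
    print(f"m={m:2d}  avg final error={avg_err:.3f}  forgetting={forget:.3f}")
```

When $m = n$, every past sample is replayed and, provided $p$ exceeds the total number of samples, the final weights coincide with a single minimum-norm fit to the pooled data of all tasks; smaller buffers give up that equivalence in exchange for less memory.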

Theorems & Definitions (26)

  • Theorem 3.1
  • Theorem 3.2
  • Theorem 4.1
  • Theorem 4.2
  • Lemma A.1
  • proof
  • Lemma A.2
  • proof
  • Lemma A.3
  • proof
  • ...and 16 more