Learning with Preserving for Continual Multitask Learning

Hanchen David Wang, Siwoo Bae, Zirong Chen, Meiyi Ma

TL;DR

This work defines Continual Multitask Learning (CMTL), where tasks arrive sequentially on a shared input domain and labels for past tasks are not fully available. It introduces Learning with Preserving (LwP), a replay-free framework that preserves the geometry of the latent representation via a Dynamically Weighted Distance Preservation (DWDP) loss, complemented by current-task supervision and distillation of past tasks. By preserving pairwise distances for intra-task pairs and using a dynamic mask to avoid inter-class conflicts, LwP mitigates catastrophic forgetting while enabling knowledge sharing across tasks. It demonstrates superior performance and robustness on time-series and image benchmarks, including under non-stationary distributions. Because the approach requires no data replay, it is particularly suitable for privacy-sensitive scenarios; it achieves state-of-the-art results across multiple datasets, often surpassing single-task baselines.

Abstract

Artificial intelligence systems in critical fields like autonomous driving and medical imaging analysis often continually learn new tasks using a shared stream of input data. For instance, after learning to detect traffic signs, a model may later need to learn to classify traffic lights or different types of vehicles using the same camera feed. This scenario introduces a challenging setting we term Continual Multitask Learning (CMTL), where a model sequentially learns new tasks on an underlying data distribution without forgetting previously learned abilities. Existing continual learning methods often fail in this setting because they learn fragmented, task-specific features that interfere with one another. To address this, we introduce Learning with Preserving (LwP), a novel framework that shifts the focus from preserving task outputs to maintaining the geometric structure of the shared representation space. The core of LwP is a Dynamically Weighted Distance Preservation (DWDP) loss that prevents representation drift by regularizing the pairwise distances between latent data representations. This mechanism of preserving the underlying geometric structure allows the model to retain implicit knowledge and support diverse tasks without requiring a replay buffer, making it suitable for privacy-conscious applications. Extensive evaluations on time-series and image benchmarks show that LwP not only mitigates catastrophic forgetting but also consistently outperforms state-of-the-art baselines in CMTL tasks. Notably, our method shows superior robustness to distribution shifts and is the only approach to surpass the strong single-task learning baseline, underscoring its effectiveness for real-world dynamic environments.
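The distance-preservation idea described above can be illustrated with a small sketch. This is not the paper's DWDP formulation: it assumes a squared-error penalty on Euclidean pairwise distances and a mask that keeps only pairs sharing the same current-task label. The function names `pairwise_distances` and `distance_preservation_loss` are hypothetical and introduced only for illustration.

```python
import math


def pairwise_distances(points):
    """Euclidean distance matrix for a list of equal-length vectors."""
    return [[math.dist(p, q) for q in points] for p in points]


def distance_preservation_loss(teacher_z, student_z, labels):
    """Mean squared gap between teacher and student pairwise distances.

    Assumption (not from the paper): a pair (i, j) is kept by the mask
    only when both samples share the same current-task label; every
    other pair is ignored so new-task class separation is not penalized.
    """
    d_teacher = pairwise_distances(teacher_z)
    d_student = pairwise_distances(student_z)
    total, kept = 0.0, 0
    n = len(labels)
    for i in range(n):
        for j in range(i + 1, n):
            if labels[i] != labels[j]:
                continue  # masked out: different current-task labels
            total += (d_student[i][j] - d_teacher[i][j]) ** 2
            kept += 1
    return total / kept if kept else 0.0


# Toy latent vectors from a frozen teacher, with current-task labels.
teacher = [[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]]
labels = [0, 0, 1]

# A rigid shift keeps every pairwise distance, so nothing is penalized.
shifted = [[5.0, 5.0], [6.0, 5.0], [5.0, 7.0]]
# Stretching the same-label pair (0, 1) changes a preserved distance.
stretched = [[0.0, 0.0], [3.0, 0.0], [0.0, 2.0]]

loss_shifted = distance_preservation_loss(teacher, shifted, labels)
loss_stretched = distance_preservation_loss(teacher, stretched, labels)
```

In this toy setup the loss depends only on relative geometry, so the student can translate its representation freely, while moving same-label points apart or together is penalized. Pairs with different labels do not contribute under the assumed mask.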

Paper Structure

This paper contains 40 sections, 22 equations, 16 figures, 7 tables, and 1 algorithm.

Figures (16)

  • Figure 1: Overview of the LwP framework. For the first task, $\mathcal{T}_1$ (e.g., Traffic Light), the model is trained on data $D_1$. When learning a subsequent task like $\mathcal{T}_2$ (Pedestrian), the model from $\mathcal{T}_1$ is frozen as a teacher. This process generalizes for any current task $\mathcal{T}_i$: the model from the previous step, $f_{\theta_{s}^{[i-1]}}$, acts as a teacher for the student model, $f_{\theta_{s}^{[i]}}$, which learns on new data $D_i$ using a supervised loss ($\mathcal{L}_{\text{cur}}$), a distillation loss ($\mathcal{L}_{\text{old}}$), and our geometric preservation loss ($\mathcal{L}_{\text{DWDP}}$).
  • Figure 2: A visualization of the latent representation space (depicted as a sphere) as new tasks are learned sequentially. The points represent data embeddings. The LwP framework organizes representations for new tasks into distinct clusters (colored points) while its preservation loss maintains the geometric structure of prior representations.
  • Figure 3: Selected matrices showcasing the accuracy progression for the dataset CelebA. More details in Appendix \ref{appendix:additional_result}.
  • Figure 4: Selected Backward Transfer Diagrams for the Benchmark Datasets. More details in Appendix \ref{appendix:acc_per_iter}.
  • Figure 5: Impact of $\mathcal{L}_{\text{DWDP}}$ on a constructed example. After training on a concentric circle task, the representation space without our loss (left) degrades. With $\mathcal{L}_{\text{DWDP}}$ (right), the structure required for a subsequent XOR task is preserved.
  • ...and 11 more figures
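
The three loss terms named in the Figure 1 caption can be read as parts of a single training objective for the student model. A minimal sketch, assuming a plain weighted sum: the symbol $\mathcal{L}_{\text{total}}$ and the weights $\lambda_{\text{old}}$ and $\lambda_{\text{DWDP}}$ are hypothetical names introduced here for illustration, and the paper's exact weighting may differ.

```latex
\mathcal{L}_{\text{total}}
  = \mathcal{L}_{\text{cur}}
  + \lambda_{\text{old}} \, \mathcal{L}_{\text{old}}
  + \lambda_{\text{DWDP}} \, \mathcal{L}_{\text{DWDP}}
```

Under this reading, $\mathcal{L}_{\text{cur}}$ fits the new task $\mathcal{T}_i$ on $D_i$, $\mathcal{L}_{\text{old}}$ keeps the student's outputs on earlier tasks close to those of the frozen teacher $f_{\theta_{s}^{[i-1]}}$, and $\mathcal{L}_{\text{DWDP}}$ constrains the latent geometry.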