
Continual Learning with Global Alignment

Xueying Bai, Jinghuan Shang, Yifan Sun, Niranjan Balasubramanian

TL;DR

This work tackles catastrophic forgetting in continual learning by analyzing cross-task gradient interference and proposing global alignment that grounds task-specific data representations in pre-trained token semantics. It introduces three transformer-based alignment models (Wire-Fixed, Wire-Neigh, C-LoRA) and a probing-first strategy for initializing class vectors, enabling strong performance without experience replay and robust class-incremental results after task-incremental training. By interpolating pre-trained token representations to form task data representations, the approach ties cross-task correlations to general semantic features, improving stability and transfer. Empirical results on multiple NLP task sequences demonstrate state-of-the-art forgetting resistance and compelling robustness across task configurations.
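The interpolation idea can be illustrated with a toy sketch. For simplicity, it assumes that a data representation is a softmax-weighted (convex) combination of a small set of frozen token vectors. The token vectors, the logits, and helper names such as `compose` and `grounded_correlation` are hypothetical and not taken from the paper. Because the token vectors are shared across tasks, the inner product between two tasks' representations can be computed purely from the token Gram matrix $G$ as $\alpha_i^\top G \alpha_j$.

```python
import math

# Hypothetical frozen "pre-trained" token representations, shared by all tasks.
PRETRAINED_TOKENS = [
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.6, 0.8, 0.0],
]

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def compose(logits, tokens):
    """Return (h, alpha): h is a convex combination of tokens with weights alpha = softmax(logits)."""
    alpha = softmax(logits)
    dim = len(tokens[0])
    h = [sum(w * tok[d] for w, tok in zip(alpha, tokens)) for d in range(dim)]
    return h, alpha

def gram(tokens):
    """Token-token correlation (Gram) matrix G[a][b] = <t_a, t_b>."""
    return [[dot(t, s) for s in tokens] for t in tokens]

def grounded_correlation(alpha_i, alpha_j, g):
    """alpha_i^T G alpha_j: the representation correlation written via token correlations only."""
    return sum(alpha_i[a] * g[a][b] * alpha_j[b]
               for a in range(len(alpha_i)) for b in range(len(alpha_j)))

# In this sketch, the composition logits are the only task-dependent quantity.
h_task1, alpha1 = compose([2.0, 0.0, -1.0], PRETRAINED_TOKENS)
h_task2, alpha2 = compose([-1.0, 2.0, 0.5], PRETRAINED_TOKENS)

direct = dot(h_task1, h_task2)
via_tokens = grounded_correlation(alpha1, alpha2, gram(PRETRAINED_TOKENS))
print(direct, via_tokens)
```

The two printed values agree up to floating-point rounding. In this toy setting, any change in how two tasks' representations correlate has to come from the task-specific weights, not from drifting token vectors.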

Abstract

Continual learning aims to learn new tasks sequentially without forgetting knowledge from previous tasks (catastrophic forgetting). One factor that can cause forgetting is interference between the gradients of different tasks' losses. When the gradient of the current task's loss points in a direction opposing the gradients of previous tasks' losses, updating the model for the current task may degrade performance on previous tasks. In this paper, we first identify causes of this interference and hypothesize that correlations between data representations are a key factor in it. We then propose a method that promotes appropriate correlations between the data representations of arbitrary tasks (i.e., global alignment) while learning each individual task. Specifically, we learn the data representation as a task-specific composition of pre-trained token representations shared across all tasks. The correlations between different tasks' data representations are then grounded in the correlations between pre-trained token representations. We explore different ways to learn such compositions. Without experience replay, our model achieves state-of-the-art (SOTA) performance on continual learning tasks. It also achieves strong class-incremental performance through task-incremental training.
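The link between representation correlations and gradient interference can be checked on a single shared linear softmax layer. This is a simplified illustration, not the paper's derivation. For logits $W\mathbf h$ with cross-entropy loss, the gradient with respect to $W$ is the outer product $(\mathbf p - \mathbf e_y)\mathbf h^\top$. As a result, the inner product of two examples' gradients factors into (error-vector correlation) × (representation correlation). The weights, examples, and helper names below are made up for illustration.

```python
import math

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

def ce_grad(W, h, y):
    """Gradient of cross-entropy w.r.t. W for logits W @ h, i.e. (p - onehot(y)) h^T.
    Also returns the error vector delta = p - onehot(y)."""
    p = softmax([dot(row, h) for row in W])
    delta = [p_c - (1.0 if c == y else 0.0) for c, p_c in enumerate(p)]
    return [[d * x for x in h] for d in delta], delta

def frob(A, B):
    """Frobenius inner product <A, B> of two equally shaped matrices."""
    return sum(a * b for row_a, row_b in zip(A, B) for a, b in zip(row_a, row_b))

# Hypothetical shared layer: 2 outputs over 3-dimensional representations.
W = [[0.5, -0.2, 0.1],
     [-0.3, 0.4, 0.2]]

h_prev, y_prev = [1.0, 0.5, 0.0], 0  # example from a previous task
h_curr, y_curr = [0.8, 0.6, 0.1], 1  # correlated example from the current task, different label
h_orth = [0.0, 0.0, 1.0]             # representation orthogonal to h_prev

g_prev, d_prev = ce_grad(W, h_prev, y_prev)
g_curr, d_curr = ce_grad(W, h_curr, y_curr)
g_orth, _ = ce_grad(W, h_orth, y_curr)

interference = frob(g_prev, g_curr)                     # < 0 means opposing directions
factorized = dot(d_prev, d_curr) * dot(h_prev, h_curr)  # (error corr.) x (representation corr.)
no_overlap = frob(g_prev, g_orth)                       # zero: orthogonal representations
print(interference, factorized, no_overlap)
```

A negative inner product means that, to first order, a gradient-descent step on the current example increases the previous example's loss. In this layer, the orthogonal representation contributes no interference at all, whatever its label.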

Paper Structure

This paper contains 25 sections, 10 equations, 4 figures, and 5 tables.

Figures (4)

  • Figure 1: Overview of our methods. Task $i$'s data representations are denoted $\mathbf h_i$, and pre-trained token representations are shown as grey dots in the 'Representation' block. Correlations between aligned data representations from different tasks depend on correlations between pre-trained token representations. In the 'Class Vectors' block, after probing, class vectors for different classes focus on different parts of the representations, which can reduce interference caused by overlapping representations.
  • Figure 2: t-SNE plots of all tasks' data representations after learning the first task (with classes Village and Athlete) and the last task. Under vanilla sequential learning (a), representations of data from unseen tasks already overlap after the first task. This overlap may cause interference when switching tasks, leaving representations indistinguishable after the last task. With our global alignment model (Wire-Neigh) (b), representations remain distinguishable after both the first and the last task.
  • Figure 3: Comparison between alignment models. Modules in blue are pre-trained, and modules in orange are learnable. Representations in grey are largely adapted, while those in blue stay close to the pre-trained ones. We denote the hidden representations for [CLS] and for any other token as ${\bf h}_{\texttt{[CLS]}}^{l}$ and ${\bf h}_{\text{others}}^{l}$, respectively.
  • Figure 4: (a) Class-IL accuracy after the last task. Dashed lines show the accuracy of ERACE, which replays previous tasks' data with a Class-IL loss. (b) Average Class-IL accuracy after each task.
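Class vectors, probing, and the two evaluation settings in Figure 4 can be sketched in a few lines. This sketch rests on three assumptions:

- The probing-first strategy is rendered as plain linear probing: class vectors are fitted by gradient descent on frozen representations.
- The subsequent fine-tuning stage is omitted.
- The features, learning rate, number of epochs, and helper names (`probe`, `predict`) are hypothetical.

Task-IL prediction takes the argmax over the current task's classes only. Class-IL prediction takes it over every class vector learned so far.

```python
import math

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

def probe(features, labels, num_classes, lr=0.5, epochs=200):
    """Linear probing: fit class vectors with full-batch gradient descent on
    cross-entropy while the (frozen) features are left unchanged."""
    dim, n = len(features[0]), len(features)
    vecs = [[0.0] * dim for _ in range(num_classes)]
    for _ in range(epochs):
        grads = [[0.0] * dim for _ in range(num_classes)]
        for h, y in zip(features, labels):
            p = softmax([dot(v, h) for v in vecs])
            for c in range(num_classes):
                err = p[c] - (1.0 if c == y else 0.0)
                for k in range(dim):
                    grads[c][k] += err * h[k] / n
        for c in range(num_classes):
            for k in range(dim):
                vecs[c][k] -= lr * grads[c][k]
    return vecs

def predict(class_vecs, h, allowed=None):
    """Task-IL: argmax over the task's own class ids (`allowed`).
    Class-IL: argmax over every class vector learned so far (allowed=None)."""
    candidates = range(len(class_vecs)) if allowed is None else allowed
    return max(candidates, key=lambda c: dot(class_vecs[c], h))

# Hypothetical frozen representations for two tasks with two classes each.
task_a = ([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]], [0, 1])  # global classes 0, 1
task_b = ([[0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]], [0, 1])  # global classes 2, 3

# Task-incremental training: each task's class vectors are probed on that task only.
class_vecs = probe(*task_a, num_classes=2) + probe(*task_b, num_classes=2)

task_classes = {"a": [0, 1], "b": [2, 3]}
test_set = [([1.0, 0.0, 0.0, 0.0], "a", 0), ([0.0, 1.0, 0.0, 0.0], "a", 1),
            ([0.0, 0.0, 1.0, 0.0], "b", 2), ([0.0, 0.0, 0.0, 1.0], "b", 3)]

til_acc = sum(predict(class_vecs, h, task_classes[t]) == y for h, t, y in test_set) / len(test_set)
cil_acc = sum(predict(class_vecs, h) == y for h, t, y in test_set) / len(test_set)
print(til_acc, cil_acc)
```

Here, the two tasks' toy representations are orthogonal, so Class-IL accuracy matches Task-IL accuracy. When representations from different tasks overlap, as in Figure 2(a), class vectors probed on one task can also score other tasks' inputs highly, and Class-IL accuracy can drop.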