
Learning to Learn without Forgetting using Attention

Anna Vettoruzzo, Joaquin Vanschoren, Mohamed-Rafik Bouguelia, Thorsteinn Rögnvaldsson

TL;DR

Continual learning faces catastrophic forgetting as tasks arrive in a non-stationary sequence. The authors propose a transformer-based meta-optimizer with attention that learns task-specific weight updates for a classifier, guided by a pre-trained task encoder and a lightweight feature extractor, and gated by importance scores to protect prior knowledge. Empirically, the approach yields strong forward and backward transfer on SplitMNIST, RotatedMNIST, and SplitCIFAR-100 with limited labeled data, without relying on a replay buffer, and ablations confirm the importance of the transformer and score mechanism. This work advances data-efficient, scalable CL by enabling selective, task-aware updates through meta-learning, with potential extensions to larger models and dynamic classifier heads for broader continual-learning settings.
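The following is a minimal, untrained sketch of one meta-optimizer step in plain Python. It illustrates how the pieces named in the TL;DR could fit together. It is not the authors' implementation, and several choices are assumptions made only for illustration:

- each classifier weight becomes one attention token;
- `feature_extractor` is a hypothetical tanh map;
- attention uses a single head;
- the update is gated multiplicatively by (1 - s_r).

The task embedding is a fixed stand-in for the output of the pre-trained task encoder.

```python
import math
import random


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def matvec(mat, vec):
    """Multiply a matrix (stored as a list of rows) by a vector."""
    return [sum(m * v for m, v in zip(row, vec)) for row in mat]


def self_attention(tokens, wq, wk, wv):
    """Single-head scaled dot-product self-attention over a list of token vectors."""
    q = [matvec(wq, t) for t in tokens]
    k = [matvec(wk, t) for t in tokens]
    v = [matvec(wv, t) for t in tokens]
    d = len(q[0])
    out = []
    for qi in q:
        logits = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k]
        attn = softmax(logits)
        out.append([sum(a * vj[c] for a, vj in zip(attn, v)) for c in range(d)])
    return out


def feature_extractor(w, fe):
    """Hypothetical per-weight feature map: one tanh(a*w + b) feature per (a, b) pair."""
    return [math.tanh(a * w + b) for a, b in fe]


def meta_optimizer_step(weights, scores, task_emb, params):
    """Hypothetical meta-optimizer step g_psi (assumed design, not the paper's).

    Each weight w_r becomes a token [phi(w_r), s_r, task_emb]. Self-attention lets
    every weight's update depend on all the others, and a linear head predicts a raw
    update that is gated by (1 - s_r), so a weight with s_r = 1 is left unchanged.
    """
    tokens = [feature_extractor(w, params["fe"]) + [s] + task_emb
              for w, s in zip(weights, scores)]
    hidden = self_attention(tokens, params["wq"], params["wk"], params["wv"])
    raw = [sum(h * o for h, o in zip(hr, params["head"])) for hr in hidden]
    return [w + (1.0 - s) * u for w, s, u in zip(weights, scores, raw)]


# Toy usage with random (untrained) parameters, for illustration only.
random.seed(0)
N_FEAT, D_TASK, D_MODEL = 2, 2, 4
D_TOK = N_FEAT + 1 + D_TASK  # features + importance score + task embedding


def rand_mat(rows, cols):
    return [[random.gauss(0.0, 0.5) for _ in range(cols)] for _ in range(rows)]


params = {
    "fe": [(1.0, 0.0), (0.5, 0.1)],
    "wq": rand_mat(D_MODEL, D_TOK),
    "wk": rand_mat(D_MODEL, D_TOK),
    "wv": rand_mat(D_MODEL, D_TOK),
    "head": [random.gauss(0.0, 0.5) for _ in range(D_MODEL)],
}
weights = [0.2, -0.4, 0.7]   # W = {w_r}, flattened classifier weights
scores = [0.0, 0.5, 1.0]     # importance scores s_r, assumed to lie in [0, 1]
task_emb = [0.3, -0.1]       # stand-in for the pre-trained task encoder's output
new_weights = meta_optimizer_step(weights, scores, task_emb, params)
print(new_weights)
```

In this toy version, the weight with importance score 1.0 is returned unchanged. This is the "protect prior knowledge" behavior in its simplest form. In the paper, the parameters of the meta-optimizer are meta-learned rather than random.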

Abstract

Continual learning (CL) refers to the ability to continually learn over time by accommodating new knowledge while retaining previously learned experience. While this concept is inherent in human learning, current machine learning methods are highly prone to overwriting previously learned patterns and thus forgetting past experience. Instead, model parameters should be updated selectively and carefully, avoiding unnecessary forgetting while optimally leveraging previously learned patterns to accelerate future learning. Since hand-crafting effective update mechanisms is difficult, we propose meta-learning a transformer-based optimizer to enhance CL. This meta-learned optimizer uses attention to learn the complex relationships between model parameters across a stream of tasks, and is designed to generate effective weight updates for the current task while preventing catastrophic forgetting on previously encountered tasks. Evaluations on benchmark datasets such as SplitMNIST, RotatedMNIST, and SplitCIFAR-100 affirm the efficacy of the proposed approach in terms of both forward and backward transfer, even on small sets of labeled data, highlighting the advantages of integrating a meta-learned optimizer within the continual learning framework.

Paper Structure

This paper contains 19 sections, 1 equation, 8 figures, 4 tables, 2 algorithms.

Figures (8)

  • Figure 1: (a) Overall framework of the proposed approach. The support set of the current data batch $D_{i,m}^{(sp)}$ is fed into the classifier model $f_{\mu, W}$ to derive importance scores $\{s_r\}_{r=1}^R$. These scores, along with the weights of the classifier model that we want to optimize $W = \{w_r\}_{r=1}^R$, and $D_{i,m}^{(sp)}$, serve as inputs to the meta-optimizer $g_\psi$ for predicting weight updates tailored to task $\mathcal{T}_m$ from which the data are sampled. (b) Architecture of the meta-optimizer. (A) The pre-trained task encoder learns a vector representation that characterizes the current task. (B) The feature extractor maps weight values into the feature space. (C) The transformer encoder predicts the weight updates.
  • Figure 2: Visualization of the importance scores $\{s_r\}_{r=1}^R$ and the weight updates learned by the meta-optimizer $g_{\psi}$ after adaptation to each task $\mathcal{T}_{i} \in \boldsymbol{\mathcal{T}}$ for the SplitMNIST dataset.
  • Figure 3: Pareto plot of average backward transfer (BWT) vs. forward transfer (FWT).
  • Figure 4: Visualization of the three metrics (average accuracy, BWT, FWT) varying the number of labeled examples in the support set of each task from $K=1$ to $K=20$, for all three datasets.
  • Figure 5: Visualization of the optimized weights at test time for all SplitMNIST tasks $\{\mathcal{T}_m\}_{m=1}^M$.
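Figures 3 and 4 report three metrics: average accuracy, backward transfer (BWT), and forward transfer (FWT). The sketch below computes them from a task-by-task accuracy matrix. It assumes the widely used GEM-style definitions (Lopez-Paz and Ranzato, 2017). The paper's exact formulas may differ, for example in the baseline used for FWT.

The sketch also picks out the methods that are not dominated in the (BWT, FWT) plane, which is what a Pareto plot highlights. The accuracy values and method names are made up for illustration.

```python
def cl_metrics(R, baseline):
    """Average accuracy, BWT and FWT from an accuracy matrix (assumed GEM-style definitions).

    R[i][j] is the test accuracy on task j after training sequentially on tasks 0..i.
    baseline[j] is the accuracy on task j of the model before any training.
    """
    T = len(R)
    acc = sum(R[T - 1]) / T
    # BWT: how much the final model gained (+) or forgot (-) on each earlier task.
    bwt = sum(R[T - 1][j] - R[j][j] for j in range(T - 1)) / (T - 1)
    # FWT: accuracy on task j just before training on it, relative to the baseline.
    fwt = sum(R[j - 1][j] - baseline[j] for j in range(1, T)) / (T - 1)
    return acc, bwt, fwt


def pareto_front(points):
    """Names of methods not dominated in (BWT, FWT); higher is better on both axes."""
    front = []
    for name, (b, f) in points.items():
        dominated = any(
            b2 >= b and f2 >= f and (b2 > b or f2 > f)
            for other, (b2, f2) in points.items()
            if other != name
        )
        if not dominated:
            front.append(name)
    return front


# Made-up 3-task accuracy matrix, for illustration only.
R = [
    [0.90, 0.10, 0.20],
    [0.80, 0.95, 0.30],
    [0.85, 0.90, 0.92],
]
baseline = [0.10, 0.10, 0.10]
acc, bwt, fwt = cl_metrics(R, baseline)
print(f"ACC={acc:.3f}  BWT={bwt:.3f}  FWT={fwt:.3f}")  # ACC=0.890  BWT=-0.050  FWT=0.100

# Hypothetical methods placed in the (BWT, FWT) plane.
methods = {"A": (bwt, fwt), "B": (-0.20, 0.30), "C": (-0.30, 0.05)}
print(pareto_front(methods))  # ['A', 'B']
```

Method C is dominated because A is better on both axes. A and B trade backward transfer against forward transfer, so both lie on the front.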