Continual Learning with Query-Only Attention

Gautham Bekal, Ashish Pujari, Scott David Kelly

TL;DR

Continual learning must balance preserving plasticity against avoiding catastrophic forgetting on non-stationary task streams. The authors introduce a query-only attention mechanism that removes keys and values but retains the inductive bias of transformers, reducing both loss of plasticity and forgetting in online continual learning. They establish conceptual links between query-only attention, full attention, and model-agnostic meta-learning (MAML), and support the approach with a preliminary Hessian-based curvature analysis suggesting that maintaining a high effective rank helps preserve plasticity. Across Permuted MNIST, Split ImageNet, and Slowly Changing Regression, the method shows competitive forward adaptation and, when task identities are available, reduced forgetting, at lower computational cost than full attention.

Abstract

Continual learning involves learning from a stream of data without repetition of data points, a setting made inherently difficult by distributional shift across tasks. We propose a query-only attention mechanism that discards keys and values yet preserves the core inductive bias of transformer architectures. In continual learning scenarios, this simplified mechanism significantly mitigates both loss of plasticity and catastrophic forgetting, outperforming baselines such as selective re-initialization. We establish a conceptual link between query-only attention, full transformer attention, and model-agnostic meta-learning (MAML), framing all three as instances of meta-learning. We further provide intuition for why query-based models and attention networks help preserve plasticity in continual settings. Finally, through a preliminary Hessian spectrum analysis, we observe that models maintaining higher curvature rank across tasks tend to retain plasticity. Our findings suggest that full attention may not be essential for capturing the benefits of meta-learning in continual learning.
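The abstract ties plasticity to the rank of the loss curvature. As a minimal sketch of how such a quantity can be measured, the code below computes the entropy-based effective rank of Roy and Vetterli (2007) from the absolute eigenvalues of a symmetric Hessian. Whether the paper uses this exact definition is an assumption; the helper name `effective_rank` and the toy linear-regression Hessians are hypothetical and for illustration only. For a linear model trained with half the mean squared error, the Hessian is X^T X / n, so collapsing every feature onto one direction drives the effective rank toward 1.

```python
import numpy as np


def effective_rank(hessian: np.ndarray, eps: float = 1e-12) -> float:
    """Entropy-based effective rank (Roy & Vetterli, 2007) of a symmetric matrix.

    Assumption: this is one common definition; the paper may use another.
    """
    # For a symmetric matrix, singular values equal the absolute eigenvalues.
    sigma = np.abs(np.linalg.eigvalsh(hessian))
    total = sigma.sum()
    if total <= eps:
        return 0.0  # no curvature at all
    p = sigma / total   # normalize the spectrum into a probability distribution
    p = p[p > eps]      # drop (near-)zeros so that 0 * log(0) counts as 0
    entropy = -np.sum(p * np.log(p))
    return float(np.exp(entropy))


# Toy example (hypothetical): the Hessian of 0.5 * mean squared error for a
# linear model is X^T X / n, independent of the weights.
rng = np.random.default_rng(0)
X = rng.normal(size=(256, 8))
H_healthy = X.T @ X / len(X)

# Every column identical: the features have collapsed onto one direction.
X_collapsed = np.repeat(X[:, :1], 8, axis=1)
H_collapsed = X_collapsed.T @ X_collapsed / len(X_collapsed)

print(f"healthy:   {effective_rank(H_healthy):.2f}")    # close to 8
print(f"collapsed: {effective_rank(H_collapsed):.2f}")  # close to 1
```

The well-spread feature matrix yields an effective rank near the number of features, while the collapsed one yields roughly 1, the kind of curvature-rank drop the abstract associates with lost plasticity. In a real network the Hessian would come from automatic differentiation rather than a closed form.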

Paper Structure

This paper contains 33 sections, 11 equations, 4 figures, 17 tables, and 2 algorithms.

Figures (4)

  • Figure 1: Predictions span 7500 tasks; each data point in the graph is averaged over 100 tasks for all models except MAML. MAML is run over only 75 tasks and shown without averaging.
  • Figure 2: Predictions span 9000 tasks; each data point in the graph is averaged over 100 tasks for all models except MAML, which is run over 500 tasks with each data point averaged over 5 tasks.
  • Figure 3: Predictions span 800 tasks; each data point in the graph is averaged over 10 tasks for all models.
  • Figure 4: Effective rank co-varies with forward performance; dips in effective rank align with reduced plasticity. Mean ± std over 3 seeds.
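The captions average a per-task metric over fixed blocks of tasks before plotting. Assuming the blocks are consecutive and non-overlapping (the captions do not say so explicitly), the arithmetic works out to 7500 / 100 = 75 points in Figure 1, 9000 / 100 = 90 points in Figure 2 (500 / 5 = 100 for MAML), and 800 / 10 = 80 points in Figure 3. The helper `bin_by_tasks` and the synthetic per-task values below are hypothetical; this is a sketch of the averaging, not the authors' plotting code.

```python
import numpy as np


def bin_by_tasks(per_task: np.ndarray, tasks_per_point: int) -> np.ndarray:
    """Average consecutive, non-overlapping blocks of tasks into plot points.

    Assumption: any incomplete trailing block is dropped.
    """
    per_task = np.asarray(per_task, dtype=float)
    n_points = len(per_task) // tasks_per_point
    trimmed = per_task[: n_points * tasks_per_point]
    # Each row is one block of tasks; averaging a row gives one plotted point.
    return trimmed.reshape(n_points, tasks_per_point).mean(axis=1)


# Hypothetical per-task metric (e.g. online accuracy), one value per task.
per_task = np.arange(7500, dtype=float)

points = bin_by_tasks(per_task, tasks_per_point=100)  # Figure 1 setting
print(points.shape)  # (75,)
print(points[0])     # 49.5, the mean of tasks 0..99
```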