Attention Retention for Continual Learning with Vision Transformers

Yue Lu, Xiangyu Zhou, Shizhou Zhang, Yinghui Xing, Guoqiang Liang, Wencong Zhang

TL;DR

This paper tackles catastrophic forgetting in continual learning by identifying attention drift in Vision Transformers as a key bottleneck. It introduces ARCL-ViT, an attention-retaining framework that constrains drift by masking gradients in the attention regions of previous tasks. These regions are located via layer-wise attention rollout and adaptive thresholding, and the framework rescales parameter updates so that the masking remains compatible with standard optimizers. The method achieves state-of-the-art performance and generalizes well across diverse pre-training regimes and long task sequences. Its ablations and analyses support the importance of attention-centered masking. Overall, the approach offers a principled, data-free mechanism for preserving previously learned concepts in ViTs during continual learning, with practical implications for scalable, robust continual vision systems.
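The following is a minimal NumPy sketch of the layer-wise rollout and adaptive thresholding described above. It assumes the standard attention rollout formulation (Abnar and Zuidema, 2020): average the heads, add the identity to account for the residual path, renormalize the rows, and multiply across layers. The paper's exact rollout may differ. The thresholding rule keeps the fewest patches that hold a fraction `mass` of the [CLS] attention. That rule and the names `attention_rollout`, `adaptive_mask` and `mass` are hypothetical illustrations, not the authors' method.

```python
import numpy as np


def attention_rollout(attentions, residual_weight=0.5):
    """Standard attention rollout over per-layer attention tensors.

    attentions: list of arrays of shape (heads, N, N), first layer first;
    token 0 is [CLS] and each row is a softmax distribution.
    Returns an (N, N) row-stochastic matrix.
    """
    n = attentions[0].shape[-1]
    eye = np.eye(n)
    rollout = eye.copy()
    for attn in attentions:
        a = attn.mean(axis=0)  # fuse heads by averaging
        a = residual_weight * eye + (1.0 - residual_weight) * a  # model the residual path
        a = a / a.sum(axis=-1, keepdims=True)  # keep rows summing to 1
        rollout = a @ rollout  # compose with the layers below
    return rollout


def adaptive_mask(rollout, mass=0.9):
    """Hypothetical instance-adaptive threshold: keep the fewest patches whose
    [CLS] attention covers at least `mass` of the total."""
    cls_to_patches = rollout[0, 1:]  # drop the [CLS]-to-[CLS] entry
    p = cls_to_patches / cls_to_patches.sum()
    order = np.argsort(p)[::-1]  # patch indices sorted by attention, highest first
    k = int(np.searchsorted(np.cumsum(p[order]), mass)) + 1
    mask = np.zeros(p.shape, dtype=bool)
    mask[order[:k]] = True  # True marks the previous-task attention region
    return mask
```

The number of kept patches depends on how concentrated each image's attention is, so the mask size adapts per instance instead of using a fixed top-k.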

Abstract

Continual learning (CL) empowers AI systems to progressively acquire knowledge from non-stationary data streams. However, catastrophic forgetting remains a critical challenge. In this work, we identify attention drift in Vision Transformers as a primary source of catastrophic forgetting, where the attention to previously learned visual concepts shifts significantly after learning new tasks. Inspired by neuroscientific insights into selective attention in the human visual system, we propose a novel attention-retaining framework to mitigate forgetting in CL. Our method constrains attention drift by explicitly modifying gradients during backpropagation through a two-step process: 1) extracting attention maps of the previous task using a layer-wise rollout mechanism and generating instance-adaptive binary masks, and 2) when learning a new task, applying these masks to zero out gradients associated with previous attention regions, thereby preventing disruption of learned visual concepts. For compatibility with modern optimizers, the gradient masking process is further enhanced by scaling parameter updates proportionally to maintain their relative magnitudes. Experiments and visualizations demonstrate the effectiveness of our method in mitigating catastrophic forgetting and preserving visual concepts. It achieves state-of-the-art performance and exhibits robust generalizability across diverse CL scenarios.
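The abstract's final step rescales updates so that the masking survives adaptive optimizers. Here is why that matters. On its first step, bias-corrected Adam moves each coordinate by roughly `lr * sign(g)`. A coordinate whose gradient the mask shrinks from 1.0 to 0.1 would therefore still move by the full learning rate. The sketch below is one hypothetical reading of "scaling parameter updates proportionally": it multiplies the Adam step elementwise by |g_masked| / |g_full|. The function names and this particular ratio are assumptions, not the paper's published rule.

```python
import numpy as np


def adam_direction(g, m, v, t, beta1=0.9, beta2=0.999, eps=1e-8):
    """Bias-corrected Adam direction (before the learning rate) and updated moments."""
    m = beta1 * m + (1.0 - beta1) * g
    v = beta2 * v + (1.0 - beta2) * g * g
    m_hat = m / (1.0 - beta1 ** t)
    v_hat = v / (1.0 - beta2 ** t)
    return m_hat / (np.sqrt(v_hat) + eps), m, v


def scaled_masked_update(g_full, g_masked, m, v, t, lr=1e-3, tiny=1e-12):
    """Hypothetical: Adam step on the masked gradient, rescaled per coordinate by
    how much of the original gradient survived masking (1 = untouched, 0 = fully masked)."""
    direction, m, v = adam_direction(g_masked, m, v, t)
    # Clip at 1: dropping tokens whose contributions cancel can enlarge a coordinate.
    keep_ratio = np.clip(np.abs(g_masked) / (np.abs(g_full) + tiny), 0.0, 1.0)
    return -lr * keep_ratio * direction, m, v


g_full = np.array([1.0, 1.0, 1.0])
g_masked = np.array([1.0, 0.1, 0.0])  # second coordinate partly masked, third fully
zeros = np.zeros(3)
plain, _, _ = adam_direction(g_masked, zeros, zeros, t=1)  # approx. [1.0, 1.0, 0.0]
scaled, _, _ = scaled_masked_update(g_full, g_masked, zeros, zeros, t=1)  # approx. [-1e-3, -1e-4, 0.0]
```

Without the rescaling, the partly masked coordinate receives the same step as the unmasked one. With it, the step shrinks in proportion to the mask's effect.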

Paper Structure (20 sections, 11 equations, 2 figures, 5 tables)

Figures (2)

  • Figure 1: (a) Visualization of attention maps on the first task's ($\mathcal{T}_1$) samples. Columns two to four show attention maps from ViT models that are (i) only fine-tuned on $\mathcal{T}_1$, (ii) sequentially fine-tuned over 10 tasks (Model-$\mathcal{T}_{10}$ Seq-FT), and (iii) trained with our method over 10 tasks. Seq-FT leads to severe attention drift, while our method effectively preserves the original attention and thus mitigates forgetting. (b) Quantitative comparison of attention drift relative to $\mathcal{T}_1$.
  • Figure 2: Illustration of our proposed attention-retaining continual learning (ARCL-ViT) framework. After finishing $\mathcal{T}_{t-1}$, the model extracts attention maps $\mathbf{U}_{t-1}$ and generates the mask $\bar{\mathbf{M}}_{t-1}$ to identify attention regions. During subsequent training in $\mathcal{T}_{t}$, the mask $\bar{\mathbf{M}}_{t-1}$ is used to selectively zero out gradients in the corresponding attention regions. The masked gradients $\nabla(\mathbf{W}_{q.t})^{\prime}, \nabla(\mathbf{W}_{k.t})^{\prime}, \nabla(\mathbf{W}_{v.t})^{\prime}$ are used to update the learnable weights $\mathbf{W}_{q.t}, \mathbf{W}_{k.t}, \mathbf{W}_{v.t}$.
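Figure 2 does not spell out how $\bar{\mathbf{M}}_{t-1}$ maps onto the projection gradients. For illustration only, the sketch below assumes the mask is a boolean vector over token positions. Each projection $\mathbf{Y}=\mathbf{X}\mathbf{W}$ has gradient $\nabla\mathbf{W}=\mathbf{X}^{\top}\nabla\mathbf{Y}$, so zeroing the rows of $\nabla\mathbf{Y}$ for masked tokens removes their contribution to $\nabla(\mathbf{W}_{q.t})^{\prime}, \nabla(\mathbf{W}_{k.t})^{\prime}, \nabla(\mathbf{W}_{v.t})^{\prime}$. The helper names `masked_projection_grad` and `masked_qkv_grads` and the token-level interpretation are assumptions. The sketch also does not model how per-instance masks are aggregated into $\bar{\mathbf{M}}_{t-1}$.

```python
import numpy as np


def masked_projection_grad(x, grad_out, token_mask):
    """Gradient of a linear projection Y = X @ W with masked tokens excluded.

    x:          (N, D_in) token features entering the projection
    grad_out:   (N, D_out) upstream gradient dL/dY
    token_mask: (N,) bool, True where the previous-task attention region lies
    Returns dL/dW of shape (D_in, D_out) with masked rows of dL/dY zeroed.
    """
    keep = (~token_mask).astype(grad_out.dtype)[:, None]  # 0 for masked tokens
    return x.T @ (grad_out * keep)


def masked_qkv_grads(x, grads, token_mask):
    """Apply the same (hypothetical) token mask to the q, k and v projections."""
    return {name: masked_projection_grad(x, g, token_mask) for name, g in grads.items()}


rng = np.random.default_rng(0)
x = rng.normal(size=(5, 8))  # 5 tokens, 8 features
grads = {n: rng.normal(size=(5, 8)) for n in ("q", "k", "v")}  # upstream dL/dQ, dL/dK, dL/dV
mask = np.array([False, True, False, True, False])
masked = masked_qkv_grads(x, grads, mask)
```

Zeroing those rows is equivalent to computing each weight gradient from the unmasked tokens alone, for example `x[~mask].T @ grads["q"][~mask]`.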