One-stage Prompt-based Continual Learning

Youngeun Kim, Yuhang Li, Priyadarshini Panda

TL;DR

This work tackles the heavy computational burden of Prompt-based Continual Learning (PCL) by introducing OS-Prompt, a one-stage PCL framework that uses intermediate-layer token embeddings as prompt queries, eliminating the separate query ViT stage and achieving a $\sim 50\%$ reduction in GFLOPs with a marginal accuracy drop of $<1\%$. To counteract the reduced representational power of early-layer queries, the authors add a Query-Pool Regularization (QR) loss that aligns the prompt-query relationships with those of a reference final-layer query; the variant trained with this loss is called OS-Prompt++. Because the QR loss is applied only during training, it adds no inference cost. Experiments on CIFAR-100, ImageNet-R, and DomainNet show that OS-Prompt++ outperforms prior two-stage PCL methods by roughly $1.4\%$ while retaining the $\sim 50\%$ inference cost reduction, indicating practical efficiency gains for rehearsal-free continual learning on resource-limited devices.

Abstract

Prompt-based Continual Learning (PCL) has gained considerable attention as a promising continual learning solution because it achieves state-of-the-art performance while avoiding privacy violations and memory overhead. Nonetheless, existing PCL approaches face a significant computational burden because they require two Vision Transformer (ViT) feed-forward stages: a query ViT that generates a prompt query to select prompts from a prompt pool, and a backbone ViT that mixes information between the selected prompts and the image tokens. To address this, we introduce a one-stage PCL framework that directly uses an intermediate layer's token embedding as the prompt query. This design removes the need for an additional feed-forward stage for the query ViT, resulting in a ~50% computational cost reduction for both training and inference with a marginal accuracy drop of < 1%. We further introduce a Query-Pool Regularization (QR) loss that regulates the relationship between the prompt query and the prompt pool to improve representation power. The QR loss is applied only during training, so it adds no computational overhead at inference. With the QR loss, our approach maintains the ~50% computational cost reduction during inference and outperforms prior two-stage PCL methods by ~1.4% on public class-incremental continual learning benchmarks, including CIFAR-100, ImageNet-R, and DomainNet.

Paper Structure

This paper contains 20 sections, 11 equations, 6 figures, and 9 tables.

Figures (6)

  • Figure 1: Difference between prior PCL work and ours. Prior PCL work (Left) has two feed-forward stages for (1) a query function (ViT) to select input-specific prompts and (2) a backbone ViT layer to perform prompt learning with the selected prompts and image tokens. On the other hand, our one-stage PCL framework (Middle) uses an intermediate layer's token embedding as a prompt query so that it requires only one backbone ViT feed-forward stage. As a result, our method reduces GFLOPs by $\sim 50\%$ compared to prior work while maintaining accuracy (Right).
  • Figure 2: Layer-wise feature distances between the token embeddings of the model after training on Task 1 and those after a later task is learned. The columns show the embeddings after training the model on Tasks 3, 5, 7, and 9; for instance, the leftmost panel shows the distance between the Task 1 token embeddings and those obtained once Task 3 is completed. We train prompts using CodaPrompt (Smith et al., 2023) and use $1-CosSim(x, y)$ to measure the layer-wise distance on the training dataset. We use the CIFAR-100 10-task setting. We provide more examples in the Supplementary Materials.
  • Figure 3: Our OS-Prompt (OS-Prompt++) framework. An image passes through the backbone ViT layers to produce the final prediction. From layer 1 to layer 5, we prepend prompt tokens to the image tokens; the prompt tokens are obtained as follows: ① We first compute the image token embedding from the previous layer. ② We use the $[CLS]$ token as the prompt query and measure its cosine similarity to the prompt keys inside the prompt pool. ③ Based on the similarity, we compute a weighted sum of the values to obtain the prompt tokens. ④ To further improve accuracy, we present OS-Prompt++, which integrates the query-pool regularization loss (dotted line), enabling the prompt pool to capture a stronger representation from the reference ViT.
  • Figure 4: Comparison of latency across PCL methods on different GPUs.
  • Figure 5: Analysis of accuracy $A_N$ with respect to prompt components (Left) and prompt length (Right). We use the 10-task ImageNet-R setting.
  • ...and 1 more figures
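
The prompt-generation steps in the Figure 3 caption (steps ① to ③) can be illustrated with a minimal sketch in plain Python. This is not the authors' implementation: the names `cosine_similarity` and `make_prompt`, the list-based data layout, and the use of raw cosine similarities as the weights (with no further normalization or attention terms) are all assumptions made for illustration, and the paper's exact weighting may differ.

```python
import math

def cosine_similarity(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    # Small epsilon avoids division by zero for all-zero vectors.
    return dot / (norm_a * norm_b + 1e-12)

def make_prompt(cls_query, keys, values):
    """Build prompt tokens from a [CLS] query (hypothetical helper).

    cls_query: [CLS] embedding taken from the previous layer (step 1).
    keys:      M key vectors of the prompt pool.
    values:    M prompt components, each a list of token vectors.
    """
    # Step 2: cosine similarity between the query and every key.
    # ASSUMPTION: the raw similarities are used directly as weights.
    weights = [cosine_similarity(cls_query, k) for k in keys]

    # Step 3: similarity-weighted sum of the values.
    length, dim = len(values[0]), len(values[0][0])
    prompt = [[0.0] * dim for _ in range(length)]
    for w, component in zip(weights, values):
        for t in range(length):
            for d in range(dim):
                prompt[t][d] += w * component[t][d]
    return weights, prompt

# Toy pool: two keys in a 2-D embedding space, one prompt token each.
query = [1.0, 0.0]
keys = [[1.0, 0.0], [0.0, 1.0]]
values = [[[1.0, 2.0]], [[10.0, 20.0]]]
weights, prompt = make_prompt(query, keys, values)
```

In this toy pool the query is aligned with the first key and orthogonal to the second, so the resulting prompt is dominated by the first value. The same `cosine_similarity` helper also gives the layer-wise distance used in Figure 2 as `1 - cosine_similarity(x, y)`.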