Prompt-aligned Gradient for Prompt Tuning

Beier Zhu, Yulei Niu, Yucheng Han, Yue Wu, Hanwang Zhang

TL;DR

Prompt tuning of vision-language models often overfits when data is scarce, leading to forgetting of the pre-trained general knowledge. ProGrad introduces a gradient-alignment mechanism that regularizes prompt updates by tying them to a general direction $G_g$, derived from the KL divergence between zero-shot and downstream predictions, while using a domain direction $G_d$ from cross-entropy with ground truth. The update rule projects conflicting downstream gradients onto the orthogonal subspace of the general direction, controlled by a hyperparameter $\lambda$, and updates only the learnable context tokens. Across 11 datasets and multiple generalization settings (few-shot, domain generalization, base-to-new, cross-dataset transfer), ProGrad consistently improves over CoOp and CoCoOp, showing enhanced robustness without changing the underlying architecture.
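The update rule summarized above can be illustrated with a minimal NumPy sketch. It assumes the two gradients are flattened into 1-D vectors; the function name `prograd_direction` and the argument name `lam` (standing in for $\lambda$) are hypothetical and not taken from the authors' code. When the two gradients do not conflict, the domain gradient is used unchanged; otherwise its component along the general direction is removed, scaled by `lam`.

```python
import numpy as np

def prograd_direction(g_d, g_g, lam=1.0):
    """Sketch of a prompt-aligned update direction (names are illustrative).

    g_d: domain-specific gradient (from cross-entropy with ground truth).
    g_g: general-direction gradient (from the KL loss against zero-shot predictions).
    lam: strength of the projection; lam=1 removes the whole conflicting component.
    """
    dot = float(np.dot(g_d, g_g))
    if dot >= 0.0:
        # Aligned (angle at most 90 degrees): keep the domain gradient as-is.
        return g_d.copy()
    # Conflicting: subtract the component of g_d that lies along g_g.
    return g_d - lam * dot / float(np.dot(g_g, g_g)) * g_g

g_g = np.array([1.0, 0.0])
aligned = prograd_direction(np.array([0.5, 1.0]), g_g)    # unchanged
conflict = prograd_direction(np.array([-1.0, 1.0]), g_g)  # component along g_g removed
```

With `lam=1.0`, the conflicting case returns a vector orthogonal to `g_g`; smaller values of `lam` would leave part of the conflicting component in place.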

Abstract

Thanks to large pre-trained vision-language models (VLMs) like CLIP, we can craft a zero-shot classifier by "prompt", e.g., the confidence score of an image being "[CLASS]" can be obtained by using the VLM-provided similarity measure between the image and the prompt sentence "a photo of a [CLASS]". Therefore, prompts show great potential for fast adaptation of VLMs to downstream tasks if we fine-tune the prompt-based similarity measure. However, we find a common failure: improper fine-tuning may undermine the prompt's inherent prediction not only for the task-related classes, but also for other classes in the VLM vocabulary. Existing methods still address this problem with traditional anti-overfitting techniques such as early stopping and data augmentation, which lack a principled solution specific to prompts. We present Prompt-aligned Gradient, dubbed ProGrad, to prevent prompt tuning from forgetting the general knowledge learned from VLMs. In particular, ProGrad only updates the prompt whose gradient is aligned (or non-conflicting) with the "general direction", which is represented as the gradient of the KL loss of the pre-defined prompt prediction. Extensive experiments demonstrate the stronger few-shot generalization ability of ProGrad over state-of-the-art prompt tuning methods. Code is available at https://github.com/BeierZhu/Prompt-align.
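To make the two losses mentioned in the abstract concrete, the following is a rough NumPy sketch of a prompt-based classifier: class probabilities come from a softmax over image-text cosine similarities, the KL loss compares the hand-crafted (zero-shot) prompt prediction with the learned-prompt prediction, and the cross-entropy loss uses the ground-truth label. Random vectors stand in for real CLIP features, and the helper names, the temperature `tau`, and the choice of the zero-shot distribution as the KL reference are assumptions of this sketch rather than details confirmed here.

```python
import numpy as np

def softmax(z):
    z = z - z.max()  # shift for numerical stability
    e = np.exp(z)
    return e / e.sum()

def class_probs(image_feat, text_feats, tau=0.01):
    """Softmax over cosine similarities between one image and K class prompts."""
    img = image_feat / np.linalg.norm(image_feat)
    txt = text_feats / np.linalg.norm(text_feats, axis=1, keepdims=True)
    return softmax(txt @ img / tau)

def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) for two discrete distributions over the same classes."""
    return float(np.sum(p * (np.log(p + eps) - np.log(q + eps))))

def cross_entropy(p, label, eps=1e-12):
    """Negative log-probability of the ground-truth class."""
    return float(-np.log(p[label] + eps))

rng = np.random.default_rng(0)
image_feat = rng.normal(size=8)            # stand-in for an image embedding
zero_shot_text = rng.normal(size=(3, 8))   # stand-in for "a photo of a [CLASS]" embeddings
learned_text = zero_shot_text + 0.1 * rng.normal(size=(3, 8))  # stand-in for tuned prompts

p_zs = class_probs(image_feat, zero_shot_text)
p = class_probs(image_feat, learned_text)
loss_kl = kl_divergence(p_zs, p)       # its gradient would give the general direction
loss_ce = cross_entropy(p, label=0)    # its gradient would give the domain direction
```

In an actual training loop the two losses would be differentiated with respect to the learnable context tokens only; the sketch stops at the loss values.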

Paper Structure

This paper contains 20 sections, 1 theorem, 15 equations, 7 figures, 12 tables.

Key Result

Theorem 1

Let $\mathbf{X}_1^{N_d}=\{\mathbf{x}_n^{(d)}\}_{n=1}^{N_d}$ and $\mathbf{X}_1^{N_p}=\{\mathbf{x}_n^{(p)}\}_{n=1}^{N_p}$ be two sets of i.i.d. samples drawn from the downstream domain $\mathcal{D}_{d}$ and the pre-trained domain $\mathcal{D}_{p}$, respectively. Then for any $\epsilon>0$, we have with probability at [...] where $\gamma_\mathcal{F}(D,P)$ is the integral probability metric (Müller, 1997) that measures [...]

Figures (7)

  • Figure 1: Comparison of zero-shot CLIP, CoOp, and our ProGrad on the Stanford Cars and OxfordPets datasets. (a)&(b): Given 1-shot training samples, CoOp's performance drops severely and falls below zero-shot CLIP by large margins as training continues. (c)&(d): CoOp may fail to improve on CLIP without data augmentation or plenty of samples.
  • Figure 2: Comparison of Grad-CAM visualizations for prompt tuning methods using different gradient strategies on the Stanford Cars dataset.
  • Figure 3: (a) If $\bm{G}_\text{d}$ is aligned with $\bm{G}_\text{g}$, we set $\bm{G}_\text{prograd}$ as $\bm{G}_\text{d}$. (b) If $\bm{G}_\text{d}$ conflicts with $\bm{G}_\text{g}$ (i.e., their angle is larger than 90°), we set $\bm{G}_\text{prograd}$ as the projection of $\bm{G}_\text{d}$ on the orthogonal direction of $\bm{G}_\text{g}$. (c) Training pipeline of our ProGrad. Only the context vectors are learnable.
  • Figure 4: Accuracy (%) of few-shot learning on 11 datasets. The context length $M$ is set to 16. Standard deviations are reported in Appendix.
  • Figure 5: Distribution of samples that are mis-classified by ProGrad but correctly classified by CoOp.
  • ...and 2 more figures

Theorems & Definitions (1)

  • Theorem 1