Prompt-aligned Gradient for Prompt Tuning
Beier Zhu, Yulei Niu, Yucheng Han, Yue Wu, Hanwang Zhang
TL;DR
Prompt tuning of vision-language models often overfits when data is scarce, leading to forgetting of the pre-trained general knowledge. ProGrad introduces a gradient-alignment mechanism that regularizes prompt updates by tying them to a general direction $G_g$, derived from the KL divergence between zero-shot and downstream predictions, while using a domain direction $G_d$ from cross-entropy with ground truth. The update rule projects conflicting downstream gradients onto the orthogonal subspace of the general direction, controlled by a hyperparameter $\lambda$, and updates only the learnable context tokens. Across 11 datasets and multiple generalization settings (few-shot, domain generalization, base-to-new, cross-dataset transfer), ProGrad consistently improves over CoOp and CoCoOp, showing enhanced robustness without changing the underlying architecture.
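The update rule described above can be sketched in a few lines. This is a minimal illustration, not the authors' implementation: it assumes gradients are flattened into plain Python lists, that "conflicting" means a negative inner product between $G_d$ and $G_g$, and that $\lambda$ scales how much of the conflicting component is removed. The function name `prograd_direction` and the `eps` stabilizer are hypothetical.

```python
from typing import List


def dot(u: List[float], v: List[float]) -> float:
    """Inner product of two equally sized vectors."""
    return sum(a * b for a, b in zip(u, v))


def prograd_direction(
    g_d: List[float], g_g: List[float], lam: float = 1.0, eps: float = 1e-12
) -> List[float]:
    """Return the update direction for the learnable context tokens.

    g_d: domain gradient (from cross-entropy with ground truth).
    g_g: general gradient (from the KL loss against zero-shot predictions).
    lam: assumed strength of the projection; lam = 1 removes the whole
         component of g_d that points against g_g.
    eps: hypothetical stabilizer that avoids division by zero.
    """
    inner = dot(g_d, g_g)
    if inner >= 0.0:
        # Non-conflicting case: the domain gradient is used unchanged.
        return list(g_d)
    # Conflicting case: subtract lam times the projection of g_d onto g_g.
    scale = lam * inner / (dot(g_g, g_g) + eps)
    return [d - scale * g for d, g in zip(g_d, g_g)]
```

With `lam = 1.0` the returned direction is orthogonal to `g_g` whenever the two gradients conflict, and with `lam = 0.0` the rule reduces to plain gradient descent on the domain loss.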
Abstract
Thanks to large pre-trained vision-language models (VLMs) like CLIP, we can craft a zero-shot classifier by "prompt", e.g., the confidence score of an image being "[CLASS]" can be obtained by using the VLM-provided similarity measure between the image and the prompt sentence "a photo of a [CLASS]". Therefore, prompts show great potential for fast adaptation of VLMs to downstream tasks if we fine-tune the prompt-based similarity measure. However, we find a common failure: improper fine-tuning may undermine the prompt's inherent prediction not only for the task-related classes, but also for other classes in the VLM vocabulary. Existing methods still address this problem with traditional anti-overfitting techniques such as early stopping and data augmentation, which are not principled solutions specific to prompts. We present Prompt-aligned Gradient, dubbed ProGrad, to prevent prompt tuning from forgetting the general knowledge learned from VLMs. In particular, ProGrad only updates the prompt whose gradient is aligned (or non-conflicting) to the "general direction", which is represented as the gradient of the KL loss of the pre-defined prompt prediction. Extensive experiments demonstrate the stronger few-shot generalization ability of ProGrad over state-of-the-art prompt tuning methods. Code is available at https://github.com/BeierZhu/Prompt-align.
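To make the "general direction" concrete, the toy sketch below computes the two losses and their gradients for a single image. It rests on several assumptions that the text above does not spell out: the KL loss is taken as $\mathrm{KL}(p_{zs} \,\|\, p)$ with the zero-shot distribution as the target, and gradients are taken with respect to the class logits rather than the learnable context tokens, purely to keep the example self-contained. All function names are hypothetical.

```python
import math
from typing import List, Tuple


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(p: List[float], label: int) -> float:
    """Domain loss: negative log-probability of the ground-truth class."""
    return -math.log(p[label])


def kl_divergence(p_zs: List[float], p: List[float]) -> float:
    """General loss (assumed direction): KL(p_zs || p)."""
    return sum(q * math.log(q / pi) for q, pi in zip(p_zs, p) if q > 0.0)


def logit_gradients(
    logits: List[float], p_zs: List[float], label: int
) -> Tuple[List[float], List[float]]:
    """Gradients of both losses with respect to the logits.

    For p = softmax(logits), d CE / d logits = p - onehot(label) and
    d KL(p_zs || p) / d logits = p - p_zs.
    """
    p = softmax(logits)
    g_d = [pi - (1.0 if i == label else 0.0) for i, pi in enumerate(p)]
    g_g = [pi - q for pi, q in zip(p, p_zs)]
    return g_d, g_g
```

In this toy setting the two gradients agree whenever the zero-shot prediction already favors the ground-truth class, and they can conflict when the few-shot label pulls the prediction away from what the hand-crafted prompt predicts.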
