
Align Your Prompts: Test-Time Prompting with Distribution Alignment for Zero-Shot Generalization

Jameel Hassan, Hanan Gani, Noor Hussein, Muhammad Uzair Khattak, Muzammal Naseer, Fahad Shahbaz Khan, Salman Khan

TL;DR

PromptAlign tackles the core challenge of zero-shot generalization under distribution shift for vision-language models by introducing test-time distribution alignment. It jointly tunes multi-modal prompts and aligns the token distributions of a test input, computed from augmented views of a single test sample, to source statistics computed offline on a proxy dataset (commonly ImageNet). The alignment loss, averaged over $L$ vision-backbone layers, is $\mathcal{L}_{\text{align}} = \frac{1}{L} \sum_{l=1}^{L} \left( \| \bm{\mu}_{l}(\mathcal{T};\bm{p}) - \bm{\hat{\mu}}_{l} \|_1 + \| \bm{\sigma^2}_{l}(\mathcal{T};\bm{p}) - \bm{\hat{\sigma}^2}_{l} \|_1 \right)$. It is combined with the entropy objective as $\mathcal{L}_{\text{final}} = \mathcal{L}_{\text{entropy}} + \beta \mathcal{L}_{\text{align}}$, so that prompts on both the image and text branches are updated. Experiments on domain generalization and cross-dataset transfer show consistent improvements over baselines such as MaPLe and TPT, including a 3.08% average Top-1 improvement over MaPLe in domain generalization, along with strong cross-dataset performance. The results indicate that token distribution alignment can narrow the train-test distribution gap for CLIP-like models while relying on a scalable proxy dataset and adding little runtime overhead.
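As a rough illustration of the two objectives above, the NumPy sketch below evaluates $\mathcal{L}_{\text{align}}$ and $\mathcal{L}_{\text{final}}$ from precomputed per-layer statistics. It is not the authors' implementation: the function names `align_loss` and `final_loss` are hypothetical, each statistic is assumed to be a 1-D array per layer, and the prompt gradients that an autodiff framework would compute are omitted.

```python
import numpy as np


def align_loss(test_means, test_vars, src_means, src_vars):
    """L_align: layer-averaged L1 distance between test and source statistics.

    Each argument is a list with one 1-D array per vision layer l = 1..L.
    test_* come from the augmented views of one test sample; src_* are the
    offline source (proxy dataset) statistics. Hypothetical helper.
    """
    if not (len(test_means) == len(test_vars) == len(src_means) == len(src_vars)):
        raise ValueError("all statistics must cover the same layers")
    total = 0.0
    for mu, var, mu_hat, var_hat in zip(test_means, test_vars, src_means, src_vars):
        total += np.abs(mu - mu_hat).sum()    # ||mu_l(T; p) - mu_hat_l||_1
        total += np.abs(var - var_hat).sum()  # ||sigma^2_l(T; p) - sigma_hat^2_l||_1
    return total / len(test_means)            # average over the L layers


def final_loss(entropy_loss, alignment_loss, beta):
    """L_final = L_entropy + beta * L_align (beta is the loss scaling factor)."""
    return entropy_loss + beta * alignment_loss
```

A real test-time update would need an autodiff framework so that gradients reach the prompts $\bm{p}$; this sketch only shows the arithmetic of the loss.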

Abstract

The promising zero-shot generalization of vision-language models such as CLIP has led to their adoption, via prompt learning, for numerous downstream tasks. Previous works have used test-time prompt tuning with entropy minimization to adapt text prompts to unseen domains. While effective, this overlooks the key cause of performance degradation on unseen domains: distribution shift. In this work, we explicitly handle this problem by aligning the out-of-distribution (OOD) test sample statistics to those of the source data using prompt tuning. We use a single test sample to adapt multi-modal prompts at test time, minimizing the feature distribution shift to bridge the gap to the test domain. On the domain generalization benchmark, our method improves zero-shot top-1 accuracy beyond existing prompt-learning techniques, with a 3.08% improvement over the baseline MaPLe. In cross-dataset generalization with unseen categories across 10 datasets, our method improves consistently across all datasets compared to the existing state-of-the-art. Our source code and models are available at https://jameelhassan.github.io/promptalign.
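The entropy objective that prior test-time prompt tuning minimizes can be sketched as follows. This assumes, for illustration, that the entropy is taken over the class distribution averaged across the augmented views of one test image; the helper names `softmax` and `marginal_entropy` are hypothetical, and any confidence-based filtering of views that specific methods may apply is omitted.

```python
import numpy as np


def softmax(logits):
    # Subtract the row-wise max for numerical stability.
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def marginal_entropy(view_logits):
    """Entropy of the class distribution averaged over augmented views.

    view_logits: array of shape (num_views, num_classes), e.g. image-text
    similarity logits for each augmented view of one test image.
    """
    probs = softmax(view_logits)   # per-view class probabilities
    avg = probs.mean(axis=0)       # average the distributions over views
    # Small epsilon guards against log(0) for classes with zero probability.
    return float(-(avg * np.log(avg + 1e-12)).sum())
```

Averaging first means the objective is low only when the views agree on a confident prediction.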

Paper Structure

This paper contains 25 sections, 6 equations, 8 figures, and 15 tables.

Figures (8)

  • Figure 1: (a) PromptAlign matches the distribution statistics $\bm{\mu}_{l}(\mathcal{T} ; \bm{p})$, $\bm{\sigma^2}_{l}(\mathcal{T} ; \bm{p})$, obtained from multiple augmented views of a single test sample, with the source data distribution statistics $\bm{\hat{\mu}}_{l}$, $\bm{\hat{\sigma}^2}_{l}$. This effectively brings the test sample closer to the distribution of the source data, where the domain shift is denoted by $\blacktriangle_{1} \rightarrow \blacktriangle_{2}$. $\mathcal{T}$ denotes the distribution of the test sample, $\bm{p}$ represents the prompts that are updated, and $l$ refers to the vision-backbone layers. (b) Owing to the distribution matching via prompts, PromptAlign surpasses the existing state-of-the-art prompt learning approaches on 8 out of 10 datasets in cross-dataset generalization benchmarks.
  • Figure 2: Overview of our proposed PromptAlign method for zero-shot image classification. At test time, a single test sample along with its augmented views is passed through the CLIP image encoder, and the text labels are passed to the CLIP text encoder. The token distribution statistics -- mean and variance -- of the test sample are aligned with the offline computed source data statistics using a distribution alignment loss. The resulting alignment loss from the distribution shift is combined with the entropy loss to update the multi-modal prompts.
  • Figure 3: Effect of the loss scaling factor $\beta$ on ImageNet. Performance plateaus beyond $\beta = 100$.
  • Figure 4: Effect of the choice of loss function for distribution alignment on performance.
  • Figure 5: Analysis of compute resource constraints on performance. (a) Top-1 accuracy increases with the number of augmented views. (b) Top-1 accuracy improves consistently with the number of prompt update steps. (c) Latency scales similarly with the number of prompt update steps for both methods.
  • ...and 3 more figures
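Figures 1 and 2 rely on two kinds of token statistics: the per-layer mean and variance from the augmented views of one test sample, and the same statistics precomputed offline on a source proxy dataset. A minimal NumPy sketch of both follows, under stated assumptions: tokens for one layer arrive as arrays of shape `(num_views, num_tokens, dim)`, statistics are pooled over views and tokens per feature dimension, and the names `token_stats` and `source_stats` are hypothetical.

```python
import numpy as np


def token_stats(tokens):
    """Per-dimension mean and variance of one layer's tokens for a test sample.

    tokens: array of shape (num_views, num_tokens, dim) holding the layer-l
    token embeddings of every augmented view. Pooling over views and tokens
    is an assumption of this sketch.
    """
    flat = tokens.reshape(-1, tokens.shape[-1])  # (num_views * num_tokens, dim)
    return flat.mean(axis=0), flat.var(axis=0)


def source_stats(batches):
    """Offline source statistics for one layer, streamed over proxy batches.

    batches: iterable of arrays shaped (batch, num_tokens, dim). Running sums
    avoid holding the whole proxy dataset in memory.
    """
    count, s1, s2 = 0, None, None
    for tokens in batches:
        flat = tokens.reshape(-1, tokens.shape[-1])
        if s1 is None:
            s1 = np.zeros(flat.shape[1])
            s2 = np.zeros(flat.shape[1])
        count += flat.shape[0]
        s1 += flat.sum(axis=0)          # running sum of x
        s2 += (flat ** 2).sum(axis=0)   # running sum of x^2
    if count == 0:
        raise ValueError("no source batches were provided")
    mean = s1 / count
    var = s2 / count - mean ** 2  # population variance, matching np.var
    return mean, var
```

The running-sum form lets the source pass stream once over a large proxy set such as ImageNet; the $E[x^2] - E[x]^2$ shortcut can lose precision, so a Welford-style update is a safer choice at scale.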