OTTER: Effortless Label Distribution Adaptation of Zero-shot Models

Changho Shin, Jitian Zhao, Sonia Cromp, Harit Vishwakarma, Frederic Sala

TL;DR

OTTER is a simple, lightweight approach that adjusts pretrained model predictions via optimal transport. It requires only an estimate of the downstream task's label distribution and is validated on a wide array of zero-shot image and text classification tasks.

Abstract

Popular zero-shot models suffer due to artifacts inherited from pretraining. One particularly detrimental issue, caused by unbalanced web-scale pretraining data, is mismatched label distribution. Existing approaches that seek to repair the label distribution are not suitable in zero-shot settings, as their requirements cannot be met there: they need access to labeled downstream task data or knowledge of the true label balance in the pretraining distribution. We sidestep these challenges and introduce a simple and lightweight approach to adjust pretrained model predictions via optimal transport. Our technique requires only an estimate of the label distribution of a downstream task. Theoretically, we characterize the improvement produced by our procedure under certain mild conditions and provide bounds on the error caused by misspecification. Empirically, we validate our method in a wide array of zero-shot image and text classification tasks, improving accuracy by 4.8% and 15.9% on average, respectively, and beating baselines like prior matching, often by significant margins, in 17 out of 21 datasets.
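
As a rough illustration of the idea in the abstract, the sketch below reassigns labels so that the predicted class frequencies match a given label distribution. It is a minimal sketch under stated assumptions, not the authors' implementation: the function name `otter_predict`, the parameters `reg` and `n_iters`, and the use of entropy-regularized optimal transport (Sinkhorn iterations) in place of an exact OT solver are all choices made here for brevity. The cost $C_{ij} = -\log P_\theta(Y=j|X=x_i)$ follows Theorem 4.1.

```python
import math


def otter_predict(probs, label_dist, reg=0.1, n_iters=2000):
    """Relabel n points so class frequencies approximately match label_dist.

    probs: n rows of K strictly positive class probabilities.
    label_dist: K target class proportions summing to 1.
    Assumption: entropic OT (Sinkhorn) approximates the exact OT plan.
    """
    n, k = len(probs), len(label_dist)
    # Cost matrix C_ij = -log P(Y=j | X=x_i).
    cost = [[-math.log(p) for p in row] for row in probs]
    # Gibbs kernel exp(-C/reg); subtracting each row's minimum cost only
    # rescales rows, which the row scaling u absorbs, and avoids underflow.
    kernel = []
    for row in cost:
        lo = min(row)
        kernel.append([math.exp(-(c - lo) / reg) for c in row])
    u = [1.0] * n
    v = [1.0] * k
    for _ in range(n_iters):
        # Match row marginals (each point carries mass 1/n).
        u = [(1.0 / n) / sum(kernel[i][j] * v[j] for j in range(k))
             for i in range(n)]
        # Match column marginals (the target label distribution).
        v = [label_dist[j] / sum(kernel[i][j] * u[i] for i in range(n))
             for j in range(k)]
    preds = []
    for i in range(n):
        plan_row = [u[i] * kernel[i][j] * v[j] for j in range(k)]
        # Predict the class receiving most of point i's transported mass.
        preds.append(max(range(k), key=plan_row.__getitem__))
    return preds


# Toy scores biased toward class 0: plain argmax predicts class 0 everywhere.
probs = [[0.9, 0.1], [0.8, 0.2], [0.6, 0.4], [0.55, 0.45]]
print(otter_predict(probs, [0.5, 0.5]))  # [0, 0, 1, 1]
```

With a balanced target distribution, the two points whose scores favor class 0 least are moved to class 1, while the confident ones keep their label. Smaller `reg` values track the exact transport plan more closely but need more iterations.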

Paper Structure (63 sections, 10 theorems, 47 equations, 7 figures, 14 tables, 2 algorithms)

Key Result

Theorem 4.1

Let $\nu^{ZS}_j=\frac{1}{n}\sum_{i=1}^n\mathbbm{1}[\hat{y}^{ZS}_i=j]$, where $\hat{y}^{ZS}_{i} = \arg \max_{j' \in [K]} P_\theta(Y=j'|X=x_i)$. Then, given $C_{ij} = -\log P_\theta(Y=j|X=x_i)$ and assuming there are no ties in scores, i.e. $P_\theta(Y=j|X=x_i) \neq P_\theta(Y=j'|X=x_i)$ for all $j \neq j'$, the OTTER predictions are equivalent to the zero-shot predictions, i.e. $\hat{y}^{OT}_i = \hat{y}^{ZS}_i$ ...
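
One way to see why such an equivalence is plausible is sketched below. The sketch assumes a formulation that is not spelled out in the excerpt above: OTTER minimizes $\sum_{i,j}\pi_{ij}C_{ij}$ over transport plans $\pi$ with row marginals $1/n$ and column marginals $\nu^{ZS}$, and predicts $\hat{y}^{OT}_i = \arg\max_j \pi_{ij}$. The plan $\pi^{ZS}$ is notation introduced here for illustration.

```latex
\begin{align*}
\sum_{i=1}^{n}\sum_{j=1}^{K}\pi_{ij}C_{ij}
  &\ge \sum_{i=1}^{n}\Big(\sum_{j=1}^{K}\pi_{ij}\Big)\min_{j'\in[K]}C_{ij'}
   = \frac{1}{n}\sum_{i=1}^{n}C_{i\hat{y}^{ZS}_i}, \\
\pi^{ZS}_{ij} = \frac{1}{n}\mathbbm{1}[\hat{y}^{ZS}_i=j]
  &\;\Longrightarrow\;
  \sum_{i=1}^{n}\pi^{ZS}_{ij}=\nu^{ZS}_j
  \quad\text{and}\quad
  \sum_{i=1}^{n}\sum_{j=1}^{K}\pi^{ZS}_{ij}C_{ij}
   = \frac{1}{n}\sum_{i=1}^{n}C_{i\hat{y}^{ZS}_i}.
\end{align*}
```

The first line uses $\min_{j'} C_{ij'} = C_{i\hat{y}^{ZS}_i}$, which holds because minimizing $-\log P_\theta$ is the same as maximizing $P_\theta$. The second line shows that the zero-shot plan is feasible and attains the lower bound. Without ties, equality in the first line forces each row to place all of its mass on $\hat{y}^{ZS}_i$, so under these assumptions the optimal plan is unique and its row-wise argmax returns the zero-shot prediction.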

Figures (7)

  • Figure 1: Label distribution mismatch example in zero-shot classification. In the Oxford-IIIT-Pet dataset, the ground-truth labels are uniformly distributed, while zero-shot models exhibit biased predictions toward certain classes. This bias is influenced by the distribution of labels in the pretraining task.
  • Figure 2: Synthetic experiment results. X-axis represents total variation distance between the source and the target distribution, describing label shift severity. Y-axis represents prediction accuracy. Curves represent different methods and noise levels. Our approaches dramatically outperform the baseline at higher mismatch levels.
  • Figure 3: Ablation on the number of samples in few-shot learning. In (a), BBSE estimation becomes more precise as the number of samples increases. Accordingly, OTTER achieves better accuracy in (b). Additionally, OTTER consistently improves linear probing when the two are combined.
  • Figure 4: Ablation experiment on the class balance specification. X-axis represents the total variation distance between the true class balance $P_t(Y)$ and the class balance specification $\hat{P}_t(Y)$. Y-axis represents accuracy. ViT-B/16 is used as the image zero-shot classifier, and BERT is used as the text zero-shot classifier.
  • Figure 5: Ablation experiment on the number of samples. We report the mean of 10 different samplings in each setting. We use ViT-B/16 for image classification, and BERT for text classification.
  • ...and 2 more figures
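
Figures 2 and 4 use total variation distance on their x-axes. Assuming the standard definition for discrete distributions, $\frac{1}{2}\sum_j |p_j - q_j|$, it can be computed as in the short sketch below. The function name `total_variation` and the example distributions are illustrative and do not come from the paper.

```python
def total_variation(p, q):
    """Total variation distance between two discrete distributions.

    p, q: sequences of class proportions over the same K classes.
    Returns a value in [0, 1]; 0 means the distributions are identical.
    """
    if len(p) != len(q):
        raise ValueError("distributions must cover the same classes")
    return 0.5 * sum(abs(a - b) for a, b in zip(p, q))


# Hypothetical example: a uniform true class balance versus a skewed
# specification over three classes.
true_balance = [1 / 3, 1 / 3, 1 / 3]
specified = [0.6, 0.3, 0.1]
print(round(total_variation(true_balance, specified), 3))
```

A value near 0 on the x-axis therefore corresponds to an almost exact match between the two distributions, and larger values to a more severe mismatch.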

Theorems & Definitions (16)

  • Theorem 4.1
  • Theorem 4.2
  • Theorem 4.3
  • Theorem 4.4
  • proof
  • Theorem D.1
  • proof
  • Corollary D.2
  • proof
  • Lemma D.3: robinson1973bounds, Corollary 3.1.
  • ...and 6 more