Controllable Prompt Tuning For Balancing Group Distributional Robustness

Hoang Phan, Andrew Gordon Wilson, Qi Lei

TL;DR

This work tackles spurious correlations under distribution shift by balancing learning across multiple groups rather than optimizing only the worst group. It introduces Controllable Prompt Tuning (CPT), which couples a gradient-based, entropy-regularized multi-objective update with parameter-efficient prompt tuning (for ViT and CLIP) to achieve uniform group performance with minimal trainable parameters. The key ideas are a $K$-dimensional vector of per-group losses, a gradient update that maximizes the entropy over the group losses, and a controllable vector $\mathbf{c}$ that biases learning toward selected groups. Empirically, CPT attains state-of-the-art worst-group accuracy on Waterbirds and CelebA, yields substantial mean/average accuracy gains on multimodal backbones, and scales efficiently via prompts, enabling debiasing across transformer, vision-language, and unimodal settings while updating only $\approx$0.4% of the parameters.
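The $K$-dimensional per-group loss vector, and why maximizing the entropy over it favors uniform group performance, can be illustrated with a small NumPy sketch. It assumes per-sample losses and integer group labels are already available, takes the distribution over groups to be the softmax of the group losses (the form of $p_i$ stated in Theorem 4.1), and treats a group absent from the batch as having zero loss. The function names are hypothetical, not the authors' code.

```python
import numpy as np

def group_loss_vector(sample_losses, group_ids, num_groups):
    """Average the per-sample losses within each group -> vector of length K."""
    sample_losses = np.asarray(sample_losses, dtype=float)
    group_ids = np.asarray(group_ids)
    losses = np.zeros(num_groups)
    for k in range(num_groups):
        mask = group_ids == k
        if mask.any():  # groups absent from the batch keep loss 0 (assumption)
            losses[k] = sample_losses[mask].mean()
    return losses

def loss_entropy(group_losses):
    """Shannon entropy of p_i = exp(l_i) / sum_j exp(l_j)."""
    z = np.asarray(group_losses, dtype=float)
    z = z - z.max()                          # shift for numerical stability
    log_p = z - np.log(np.exp(z).sum())      # log-softmax
    return float(-(np.exp(log_p) * log_p).sum())

# Toy batch: 6 samples from K = 4 groups (e.g. the four Waterbirds groups).
losses = group_loss_vector([0.2, 0.4, 1.5, 0.1, 0.9, 0.3],
                           [0, 0, 1, 2, 3, 3], num_groups=4)
print(losses)                                 # per-group mean losses
print(loss_entropy(losses), np.log(4))        # unequal losses: entropy < log K
print(loss_entropy(np.full(4, 0.7)))          # equal losses: entropy ~= log K
```

The entropy reaches its maximum, $\log K$, exactly when all group losses are equal, so pushing it upward pulls the groups toward the same loss level.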

Abstract

Models trained on data composed of different groups or domains can suffer from severe performance degradation under distribution shifts. While recent methods have largely focused on optimizing the worst-group objective, this often comes at the expense of performance on the other groups. To address this problem, we introduce an optimization scheme that seeks good performance across all groups, finding a solution that does not severely sacrifice performance on any of them. However, applying this optimization directly requires updating the parameters of the entire network, which is computationally expensive and challenging. We therefore introduce Controllable Prompt Tuning (CPT), which couples our approach with prompt-tuning techniques. On spurious correlation benchmarks, our procedures achieve state-of-the-art results across both transformer and non-transformer architectures, as well as unimodal and multimodal data, while tuning only 0.4% of the parameters.

Paper Structure (22 sections, 1 theorem, 17 equations, 10 figures, 15 tables)

Key Result

Theorem 4.1

Assume that the loss function $\ell$ is differentiable up to the first order with respect to $\theta$. Then the following, where $p_i = \frac{e^{\ell_i(\theta)}}{\sum_{j=1}^{K} e^{\ell_j(\theta)}}$, maximizes the objective $\mathcal{L}_{ent}(\theta)$.
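Assuming that $\mathcal{L}_{ent}(\theta) = -\sum_{i=1}^{K} p_i \log p_i$ is the Shannon entropy of the softmax distribution $p$ over group losses (consistent with Figure 2's description of maximizing the entropy over the loss distribution; the definition itself is not given here), differentiating with $\bar{\ell}(\theta) = \sum_{j=1}^{K} p_j \ell_j(\theta)$ gives

$$\nabla_\theta \mathcal{L}_{ent}(\theta) = -\sum_{i=1}^{K} p_i \big(\ell_i(\theta) - \bar{\ell}(\theta)\big) \nabla_\theta \ell_i(\theta).$$

The NumPy sketch below checks this closed form against central finite differences on hypothetical quadratic group losses $\ell_i(\theta) = \tfrac{1}{2}\lVert \theta - a_i \rVert^2$. It is a sanity check of the derivative under that assumption, not the authors' update rule.

```python
import numpy as np

def softmax(x):
    z = x - x.max()                        # shift for numerical stability
    e = np.exp(z)
    return e / e.sum()

def group_losses(theta, anchors):
    # Hypothetical toy losses: l_i(theta) = 0.5 * ||theta - a_i||^2.
    return 0.5 * ((theta[None, :] - anchors) ** 2).sum(axis=1)

def group_grads(theta, anchors):
    # Row i holds grad_theta l_i(theta) = theta - a_i.
    return theta[None, :] - anchors

def entropy(theta, anchors):
    p = softmax(group_losses(theta, anchors))
    return float(-(p * np.log(p)).sum())

def entropy_grad(theta, anchors):
    """Closed form: -sum_i p_i (l_i - l_bar) grad l_i."""
    losses = group_losses(theta, anchors)
    p = softmax(losses)
    l_bar = p @ losses                     # p-weighted mean loss
    coeff = p * (losses - l_bar)           # shape (K,)
    return -(coeff[:, None] * group_grads(theta, anchors)).sum(axis=0)

def finite_diff_grad(f, theta, eps=1e-6):
    g = np.zeros_like(theta)
    for d in range(theta.size):
        step = np.zeros_like(theta)
        step[d] = eps
        g[d] = (f(theta + step) - f(theta - step)) / (2 * eps)
    return g

rng = np.random.default_rng(0)
anchors = rng.normal(size=(4, 3))          # K = 4 groups, 3 parameters
theta = rng.normal(size=3)

analytic = entropy_grad(theta, anchors)
numeric = finite_diff_grad(lambda t: entropy(t, anchors), theta)
print(np.max(np.abs(analytic - numeric)))  # should be tiny

# A small ascent step along the closed-form gradient raises the entropy.
theta_next = theta + 1e-2 * analytic
print(entropy(theta, anchors), entropy(theta_next, anchors))
```

Reading the formula: an ascent step lowers the losses of groups above the weighted mean $\bar{\ell}$ and raises those below it, which is how the entropy term pulls groups toward balanced performance.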

Figures (10)

  • Figure 1: Accuracy curves on Waterbirds for the four groups during training. Vertical lines mark the early-stopping epochs, at which the models achieve their best performance on the validation set.
  • Figure 2: Overview of our method on the Waterbirds dataset. Our main objective is to not only improve model performance across groups by optimizing their loss functions $\ell_1, \ell_2, \ell_3, \ell_4$, but also maximize the entropy over this loss distribution.
  • Figure 3: Results of fine-tuning ViT backbones on Waterbirds. Error bars represent the standard deviation over independent runs.
  • Figure 4: Performance of ResNet50 at two checkpoints: the final checkpoint used for evaluation (highest worst-group accuracy on the validation set), denoted by $\bigtriangleup$, and the checkpoint with the highest minority-group accuracy, denoted by $\square$. Results are obtained over three random seeds.
  • Figure 5: Grad-CAM visualizations of ResNet50 on in-domain samples.
  • ...and 5 more figures
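Figure 4 describes selecting the checkpoint with the highest worst-group accuracy on the validation set. The NumPy sketch below shows that metric and selection rule on toy predictions; the function names and data are hypothetical, and it assumes every group appears in the validation split.

```python
import numpy as np

def group_accuracies(preds, labels, group_ids, num_groups):
    """Accuracy within each group; assumes every group is present."""
    preds, labels, group_ids = map(np.asarray, (preds, labels, group_ids))
    correct = preds == labels
    return np.array([correct[group_ids == k].mean() for k in range(num_groups)])

def select_checkpoint(val_preds_by_epoch, labels, group_ids, num_groups):
    """Return (epoch, worst-group accuracy) of the best checkpoint by WGA."""
    best_epoch, best_wga = None, -1.0
    for epoch, preds in val_preds_by_epoch.items():
        wga = float(group_accuracies(preds, labels, group_ids, num_groups).min())
        if wga > best_wga:
            best_epoch, best_wga = epoch, wga
    return best_epoch, best_wga

# Toy validation split: two large groups (0, 1) and two small ones (2, 3).
labels = [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1]
groups = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 3, 3]
val_preds = {
    1: [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1],  # 10/12 correct, group 2 all wrong
    2: [0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0],  # 8/12 correct, every group >= 50%
}
print(select_checkpoint(val_preds, labels, groups, num_groups=4))  # (2, 0.5)
```

In this toy case the checkpoint with higher average accuracy (epoch 1) is rejected because it misclassifies an entire minority group, which is the failure mode worst-group selection guards against.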

Theorems & Definitions (3)

  • Definition 3.1
  • Definition 3.2
  • Theorem 4.1