Controllable Prompt Tuning For Balancing Group Distributional Robustness
Hoang Phan, Andrew Gordon Wilson, Qi Lei
TL;DR
This work tackles spurious correlations under distribution shift by balancing learning across multiple groups rather than optimizing only the worst group. It introduces Controllable Prompt Tuning (CPT), which couples a gradient-based, entropy-regularized multi-objective update with parameter-efficient prompt tuning (for ViT and CLIP) to achieve uniform group performance with few trainable parameters. The key ideas are a $K$-dimensional vector of per-group losses, a gradient update that maximizes entropy over those group losses, and a controllable vector $\mathbf{c}$ that biases learning toward selected groups. Empirically, CPT attains state-of-the-art worst-group accuracy on Waterbirds and CelebA and yields substantial gains in average accuracy on multimodal backbones. Because only the prompts are tuned, about 0.4% of the parameters are updated, and the approach applies across transformer and non-transformer architectures and across unimodal and multimodal data.
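The three ingredients named above (the per-group loss vector, the entropy over group losses, and the control vector $\mathbf{c}$) can be illustrated with a minimal sketch. This is not the authors' implementation: the softmax normalization of the losses, the elementwise scaling by `c`, and the function names `group_losses` and `loss_entropy` are assumptions made for illustration only.

```python
import math

def group_losses(per_sample_loss, group_ids, num_groups):
    """Return the K-dimensional vector of mean losses, one entry per group.

    Assumes every group index in range(num_groups) appears at least once.
    """
    sums = [0.0] * num_groups
    counts = [0] * num_groups
    for loss, g in zip(per_sample_loss, group_ids):
        sums[g] += loss
        counts[g] += 1
    return [s / n for s, n in zip(sums, counts)]

def loss_entropy(losses, c=None):
    """Entropy of softmax(c * losses).

    Assumption: the control vector c rescales each group loss before
    normalization. With c = 1 the entropy is maximal (log K) exactly
    when all group losses are equal.
    """
    c = [1.0] * len(losses) if c is None else c
    z = [ci * li for ci, li in zip(c, losses)]
    m = max(z)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    p = [e / total for e in exps]
    return -sum(pi * math.log(pi) for pi in p if pi > 0.0)

# Two groups with equal mean loss reach the maximum entropy log(2).
balanced = group_losses([1.0, 3.0, 2.0, 2.0], [0, 0, 1, 1], num_groups=2)
print(balanced, loss_entropy(balanced))

# Unequal losses lower the entropy; under this sketch, a larger c for
# group 0 makes the entropy peak when group 0 has the smaller loss.
print(loss_entropy([0.5, 2.0]), loss_entropy([0.5, 2.0], c=[4.0, 1.0]))
```

In this sketch, maximizing the entropy pushes the scaled losses `c[k] * losses[k]` toward equality, so raising `c[k]` favors a lower loss on group `k`. Whether the paper uses exactly this parameterization is not stated in the summary above.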
Abstract
Models trained on data composed of different groups or domains can suffer from severe performance degradation under distribution shifts. While recent methods have largely focused on optimizing the worst-group objective, this often comes at the expense of good performance on other groups. To address this problem, we introduce an optimization scheme that achieves good performance across all groups without severely sacrificing performance on any of them. However, directly applying such optimization involves updating the parameters of the entire network, which is computationally expensive and difficult to carry out. Thus, we introduce Controllable Prompt Tuning (CPT), which couples our approach with prompt-tuning techniques. On spurious correlation benchmarks, our procedures achieve state-of-the-art results across both transformer and non-transformer architectures, as well as unimodal and multimodal data, while requiring only 0.4% of the parameters to be tuned.
