
Conformal Prediction Adaptive to Unknown Subpopulation Shifts

Nien-Shao Wang, Duygu Nur Yaldiz, Yavuz Faruk Bakman, Sai Praneeth Karimireddy

TL;DR

This work extends conformal prediction to unknown subpopulation shifts by leveraging (i) domain classifiers to weight calibration data and (ii) domain-agnostic strategies based on embedding similarities and conformal risk control. The authors establish theoretical guarantees for marginal coverage under various relaxations of domain knowledge, including Bayes-optimal, multicalibrated, and multiaccurate classifiers, and demonstrate scalability to high-dimensional vision and language tasks. Empirical results on diverse benchmarks show that the proposed methods maintain coverage across numerous test environments, outperforming standard conformal prediction, which can fail under such shifts. The practical impact lies in robust uncertainty quantification for real-world AI systems, including reliable LLM hallucination detection under distribution changes and improved risk management in high-stakes applications.
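A minimal sketch of how a domain classifier can reweight calibration scores, assuming the standard weighted split-conformal quantile (the test point contributes a point mass at $+\infty$) and a hypothetical plug-in weight $w(x) = \sum_k h_k(x_{\text{test}})\, h_k(x)/\pi_k$, where $h(x)$ is the classifier's domain posterior and $\pi_k$ are the calibration domain proportions. It only illustrates the weighting idea; it is not the authors' Algorithm A1 or A2, and both function names are hypothetical.

```python
import numpy as np

def weighted_conformal_threshold(scores, weights, test_weight, alpha):
    """(1 - alpha)-quantile of sum_i p_i * delta(s_i) + p_test * delta(+inf),
    with p_i = w_i / (sum_j w_j + w_test): the weighted split-conformal threshold."""
    scores = np.asarray(scores, dtype=float)
    weights = np.asarray(weights, dtype=float)
    order = np.argsort(scores)
    cdf = np.cumsum(weights[order]) / (weights.sum() + test_weight)
    idx = np.searchsorted(cdf, 1 - alpha)  # first sorted score whose CDF reaches 1 - alpha
    return np.inf if idx >= len(scores) else scores[order][idx]

def subpopulation_weights(cal_post, test_post, cal_props):
    """HYPOTHETICAL plug-in weights w(x) = sum_k q_k * h_k(x) / pi_k.

    With the true test mixture q and a Bayes-optimal domain classifier h, w is the
    likelihood ratio between the test and calibration mixtures. Here q is unknown,
    so it is replaced by the classifier's posterior on the test point, h(x_test)."""
    q = np.asarray(test_post, dtype=float)
    ratio = q / np.asarray(cal_props, dtype=float)
    return np.asarray(cal_post, dtype=float) @ ratio, float(q @ ratio)

# Toy usage: two calibration domains, test point known to be in domain 1.
rng = np.random.default_rng(0)
dom = np.array([0] * 50 + [1] * 37)   # toy calibration domain labels
scores = rng.random(len(dom)) + dom    # domain-1 scores are shifted upward
post = np.eye(2)[dom]                  # one-hot posteriors = perfect group labels
props = np.bincount(dom) / len(dom)
w, w_test = subpopulation_weights(post, np.eye(2)[1], props)
t = weighted_conformal_threshold(scores, w, w_test, alpha=0.1)
```

With one-hot posteriors, calibration points from other domains receive zero weight, so the threshold equals the ordinary split-conformal quantile computed on the test point's domain alone; softer posteriors interpolate between domains.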

Abstract

Conformal prediction is widely used to equip black-box machine learning models with uncertainty quantification, offering formal coverage guarantees under exchangeable data. However, these guarantees fail under subpopulation shifts, where the test environment contains a different mix of subpopulations than the calibration data. In this work, we focus on unknown subpopulation shifts, where we are not given group information, i.e., the subpopulation labels of data points must be inferred. We propose new methods that provably adapt conformal prediction to such shifts, ensuring valid coverage without explicit knowledge of the subpopulation structure. While existing methods in similar setups assume perfect subpopulation labels, our framework explicitly relaxes this requirement and characterizes the conditions under which formal coverage guarantees remain feasible. Further, our algorithms scale to high-dimensional settings and remain practical in realistic machine learning tasks. Extensive experiments on vision (with vision transformers) and language (with large language models) benchmarks demonstrate that our methods reliably maintain coverage and effectively control risks in scenarios where standard conformal prediction fails.
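To see why exchangeability matters, the toy simulation below (a hypothetical two-domain example, not taken from the paper) calibrates standard split conformal prediction on a 50/50 mixture and evaluates it on shifted mixtures. With domain-0 scores uniform on [0, 1) and domain-1 scores uniform on [1, 2), the calibrated threshold sits near 1.9, so coverage should come out near 0.95 without shift, 1.0 on domain 0 alone, and about 0.9 on domain 1 alone: over- and under-coverage from the same calibrated threshold.

```python
import numpy as np

def split_conformal_threshold(cal_scores, alpha):
    """Standard split-conformal threshold: the ceil((n + 1)(1 - alpha))-th smallest score."""
    s = np.sort(np.asarray(cal_scores, dtype=float))
    k = int(np.ceil((len(s) + 1) * (1 - alpha)))
    return np.inf if k > len(s) else s[k - 1]

def sample_scores(rng, n, p_domain1):
    """Toy two-domain mixture: domain 0 scores ~ U(0, 1), domain 1 scores ~ U(1, 2)."""
    in_domain1 = rng.random(n) < p_domain1
    return rng.random(n) + in_domain1

rng = np.random.default_rng(0)
alpha = 0.05
# Calibrate on a 50/50 mixture of the two domains.
q_hat = split_conformal_threshold(sample_scores(rng, 2000, 0.5), alpha)

# Evaluate on the calibration mix and on two shifted test environments.
for p1 in (0.5, 0.0, 1.0):
    cov = np.mean(sample_scores(rng, 100_000, p1) <= q_hat)
    print(f"P(domain 1) = {p1:.1f}: coverage = {cov:.3f}")
```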

Paper Structure

This paper contains 36 sections, 4 theorems, 28 equations, 7 figures, 5 tables, 3 algorithms.

Key Result

Theorem 2.1

Suppose we are given an algorithm $C_\alpha$ that, given perfect group information (whether $X_{\text{test}} \sim \mathbb{P}_k$), obtains perfect domain-conditional coverage as in the group-conditional coverage equation. Then, there exist domain distributions $\{\mathbb{P}_k\}_{k\in [K]}$, and a domain
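For context on why domain-conditional coverage is the natural starting point: if $C_\alpha$ covers at level $1-\alpha$ within every domain, it covers at that level under any test mixture of the same domains. Writing a test environment as $\mathbb{Q} = \sum_k q_k \mathbb{P}_k$ (our notation, with $q_k \ge 0$ and $\sum_k q_k = 1$), the argument is the law of total probability:

$$
\mathbb{P}\big(Y_{\text{test}} \in C_\alpha(X_{\text{test}}) \mid X_{\text{test}} \sim \mathbb{P}_k\big) \ge 1 - \alpha \quad \text{for all } k \in [K]
$$

$$
\Longrightarrow\quad \mathbb{P}_{\mathbb{Q}}\big(Y_{\text{test}} \in C_\alpha(X_{\text{test}})\big) = \sum_{k=1}^{K} q_k\, \mathbb{P}\big(Y_{\text{test}} \in C_\alpha(X_{\text{test}}) \mid X_{\text{test}} \sim \mathbb{P}_k\big) \ge \sum_{k=1}^{K} q_k (1-\alpha) = 1 - \alpha.
$$

When group labels are unknown, the conditioning event $X_{\text{test}} \sim \mathbb{P}_k$ is not observed and has to be inferred.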

Figures (7)

  • Figure 1: (Left) Example of subpopulation shifts with 4 domains and 3 test environments. Each colored square represents data from a particular domain. Train and test environments are mixtures of the same set of domains but in different proportions. Score distributions (gray for the train environment and blue for each test environment) and the threshold calculated by standard conformal prediction are shown for each train/test environment. The subpopulation shift in test environment 1 still yields roughly the ideal coverage, whereas the shifts in test environments 2 and 3 lead to significant under- and over-coverage, respectively. (Right) The same issue arises in LLM hallucination detection across different test environments. The standard LLM uncertainty estimation method (blue) is sensitive to distribution shifts, displaying high variance in its hallucination detection recall, while the recall with our modification (orange) tightly follows the desired target recall.
  • Figure 2: Coverage distribution over 100 test environments with subpopulation shifts. (Left) Coverage across 100 test environments generated by Dirichlet sampling over 26 domains, averaged over 15 calibration/test splits. Means and standard deviations are shown in the legend. (Right) Mean and standard deviation of coverage across 100 test environments. Note that max tends to substantially over-cover relative to the desired coverage of 0.95. Our algorithms (A1, A2, and oracle) achieve the desired coverage across test environments, unlike unweighted and Conditional Calibration, which show significant under-coverage. They also have minimal over-coverage and tightly follow the target, unlike max, which significantly over-covers. Further, the practical algorithms A1 and A2 closely match the ideal oracle coverage.
  • Figure 3: Adapting to subpopulation shifts without a domain classifier. A vision transformer is calibrated with the LAC score function for various algorithms. For Algorithm A3, the parameters $\sigma$ and $\beta$ are $0.7$ and $0.1$, respectively. (Left) Coverage across 100 test environments at $\alpha=0.05$. Each coverage value is the average over 15 calibration/test splits. Means and standard deviations are shown in the legend. (Right) Mean and standard deviation of coverage across 100 test environments. Our algorithm (A3, pink) achieves the desired coverage of 0.95 across test environments with minimal over-coverage. Further, even without using any distributional or domain information, it matches the ideal coverage of the oracle (green), which knows the test distribution exactly.
  • Figure 4: Controlling LLM hallucinations. Llama-3-8B was calibrated with 3 different score functions, and test data were labeled according to the procedure described for the language tasks. Recall is plotted with its standard deviation across 100 different test environments, obtained by sampling from a Dirichlet distribution with $\alpha'=0.5$. The standard LLM uncertainty estimation method (blue) is sensitive to distribution shifts, as evidenced by the high variance in recall across test environments, while the recall with our method A3 (orange) tightly follows the desired target recall.
  • Figure 5: Distribution of coverage across different values of $1 - \alpha$. Results are shown for 3 model architectures (Vision Transformer, ResNet50, and CLIP) and 3 score functions (LAC, APS, and RAPS). Each sub-figure plots the standard deviation across 100 test environments sampled from a Dirichlet distribution with $\alpha'=0.1$; for each test environment, coverage is averaged over 15 random calibration/test splits. The domain structure consists of 26 domains with 3 classes per domain. The proposed algorithms consistently outperform standard conformal prediction, achieving lower standard deviations across all model architectures, score functions, and values of $\alpha$.
  • ...and 2 more figures
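The evaluation protocol in the captions (test environments drawn from a Dirichlet distribution over domains, LAC scores, coverage measured per environment) can be sketched as below. Assumptions, all hypothetical: $\alpha'$ is treated as the concentration of a symmetric Dirichlet, environments are built by resampling a held-out pool with replacement (1,000 points each), every domain appears in the pool, and random probability vectors stand in for a real model's softmax outputs. The helper names are illustrative, not the paper's code.

```python
import numpy as np

def lac_scores(probs, labels):
    """LAC score: 1 minus the predicted probability of the true label."""
    return 1.0 - probs[np.arange(len(labels)), labels]

def lac_sets(probs, q_hat):
    """Prediction sets {y : 1 - p_y(x) <= q_hat} as a boolean mask of shape (n, C)."""
    return (1.0 - probs) <= q_hat

def coverage(probs, labels, q_hat):
    """Fraction of points whose true label falls inside its LAC prediction set."""
    return lac_sets(probs, q_hat)[np.arange(len(labels)), labels].mean()

def sample_environment(domains, n, alpha_prime, rng):
    """HYPOTHETICAL environment generator: domain proportions ~ Dirichlet(alpha_prime),
    then n points resampled with replacement from the pool according to them."""
    K = int(domains.max()) + 1
    props = rng.dirichlet(np.full(K, alpha_prime))
    counts = rng.multinomial(n, props)
    idx = [rng.choice(np.flatnonzero(domains == k), size=c, replace=True)
           for k, c in enumerate(counts) if c > 0]
    return props, np.concatenate(idx)

rng = np.random.default_rng(0)
N, K, C = 5200, 26, 78                        # toy pool: 26 domains x 3 classes each
domains = rng.integers(0, K, size=N)
labels = 3 * domains + rng.integers(0, 3, size=N)
probs = rng.dirichlet(np.ones(C), size=N)     # stand-in for model softmax outputs

# Standard split-conformal calibration on half of the pool.
cal = rng.random(N) < 0.5
s = np.sort(lac_scores(probs[cal], labels[cal]))
q_hat = s[int(np.ceil((len(s) + 1) * 0.95)) - 1]

# Coverage over 100 Dirichlet-sampled test environments.
test_dom, test_probs, test_labels = domains[~cal], probs[~cal], labels[~cal]
covs = []
for _ in range(100):
    _, idx = sample_environment(test_dom, 1000, 0.1, rng)
    covs.append(coverage(test_probs[idx], test_labels[idx], q_hat))
print(f"mean coverage {np.mean(covs):.3f}, std {np.std(covs):.3f}")
```

Because these toy probabilities do not depend on the domain, coverage should stay near 0.95 in every environment; domain-dependent scores from a real model are what would expose the spread the captions report for standard conformal prediction.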

Theorems & Definitions (8)

  • Theorem 2.1
  • Theorem 3.1
  • Definition 3.2
  • Theorem 3.3
  • Definition 3.4
  • Theorem 3.5
  • Remark 1
  • Proof