
Active, anytime-valid risk controlling prediction sets

Ziyu Xu, Nikos Karampatziakis, Paul Mineiro

TL;DR

This method extends the notion of a risk controlling prediction set (RCPS) to the sequential setting, where it provides guarantees even when the data is collected adaptively, and ensures that the risk guarantee is anytime-valid, i.e., simultaneously holds at all time steps.

Abstract

Rigorously establishing the safety of black-box machine learning models with respect to critical risk measures is important for providing guarantees about model behavior. Recently, Bates et al. (JACM '24) introduced the notion of a risk controlling prediction set (RCPS) for producing prediction sets from machine learning models that are statistically guaranteed to have low risk. Our method extends this notion to the sequential setting, where it provides guarantees even when the data is collected adaptively and ensures that the risk guarantee is anytime-valid, i.e., holds simultaneously at all time steps. Further, we propose a framework for constructing RCPSes for active labeling, i.e., one that allows the use of a labeling policy that chooses whether to query the true label for each received data point and ensures that the expected proportion of data points whose labels are queried is below a predetermined label budget. We also describe how to use predictors (i.e., the machine learning model for which we provide risk control guarantees) to further improve the utility of our RCPSes by estimating the expected risk conditioned on the covariates. We characterize the optimal choices of labeling policy and predictor under a fixed label budget, and show a regret result that relates the estimation error of the optimal labeling policy and predictor to the wealth process that underlies our RCPSes. Lastly, we present practical ways of formulating labeling policies and empirically show that our labeling policies use fewer labels to reach higher utility than naive baseline labeling strategies on both simulated and real data.
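The active-labeling and predictor ideas above can be illustrated with a small simulation. This is a minimal sketch under assumptions of our own, not the authors' implementation: `labeling_prob` is a hypothetical two-level labeling policy whose expected query rate equals the label budget, `r_hat` is a hypothetical predictor of the conditional risk, and `labeled_risk_estimate` is a standard inverse-propensity-weighted estimator that is unbiased for the loss in expectation over the labeling coin flip.

```python
import random


def labeling_prob(x_uncertain: bool, budget: float) -> float:
    """Hypothetical two-level labeling policy q(x); needs budget <= 2/3.

    Queries more often where the predictor is unsure. Under the toy
    X ~ Uniform(0, 1) below, half the points are 'uncertain', so
    E[q(X)] = 0.5 * 1.5 * budget + 0.5 * 0.5 * budget = budget.
    """
    return 1.5 * budget if x_uncertain else 0.5 * budget


def labeled_risk_estimate(r_true, r_hat, q, labeled):
    """Unbiased estimate of the loss r(X, Y, beta) under active labeling.

    r_hat is a predictor's guess of E[r(X, Y, beta) | X]. r_true is read only
    when the label was queried (probability q); averaging over that coin flip
    gives q * (r_hat + (r_true - r_hat) / q) + (1 - q) * r_hat = r_true.
    """
    if labeled:
        return r_hat + (r_true - r_hat) / q
    return r_hat


def simulate(n, budget, seed=0):
    rng = random.Random(seed)
    est_sum = true_sum = 0.0
    n_labels = 0
    for _ in range(n):
        x = rng.random()
        # Toy assumption: the loss is Bernoulli with mean x, and the predictor
        # knows that conditional mean exactly. r_true is drawn for every point
        # only so the simulation can report the true risk for comparison.
        r_true = 1.0 if rng.random() < x else 0.0
        r_hat = x
        q = labeling_prob(0.25 < x < 0.75, budget)
        labeled = rng.random() < q
        n_labels += labeled
        est_sum += labeled_risk_estimate(r_true, r_hat, q, labeled)
        true_sum += r_true
    return est_sum / n, true_sum / n, n_labels / n


est, truth, frac = simulate(200_000, budget=0.2)
print(f"estimated risk {est:.3f}, true risk {truth:.3f}, label rate {frac:.3f}")
```

The estimate stays unbiased for any policy with q(x) > 0; the predictor affects only its variance, because the importance weight 1/q multiplies the residual r − r̂ rather than the raw loss. This is one plausible intuition for why a good predictor improves utility, not a restatement of the paper's result.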

Paper Structure (19 sections, 9 theorems, 45 equations, 3 figures)

Key Result

Theorem 1

The sequence of estimates $(\widehat{\beta}_t)$ defined in Eq. (eq:simple-beta-hat) satisfies the anytime-valid risk control guarantee of Eq. (eq:anytime-valid-safety), i.e., $\mathbb{P}(\rho(\widehat{\beta}_t) \leq \theta \text{ for all } t \in \mathbb{N}) \geq 1 - \alpha$.
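Theorem 1 can be checked numerically on a toy problem. The sketch below is a hedged illustration, not necessarily the paper's exact estimator: it assumes i.i.d. data, losses in [0, 1] that are nonincreasing in β, and a safe largest grid value. For each grid value of β it runs a betting wealth process $M_t(\beta) = \prod_{i \le t} (1 + \lambda_i(\theta - r_i(\beta)))$ with a predictable bet $\lambda_i \ge 0$ to test $H_\beta: \rho(\beta) > \theta$, and sets $\widehat{\beta}_t$ to the smallest grid value above which every null has been rejected. The class `AnytimeRCPS`, the plug-in bet, and the toy loss $r = 1\{X > \beta\}$ are all hypothetical choices.

```python
import math
import random


class AnytimeRCPS:
    """Hypothetical sketch: one betting wealth process per grid value of beta.

    Assumes losses r(x, y, beta) lie in [0, 1] and are nonincreasing in beta,
    and that the largest grid value is safe (rho(grid[-1]) <= theta).
    """

    def __init__(self, grid, theta, alpha):
        self.grid, self.theta, self.alpha = grid, theta, alpha
        self.log_wealth = [0.0] * len(grid)
        self.loss_sum = [0.0] * len(grid)
        self.rejected = [False] * len(grid)   # has wealth ever reached 1/alpha?
        self.t = 0
        self.lam_max = 0.5 / (1.0 - theta)    # keeps every factor >= 0.5 > 0
        self.log_threshold = math.log(1.0 / alpha)

    def update(self, losses):
        """losses[j] = r(X_t, Y_t, grid[j]) for the newest labeled point."""
        for j, r in enumerate(losses):
            # Predictable bet: smoothed mean of *past* losses only, in (0, 1).
            m = (self.loss_sum[j] + 0.5) / (self.t + 1)
            lam = min(self.lam_max, max(0.0, (self.theta - m) / (m * (1.0 - m))))
            # Under H_beta: rho(beta) > theta, E[1 + lam * (theta - r)] <= 1, so
            # the wealth is a nonnegative supermartingale and Ville applies.
            self.log_wealth[j] += math.log1p(lam * (self.theta - r))
            self.loss_sum[j] += r
            if self.log_wealth[j] >= self.log_threshold:
                self.rejected[j] = True
        self.t += 1

    def beta_hat(self):
        """Smallest beta such that every null at a larger grid value is rejected."""
        j = len(self.grid) - 1
        while j > 0 and self.rejected[j - 1]:
            j -= 1
        return self.grid[j]


def run(seed, n=1000, theta=0.1, alpha=0.05):
    # Toy problem (assumption): X ~ Uniform(0, 1), r = 1{X > beta},
    # so rho(beta) = 1 - beta and beta* = 1 - theta = 0.9.
    rng = random.Random(seed)
    grid = [i / 50 for i in range(51)]
    rcps = AnytimeRCPS(grid, theta, alpha)
    for _ in range(n):
        x = rng.random()
        rcps.update([1.0 if x > b else 0.0 for b in grid])
    return rcps.beta_hat()


betas = [run(seed) for seed in range(20)]
print(min(betas), sum(betas) / len(betas))
```

Because the rejection flags only ever switch on, $\widehat{\beta}_t$ is nonincreasing in $t$, so a violation at any time shows up in the final estimate. A violation requires rejecting only the largest grid value with $\rho(\beta) > \theta$, so Ville's inequality bounds its probability by $\alpha$ without a union bound over the grid; across seeds, the fraction with $\widehat{\beta}_t < 0.9$ should therefore be at most about $\alpha$.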

Figures (3)

  • Figure 1: Diagram of the active labeling setup for ensuring anytime-valid risk control.
  • Figure 2: Experimental results for different methods in our numerical simulation setup. "pretrain" and "learned" perform better, attaining a lower average $\widehat{\beta}_t$ uniformly across the number of labels queried; the dotted line in panels fig:simulations-final and fig:simulations-label marks $\beta^*=0.5578$. Each method also has a low safety violation rate, i.e., stays below the dotted line at $\alpha=0.05$ in panel fig:simulations-error.
  • Figure 3: Experimental results for different methods on ImageNet. Again, "pretrain" and "learned" perform best; their performance is so similar that their curves overlap in panel fig:imagenet-label. Here, $\beta^*=0.8349$, marked by the dotted line in panels fig:imagenet-final and fig:imagenet-label. Again, each method has a low safety violation rate, i.e., stays below the dotted line at $\alpha=0.05$ in panel fig:imagenet-error.

Theorems & Definitions (17)

  • Definition 1
  • Definition 2
  • Theorem 1
  • Proof
  • Proposition 1
  • Proof
  • Remark 1
  • Proposition 2
  • Proof
  • Theorem 2
  • ...and 7 more