An Efficient Plugin Method for Metric Optimization of Black-Box Models

Siddartha Devic, Nurendra Choudhary, Anirudh Srinivasan, Sahika Genc, Branislav Kveton, Gaurush Hiranandani

TL;DR

This work tackles post-hoc adaptation of black-box multiclass predictors under distribution shift by learning class weights to reweight predictions. The CWPlugin method uses coordinate-wise line searches over restricted pairwise classifiers to produce a weight vector $\mathbf{w}\in[0,1]^m$ and forms $b_{\mathbf{w}}(x)=\arg\max_k b(x)_k\mathbf{w}_k$, optimizing metrics that are functions of the confusion matrix, including linear-diagonal metrics. The authors prove consistency for linear-diagonal metrics and provide finite-sample guarantees, with runtime speedups achievable via quasi-concavity and parallelization. Empirically, CWPlugin improves metric performance on tabular income data and various NLP tasks, particularly when the labeled target set is small, and remains competitive with or surpasses calibration and probing baselines before full model retraining. Overall, CWPlugin offers a scalable, strictly post-hoc approach to aligning black-box predictions with target distributions and bespoke evaluation metrics across diverse domains.
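The reweighting rule $b_{\mathbf{w}}(x)=\arg\max_k b(x)_k\mathbf{w}_k$ can be illustrated with a minimal Python sketch. It is not the authors' CWPlugin algorithm: it replaces the restricted pairwise classifiers with a plain coordinate-wise grid search, and the names `reweighted_predict`, `coordinate_line_search`, `accuracy`, and the parameters `eps` and `sweeps` are hypothetical, introduced only for illustration.

```python
import numpy as np


def reweighted_predict(probs, w):
    """Apply b_w(x) = argmax_k b(x)_k * w_k to each row of probs (shape n x m)."""
    return np.argmax(probs * w, axis=1)


def accuracy(y_true, y_pred):
    """Fraction of correct predictions; stands in for any confusion-matrix metric."""
    return float(np.mean(y_true == y_pred))


def coordinate_line_search(probs, y, metric=accuracy, eps=0.05, sweeps=1):
    """Assumed simplification: tune one class weight at a time on a grid of step eps.

    Starts from w = (1, ..., 1), which reproduces the black-box argmax, and
    accepts a candidate only if it strictly improves the metric on the sample.
    """
    _, m = probs.shape
    w = np.ones(m)
    grid = np.arange(eps, 1.0 + 1e-12, eps)  # candidate weights in (0, 1]
    best = metric(y, reweighted_predict(probs, w))
    for _ in range(sweeps):
        for k in range(m):  # one line search per class coordinate
            for cand in grid:
                trial = w.copy()
                trial[k] = cand
                score = metric(y, reweighted_predict(probs, trial))
                if score > best:
                    best, w = score, trial
    return w, best


if __name__ == "__main__":
    # Synthetic stand-in for black-box outputs on a small labeled target sample.
    rng = np.random.default_rng(0)
    probs = rng.dirichlet(np.ones(3), size=200)
    y = rng.integers(0, 3, size=200)
    w, score = coordinate_line_search(probs, y, eps=0.1)
    print("weights:", w, "metric on sample:", score)
```

Because the search starts from the all-ones vector and only accepts improvements, the metric on the labeled sample never falls below that of the unmodified black-box predictions. Only predicted probabilities and labels are used, with no feature information.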

Abstract

Many machine learning algorithms and classifiers are available only via API queries as a "black box": the downstream user has no ability to change, re-train, or fine-tune the model on a particular target distribution. Indeed, the downstream user may not even know the original training distribution or the performance metric used to construct and optimize the black-box model. We propose a simple and efficient method, Plugin, which post-processes arbitrary multiclass predictions from any black-box classifier in order to simultaneously (1) adapt these predictions to a target distribution; and (2) optimize a particular metric of the confusion matrix. Importantly, Plugin is a completely post-hoc method: it does not rely on feature information, requires only a small number of probabilistic predictions along with their corresponding true labels, and optimizes metrics by querying. We empirically demonstrate that Plugin is broadly applicable and performs competitively with related methods on a variety of tabular and language tasks.

Paper Structure

This paper contains 25 sections, 6 theorems, 10 equations, 14 figures, 1 table, 1 algorithm.

Key Result

Proposition 2

Let the desired precision $\epsilon>0$ and sample $S = \{(b(x_i), y_i)\}_{i\in[n]}$ be given. Suppose that $S$ contains $m$ classes. Then the CWPlugin algorithm (Algorithm 1) converges with a runtime of $O(mn/\epsilon)$.
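One way to read this bound is as a product of three factors. The decomposition below is an assumption made for illustration and is not stated in the proposition itself: it supposes one line search per class over a grid of spacing $\epsilon$ on $[0,1]$, with each candidate weight scored in a single $O(n)$ pass over $S$.

```latex
\underbrace{m}_{\text{line searches}}
\;\times\;
\underbrace{\lceil 1/\epsilon \rceil}_{\text{grid points per search}}
\;\times\;
\underbrace{O(n)}_{\text{cost per candidate}}
\;=\; O\!\left(\frac{mn}{\epsilon}\right)
```

Under the same assumptions, halving $\epsilon$ doubles the number of grid points per search and hence the total cost.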

Figures (14)

  • Figure 1: The setting of our work. As input (Left), our method takes arbitrary probabilistic, multiclass predictions (along with true labels) on a target distribution from a black-box model $b$. The bars are conditional label probabilities predicted by the base model on data points $x_1, x_2$, and $x_3$, and the x-axis shows classes. A metric of interest (e.g., Accuracy, F-measure, etc.) is also given as input. The $\textsc{CWPlugin}$ algorithm then post-processes these predictions in a black-box manner, without any re-training or fine-tuning of the underlying model. The resulting predictions (Right) enjoy improved performance on the selected metric of interest.
  • Figure 2: Distribution shift on US Census data; mean and standard deviation across five validation set samples. (Left) Table showing test performance metrics at a validation set size of 50 samples. Using the proposed plugin method to adapt a classifier trained on California data to Texas data outperforms training a new classifier with only the (limited) available Texas data. (Right) Test F-measure performance across varying validation set size.
  • Figure 3: Mean and standard deviation across five validation set samples. (Top) lmtweets results for each method on each metric using a validation set $S$ of size 160. (Bottom) lmtweets test G-mean and F-measure performance across varying validation set size. Adapting the outputs of a black-box model with $\textsc{CWPlugin}$ outperforms other post-hoc adaptation techniques at $\leq 400$ samples. At $\geq 400$ samples, fine-tuning a clean BERT model on the validation set (BERT-FT) starts performing better.
  • Figure 4: Mean and standard deviation across five runs. Results for lmemotions (top) and lmemotionsOOD (bottom) on G-mean and F-measure. $\textsc{CWPlugin}$ consistently performs well across metrics for smaller sample sizes relative to all tested baseline methods including fine-tuning a clean language model on only the validation set (BERT-FT).
  • Figure 5: (Left) Results for SNLI with label shift applied to the validation and test data for methods fit on $|S|=100$ validation samples. (Right) Results for ANLI with label noise on $|S|=250$ validation samples. In both cases, $\textsc{CWPlugin}$ performs favorably when compared to other baselines.
  • ...and 9 more figures

Theorems & Definitions (11)

  • Proposition 2
  • Lemma 3
  • proof: Proof of Proposition 2
  • proof: Proof of Lemma 3
  • Definition 4: Linear Diagonal Metric
  • Proposition 6
  • proof
  • Lemma 7: Proposition 2 of Hiranandani et al. (2019)
  • Lemma 8: Proposition 5 of Narasimhan et al. (2022)
  • Proposition 9
  • ...and 1 more