Language Model Cascades: Token-level uncertainty and beyond

Neha Gupta, Harikrishna Narasimhan, Wittawat Jitkrittum, Ankit Singh Rawat, Aditya Krishna Menon, Sanjiv Kumar

TL;DR

The paper addresses the cost-quality tradeoff in generative language models by studying deferral rules in LM cascades. It shows that sequence-level confidence measures such as Chow-Sum and Chow-Average suffer from length bias, and introduces Chow-Quantile, which draws on token-level uncertainty instead. Combining token-level uncertainty with learned post-hoc deferral rules, and optionally with embeddings from the smaller model and from intermediate layers of the larger model, yields a better cost-quality tradeoff. Experiments with FLAN-T5 on MNLI, TriviaQA, and WMT show consistent gains over the baselines, with the intermediate-layer embeddings providing an additional improvement. The work suggests practical cascade strategies that use token-level uncertainty and cross-model signals to make inference more efficient on complex NLP tasks.

Abstract

Recent advances in language models (LMs) have led to significant improvements in quality on complex NLP tasks, but at the expense of increased inference costs. Cascading offers a simple strategy to achieve more favorable cost-quality tradeoffs: here, a small model is invoked for most "easy" instances, while a few "hard" instances are deferred to the large model. While the principles underpinning cascading are well-studied for classification tasks - with deferral based on predicted class uncertainty favored theoretically and practically - a similar understanding is lacking for generative LM tasks. In this work, we initiate a systematic study of deferral rules for LM cascades. We begin by examining the natural extension of predicted class uncertainty to generative LM tasks, namely, the predicted sequence uncertainty. We show that this measure suffers from the length bias problem, either over- or under-emphasizing outputs based on their lengths. This is because LMs produce a sequence of uncertainty values, one for each output token; and moreover, the number of output tokens is variable across examples. To mitigate this issue, we propose to exploit the richer token-level uncertainty information implicit in generative LMs. We argue that naive predicted sequence uncertainty corresponds to a simple aggregation of these uncertainties. By contrast, we show that incorporating token-level uncertainty through learned post-hoc deferral rules can significantly outperform such simple aggregation strategies, via experiments on a range of natural language benchmarks with FLAN-T5 models. We further show that incorporating embeddings from the smaller model and intermediate layers of the larger model can give an additional boost in the overall cost-quality tradeoff.
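
The aggregation strategies named above can be illustrated with a minimal sketch. It assumes that Chow-Sum is the sum of per-token log-probabilities (the log of the product), Chow-Average is that sum divided by the output length (the log of the geometric mean), and Chow-Quantile is a quantile of the per-token log-probabilities, with quantile 0 being the minimum. The function names, the linear interpolation between order statistics, and the toy probabilities are illustrative assumptions, not the authors' implementation.

```python
import math


def chow_sum(token_probs):
    """Sum of per-token log-probabilities, i.e. the sequence log-probability."""
    if not token_probs:
        raise ValueError("token_probs must be non-empty")
    return sum(math.log(p) for p in token_probs)


def chow_average(token_probs):
    """Length-normalised log-probability (log of the geometric mean)."""
    return chow_sum(token_probs) / len(token_probs)


def chow_quantile(token_probs, alpha):
    """alpha-quantile of the per-token log-probabilities (alpha in [0, 1]).

    alpha = 0 returns the least confident token. Interpolating linearly
    between order statistics is an assumption made for this sketch.
    """
    if not token_probs:
        raise ValueError("token_probs must be non-empty")
    logs = sorted(math.log(p) for p in token_probs)
    pos = alpha * (len(logs) - 1)
    lo, hi = math.floor(pos), math.ceil(pos)
    frac = pos - lo
    return logs[lo] * (1 - frac) + logs[hi] * frac


# Toy outputs (hypothetical numbers): a short, uniformly unsure output and
# a long output that is confident except for one very uncertain token.
short_output = [0.6, 0.6]
long_output = [0.9] * 10 + [0.2]

for name, probs in [("short", short_output), ("long", long_output)]:
    print(name,
          round(chow_sum(probs), 3),
          round(chow_average(probs), 3),
          round(chow_quantile(probs, 0.0), 3))
```

In this toy case the sum ranks the long output as less confident, the average ranks it as more confident because the many predictable tokens dilute the single uncertain one, and the 0-quantile isolates that uncertain token. Lower scores mean lower confidence and therefore a stronger case for deferral.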

Paper Structure

This paper contains 26 sections, 12 equations, 19 figures, and 4 tables.

Figures (19)

  • Figure 1: (a) In cascades, small models are used for easy instances whereas hard instances are routed to larger models. For generative LMs, the key challenge is to design a deferral rule based on uncertainties from multiple tokens. (b) Standard baselines which take the product and geometric mean of the probabilities are affected by the length of the output and perform sub-optimally. (c) Our proposed solution captures nuanced per-token uncertainty and outperforms both baselines.
  • Figure 2: Example of tokenized FLAN-T5 Base model output on WMT FR $\to$ EN. Red tokens have a significantly higher uncertainty compared to the others, as shown in the left plot. (For each red token, we note the rank of its uncertainty score in the right plot.) However, due to the large number of other more predictable tokens, Chow-Sum gives the output a relatively high score.
  • Figure 3: Deferral curves on MNLI, TriviaQA, and WMT DE $\to$ FR for a FLAN-T5 Base $\to$ Large cascade. Chow-Quantile consistently outperforms Chow-Sum and Chow-Average. This confirms there is value in going beyond naïve sequence probability as an uncertainty measure for cascading.
  • Figure 4: (a) Relation between deferral rules and output length (number of tokens) for the WMT FR $\to$ EN dataset and the FLAN-T5 Base model. Chow-Sum tends to defer longer prompts: the prompts with the lowest scores are notably longer than those with higher scores. Interestingly, Chow-Average over-corrects this bias: it tends to overly defer shorter prompts. Chow-Quantile-0 again defers longer outputs more, whereas Chow-Quantile-0.8 initially focuses more on the shorter outputs. Oracle refers to deferring using the difference of BLEURT scores of the two models. Oracle also tends to defer longer outputs, but the preference is moderate compared to Chow-Sum. (b) Corresponding deferral curves. (c) Analysis of token-level uncertainty on WMT FR $\to$ EN. For each token index $i$, we show the corresponding average prediction probability across all examples (with prediction length $\leq i$) for FLAN-T5 Base. We observe that later tokens tend to have higher probability, i.e., the model is generally the most uncertain for early tokens.
  • Figure 5: FLAN-T5 Base predictions on WMT FR $\to$ EN. Top: Predictions with the longest lengths. These tend to have repetitions, indicating low quality output that could be resolved with a larger model; length does have some signal in identification of good candidates for deferral. Middle: The predictions which Chow-Quantile-0 tends to defer. This quantile tends to identify repetitions and "??" (unknown tokens) as these tokens tend to have lower probability. Bottom: The predictions which Chow-Quantile-0.8 tends to defer. This quantile prioritizes deferring shorter inputs.
  • ...and 14 more figures
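
The deferral curves mentioned in the captions above can be sketched as follows. The sketch assumes a curve is obtained by sorting examples by the small model's confidence score, deferring the least confident fraction to the large model, and averaging the resulting per-example quality. The function name, the inputs, and the toy numbers are hypothetical; the paper's exact evaluation protocol may differ.

```python
def deferral_curve(scores, small_quality, large_quality, rates):
    """Return (deferral_rate, average_quality) pairs.

    scores: small-model confidence per example (lower = less confident).
    small_quality / large_quality: per-example quality of each model.
    rates: fractions of examples to defer, each in [0, 1].
    """
    n = len(scores)
    if not (n == len(small_quality) == len(large_quality)) or n == 0:
        raise ValueError("inputs must be non-empty and of equal length")
    # Least confident examples come first, so they are deferred first.
    order = sorted(range(n), key=lambda i: scores[i])
    curve = []
    for rate in rates:
        k = round(rate * n)
        deferred = set(order[:k])
        total = sum(
            large_quality[i] if i in deferred else small_quality[i]
            for i in range(n)
        )
        curve.append((rate, total / n))
    return curve


# Toy data (hypothetical): the small model fails exactly where it is unsure.
scores = [-0.1, -2.0, -0.5, -3.0]
small_quality = [1.0, 0.0, 1.0, 0.0]
large_quality = [1.0, 1.0, 1.0, 1.0]

print(deferral_curve(scores, small_quality, large_quality, [0.0, 0.5, 1.0]))
```

A deferral rule whose scores track the small model's errors reaches the large model's quality at a low deferral rate, which is what a curve lying above the baselines indicates.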