Early-Exit Neural Networks with Nested Prediction Sets

Metod Jazbec, Patrick Forré, Stephan Mandt, Dan Zhang, Eric Nalisnick

TL;DR

The paper tackles the challenge of quantifying predictive uncertainty in early-exit neural networks (EENNs) without sacrificing the anytime guarantees required for safety-critical settings. It introduces anytime-valid confidence sequences (AVCSs) as a principled mechanism to produce nested prediction sets across exits, addressing the non-nested behavior seen with standard conformal or Bayesian approaches. The authors develop both regression (Bayesian linear regression with AVCS) and classification (Dirichlet Prior Networks with AVCS) implementations, including practical approximations and stability diagnostics. Empirically, they demonstrate perfect nestedness and competitive marginal coverage across synthetic, NLP, and image-classification tasks, while highlighting trade-offs in initial interval sizes and the potential for abrupt collapses to empty intervals for out-of-distribution inputs. Overall, the work advances fast, resource-aware, and provably safe predictive inference for sequential exit points in deep models, with broad applicability to safety-critical AI deployments.

Abstract

Early-exit neural networks (EENNs) enable adaptive and efficient inference by providing predictions at multiple stages during the forward pass. In safety-critical applications, these predictions are meaningful only when accompanied by reliable uncertainty estimates. A popular method for quantifying the uncertainty of predictive models is the use of prediction sets. However, we demonstrate that standard techniques such as conformal prediction and Bayesian credible sets are not suitable for EENNs. They tend to generate non-nested sets across exits, meaning that labels deemed improbable at one exit may reappear in the prediction set of a subsequent exit. To address this issue, we investigate anytime-valid confidence sequences (AVCSs), an extension of traditional confidence intervals tailored for data-streaming scenarios. These sequences are inherently nested and thus well-suited for an EENN's sequential predictions. We explore the theoretical and practical challenges of using AVCSs in EENNs and show that they indeed yield nested sets across exits. Thus, our work presents a promising approach towards fast, yet still safe, predictive modeling.
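
The non-nestedness problem described in the abstract can be made concrete with a small sketch. The label sets below are hypothetical, and the helper names `is_nested` and `running_intersection` are illustrative assumptions rather than anything defined in the paper. Note that intersecting sets after the fact is not the AVCS construction: it only forces nestedness and says nothing about coverage.

```python
from typing import Hashable, List, Optional, Set

LabelSet = Set[Hashable]


def is_nested(sets_per_exit: List[LabelSet]) -> bool:
    """Return True if each exit's set is contained in the previous exit's set."""
    return all(curr <= prev for prev, curr in zip(sets_per_exit, sets_per_exit[1:]))


def running_intersection(sets_per_exit: List[LabelSet]) -> List[LabelSet]:
    """Intersect each exit's set with all earlier ones (a post hoc fix, not AVCS)."""
    nested: List[LabelSet] = []
    current: Optional[LabelSet] = None
    for s in sets_per_exit:
        current = set(s) if current is None else current & s
        nested.append(current)
    return nested


# Hypothetical prediction sets from three exits:
# "dog" is ruled out at exit 2 but reappears at exit 3.
raw_sets = [{"cat", "dog", "fox"}, {"cat", "fox"}, {"cat", "dog"}]
fixed_sets = running_intersection(raw_sets)

print(is_nested(raw_sets))    # False
print(is_nested(fixed_sets))  # True
```

The running intersection can shrink a set to the point where it no longer contains the true label, which is one reason a construction with a built-in coverage guarantee is of interest.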

Paper Structure (47 sections, 2 theorems, 30 equations, 10 figures, 2 algorithms)

Key Result

Proposition 1

For a given test point $(\bm{x}^*, y^*)$, the predictive-likelihood ratio $R_t^*(y)$ in Eq. (eq:eenn_avcs_ratio) is a non-negative martingale with $R_0^* = 1$ when evaluated at $y= y^*$. Moreover, the prediction sets of the form $C_t^* \coloneqq \{ y \in \mathcal{Y} \: | \: R_t^*(y) \le 1 / \alpha\}$ a ...
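
As background on why the threshold $1/\alpha$ appears, the following is the standard argument via Ville's inequality. It is offered as a sketch of the usual reasoning for sets of this form, not as the remainder of the truncated statement above. It assumes only what the proposition states: that $R_t^*(y^*)$ is a non-negative martingale with $R_0^* = 1$.

```latex
% Ville's inequality for a non-negative martingale with initial value 1:
\mathbb{P}\left( \exists\, t \ge 0 : R_t^*(y^*) \ge \frac{1}{\alpha} \right)
  \;\le\; \alpha \, \mathbb{E}\left[ R_0^* \right] \;=\; \alpha .

% Since y^* \notin C_t^* exactly when R_t^*(y^*) > 1/\alpha, it follows that
\mathbb{P}\left( \forall\, t \ge 0 : y^* \in C_t^* \right)
  \;=\; 1 - \mathbb{P}\left( \exists\, t \ge 0 : R_t^*(y^*) > \frac{1}{\alpha} \right)
  \;\ge\; 1 - \alpha .
```

The bound holds simultaneously over all exits $t$, which is what distinguishes it from a guarantee that holds at each exit separately.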

Figures (10)

  • Figure 1: Illustrative example of a 1-dimensional regression problem using an early-exit neural network (EENN) with $T=5$ exits. Upper: At each exit, the EENN produces a prediction interval $C_t$ nested within its previous estimates, i.e., $C_t \subseteq C_{t-1}$. Lower: An example of non-nested prediction intervals across different exits, e.g., $C_2$ contains candidate labels $y$ not included in $C_1$ (the area marked with lines). Such behavior often results from an EENN becoming overconfident, i.e., exhibiting low uncertainty, too early.
  • Figure 2: We compare our EENN-AVCS with the EENN-Bayes baseline based on average nestedness (top), marginal coverage (middle), and average interval size (bottom). EENN-AVCS is the only approach that yields perfect nestedness while maintaining reasonably high marginal coverage across exits. The nestedness comes at the price of larger intervals at the initial exits, though. Note that in the top plot, the nestedness curves of EENN-AVCS and EENN-Bayes-intersection overlap at $\mathfrak{N}(t) = 1$.
  • Figure 3: Prediction intervals for EENN-Bayes (left) and our EENN-AVCS (right) on two simulated regression tasks (antoran2020depth): wiggle (top) and 3-clusters (bottom). Blue points denote training data. Where the EENN-AVCS collapses to an empty set (out-of-distribution), we do not depict anything, which explains the gaps in the EENN-AVCS predictions. We set the significance level to $\alpha = 0.05$ for EENN-AVCS, while for EENN-Bayes we plot intervals that capture 2 standard deviations around the predicted mean. Background colors denote different regions of the data distribution; see Section \ref{sec:exp_synthetic}.
  • Figure 4: Average epistemic uncertainty $v_*$ across Bayesian linear regression models at different exits. As expected, $v_*$ is larger in the regions where we observe less training data: out-of-distribution (denoted with a white background) and in-between (denoted with a grey background). Hence, $v_*$ can serve as an indicator for assessing the reliability of EENN-AVCSs.
  • Figure 5: Comparison of our EENN-AVCS with the CQR (romano2019conformalized) and EENN-Bayes baselines on the NLP regression datasets. Similar to the findings on the synthetic data (cf. Figure \ref{fig:toy_data_metrics}), EENN-AVCS attains perfect nestedness (upper plot) while maintaining reasonably high marginal coverage across exits (middle plot). However, the intervals generated by EENN-AVCS at each exit are larger than those of the baselines (bottom plot). Note that in the upper plot, the nestedness curves of EENN-AVCS, EENN-Bayes-intersection, and EENN-CQR-intersection overlap at $\mathfrak{N}(t) = 1$.
  • ...and 5 more figures
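
The three quantities reported in the figure captions (nestedness, marginal coverage, and interval size) can be sketched for regression intervals as follows. The exact definition of $\mathfrak{N}(t)$ is not given in this excerpt, so the version here, the fraction of test points whose interval at exit $t$ lies inside the interval at exit $t-1$, is an assumption. The function names, the use of `None` for an empty interval, and the example numbers are all hypothetical.

```python
from typing import List, Optional, Tuple

# (lower, upper) bounds; None encodes an empty interval (assumed convention).
Interval = Optional[Tuple[float, float]]


def is_subset(inner: Interval, outer: Interval) -> bool:
    """Check whether `inner` is contained in `outer`; the empty interval fits anywhere."""
    if inner is None:
        return True
    if outer is None:
        return False
    return outer[0] <= inner[0] and inner[1] <= outer[1]


def nestedness(prev: List[Interval], curr: List[Interval]) -> float:
    """Fraction of test points whose current interval lies inside the previous one."""
    return sum(is_subset(c, p) for p, c in zip(prev, curr)) / len(curr)


def coverage(intervals: List[Interval], targets: List[float]) -> float:
    """Fraction of test points whose true target falls inside its interval."""
    hits = sum(iv is not None and iv[0] <= y <= iv[1] for iv, y in zip(intervals, targets))
    return hits / len(targets)


def mean_size(intervals: List[Interval]) -> float:
    """Average interval width, counting empty intervals as width zero."""
    return sum(0.0 if iv is None else iv[1] - iv[0] for iv in intervals) / len(intervals)


# Hypothetical intervals for three test points at two consecutive exits.
exit1 = [(-2.0, 2.0), (0.0, 4.0), (1.0, 3.0)]
exit2 = [(-1.0, 1.0), (1.0, 5.0), None]  # second is not nested, third is empty
targets = [0.5, 3.0, 2.0]

print(nestedness(exit1, exit2))
print(coverage(exit1, targets), coverage(exit2, targets))
print(mean_size(exit1), mean_size(exit2))
```

In this toy example the empty interval counts as nested but can never cover the target, which mirrors the trade-off the captions describe for out-of-distribution inputs.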

Theorems & Definitions (2)

  • Proposition 1
  • Proposition 2