Early-Exit Neural Networks with Nested Prediction Sets
Metod Jazbec, Patrick Forré, Stephan Mandt, Dan Zhang, Eric Nalisnick
TL;DR
The paper tackles the challenge of quantifying predictive uncertainty in early-exit neural networks (EENNs) without sacrificing the anytime guarantees required for safety-critical settings. It introduces anytime-valid confidence sequences (AVCSs) as a principled mechanism to produce nested prediction sets across exits, addressing the non-nested behavior seen with standard conformal or Bayesian approaches. The authors develop both regression (Bayesian linear regression with AVCS) and classification (Dirichlet Prior Networks with AVCS) implementations, including practical approximations and stability diagnostics. Empirically, they demonstrate perfect nestedness and competitive marginal coverage across synthetic, NLP, and image-classification tasks, while highlighting trade-offs in initial interval sizes and the potential for abrupt collapses to empty intervals for out-of-distribution inputs. Overall, the work advances fast, resource-aware, and provably safe predictive inference for sequential exit points in deep models, with broad applicability to safety-critical AI deployments.
Abstract
Early-exit neural networks (EENNs) enable adaptive and efficient inference by providing predictions at multiple stages during the forward pass. In safety-critical applications, these predictions are meaningful only when accompanied by reliable uncertainty estimates. A popular method for quantifying the uncertainty of predictive models is the use of prediction sets. However, we demonstrate that standard techniques such as conformal prediction and Bayesian credible sets are not suitable for EENNs. They tend to generate non-nested sets across exits, meaning that labels deemed improbable at one exit may reappear in the prediction set of a subsequent exit. To address this issue, we investigate anytime-valid confidence sequences (AVCSs), an extension of traditional confidence intervals tailored for data-streaming scenarios. These sequences are inherently nested and thus well-suited for an EENN's sequential predictions. We explore the theoretical and practical challenges of using AVCSs in EENNs and show that they indeed yield nested sets across exits. Thus, our work presents a promising approach towards fast, yet still safe, predictive modeling.
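To make the nestedness property concrete, the following minimal Python sketch checks whether a sequence of per-exit prediction sets is nested, i.e. whether each exit's set is contained in the set of the exit before it. The function name `first_nestedness_violation` and the toy label sets are illustrative assumptions, not part of the paper's code; the sketch only illustrates the property and does not implement AVCSs.

```python
from typing import Hashable, Optional, Sequence, Set


def first_nestedness_violation(
    sets_per_exit: Sequence[Set[Hashable]],
) -> Optional[int]:
    """Return the index of the first exit whose prediction set is NOT a
    subset of the previous exit's set, or None if all sets are nested.

    Hypothetical helper for illustration only.
    """
    for i in range(1, len(sets_per_exit)):
        # Nested means: every label kept at exit i was already present at exit i-1.
        if not sets_per_exit[i] <= sets_per_exit[i - 1]:
            return i
    return None


# Toy example (made-up labels): the sets shrink monotonically, so they are nested.
nested = [{"cat", "dog", "fox"}, {"cat", "dog"}, {"cat"}]

# Toy example: "fox" was excluded at exit 0 but reappears at exit 1.
non_nested = [{"cat", "dog"}, {"cat", "fox"}]

print(first_nestedness_violation(nested))      # None
print(first_nestedness_violation(non_nested))  # 1
```

In the second example, a label ruled out at the first exit comes back at the next one, which is the behavior the abstract describes for standard conformal and Bayesian credible sets.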
