Fast yet Safe: Early-Exiting with Risk Control

Metod Jazbec, Alexander Timans, Tin Hadži Veljković, Kaspar Sakmann, Dan Zhang, Christian A. Naesseth, Eric Nalisnick

TL;DR

This work investigates how to adapt frameworks of risk control to early-exit neural networks (EENNs) and empirically validates the insights on a range of vision and language tasks, demonstrating that risk control can produce substantial computational savings while preserving user-specified performance goals.

Abstract

Scaling machine learning models significantly improves their performance. However, such gains come at the cost of inference being slow and resource-intensive. Early-exit neural networks (EENNs) offer a promising solution: they accelerate inference by allowing intermediate layers to exit and produce a prediction early. Yet a fundamental issue with EENNs is how to determine when to exit without severely degrading performance. In other words, when is it 'safe' for an EENN to go 'fast'? To address this issue, we investigate how to adapt frameworks of risk control to EENNs. Risk control offers a distribution-free, post-hoc solution that tunes the EENN's exiting mechanism so that exits only occur when the output is of sufficient quality. We empirically validate our insights on a range of vision and language tasks, demonstrating that risk control can produce substantial computational savings, all the while preserving user-specified performance goals.
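The abstract describes an exiting mechanism that lets intermediate layers produce a prediction only when the output is good enough. One common way to implement such a mechanism, assumed here rather than taken from the paper, is to exit at the first layer whose top-1 confidence reaches a threshold $\lambda$. Raising $\lambda$ pushes exits deeper, and risk control then amounts to choosing $\lambda$. The sketch below is a minimal illustration with hypothetical names (`exit_layer`, `early_exit`) and 0-based layer indices.

```python
from typing import Sequence, Tuple


def exit_layer(confs: Sequence[float], lam: float) -> int:
    """Index of the first exit whose confidence reaches lam (last exit otherwise)."""
    for l, c in enumerate(confs):
        if c >= lam:
            return l
    return len(confs) - 1


def early_exit(layer_probs: Sequence[Sequence[float]], lam: float) -> Tuple[int, int]:
    """Return (exit_layer, predicted_class) for a single input.

    layer_probs[l] holds the class probabilities of (hypothetical) exit head l.
    """
    confs = [max(p) for p in layer_probs]  # top-1 confidence per exit
    l = exit_layer(confs, lam)
    probs = layer_probs[l]
    pred = max(range(len(probs)), key=probs.__getitem__)  # argmax class at that exit
    return l, pred


# Three exits, three classes: confidence grows with depth.
layer_probs = [[0.5, 0.3, 0.2], [0.1, 0.85, 0.05], [0.02, 0.95, 0.03]]
print(early_exit(layer_probs, lam=0.4))   # (0, 0): exits at the first layer
print(early_exit(layer_probs, lam=0.8))   # (1, 1): exits at the second layer
print(early_exit(layer_probs, lam=0.99))  # (2, 1): falls back to the last layer
```

A low threshold saves compute but may accept a poor early prediction (class 0 above), which is exactly the trade-off the risk-control calibration is meant to manage.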

Paper Structure

This paper contains 54 sections, 4 theorems, 24 equations, 13 figures, 4 tables, 2 algorithms.

Key Result

Proposition 1

Let $\ell: \Lambda \rightarrow (-\infty, B]$ be a right-continuous bounded loss, and assume a marginally monotone EENN (eq:marg_mono). Then the exit threshold $\hat{\lambda}_{\text{CRC}}$ ensures risk control in expectation, i.e., it holds that $\mathbb{E}_{\mathcal{D}_{cal} \sim \mathcal{P}^n}\big[\mathcal{R}(\hat{\lambda}_{\text{CRC}})\big] \leq \epsilon$.
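This summary does not spell out how $\hat{\lambda}_{\text{CRC}}$ is computed. A minimal sketch, assuming the standard conformal risk control rule for losses bounded above by $B$, picks the smallest threshold on a grid whose inflated empirical calibration risk stays below $\epsilon$: $\hat{\lambda} = \inf\{\lambda : \tfrac{n}{n+1}\hat{\mathcal{R}}_n(\lambda) + \tfrac{B}{n+1} \le \epsilon\}$. The inputs are hypothetical: `confs[i][l]` is the top-1 confidence of calibration sample `i` at exit `l`, `losses[i][l]` is the loss incurred if that sample exits at `l`, and the exit rule (first exit whose confidence reaches $\lambda$) is itself an assumption.

```python
from typing import Sequence


def exit_layer(confs: Sequence[float], lam: float) -> int:
    """First exit whose confidence reaches lam; the last exit otherwise."""
    for l, c in enumerate(confs):
        if c >= lam:
            return l
    return len(confs) - 1


def empirical_risk(confs, losses, lam: float) -> float:
    """Mean calibration loss when every sample exits according to threshold lam."""
    n = len(confs)
    return sum(losses[i][exit_layer(confs[i], lam)] for i in range(n)) / n


def crc_threshold(confs, losses, lambdas: Sequence[float], eps: float, B: float = 1.0) -> float:
    """Smallest grid value lam with n/(n+1) * R_hat(lam) + B/(n+1) <= eps."""
    n = len(confs)
    for lam in sorted(lambdas):
        if n / (n + 1) * empirical_risk(confs, losses, lam) + B / (n + 1) <= eps:
            return lam
    return float("inf")  # nothing on the grid passes: never exit early


# Toy calibration set: 4 samples, 3 exits; hypothetical loss is 1 if an exit
# disagrees with the last layer, so exiting at the last layer costs 0.
confs = [[0.9, 0.95, 0.99], [0.6, 0.9, 0.97], [0.5, 0.7, 0.95], [0.3, 0.6, 0.9]]
losses = [[0, 0, 0], [1, 0, 0], [1, 1, 0], [1, 1, 0]]
lambdas = [0.0, 0.5, 0.8, 0.95, 1.0]
print(crc_threshold(confs, losses, lambdas, eps=0.5))  # 0.8
print(crc_threshold(confs, losses, lambdas, eps=0.1))  # inf, since B/(n+1) = 0.2 > 0.1
```

With only $n=4$ points, the $B/(n+1)$ term alone equals 0.2, so a tiny calibration set either forces a loose $\epsilon$ or disables early exits; the experiments listed below use $n=100$ and $n=500$.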

Figures (13)

  • Figure 1: Accuracy and Brier score [brier1950verification] across exits for different EENNs for image classification on ImageNet (subsec:exp-classif). Marginally monotone performance trends (eq:marg_mono) are generally observed across models, with last-layer exits performing best.
  • Figure 2: Empirical test risk (top) and efficiency gains (bottom) for the CALM model [schuster2022confident] for text summarization on CNN/DM. Our adaptation of UCB [bates2021distribution] (prop:ucb) outperforms the LTT [angelopoulos2021learn] approach used in CALM, yielding larger efficiency gains under the same risk control guarantees (see sec:exp-calm for details). Shading denotes the standard deviation across $S=100$ calibration/test splits.
  • Figure 3: Empirical test risk (top) and efficiency gains (bottom) for different early-exit models, risks (subsec:methods-risks), and risk levels $\epsilon$ on ImageNet (calibration set size $n = 100$). In line with the theoretical results, the test risk is controlled across models, risk types, and levels. Despite guaranteeing control in expectation (CRC, prop:crc) or with high probability (UCB, prop:ucb), the obtained gains remain substantial.
  • Figure 4: Right: Example of early-exiting with our method on Cityscapes [cordts2016cityscapes]. For two samples, one that exits early ($l=1$) and one that exits late ($l=4$), we display ground-truth segmentation masks and confidence maps at the first and last model layer. Left: For every sample, we compute the Brier loss difference $\Delta\ell_{B} = |\ell_{B}(\hat{p}_1(\mathbf{y} | \bm{x}), \bm{y}) - \ell_{B}(\hat{p}_4(\mathbf{y} | \bm{x}), \bm{y})|$ between the first and last model layer (eq:brier), and stratify the values by exit layer; the red dot denotes the mean. For both panels, we use the simplest combination of Top-1 confidence score and mean image-level aggregation (with $\epsilon=0.08$).
  • Figure 5: Results for early-exit diffusion with DeeDiff [tang2023deediff] on CelebA [liu2015deep]. Left: The quality of generated images is directly related to the target risk control level $\epsilon$. Right: Empirical test risks are controlled for both CRC (prop:crc) and UCB (prop:ucb) (calibration set size $n=500$).
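Figures 2, 3 and 5 also report a high-probability (UCB) variant of the calibration. The sketch below is a generic Hoeffding-based upper-confidence-bound rule in the style of risk-controlling prediction sets, not necessarily the paper's exact adaptation. It assumes losses in $[0, B]$, scans thresholds from the most conservative downwards, and keeps lowering $\lambda$ while $\hat{\mathcal{R}}_n(\lambda) + B\sqrt{\log(1/\delta)/(2n)} < \epsilon$. The function names and inputs (`ucb_threshold`, `confs`, `losses`) are hypothetical, and the confidence-threshold exit rule is an assumption.

```python
import math
from typing import Sequence


def exit_layer(confs: Sequence[float], lam: float) -> int:
    """First exit whose confidence reaches lam; the last exit otherwise."""
    for l, c in enumerate(confs):
        if c >= lam:
            return l
    return len(confs) - 1


def ucb_threshold(confs, losses, lambdas: Sequence[float], eps: float,
                  delta: float, B: float = 1.0) -> float:
    """Lowest grid threshold such that it and all larger grid values pass the UCB test."""
    n = len(confs)
    slack = B * math.sqrt(math.log(1 / delta) / (2 * n))  # Hoeffding bound width
    lam_hat = float("inf")  # default: never exit early
    for lam in sorted(lambdas, reverse=True):
        emp_risk = sum(losses[i][exit_layer(confs[i], lam)] for i in range(n)) / n
        if emp_risk + slack >= eps:
            break  # first failure: stop lowering the threshold
        lam_hat = lam
    return lam_hat


# Same hypothetical toy calibration set as a CRC example would use.
confs = [[0.9, 0.95, 0.99], [0.6, 0.9, 0.97], [0.5, 0.7, 0.95], [0.3, 0.6, 0.9]]
losses = [[0, 0, 0], [1, 0, 0], [1, 1, 0], [1, 1, 0]]
lambdas = [0.0, 0.5, 0.8, 0.95, 1.0]
print(ucb_threshold(confs, losses, lambdas, eps=0.6, delta=0.1))  # 0.8
print(ucb_threshold(confs, losses, lambdas, eps=0.5, delta=0.1))  # inf (slack is about 0.54)
```

Stopping at the first failing threshold, instead of accepting any passing $\lambda$, mirrors the "for all larger $\lambda$" condition of UCB-style calibration, which keeps the selection valid even when the empirical risk is not monotone in $\lambda$.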

Theorems & Definitions (9)

  • Proposition 1
  • Proposition 2
  • Lemma 1
  • Proof
  • Proof
  • Proof
  • Proposition 3
  • Proof
  • Remark 1