Jointly-Learned Exit and Inference for a Dynamic Neural Network: JEI-DNN

Florence Regol, Joud Chataoui, Mark Coates

TL;DR

This work tackles the inference cost of large pretrained models by proposing JEI-DNN, a jointly trained, dynamic exit-and-inference framework. JEI-DNN attaches lightweight, trainable intermediate inference modules (IMs) and gates to a fixed backbone, models exit probabilities with a sequential, learnable gate model, and uses bi-level optimization to keep gate decisions aligned with the evolving IMs and their inference costs. The approach yields better accuracy-cost trade-offs and better uncertainty characterization, including calibrated probabilities and tighter conformal intervals. It outperforms architecture-agnostic baselines and improves state-of-the-art EDNNs on multiple datasets. In practice, it enables cost-aware inference with off-the-shelf backbones on vision tasks while providing uncertainty estimates that support decision-making under resource constraints.

Abstract

Large pretrained models, coupled with fine-tuning, are slowly becoming established as the dominant architecture in machine learning. Even though these models offer impressive performance, their practical application is often limited by the prohibitive amount of resources required for every inference. Early-exiting dynamic neural networks (EDNN) circumvent this issue by allowing a model to make some of its predictions from intermediate layers (i.e., early-exit). Training an EDNN architecture is challenging as it consists of two intertwined components: the gating mechanism (GM) that controls early-exiting decisions and the intermediate inference modules (IMs) that perform inference from intermediate representations. As a result, most existing approaches rely on thresholding confidence metrics for the gating mechanism and strive to improve the underlying backbone network and the inference modules. Although successful, this approach has two fundamental shortcomings: 1) the GMs and the IMs are decoupled during training, leading to a train-test mismatch; and 2) the thresholding gating mechanism introduces a positive bias into the predictive probabilities, making it difficult to readily extract uncertainty information. We propose a novel architecture that connects these two modules. This leads to significant performance improvements on classification datasets and enables better uncertainty characterization capabilities.
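The early-exit mechanism described above can be illustrated with a minimal sketch. Everything in it is an assumption made for illustration, not the authors' implementation: the toy backbone of linear+tanh blocks, the class name `ToyEarlyExitNet`, the choice of feeding the hidden representation to each gate, the fixed 0.5 exit threshold, and the helper `exit_distribution`, which shows one plausible way sequential gate probabilities could combine into a distribution over exits.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def exit_distribution(gate_probs):
    """Turn per-gate exit probabilities into a distribution over exits.

    Assumed form: P(exit at l) = g_l * prod_{j<l} (1 - g_j),
    with the final exit taking all remaining probability mass.
    """
    g = np.append(np.asarray(gate_probs, dtype=float), 1.0)  # last exit always fires
    survive = np.concatenate(([1.0], np.cumprod(1.0 - g[:-1])))  # P(reaching exit l)
    return g * survive

class ToyEarlyExitNet:
    """Hypothetical frozen backbone with one IM per block and one gate per non-final block."""

    def __init__(self, dim, num_classes, num_layers, seed=0):
        rng = np.random.default_rng(seed)
        scale = dim ** -0.5
        self.blocks = [rng.normal(scale=scale, size=(dim, dim)) for _ in range(num_layers)]
        self.ims = [rng.normal(scale=scale, size=(dim, num_classes)) for _ in range(num_layers)]
        self.gates = [rng.normal(scale=scale, size=dim) for _ in range(num_layers - 1)]

    def predict(self, x, threshold=0.5):
        """Run blocks in order and stop at the first gate whose exit probability exceeds threshold."""
        h = x
        last = len(self.blocks) - 1
        for l, block in enumerate(self.blocks):
            h = np.tanh(h @ block)            # backbone block (kept fixed)
            probs = softmax(h @ self.ims[l])  # intermediate inference module
            if l == last or sigmoid(h @ self.gates[l]) > threshold:
                return int(probs.argmax()), l, probs

net = ToyEarlyExitNet(dim=8, num_classes=3, num_layers=4)
label, exit_layer, probs = net.predict(np.ones(8))
print(label, exit_layer, exit_distribution([0.5, 0.5]))
```

The sketch only covers inference. The joint training of gates and IMs that the abstract proposes is not shown here.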

Paper Structure (41 sections, 25 equations, 21 figures, 4 tables, 1 algorithm)

Figures (21)

  • Figure 1: Left: Modelling of GM and how early exiting is achieved. Right: Illustration of the mutual influence of GM and IM.
  • Figure 2: Accuracy vs Mul-Add. Left: CIFAR100 (t2t-14); Middle: SVHN (t2t-7); Right: CIFAR100LT (t2t-14). The x-axes are scaled by the Mul-Add inference cost of the full model ($IC^L$).
  • Figure 3: Decomposition of the contributions of the IMs to the final accuracies (depicted by dotted lines) for CIFAR100. The operational point is marked by a star in the left panel of Figure \ref{fig:acc_vs_ece}. a) Accuracy of each IM, $f^l_{\theta}$, evaluated only on its exited samples ($\mathcal{D}^l$). b) Accuracy of each IM, $f^l_{\theta}$, on the full dataset $\mathcal{D}$. The size of a circle is proportional to the number of samples exited for a trial.
  • Figure 4: ECE of the IMs averaged over all baselines on all datasets with t2t-7.
  • Figure 5: Uncertainty metrics on CIFAR-100 (t2t-14). a) ECE vs Mul-Add. b) Inefficiency vs Mul-Add for an empirical coverage bounded by $1-\hat{\alpha} \geq 95\%$. c) Average empirical $\hat{\alpha}$ vs requested $\alpha$. Appendix \ref{sec:add_results} presents results on other datasets and architectures.
  • ...and 16 more figures
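
Several captions above report ECE (expected calibration error). As a reminder of what that metric measures, here is a minimal sketch of the common equal-width-bin estimator. The function name, the choice of 10 bins, and the right-inclusive binning are assumptions for illustration; the paper's exact evaluation protocol is not given in this excerpt.

```python
import numpy as np

def expected_calibration_error(confidences, correct, n_bins=10):
    """Equal-width-bin ECE: weighted mean of |accuracy - confidence| per bin.

    confidences: predicted probability of the chosen class, in (0, 1].
    correct:     1 if the prediction was right, 0 otherwise.
    """
    confidences = np.asarray(confidences, dtype=float)
    correct = np.asarray(correct, dtype=float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        mask = (confidences > lo) & (confidences <= hi)  # bin (lo, hi]
        if mask.any():
            gap = abs(correct[mask].mean() - confidences[mask].mean())
            ece += mask.mean() * gap  # weight by the fraction of samples in the bin
    return ece

# Four predictions made with 90% confidence, three of them correct.
print(expected_calibration_error([0.9, 0.9, 0.9, 0.9], [1, 1, 1, 0]))
```

A lower value means the reported confidences track empirical accuracy more closely. In an early-exit setting the same estimator could be applied per IM or to the predictions actually emitted at each exit.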