Federated Learning for Collaborative Inference Systems: The Case of Early Exit Networks

Caelin Kaplan, Angelo Rodio, Tareq Si Salem, Chuan Xu, Giovanni Neglia

TL;DR

This paper tackles training for Collaborative Inference Systems (CISs) under heterogeneous serving rates by introducing an inference-aware Federated Learning framework. It formalizes a weighted objective that ties exit-level losses to expected future inference requests, and proposes a training algorithm that uses participation probabilities across node-exit pairs to optimize a global model. The theoretical analysis decomposes the true error into generalization, bias, and optimization components, and provides practical configuration rules (Serving Rate and Balanced Adj) to minimize these errors under realistic CIS dynamics. Extensive CIFAR-10/100 experiments with ResNet-18-based Early Exit Networks demonstrate that accounting for inference load and gradient variance yields robust improvements over traditional baselines, especially when small devices shoulder most requests. The approach generalizes to various nested training strategies (e.g., pruning, ordered dropout) and offers actionable guidance for deploying inference-aware FL in heterogeneous CISs.

Abstract

As Internet of Things (IoT) technology advances, end devices like sensors and smartphones are progressively equipped with AI models tailored to their local memory and computational constraints. Local inference reduces communication costs and latency; however, these smaller models typically underperform compared to more sophisticated models deployed on edge servers or in the cloud. Cooperative Inference Systems (CISs) address this performance trade-off by enabling smaller devices to offload part of their inference tasks to more capable devices. These systems often deploy hierarchical models that share numerous parameters, exemplified by Deep Neural Networks (DNNs) that utilize strategies like early exits or ordered dropout. In such instances, Federated Learning (FL) may be employed to jointly train the models within a CIS. Yet, traditional training methods have overlooked the operational dynamics of CISs during inference, particularly the potential high heterogeneity in serving rates across clients. To address this gap, we propose a novel FL approach designed explicitly for use in CISs that accounts for these variations in serving rates. Our framework not only offers rigorous theoretical guarantees, but also surpasses state-of-the-art (SOTA) training algorithms for CISs, especially in scenarios where inference request rates or data availability are uneven among clients.

Paper Structure

This paper contains 31 sections, 6 theorems, 34 equations, 2 figures, 6 tables, 1 algorithm.

Key Result

Theorem 1

Under Assumptions asm:bounded_loss--asm:sgd_var, the true error of the output $\boldsymbol{w}^{(T)}$ of Alg. alg:fed-cis with learning rate $\eta^{(t,j)} = \frac{2}{\mu (\gamma + (t-1)J + j + 1)}$ and $\gamma \triangleq \max\{8\kappa, J\} - 1$ can be bounded as follows: where $\kappa \triangleq \frac{L}{\mu}$, $\mathrm{Pdim}(H_e)$ represents the pseudo-dimension of the class of models for exit $e$, ...
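
The step-size schedule in the theorem statement can be evaluated directly. The following is a minimal sketch, not the authors' code. It assumes that $t$ indexes communication rounds starting at 1, that $j$ indexes local steps within a round, that $J$ is the number of local steps per round, and that $\mu$ and $L$ are the constants whose ratio defines $\kappa$. The function name and argument names are hypothetical.

```python
def learning_rate(t: int, j: int, J: int, mu: float, L: float) -> float:
    """Evaluate eta^(t,j) = 2 / (mu * (gamma + (t - 1) * J + j + 1)).

    Assumptions (not stated in the excerpt): t >= 1 is the round index,
    j is the local step index within the round, and J is the number of
    local steps per round.
    """
    kappa = L / mu                  # kappa = L / mu, as defined in the theorem
    gamma = max(8 * kappa, J) - 1   # gamma = max{8 * kappa, J} - 1
    return 2.0 / (mu * (gamma + (t - 1) * J + j + 1))


# Example: the step size shrinks as rounds and local steps accumulate.
first = learning_rate(t=1, j=0, J=5, mu=1.0, L=2.0)
later = learning_rate(t=2, j=0, J=5, mu=1.0, L=2.0)
```

Because the denominator grows linearly in $(t-1)J + j$, the schedule decays roughly as one over the total number of local steps taken so far.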

Figures (2)

  • Figure 1: Early Exit Networks for Collaborative Inference System. An input sample is first passed through the initial layers of the DNN until it reaches Exit $1$. If the measure of prediction uncertainty is below the threshold $T_1$, the prediction is served at the current node. Otherwise, the intermediate representation of the current input is transferred to a node with greater computational capacity, and inference continues. This process repeats until the prediction uncertainty is below $T_e$ or the final Exit $E$ is reached.
  • Figure 2: An example of a two-layer network with four nodes: Node $0$, Node $1$, and Node $2$ each receive local requests, $\lambda^{a}_i$ (in requests per second, r/s), serve a portion locally, $\lambda^s_{i}$, and transfer the remainder, $\lambda^{t}_{i}$, to their parent. Node $3$ receives requests both locally and from its children, and serves all requests as it has no parent.
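
The two captions above describe procedures that can be sketched in a few lines. This is an illustrative sketch under stated assumptions, not the paper's implementation: `cascade_predict` mirrors the exit rule of Figure 1, and `serving_rates` mirrors the request flow of Figure 2. All function and parameter names are hypothetical. In particular, `local_fraction` (the share of incoming requests a node serves itself) is a simplification introduced here; in an actual CIS that share would result from the uncertainty thresholds.

```python
from typing import Any, Callable, List, Optional, Sequence, Tuple


def cascade_predict(
    x: Any,
    stages: Sequence[Callable[[Any], Any]],
    heads: Sequence[Callable[[Any], Tuple[Any, float]]],
    thresholds: Sequence[float],
) -> Tuple[Any, int]:
    """Return (prediction, exit number) following the rule in Figure 1.

    stages[e] maps the current representation to the one at exit e+1;
    heads[e] returns (prediction, uncertainty) at that exit.
    """
    if not stages:
        raise ValueError("at least one exit is required")
    h = x
    for e, (stage, head) in enumerate(zip(stages, heads)):
        h = stage(h)  # on a real system, h would be sent to a larger node
        pred, uncertainty = head(h)
        is_final = e == len(stages) - 1
        if is_final or uncertainty < thresholds[e]:
            return pred, e + 1
    raise ValueError("stages and heads must have the same length")


def serving_rates(
    arrival: Sequence[float],
    parent: Sequence[Optional[int]],
    local_fraction: Sequence[float],
) -> Tuple[List[float], List[float]]:
    """Compute served and transferred rates (r/s) per node, as in Figure 2.

    Assumes every child has a smaller index than its parent, so one
    pass in index order sees each node after all of its children.
    """
    incoming = list(arrival)
    served = [0.0] * len(arrival)
    transferred = [0.0] * len(arrival)
    for i in range(len(arrival)):
        if parent[i] is None:
            served[i] = incoming[i]  # a root node serves everything
        else:
            served[i] = local_fraction[i] * incoming[i]
            transferred[i] = incoming[i] - served[i]
            incoming[parent[i]] += transferred[i]
    return served, transferred


# Figure 2 layout: nodes 0, 1, 2 are children of node 3 (rates are made up).
served, transferred = serving_rates(
    arrival=[10.0, 20.0, 30.0, 5.0],
    parent=[3, 3, 3, None],
    local_fraction=[0.5, 0.5, 0.5, 1.0],
)
```

In this sketch the served rates always sum to the total arrival rate, since every request is either served locally or passed to a parent until it reaches a node without one.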

Theorems & Definitions (11)

  • Theorem 1
  • Theorem 1
  • proof
  • Lemma 1
  • proof
  • Lemma 2
  • proof
  • Lemma 3
  • proof
  • Theorem 2
  • ...and 1 more