Leveraging Stochastic Depth Training for Adaptive Inference
Guilherme Korol, Antonio Carlos Schneider Beck, Jeronimo Castrillon
TL;DR
The paper tackles adaptive inference at the edge by exploiting the observation that models trained with Stochastic Depth become resilient to arbitrary layer skipping. It replaces costly, learnable gating or decision networks with a zero-overhead gating mechanism and a design-time Pareto-front exploration of skipping configurations, enabling single-model, time-predictable inference. At runtime, a lightweight algorithm selects among the Pareto configurations based on current edge conditions. This yields significant improvements in energy efficiency and throughput (up to 2× better power efficiency and nearly 2× more inferences in some cases) while keeping accuracy drops small (as low as 0.71%). The approach provides predictable, controllable adaptive inference for resource-constrained edge deployments. Experiments with ResNets on CIFAR-10/100, compiled with IREE for edge devices, demonstrate concrete gains. The work bridges Stochastic Depth training with practical, zero-overhead adaptive inference, enabling deployment of large models with tunable accuracy-time-energy profiles.
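The design-time/runtime split described above can be illustrated with a minimal Python sketch. It assumes only two objectives (accuracy and latency) and treats a latency budget as the "current edge condition". The names `SkipConfig`, `pareto_front`, and `select` are hypothetical, and the profile numbers in the usage lines are made-up placeholders, not results from the paper. The paper's actual exploration and runtime algorithm may use other objectives (e.g., energy) and different selection rules.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass(frozen=True)
class SkipConfig:
    """One layer-skipping configuration and its (hypothetical) design-time profile."""
    skipped: Tuple[int, ...]  # indices of residual blocks bypassed at inference
    accuracy: float           # validation accuracy, higher is better
    latency_ms: float         # measured latency, lower is better


def dominates(a: SkipConfig, b: SkipConfig) -> bool:
    """True if a is no worse than b on both objectives and strictly better on one."""
    no_worse = a.accuracy >= b.accuracy and a.latency_ms <= b.latency_ms
    strictly_better = a.accuracy > b.accuracy or a.latency_ms < b.latency_ms
    return no_worse and strictly_better


def pareto_front(configs: List[SkipConfig]) -> List[SkipConfig]:
    """Design time: keep only non-dominated configurations, sorted fastest first."""
    front = [c for c in configs if not any(dominates(o, c) for o in configs)]
    return sorted(front, key=lambda c: c.latency_ms)


def select(front: List[SkipConfig], budget_ms: float) -> SkipConfig:
    """Runtime: pick the most accurate Pareto point within the latency budget,
    falling back to the fastest configuration if none fits."""
    feasible = [c for c in front if c.latency_ms <= budget_ms]
    if not feasible:
        return front[0]
    return max(feasible, key=lambda c: c.accuracy)


# Usage with placeholder profiles (illustrative numbers, not measured data).
candidates = [
    SkipConfig((), 0.93, 10.0),
    SkipConfig((3,), 0.92, 9.0),
    SkipConfig((1, 3), 0.90, 8.0),
    SkipConfig((2,), 0.89, 9.5),  # dominated by (3,): less accurate and slower
]
front = pareto_front(candidates)
print([c.skipped for c in front])            # [(1, 3), (3,), ()]
print(select(front, budget_ms=9.2).skipped)  # (3,)
```

On a Pareto front sorted by latency, accuracy generally rises with latency, so picking the most accurate feasible point amounts to choosing the slowest configuration that still meets the budget. All the expensive profiling happens offline, which keeps the runtime decision cheap.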
Abstract
Dynamic DNN optimization techniques such as layer skipping offer increased adaptability and efficiency gains but can lead to i) a larger memory footprint, as with decision gates, ii) increased training complexity (e.g., due to non-differentiable operations), and iii) less control over performance-quality trade-offs due to their inherently input-dependent execution. To address these issues, we propose a simpler yet effective alternative for adaptive inference that offers zero-overhead, single-model, time-predictable inference. Central to our approach is the observation that models trained with Stochastic Depth, a method for faster training of residual networks, become more resilient to arbitrary layer skipping at inference time. Our method first selects near-Pareto-optimal skipping configurations from a stochastically trained model and later uses them to adapt inference at runtime. Compared to the original ResNets, our method improves power efficiency by up to 2× at accuracy drops as low as 0.71%.
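To make the link between Stochastic Depth training and layer skipping concrete, the following toy Python sketch contrasts two passes. In the training pass, each residual branch survives with probability p_l under the linear-decay rule of the original Stochastic Depth formulation. In the inference pass, a fixed skipping configuration is chosen ahead of time, so no gate is evaluated at runtime. The sketch makes several simplifications. Blocks are scalar functions standing in for convolutional branches, and the nonlinearity after the addition is omitted. The test-time scaling x + p_l·f(x) is the original Stochastic Depth rule and is assumed here rather than taken from this paper. All function names are hypothetical.

```python
import random
from typing import Callable, List, Set

# Toy residual branch f_l operating on a scalar; real ResNet branches are conv layers.
Block = Callable[[float], float]


def survival_probs(num_blocks: int, p_last: float = 0.5) -> List[float]:
    """Linear-decay survival probabilities p_l = 1 - (l / L) * (1 - p_L) for l = 1..L."""
    return [1.0 - (l / num_blocks) * (1.0 - p_last) for l in range(1, num_blocks + 1)]


def forward_train(x: float, blocks: List[Block], probs: List[float],
                  rng: random.Random) -> float:
    """Training pass: block l keeps its residual branch with probability p_l;
    otherwise only the identity shortcut is applied (post-addition ReLU omitted)."""
    for f, p in zip(blocks, probs):
        if rng.random() < p:
            x = x + f(x)
    return x


def forward_skip(x: float, blocks: List[Block], probs: List[float],
                 skipped: Set[int]) -> float:
    """Inference with a fixed skipping configuration chosen at design time.
    Skipped blocks (0-based indices) are never executed, so there is no gate cost;
    kept blocks use the original Stochastic Depth test-time rule x + p_l * f(x)
    (an assumption, not necessarily this paper's choice)."""
    for i, (f, p) in enumerate(zip(blocks, probs)):
        if i not in skipped:
            x = x + p * f(x)
    return x


# Usage with constant toy branches f_l(x) = 1.0, L = 4 and p_L = 0.5.
blocks: List[Block] = [lambda x: 1.0 for _ in range(4)]
probs = survival_probs(len(blocks))                      # [0.875, 0.75, 0.625, 0.5]
print(forward_skip(0.0, blocks, probs, skipped=set()))   # 2.75 (no blocks skipped)
print(forward_skip(0.0, blocks, probs, skipped={1, 3}))  # 1.5 (blocks 1 and 3 bypassed)
```

Because the skipped set is fixed for each configuration, the amount of executed work (and hence latency) is known before the input arrives. This is the property that makes the inference time-predictable, in contrast to input-dependent gating.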
