AdaPerceiver: Transformers with Adaptive Width, Depth, and Tokens
Purvish Jajal, Nick John Eliopoulos, Benjamin Shiue-Hal Chou, George K. Thiruvathukal, Yung-Hsiang Lu, James C. Davis
TL;DR
AdaPerceiver introduces a unified transformer that can adapt along three axes (tokens, depth, and width) at inference time. It does this with a latent-stream design, block-masked attention, and Matryoshka FFNs, trained with a joint, once-for-all objective that optimizes multiple configurations in a single forward pass. Experiments on ImageNet-1K classification, ADE20K segmentation, and NYUv2 depth estimation show improved accuracy–throughput Pareto fronts and large encoder FLOP reductions relative to strong baselines (about 26x fewer encoder FLOPs than ViT-H/14 at matched dense-prediction quality). The work demonstrates practical gains for deployment under diverse hardware and latency constraints, and it points to policy-driven adaptivity as a viable route to further efficiency gains.
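The summary names two mechanisms, block-masked attention for token adaptivity and Matryoshka FFNs for width adaptivity, without showing how they behave. Below is a minimal NumPy sketch under stated assumptions, not the authors' implementation. It assumes that the block mask lets latent tokens in block i attend only to blocks 0..i, so truncating the latent sequence to a prefix of blocks reproduces the outputs of those tokens exactly. It also assumes that a Matryoshka FFN runs with only the first `width` hidden units of one shared weight matrix. All function names, dimensions, and block sizes are hypothetical.

```python
import numpy as np

def matryoshka_ffn(x, w1, b1, w2, b2, width):
    """FFN that uses only the first `width` hidden units (Matryoshka-style nesting)."""
    hidden = np.maximum(x @ w1[:, :width] + b1[:width], 0.0)  # ReLU on kept units only
    return hidden @ w2[:width, :] + b2

def block_prefix_mask(block_sizes):
    """mask[q, k] is True when key k's block index is <= query q's block index."""
    block_ids = np.repeat(np.arange(len(block_sizes)), block_sizes)
    return block_ids[None, :] <= block_ids[:, None]

def masked_self_attention(x, wq, wk, wv, mask):
    """Single-head softmax self-attention; disallowed query/key pairs get zero weight."""
    q, k, v = x @ wq, x @ wk, x @ wv
    scores = np.where(mask, q @ k.T / np.sqrt(q.shape[-1]), -np.inf)
    scores -= scores.max(axis=-1, keepdims=True)  # stabilize before exp
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

rng = np.random.default_rng(0)
d_model, max_hidden = 8, 32          # hypothetical toy sizes
blocks = [4, 4, 8]                   # hypothetical nested latent budgets: 4, 8, or 16 tokens
latents = rng.normal(size=(sum(blocks), d_model))
wq, wk, wv = (rng.normal(size=(d_model, d_model)) for _ in range(3))

# Token adaptivity: the first 8 rows of the full 16-token pass equal a standalone 8-token pass.
full = masked_self_attention(latents, wq, wk, wv, block_prefix_mask(blocks))
small = masked_self_attention(latents[:8], wq, wk, wv, block_prefix_mask(blocks[:2]))
print(np.allclose(full[:8], small))  # True

# Width adaptivity: every narrower FFN reuses the leading slice of the same weights.
w1, b1 = rng.normal(size=(d_model, max_hidden)), np.zeros(max_hidden)
w2, b2 = rng.normal(size=(max_hidden, d_model)), np.zeros(d_model)
outputs = {w: matryoshka_ffn(full, w1, b1, w2, b2, w) for w in (8, 16, 32)}
```

The prefix property is what would let one forward pass supervise several token budgets at once, which fits the single-pass joint objective described above. The training loss itself is not modeled here.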
Abstract
Modern transformer architectures achieve remarkable performance across tasks and domains but remain rigid in how they allocate computation at inference time. Real-world deployment often requires models to adapt to diverse hardware and latency constraints, yet most approaches to dynamic computation focus on a single axis, such as reducing the number of tokens. We present AdaPerceiver, the first transformer architecture with unified adaptivity across depth, width, and tokens within a single model. We pair this architecture with an efficient joint training regime that maintains performance across its many configurations. We evaluate AdaPerceiver on image classification, semantic segmentation, and depth estimation. On image classification, AdaPerceiver expands the accuracy-throughput Pareto front: it reaches 85.4% accuracy with 36% higher throughput than FlexiViT-L. On dense prediction, AdaPerceiver matches ViT-H/14 on both semantic segmentation and depth estimation while using ~26x fewer encoder FLOPs (floating-point operations). Finally, we show that AdaPerceiver equipped with a policy can maintain ImageNet-1K accuracy (within ±0.1 percentage points) while reducing FLOPs by 24–33%.
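To see why adapting tokens, depth, and width gives a policy three separate levers on compute, it can help to write down a rough cost model. The function below is a textbook approximation for a stack of standard transformer blocks. It counts 2 FLOPs per multiply-accumulate and ignores softmax, normalization, biases, and any cross-attention from the input into the latent stream, so it is not AdaPerceiver's actual FLOP accounting. The function name and the ViT-L-like sizes in the usage (256 tokens, 24 layers, 1024-dim model, 4096-unit FFN) are hypothetical.

```python
def encoder_flops(num_tokens, depth, d_model, ffn_width):
    """Approximate FLOPs of `depth` transformer blocks (2 FLOPs per multiply-add).

    Simplifying assumptions: no softmax/norm/bias costs and no input cross-attention.
    """
    n, d, m = num_tokens, d_model, ffn_width
    qkvo = 2 * 4 * n * d * d   # Q, K, V and output projections: linear in tokens
    attn = 2 * 2 * n * n * d   # QK^T scores and weighted sum of V: quadratic in tokens
    ffn = 2 * 2 * n * d * m    # two FFN matmuls using the first m hidden units
    return depth * (qkvo + attn + ffn)

full_cost = encoder_flops(num_tokens=256, depth=24, d_model=1024, ffn_width=4096)
reduced = {
    "fewer tokens": dict(num_tokens=128, depth=24, d_model=1024, ffn_width=4096),
    "shallower": dict(num_tokens=256, depth=16, d_model=1024, ffn_width=4096),
    "narrower FFN": dict(num_tokens=256, depth=24, d_model=1024, ffn_width=2048),
}
for name, cfg in reduced.items():
    print(f"{name}: {encoder_flops(**cfg) / full_cost:.2f} of full cost")
```

In this model depth scales cost linearly, FFN width affects only the FFN term, and tokens enter both linearly and quadratically. That is one reason combining axes offers a finer-grained set of operating points than tuning any single axis alone.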
