HydraViT: Stacking Heads for a Scalable ViT
Janek Haberer, Ali Hojjat, Olaf Landsiedel
TL;DR
HydraViT tackles the challenge of deploying Vision Transformers on devices with diverse resource profiles by introducing a universal ViT that intrinsically contains multiple subnetworks. It trains a single model with $H$ attention heads and embedding dimension $E$ using a stochastic scheme that repeatedly activates the first $k$ heads and corresponding embedding slices, enabling runtime selection of the subnetwork based on available hardware. The approach is augmented with optional separate classifiers and a weighted subnetwork sampling strategy to improve per-subnetwork accuracy. Experimental results on ImageNet-1K show HydraViT can match or exceed the accuracy of independently trained DeiT tiny/small/base models at the same GMACs or throughput, and can support up to 10 subnetworks, with robustness across several ImageNet variants. While HydraViT increases training complexity due to multi-subnetwork optimization, it reduces total training time relative to training multiple distinct models and offers a practical pathway for scalable, device-aware ViT deployment.
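The idea of activating the first $k$ heads and the matching embedding slice can be illustrated with a minimal sketch. It assumes that the embedding dimension $E$ is split evenly across the $H$ heads and that a subnetwork is obtained by keeping the leading rows and columns of each dense layer. The helper names `subnetwork_dim` and `slice_linear` are hypothetical and not taken from the HydraViT code base.

```python
def subnetwork_dim(k, num_heads, embed_dim):
    """Embedding width used when only the first k of num_heads heads are active.

    Assumes embed_dim is divided evenly across heads (an illustrative assumption).
    """
    if embed_dim % num_heads != 0:
        raise ValueError("embed_dim must be divisible by num_heads")
    if not 1 <= k <= num_heads:
        raise ValueError("k must be between 1 and num_heads")
    head_dim = embed_dim // num_heads
    return k * head_dim


def slice_linear(weight, bias, in_dim, out_dim):
    """Keep the leading out_dim rows and in_dim columns of a dense layer.

    weight is a list of rows with shape [out_features][in_features].
    """
    sub_weight = [row[:in_dim] for row in weight[:out_dim]]
    sub_bias = bias[:out_dim]
    return sub_weight, sub_bias


# Toy 4x4 layer reduced to its leading 2x2 block.
weight = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
bias = [0, 1, 2, 3]
sub_weight, sub_bias = slice_linear(weight, bias, in_dim=2, out_dim=2)
```

With $H = 12$ and $E = 768$, this gives widths of 192, 384, and 768 for $k = 3$, $6$, and $12$, which are the embedding sizes of the standard DeiT tiny, small, and base configurations.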
Abstract
The architecture of Vision Transformers (ViTs), particularly the Multi-head Attention (MHA) mechanism, imposes substantial hardware demands. Deploying ViTs on devices with varying constraints, such as mobile phones, requires multiple models of different sizes. However, this approach has limitations: each required model must be trained and stored separately. This paper introduces HydraViT, a novel approach that addresses these limitations by stacking attention heads to achieve a scalable ViT. By repeatedly changing the embedding dimension of each layer and the corresponding number of attention heads in MHA during training, HydraViT induces multiple subnetworks. HydraViT thereby adapts to a wide spectrum of hardware environments while maintaining performance. Our experimental results demonstrate the efficacy of HydraViT in achieving a scalable ViT with up to 10 subnetworks, covering a wide range of resource constraints. HydraViT achieves up to 5 p.p. more accuracy with the same GMACs and up to 7 p.p. more accuracy with the same throughput on ImageNet-1K compared to the baselines, making it an effective solution for scenarios where hardware availability is diverse or varies over time. Source code is available at https://github.com/ds-kiel/HydraViT.
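The repeated resizing during training can be pictured as drawing a head count for each training step. The sketch below is an illustration under stated assumptions, not the authors' training loop. The function `sample_schedule`, the choice of head counts 3 to 12 (which yields 10 subnetworks), and the example weights are all hypothetical.

```python
import random
from collections import Counter


def sample_schedule(num_steps, head_choices, weights=None, seed=0):
    """Draw one head count per training step.

    weights=None samples uniformly; otherwise head counts are drawn in
    proportion to the given weights (a stand-in for weighted sampling).
    """
    rng = random.Random(seed)
    return rng.choices(head_choices, weights=weights, k=num_steps)


# Assumed set of subnetworks: 3 to 12 active heads, i.e. 10 subnetworks.
head_choices = list(range(3, 13))

# Uniform sampling over all subnetworks.
uniform_schedule = sample_schedule(1000, head_choices)

# Hypothetical weighting that favors the 3-, 6-, and 12-head subnetworks.
weights = [4 if k in (3, 6, 12) else 1 for k in head_choices]
weighted_schedule = sample_schedule(1000, head_choices, weights=weights)
weighted_counts = Counter(weighted_schedule)
```

In an actual training loop, each sampled head count would determine which leading heads and embedding slice receive the forward and backward pass for that step.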
