Communication-Efficient Multi-Device Inference Acceleration for Transformer Models
Xiao Liu, Lijun Zhang, Deepak Ganesan, Hui Guan
TL;DR
This paper tackles the latency bottleneck of Transformer inference in bandwidth-constrained multi-device settings. It identifies inter-device communication as the dominant cost and introduces ASTRA, a framework that combines sequence parallelism with a Mixed-Precision Attention mechanism. ASTRA compresses non-local token embeddings via vector quantization and is trained with Noise-Augmented Quantization and a Distributed Class Token scheme, which preserves accuracy while reducing inter-device data exchange. Experiments on ViT and GPT-2 show speedups of up to $2.64\times$ over single-device inference and up to $15.25\times$ over state-of-the-art multi-device methods at bandwidths as low as 10 Mbps, with both models showing strong resilience under compression. The work offers a practical approach to deploying large Transformer models in edge and wireless environments, and the code is open-sourced for reproducibility.
Abstract
Transformer models power many AI applications but suffer from high inference latency, limiting their use in real-time settings. Multi-device inference can reduce latency by parallelizing computation, but existing methods require high inter-device bandwidth, making them impractical for bandwidth-constrained environments. We propose ASTRA, a communication-efficient framework that accelerates Transformer inference through a novel integration of sequence parallelism and a Mixed-Precision Attention mechanism designed to minimize inter-device communication. ASTRA compresses non-local token embeddings via vector quantization and preserves task accuracy through two optimizations: Noise-Augmented Quantization and Distributed Class Tokens. Experiments on ViT and GPT-2 across vision and NLP tasks show that ASTRA achieves up to $2.64\times$ speedups over single-device inference and up to $15.25\times$ speedups over state-of-the-art multi-device inference methods, while operating under bandwidths as low as 10 Mbps. ASTRA is open-sourced at https://github.com/xl1990/Astra.
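To give a rough sense of why vector quantization shrinks inter-device traffic, the following is a minimal sketch of generic nearest-neighbor vector quantization, not ASTRA's actual implementation. The function names (`nearest_code`, `quantize`, `dequantize`, `payload_bits`), the toy codebook, and the sizes used in the example (196 tokens, 768 dimensions, a 1024-entry codebook, 32-bit floats) are all illustrative assumptions rather than values taken from the paper.

```python
def nearest_code(vec, codebook):
    """Return the index of the codebook entry closest to vec (squared L2 distance)."""
    best_idx, best_dist = 0, float("inf")
    for idx, code in enumerate(codebook):
        dist = sum((v - c) ** 2 for v, c in zip(vec, code))
        if dist < best_dist:
            best_idx, best_dist = idx, dist
    return best_idx


def quantize(tokens, codebook):
    """Sender side: replace each token embedding with a single codebook index."""
    return [nearest_code(t, codebook) for t in tokens]


def dequantize(indices, codebook):
    """Receiver side: look up the (lossy) embeddings from the shared codebook."""
    return [codebook[i] for i in indices]


def payload_bits(num_tokens, dim, codebook_size, float_bits=32):
    """Bits sent for full-precision embeddings vs. one codebook index per token."""
    full = num_tokens * dim * float_bits
    index_bits = (codebook_size - 1).bit_length()  # bits needed to address one entry
    return full, num_tokens * index_bits


# Toy example: three 2-D "embeddings" and a three-entry codebook (hypothetical values).
codebook = [[0.0, 0.0], [1.0, 1.0], [4.0, 4.0]]
tokens = [[0.1, -0.1], [0.9, 1.2], [3.5, 4.2]]
indices = quantize(tokens, codebook)
recovered = dequantize(indices, codebook)

# Assumed sizes for illustration only: 196 tokens, 768 dims, 1024 codes.
full, compressed = payload_bits(196, 768, 1024)
```

Under these assumed sizes, one exchange drops from 4,816,896 bits to 1,960 bits, which at 10 Mbps is roughly 0.48 s versus about 0.2 ms of transfer time. The price is that the receiver sees only an approximation of each non-local embedding, which is the accuracy loss that the abstract says Noise-Augmented Quantization and Distributed Class Tokens are meant to counter.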
