Communication-Efficient Multi-Device Inference Acceleration for Transformer Models

Xiao Liu, Lijun Zhang, Deepak Ganesan, Hui Guan

TL;DR

This paper tackles the latency bottleneck of Transformer inference in bandwidth-constrained multi-device settings by identifying inter-device communication as the dominant factor and introducing ASTRA, a framework that fuses sequence parallelism with a Mixed-Precision Attention mechanism. By compressing non-local token embeddings via vector quantization and training with Noise-Augmented Quantization plus a Distributed Class Token scheme, ASTRA preserves accuracy while reducing inter-device data exchange. Empirical results on ViT and GPT-2 demonstrate up to $2.64\times$ speedups over single-device and up to $15.25\times$ over state-of-the-art multi-device methods at bandwidths as low as 10 Mbps, with ViT and GPT-2 tasks showing strong resilience under compression. The work provides a practical approach for deploying large Transformer models in edge and wireless environments, with open-source release for reproducibility.

Abstract

Transformer models power many AI applications but suffer from high inference latency, limiting their use in real-time settings. Multi-device inference can reduce latency by parallelizing computation. Yet existing methods require high inter-device bandwidth, making them impractical for bandwidth-constrained environments. We propose ASTRA, a communication-efficient framework that accelerates Transformer inference through a novel integration of sequence parallelism and a Mixed-Precision Attention mechanism designed to minimize inter-device communication. ASTRA compresses non-local token embeddings via vector quantization and preserves task accuracy through two optimizations, Noise-Augmented Quantization and Distributed Class Tokens. Experiments on ViT and GPT-2 across vision and NLP tasks show that ASTRA achieves up to $2.64\times$ speedups over single-device inference and up to $15.25\times$ speedups over state-of-the-art multi-device inference methods, while operating under bandwidths as low as 10 Mbps. ASTRA is open-sourced at https://github.com/xl1990/Astra.
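
To make the compression idea concrete, the following is a minimal sketch of vector quantization applied to a single token embedding, assuming the embedding is split into equal groups that each have their own codebook. The function names, the tiny dimensions, and the randomly initialized codebooks are illustrative assumptions, not taken from the ASTRA code base; in practice the codebooks would be learned. The sketch also compares the bits needed to send a full-precision embedding with the bits needed to send only codeword indices.

```python
import math
import random


def nearest_codeword(vec, codebook):
    """Return the index of the codeword closest to vec (squared L2 distance)."""
    best_idx, best_dist = 0, float("inf")
    for idx, code in enumerate(codebook):
        dist = sum((v - c) ** 2 for v, c in zip(vec, code))
        if dist < best_dist:
            best_idx, best_dist = idx, dist
    return best_idx


def vq_encode(embedding, codebooks):
    """Split the embedding into len(codebooks) equal groups; quantize each group."""
    group_dim = len(embedding) // len(codebooks)
    return [
        nearest_codeword(embedding[g * group_dim:(g + 1) * group_dim], cb)
        for g, cb in enumerate(codebooks)
    ]


def vq_decode(indices, codebooks):
    """Rebuild an approximate embedding by concatenating the selected codewords."""
    out = []
    for idx, cb in zip(indices, codebooks):
        out.extend(cb[idx])
    return out


def bits_per_token(dim, groups, codebook_size, float_bits=32):
    """Bits to send one token at full precision vs. as codeword indices."""
    full = dim * float_bits
    compressed = groups * math.ceil(math.log2(codebook_size))
    return full, compressed


# Hypothetical toy sizes, chosen only to keep the example small.
DIM, GROUPS, CODEBOOK_SIZE = 8, 2, 4
rng = random.Random(0)

# Placeholder codebooks with random entries (assumption: real ones are trained).
codebooks = [
    [[rng.gauss(0, 1) for _ in range(DIM // GROUPS)] for _ in range(CODEBOOK_SIZE)]
    for _ in range(GROUPS)
]

token = [rng.gauss(0, 1) for _ in range(DIM)]
indices = vq_encode(token, codebooks)   # what a device would transmit
approx = vq_decode(indices, codebooks)  # what the receiving device reconstructs
full_bits, vq_bits = bits_per_token(DIM, GROUPS, CODEBOOK_SIZE)
print(indices, full_bits, vq_bits)
```

In this toy setting the receiver only needs the per-group indices plus a shared copy of the codebooks, which is why the transmitted payload shrinks so sharply relative to sending raw floating-point values.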

Paper Structure

This paper contains 18 sections, 4 theorems, 31 equations, 10 figures, 7 tables.

Key Result

Theorem 1

Let $\hat{\mathbf{X}}$ denote the quantized embedding of $\mathbf{X}$, and let $\tilde{\mathbf{X}} = \hat{\mathbf{X}} + \lambda \xi$ with $\xi \sim \mathcal{N}(\boldsymbol{\mu}, \boldsymbol{\Sigma})$ sampled from the quantization residuals. Then the 2-Wasserstein distance between the original embedd… i.e., the noise-augmented distribution is statistically closer to the true distribution than the ra…
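
The construction $\tilde{\mathbf{X}} = \hat{\mathbf{X}} + \lambda \xi$ can be sketched as follows. This is a minimal illustration under stated assumptions, not the authors' implementation: it uses a diagonal covariance (per-dimension standard deviations) instead of a full $\boldsymbol{\Sigma}$, and it uses rounding to the nearest 0.5 as a stand-in quantizer in place of vector quantization. The helper names `fit_residual_stats` and `noise_augment` are hypothetical.

```python
import random
import statistics


def fit_residual_stats(originals, quantized):
    """Per-dimension mean and std of the residuals X - X_hat.

    Assumption: a diagonal covariance is used here for simplicity.
    """
    dims = len(originals[0])
    means, stds = [], []
    for d in range(dims):
        residuals = [x[d] - q[d] for x, q in zip(originals, quantized)]
        means.append(statistics.fmean(residuals))
        stds.append(statistics.pstdev(residuals))
    return means, stds


def noise_augment(x_hat, means, stds, lam, rng):
    """Return x_hat + lam * xi, with xi drawn per dimension from N(mean, std^2)."""
    return [q + lam * rng.gauss(m, s) for q, m, s in zip(x_hat, means, stds)]


rng = random.Random(0)

# Synthetic "embeddings" and a stand-in quantizer (round to the nearest 0.5).
originals = [[rng.gauss(0, 1) for _ in range(4)] for _ in range(1000)]
quantized = [[round(v * 2) / 2 for v in x] for x in originals]

# Estimate the residual statistics, then perturb one quantized embedding.
means, stds = fit_residual_stats(originals, quantized)
augmented = noise_augment(quantized[0], means, stds, lam=1.0, rng=rng)
print(augmented)
```

Setting `lam=0` recovers the plain quantized embedding, so $\lambda$ acts as the knob that controls how much residual-shaped noise is injected.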

Figures (10)

  • Figure 1: Latency speedup for existing multi-device inference methods and our proposed method Astra with different groups $G$ under different bandwidths with 4 devices and 1024 input tokens. BP: Block Parallelism, SP: Sequence Parallelism, TP: Tensor Parallelism, AG: Allgather. A smaller $N_b$ means less communication for BP.
  • Figure 1: Task accuracy and communication overhead on CIFAR-100 and ImageNet-1K with ViT-Base.
  • Figure 2: Overview of Astra with two devices. We introduce three key innovations: (1) Mixed-Precision Attention, (2) Noise-Augmented Vector Quantization, and (3) Distributed Class Tokens to achieve communication-efficient multi-device inference. Astra can be applied to transformers for both deterministic and generative tasks.
  • Figure 3: Latency breakdown of local computation and inter-device communication time. The red dashed line represents the single-device latency.
  • Figure 4: Speedup comparison under different numbers of devices (w/ 1024 tokens).
  • ...and 5 more figures

Theorems & Definitions (6)

  • Theorem 1: Noise-Augmented Embeddings Improve Distributional Fidelity
  • Theorem 2: Variance Reduction via Distributed Class Tokens
  • Theorem 1: Noise-Augmented Embeddings Improve Distributional Fidelity
  • Proof: Proof of Theorem 1
  • Theorem 2: Variance Reduction via Distributed Class Tokens
  • Proof: Proof of Theorem 2