Distributed On-Device LLM Inference With Over-the-Air Computation
Kai Zhang, Hengtao He, Shenghui Song, Jun Zhang, Khaled B. Letaief
TL;DR
The paper tackles the challenge of deploying large language models on resource-constrained edge devices by distributing inference across devices with tensor parallelism and performing the resulting all-reduce aggregation via over-the-air computation. It formulates a mixed-timescale stochastic non-convex optimization problem $\mathcal{P}$ that minimizes the average transmission mean-squared error (MSE) over a long-term model assignment $\mathbf{m}$ and short-term transceiver beamformers $\mathbf{A}, \{\mathbf{B}_n\}$. It then develops a short-term algorithm based on semidefinite relaxation (SDR) together with a long-term stochastic successive convex approximation (SCA) procedure. The short-term algorithm combines SDR with Gaussian randomization to obtain feasible transceivers, while the model assignment is updated by iteratively minimizing surrogate functions with diminishing step sizes. Simulations using LLaMA2/3 models demonstrate substantial latency reductions while maintaining or improving inference accuracy, indicating that distributed on-device LLM inference over wireless networks is practically viable.
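To illustrate the Gaussian randomization step that follows SDR, here is a minimal pure-Python sketch under loudly simplifying assumptions. It is not the paper's algorithm: it assumes single-antenna devices and a single receive beamformer `a` (standing in for $\mathbf{A}$), with transmit scalars that invert each device's effective channel. Under that common AirComp simplification, minimizing the MSE reduces to minimizing $\|\mathbf{a}\|^2$ subject to $|\mathbf{a}^H \mathbf{h}_n|^2 \ge 1$ for every device $n$. The SDP solve itself is omitted (it would need a convex solver); the relaxed solution `X` is taken as given. The helper names `cholesky`, `gain`, and `gaussian_randomization` are hypothetical.

```python
import math
import random

Vector = list[complex]
Matrix = list[list[complex]]


def cholesky(X: Matrix, jitter: float = 1e-9) -> Matrix:
    """Lower-triangular L with L L^H approximately X + jitter*I for Hermitian PSD X.

    The jitter keeps the factorization defined when X is rank-deficient,
    as SDR solutions often are.
    """
    n = len(X)
    L: Matrix = [[0j] * n for _ in range(n)]
    for j in range(n):
        s = sum(abs(L[j][k]) ** 2 for k in range(j))
        L[j][j] = complex(math.sqrt(max(X[j][j].real + jitter - s, jitter)))
        for i in range(j + 1, n):
            s = sum(L[i][k] * L[j][k].conjugate() for k in range(j))
            L[i][j] = (X[i][j] - s) / L[j][j]
    return L


def gain(a: Vector, h: Vector) -> float:
    """Effective gain |a^H h|^2 of one device's channel h through receive beamformer a."""
    return abs(sum(ai.conjugate() * hi for ai, hi in zip(a, h))) ** 2


def gaussian_randomization(X: Matrix, channels: list[Vector],
                           trials: int = 200, seed: int = 0) -> tuple[Vector, float]:
    """Extract a feasible receive beamformer a from a relaxed solution X ~ a a^H.

    Simplified problem (assumption): minimize ||a||^2 s.t. |a^H h_n|^2 >= 1 for all n.
    """
    rng = random.Random(seed)
    L = cholesky(X)
    n = len(X)
    best_a: Vector = []
    best_obj = math.inf
    for _ in range(trials):
        # z ~ CN(0, I), hence a = L z ~ CN(0, L L^H)
        z = [complex(rng.gauss(0, 1), rng.gauss(0, 1)) / math.sqrt(2) for _ in range(n)]
        a = [sum(L[i][k] * z[k] for k in range(i + 1)) for i in range(n)]
        worst = min(gain(a, h) for h in channels)
        if worst <= 0.0:
            continue  # degenerate sample; cannot be rescaled to feasibility
        # Rescale so the weakest device meets its constraint with equality.
        a = [ai / math.sqrt(worst) for ai in a]
        obj = sum(abs(ai) ** 2 for ai in a)
        if obj < best_obj:
            best_a, best_obj = a, obj
    return best_a, best_obj


# Example: rank-one relaxed solution X = v v^H, two receive antennas, three devices.
v = [1 + 0j, 1 + 0j]
X = [[vi * vj.conjugate() for vj in v] for vi in v]
channels = [[1 + 0j, 0.5j], [0.3 + 0j, 1 - 0.2j], [0.8 + 0.1j, 0.4 + 0j]]
a, obj = gaussian_randomization(X, channels)
```

When `X` is (close to) rank one, every sample points essentially along its principal direction, so randomization mostly recovers that direction; for higher-rank relaxations, keeping the best of many rescaled samples is what makes the step useful.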
Abstract
Large language models (LLMs) have achieved remarkable success across various artificial intelligence tasks. However, their enormous sizes and computational demands pose significant challenges for deployment on edge devices. To address this issue, we present a distributed on-device LLM inference framework based on tensor parallelism, which partitions neural network tensors (e.g., weight matrices) of LLMs among multiple edge devices for collaborative inference. Nevertheless, tensor parallelism involves frequent all-reduce operations to aggregate intermediate layer outputs across participating devices during inference, resulting in substantial communication overhead. To mitigate this bottleneck, we propose an over-the-air computation method that leverages the analog superposition property of wireless multiple-access channels to facilitate fast all-reduce operations. To minimize the average transmission mean-squared error, we investigate joint model assignment and transceiver optimization, which can be formulated as a mixed-timescale stochastic non-convex optimization problem. Then, we develop a mixed-timescale algorithm leveraging semidefinite relaxation and stochastic successive convex approximation methods. Comprehensive simulation results show that the proposed approach significantly reduces inference latency while improving accuracy. This makes distributed on-device LLM inference practical for resource-constrained edge devices.
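To make the all-reduce step concrete, here is a minimal pure-Python sketch of one linear layer $\mathbf{y} = \mathbf{W}\mathbf{x}$ whose weight columns are split across devices, so each device produces a partial output and the partials must be summed. The summation is then simulated over a multiple-access channel. Everything here is an illustrative assumption rather than the paper's scheme: single-antenna devices with known flat complex channels, a per-device amplitude budget, and textbook channel-inversion transmit scalars with one common receive scale (not the optimized MIMO transceivers $\mathbf{A}, \{\mathbf{B}_n\}$). The function names `matvec`, `split_columns`, and `aircomp_allreduce` are hypothetical.

```python
import math
import random

Matrix = list[list[float]]


def matvec(W: Matrix, x: list[float]) -> list[float]:
    """Dense matrix-vector product y = W x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def split_columns(W: Matrix, x: list[float],
                  num_devices: int) -> list[tuple[Matrix, list[float]]]:
    """Give each device a contiguous block of W's columns and the matching slice of x."""
    d = len(x)
    bounds = [d * n // num_devices for n in range(num_devices + 1)]
    return [([row[lo:hi] for row in W], x[lo:hi]) for lo, hi in zip(bounds, bounds[1:])]


def aircomp_allreduce(partials: list[list[float]], channels: list[complex],
                      power: float, noise_std: float, rng: random.Random) -> list[float]:
    """Estimate the element-wise sum of the partial outputs over a simulated MAC.

    Assumptions: one single-antenna device per partial, flat complex channel
    h_n known at the devices, amplitude budget sqrt(power) per device.
    """
    # Common receive scale a: with b_n = 1/(a h_n), the device with the weakest
    # channel needs the most transmit power, so a keeps that device within budget.
    a = 1.0 / (math.sqrt(power) * min(abs(h) for h in channels))
    b = [1.0 / (a * h) for h in channels]  # |b_n| <= sqrt(power) and a * h_n * b_n = 1
    out = []
    for k in range(len(partials[0])):
        # All devices transmit at once; the channel itself adds the signals.
        superposed = sum(h * bn * p[k] for h, bn, p in zip(channels, b, partials))
        noise = complex(rng.gauss(0, noise_std), rng.gauss(0, noise_std))
        out.append((a * (superposed + noise)).real)
    return out


# Example: a 3x4 layer split across two devices.
W = [[1.0, 2.0, 3.0, 4.0], [0.0, 1.0, 0.0, 1.0], [2.0, 0.0, 1.0, 0.0]]
x = [1.0, 1.0, 2.0, 0.5]
partials = [matvec(W_n, x_n) for W_n, x_n in split_columns(W, x, num_devices=2)]
y_air = aircomp_allreduce(partials, [0.9 + 0.2j, 0.3 - 0.4j], power=1.0,
                          noise_std=0.05, rng=random.Random(0))
```

In this toy model the aggregation error is the scaled noise $a w$, so a weaker worst-case channel or a larger noise level directly inflates the MSE of the all-reduce result; this is the kind of error the paper's transceiver and model-assignment optimization aims to reduce.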
