Tensor-Parallelism with Partially Synchronized Activations
Itay Lamprecht, Asaf Karnieli, Yair Hanani, Niv Giladi, Daniel Soudry
TL;DR
The paper tackles the substantial communication overhead inherent in tensor-parallel LLM training and inference. It introduces CAAT-Net, a partial activation synchronization architecture that uses a partial channel-reduce to decrease communication without altering total compute, supported by adjustments to the backward pass and 32-bit gradient accumulation. Empirical results on 7B, 1.1B, and 130M models demonstrate up to 50% reduction in tensor-parallel communication with negligible accuracy loss on most benchmarks, along with notable training and inference speedups. The findings suggest a practical path toward more scalable and cost-effective large-model training and serving across diverse hardware.
Abstract
Training and inference of Large Language Models (LLMs) with tensor-parallelism require substantial communication to synchronize activations. Our findings suggest that with a few minor adjustments to current practices, LLMs can be trained without fully synchronizing activations, reducing bandwidth demands. We name this approach "Communication-Aware Architecture for Tensor-parallelism" (CAAT-Net). We train a 7B parameter CAAT-Net model and show that tensor-parallel communication can be reduced by up to 50% with no significant drop in pretraining accuracy across nearly all evaluated benchmarks. We also experiment with smaller 130M and 1.1B models to show the robustness and scalability of our method. We find that, in some scenarios, validation loss can even improve when reducing communication. Finally, we demonstrate how CAAT-Net accelerates both training and inference workloads across various settings and model sizes.
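To make the idea of partially synchronized activations concrete, the following is a minimal illustrative sketch and not the authors' implementation. It assumes that each tensor-parallel device holds a partial output vector, that only a leading fraction of the channels is summed across devices, and that the remaining channels stay device-local. The names `partial_all_reduce` and `sync_fraction` are hypothetical, and the all-reduce is simulated with plain Python lists rather than a real collective.

```python
def partial_all_reduce(partials, sync_fraction):
    """Simulate a partial channel-reduce across tensor-parallel devices.

    partials: one partial output vector per device (all the same length).
    sync_fraction: fraction of channels summed across devices (assumed knob).
    Returns the per-device outputs and the number of communicated channels.
    """
    hidden = len(partials[0])
    num_synced = int(sync_fraction * hidden)
    # Channels [0, num_synced) are summed over all devices (communicated).
    reduced = [sum(dev[c] for dev in partials) for c in range(num_synced)]
    # Remaining channels keep each device's local partial sum (no communication).
    outputs = [reduced + list(dev[num_synced:]) for dev in partials]
    return outputs, num_synced


# Two simulated devices, four hidden channels each.
partials = [
    [1.0, 2.0, 3.0, 4.0],
    [10.0, 20.0, 30.0, 40.0],
]

full, n_full = partial_all_reduce(partials, 1.0)  # standard full synchronization
half, n_half = partial_all_reduce(partials, 0.5)  # half of the channels synchronized

print(full)             # both devices hold identical, fully reduced activations
print(half)             # devices agree on the first two channels only
print(n_half / n_full)  # fraction of channels communicated relative to full sync
```

In this toy setting, synchronizing half of the channels halves the number of values exchanged per reduce, which is the kind of saving the abstract's "up to 50%" figure refers to. The devices then no longer hold identical activations, which is why the architecture has to be trained with this behavior in mind.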
