Tensor-Parallelism with Partially Synchronized Activations

Itay Lamprecht, Asaf Karnieli, Yair Hanani, Niv Giladi, Daniel Soudry

TL;DR

The paper tackles the substantial communication overhead inherent in tensor-parallel LLM training and inference. It introduces CAAT-Net, a partial activation synchronization architecture that uses a partial channel-reduce to decrease communication without altering total compute, supported by adjustments to the backward pass and 32-bit gradient accumulation. Empirical results on 7B, 1.1B, and 130M models demonstrate up to 50% reduction in tensor-parallel communication with negligible accuracy loss on most benchmarks, along with notable training and inference speedups. The findings suggest a practical path toward more scalable and cost-effective large-model training and serving across diverse hardware.

Abstract

Training and inference of Large Language Models (LLMs) with tensor-parallelism requires substantial communication to synchronize activations. Our findings suggest that with a few minor adjustments to current practices, LLMs can be trained without fully synchronizing activations, reducing bandwidth demands. We name this "Communication-Aware Architecture for Tensor-parallelism" (CAAT-Net). We train a 7B parameter CAAT-Net model and show that tensor-parallel communication can be reduced by up to 50% with no significant drop in pretraining accuracy across nearly all evaluated benchmarks. We also experiment with smaller 130M and 1.1B models to show the robustness and scalability of our method. We find that, in some scenarios, validation loss can even improve when reducing communication. Finally, we demonstrate how CAAT-Net accelerates both training and inference workloads across various settings and model sizes.

Paper Structure

This paper contains 24 sections, 35 equations, 7 figures, 9 tables.

Figures (7)

  • Figure 1: CAAT-Net model architecture. Left. We exemplify our approach on a two-layer fully-connected neural network on a single device. Middle. When using tensor-parallelism with two devices, the input activation $X$ is identical on both devices. Each device uses its own set of weights to multiply the inputs, yielding intermediate activations, which are then reduced into identical copies of $Z$ on both devices. Right. CAAT-Net receives different input activations, $X_1$ and $X_2$, for Device 1 and Device 2, respectively. These yield intermediate activations that are partially synced between the devices, producing $Z_1$ and $Z_2$ on Device 1 and Device 2, respectively. Private channels are marked $P$, and shared channels are marked $S$.
  • Figure 2: Partial synchronization and partial channel-reduce. (a) Vanilla transformers in current training frameworks. The operation $f$ is an all-reduce in the forward pass and identity in the backward pass. The operation $g$ is an all-reduce in the backward pass and identity in the forward pass. (b) With partial synchronization, $h$ denotes the reduction operation in both the forward and the backward pass, since both must be done at the same location. Synchronization of the normalization function parameters is necessary. (c) Partial channel-reduce with parameter $p$ over 2 devices.
  • Figure 3: Training accuracy in multiple scenarios. Left. Validation loss of 130M and 1.1B models for different values of $p$, and of the 7B model with $p=0.5$, normalized to the loss at $p=1$. Right. Validation loss for the 130M model with varying values of $p$ and tensor-parallel dimension (TP).
  • Figure 4: Speedup in training and inference for a Llama 7B model. Top Left. Training speedup for varying values of $p$ and tensor-parallel dimension. Top Right. Inference Time-To-First-Token (TTFT) speedup on different hardware, using batch-size 128. Bottom Left. Inference TTFT speedup using tensor-parallel 8 as a function of batch size, for different values of $p$. Bottom Right. Inference TTFT speedup using tensor-parallel 16 as a function of batch size, for different values of $p$.
  • Figure 5: Comparison to compression methods. Validation loss vs. $p$ for 130M models using CAAT-Net, Top-k masking and random masking.
  • ...and 2 more figures
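
The partial channel-reduce named in the Figure 1 and Figure 2(c) captions can be illustrated with a small single-process simulation. This is only a sketch under stated assumptions, not the authors' implementation: it assumes $p$ is the fraction of channels that are shared (so $p=1$ corresponds to a full all-reduce), that the shared channels are the leading ones, and that private channels simply keep each device's own partial result. The function name `partial_channel_reduce` is hypothetical.

```python
import numpy as np


def partial_channel_reduce(partials, p):
    """Simulate a partial channel-reduce over several devices.

    partials: list of per-device arrays, each of shape (tokens, hidden).
    p: fraction of channels that are shared (summed across devices).

    Assumption (not from the source): the first round(p * hidden)
    channels are the shared ones; the rest stay private to each device.
    """
    hidden = partials[0].shape[-1]
    n_shared = int(round(p * hidden))

    # Shared channels: sum over devices, as an all-reduce would do.
    shared_sum = sum(y[..., :n_shared] for y in partials)

    outputs = []
    for y in partials:
        z = y.copy()                    # do not modify the caller's array
        z[..., :n_shared] = shared_sum  # identical on every device
        # Private channels z[..., n_shared:] keep the local partial result.
        outputs.append(z)
    return outputs


# Two simulated devices, 4 channels, half of them shared (p = 0.5).
y1 = np.ones((2, 4))
y2 = 2.0 * np.ones((2, 4))
z1, z2 = partial_channel_reduce([y1, y2], p=0.5)

# Only the shared slice would have to be communicated between devices.
shared_fraction = int(round(0.5 * 4)) / 4
```

In this toy setup `z1` and `z2` agree on the two shared channels and differ on the two private ones, and only half of the channels would need to cross the interconnect, which is the sense in which $p=0.5$ corresponds to a 50% reduction in tensor-parallel communication. A real implementation would replace the Python `sum` with a collective over the shared slice and would also need the backward-pass changes described in the Figure 2(b) caption.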