TAGC: Optimizing Gradient Communication in Distributed Transformer Training

Igor Polyakov, Alexey Dukhanov, Egor Spirin

TL;DR

TAGC targets the gradient communication bottleneck in distributed transformer training. It adapts lossless homomorphic compression to sharded models and adds transformer-specific optimizations such as layer-selective compression and dynamic sparsification. Integrated into PyTorch FSDP, TAGC uses a hybrid communication strategy: the Index is synchronized with All-Reduce, while the Count Sketch data is aggregated with Reduce, which allows computation and communication to overlap. Empirical results show up to a 15% end-to-end speedup under low network bandwidth with modest model-quality degradation (a few percent), particularly when compressing the large non-attention linear layers. The approach offers a practical, architecture-aware path to faster large-scale transformer training with configurable trade-offs, and the authors provide public code for replication and further research.
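To make the hybrid strategy more concrete, here is a minimal pure-Python sketch of the two objects it synchronizes, under stated assumptions. It models the Index as a 1-bit nonzero bitmap merged with logical OR, and the Count Sketch as a standard linear sketch whose tables can simply be summed across ranks. This linearity is the homomorphic property that makes aggregating compressed data with a Reduce valid. All function names are hypothetical, the hash tables come from a shared seed purely for illustration, and the median-of-rows decoder is a simplification, not the LHC or TAGC decoding procedure.

```python
import random
from typing import List, Tuple

Table = List[List[float]]


def make_hashes(dim: int, rows: int, width: int, seed: int = 0) -> Tuple[List[List[int]], List[List[int]]]:
    """Bucket and sign tables; every rank must build them from the same seed."""
    rng = random.Random(seed)
    buckets = [[rng.randrange(width) for _ in range(dim)] for _ in range(rows)]
    signs = [[rng.choice((-1, 1)) for _ in range(dim)] for _ in range(rows)]
    return buckets, signs


def sketch(grad: List[float], buckets: List[List[int]], signs: List[List[int]], width: int) -> Table:
    """Count Sketch: add each nonzero g_i, with a random sign, to one bucket per row."""
    table = [[0.0] * width for _ in buckets]
    for r in range(len(buckets)):
        for i, g in enumerate(grad):
            if g != 0.0:
                table[r][buckets[r][i]] += signs[r][i] * g
    return table


def index_bitmap(grad: List[float]) -> List[bool]:
    """1-bit Index (assumed semantics): True where this rank's gradient is nonzero."""
    return [g != 0.0 for g in grad]


def reduce_sketches(tables: List[Table]) -> Table:
    """Stand-in for Reduce: elementwise sum, valid because sketching is linear."""
    return [[sum(cells) for cells in zip(*rows)] for rows in zip(*tables)]


def allreduce_indices(indices: List[List[bool]]) -> List[bool]:
    """Stand-in for All-Reduce on the Index: logical OR across ranks."""
    return [any(bits) for bits in zip(*indices)]


def decode(table: Table, index: List[bool], buckets: List[List[int]], signs: List[List[int]]) -> List[float]:
    """Median-of-rows estimate at indexed coordinates (simplified, not LHC's decoder)."""
    out = []
    for i, present in enumerate(index):
        if not present:
            out.append(0.0)
            continue
        est = sorted(signs[r][i] * table[r][buckets[r][i]] for r in range(len(table)))
        out.append(est[len(est) // 2])
    return out


if __name__ == "__main__":
    width = 16
    buckets, signs = make_hashes(dim=8, rows=3, width=width)
    g0 = [0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0]   # rank 0 gradient
    g1 = [0.0, 0.0, 1.0, 0.0, 0.0, -3.0, 0.0, 0.0]  # rank 1 gradient
    merged = reduce_sketches([sketch(g, buckets, signs, width) for g in (g0, g1)])
    index = allreduce_indices([index_bitmap(g) for g in (g0, g1)])
    print("decoded:", decode(merged, index, buckets, signs))
```

Because sketching is linear, summing the per-rank tables gives exactly the sketch of the summed gradient. Aggregation can therefore run on compressed data, while the merged Index tells the receiver which coordinates to reconstruct.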

Abstract

The increasing complexity of large language models (LLMs) calls for efficient training strategies that mitigate the high computational cost of distributed training. A significant bottleneck in this process is gradient synchronization across multiple GPUs, particularly in the zero-redundancy parallelism mode. In this paper, we introduce Transformer-Aware Gradient Compression (TAGC), an optimized gradient compression algorithm designed specifically for transformer-based models. TAGC extends the Lossless Homomorphic Compression (LHC) method by adapting it to sharded models and incorporating transformer-specific optimizations, such as layer-selective compression and dynamic sparsification. Our experimental results demonstrate that TAGC accelerates training by up to 15% compared to the standard Fully Sharded Data Parallel (FSDP) approach, with minimal impact on model quality. We integrate TAGC into the PyTorch FSDP framework; the implementation is publicly available at https://github.com/ipolyakov/TAGC.
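The two transformer-specific optimizations can be illustrated with a hedged sketch. The layer filter encodes one plausible reading of the TL;DR (compress only large non-attention linear layers). The size threshold and the hypothetical `LayerInfo` fields are assumptions, not TAGC's configuration. For sparsification we substitute plain top-k magnitude selection with a fixed keep ratio, because the dynamic schedule TAGC uses is not described in this section.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class LayerInfo:
    name: str           # hypothetical module name, e.g. "blocks.0.mlp.up_proj"
    numel: int          # number of gradient elements in the layer
    is_attention: bool  # True for attention projections


def select_layers(layers: List[LayerInfo], min_numel: int) -> List[str]:
    """Assumed layer-selective policy: compress only large non-attention layers."""
    return [layer.name for layer in layers if not layer.is_attention and layer.numel >= min_numel]


def topk_sparsify(grad: List[float], keep_ratio: float) -> List[float]:
    """Keep the keep_ratio fraction of largest-magnitude entries (at least one), zero the rest."""
    k = max(1, int(len(grad) * keep_ratio))
    top = sorted(range(len(grad)), key=lambda i: abs(grad[i]), reverse=True)[:k]
    keep = set(top)
    return [g if i in keep else 0.0 for i, g in enumerate(grad)]


if __name__ == "__main__":
    layers = [
        LayerInfo("attn.qkv", 3_000_000, True),
        LayerInfo("mlp.up", 16_000_000, False),
        LayerInfo("norm", 4096, False),
    ]
    print(select_layers(layers, min_numel=1_000_000))
    print(topk_sparsify([0.1, -3.0, 0.5, 2.0], keep_ratio=0.5))
```

Sparsifying before sketching matters because, in general, a Count Sketch can be decoded reliably only when the number of nonzero entries is small relative to its width.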

Paper Structure

This paper contains 15 sections, 4 figures, 2 tables, and 1 algorithm.

Figures (4)

  • Figure 1: Communication amount per parameter per rank, in bits, for the LHC and TAGC algorithms with 1-bit and 4-bit Index. The Count Sketch stores compressed parameters; 2x compression yields 16 bits per parameter. (a) 4-bit Index for LHC and TAGC. (b) 1-bit Index for LHC and TAGC.
  • Figure 2: Communication amount per parameter per rank in bits for various compression configurations in TAGC.
  • Figure 3: CUDA stream profiles for the backward step of 2 layers placed in consecutive FSDP units. The layers are an attention projection and a feed-forward up-projection. Each row represents a computation or communication stream in the execution timeline. For TAGC, communication time is 32.1% shorter than for the baseline.
  • Figure 4: Validation loss by iteration number for various TAGC configurations and plain FSDP baseline.
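The Figure 1 caption implies simple per-parameter arithmetic. The sketch below assumes gradients are 32-bit values, which is consistent with 2x compression yielding 16 bits. Under that assumption, the Count Sketch costs 32/r bits per parameter at compression ratio r, plus the Index bits. The hypothetical helpers compute only this raw payload. They deliberately ignore how All-Reduce and Reduce traffic scale with the number of ranks, which is what separates the LHC and TAGC curves in the figure.

```python
def sketch_bits_per_param(compression_ratio: float, value_bits: int = 32) -> float:
    """Bits per parameter spent on Count Sketch values at a given compression ratio."""
    return value_bits / compression_ratio


def payload_bits_per_param(compression_ratio: float, index_bits: int, value_bits: int = 32) -> float:
    """Raw payload per parameter: sketch values plus Index bits (no collective overhead)."""
    return sketch_bits_per_param(compression_ratio, value_bits) + index_bits


if __name__ == "__main__":
    for index_bits in (4, 1):
        for ratio in (2, 4):
            bits = payload_bits_per_param(ratio, index_bits)
            print(f"{index_bits}-bit Index, {ratio}x compression: {bits:.1f} bits/param")
```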