TawPipe: Topology-Aware Weight Pipeline Parallelism for Accelerating Long-Context Large Models Training

Houming Wu, Ling Chen

TL;DR

TawPipe tackles the dual bottleneck of memory and inter-device communication in long-context LLM training by introducing topology-aware weight pipeline parallelism. It integrates three core ideas: Device-Bound Storage (DBS), which fixes one weight shard per device; a Group-Based Weight Pipeline Scheduler (GWPS), which maximizes use of intra-node bandwidth and minimizes cross-node transfers; and Communication-Computation Overlap (CCO), which hides communication latency. Together these combine weight passing with topology-aware data movement. Theoretical analysis and extensive experiments on up to 24 GPUs show that TawPipe achieves higher throughput with modest memory overhead, outperforming both activation-passing and prior weight-passing baselines across long-context configurations. In practice, TawPipe reduces cross-node traffic and improves scalability for large-scale distributed training of long-context models such as LLaMA-family variants.

Abstract

Training large language models (LLMs) is fundamentally constrained by limited device memory and costly inter-device communication. Although pipeline parallelism alleviates memory pressure by partitioning models across devices, it incurs activation communication overhead that scales linearly with sequence length, limiting efficiency in long-context training. Recent weight-passing approaches (e.g., WeiPipe) mitigate this by transmitting model weights instead of activations, but suffer from redundant peer-to-peer (P2P) transfers and underutilized intra-node bandwidth. We propose TawPipe (topology-aware weight pipeline parallelism), which exploits hierarchical bandwidth in distributed clusters for improved communication efficiency. TawPipe: (i) groups devices based on topology to optimize intra-node collective and inter-node P2P communication; (ii) assigns each device a fixed shard of model weights and gradients, avoiding redundant transfers; and (iii) overlaps communication with computation to hide latency. Unlike the global collective operations used in fully sharded data parallelism (FSDP), TawPipe confines most communication within node boundaries, significantly reducing cross-node traffic. Extensive experiments on up to 24 GPUs with LLaMA-style models show that TawPipe achieves superior throughput and scalability compared to state-of-the-art baselines.
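To make the abstract's two central points concrete, the following is a minimal back-of-envelope sketch, not the authors' implementation. It assumes ranks are laid out contiguously node by node, that an activation tensor at a stage boundary has shape (micro-batch, sequence length, hidden size), and that one LLaMA-style layer holds roughly 12·hidden² parameters; all function names are hypothetical. Gradient traffic and communication overlap are ignored.

```python
from typing import List


def group_ranks_by_node(world_size: int, gpus_per_node: int) -> List[List[int]]:
    """Partition global ranks into per-node groups.

    Assumes ranks are assigned contiguously node by node
    (an illustrative assumption, not taken from the paper).
    """
    if world_size <= 0 or gpus_per_node <= 0 or world_size % gpus_per_node != 0:
        raise ValueError("world_size must be a positive multiple of gpus_per_node")
    return [
        list(range(start, start + gpus_per_node))
        for start in range(0, world_size, gpus_per_node)
    ]


def activation_bytes(micro_batch: int, seq_len: int, hidden: int,
                     bytes_per_elem: int = 2) -> int:
    """Bytes sent per stage boundary when passing activations.

    Grows linearly with seq_len, which is the long-context bottleneck.
    """
    return micro_batch * seq_len * hidden * bytes_per_elem


def layer_weight_bytes(hidden: int, bytes_per_elem: int = 2) -> int:
    """Approximate bytes of one layer's weights (assumed ~12 * hidden^2 parameters).

    Independent of seq_len, which is why weight passing suits long contexts.
    """
    return 12 * hidden * hidden * bytes_per_elem


if __name__ == "__main__":
    # 24 GPUs with 8 per node, matching the largest setup in the abstract.
    for node_id, ranks in enumerate(group_ranks_by_node(24, 8)):
        print(f"node {node_id}: ranks {ranks}")

    hidden = 4096  # hypothetical hidden size
    for seq_len in (4096, 16384, 65536):
        act = activation_bytes(1, seq_len, hidden)
        wgt = layer_weight_bytes(hidden)
        print(f"seq_len={seq_len}: activations={act} B, layer weights={wgt} B")
```

Under these assumptions the activation volume overtakes the per-layer weight volume once micro-batch × sequence length exceeds about 12 × hidden size, which is the regime where weight passing becomes attractive.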

Paper Structure

This paper contains 18 sections, 6 figures, 4 tables.

Figures (6)

  • Figure 1: Overview of the TawPipe design. DBS fixes weight shards on each device, GWPS aligns communication with the underlying hardware topology, and CCO overlaps inter-group prefetching with computation. “wp”, “wb”, “gp”, and “gr” denote weight passing, weight broadcasting, gradient passing, and gradient reduction, respectively.
  • Figure 2: Weight initialization under ring-based and device-bound strategies. Each set of three GPUs corresponds to a compute node. (a) The ring-based scheme must continuously buffer and rotate two weight shards. (b) The device-bound strategy stores one weight shard and initiates communication only when needed.
  • Figure 3: Group-based weight pipeline scheduling. Text without a background color represents device-bound data; text with a gray background denotes data buffered in memory. For instance, at $t=4$, ${\rm P}_2$ holds $W_4$ and receives $W_5$ from ${\rm P}_5$.
  • Figure 4: Comparison of TawPipe with 1F1B and WeiPipe. In the pipelining graph, the upper number in a block represents the layer index and the bottom number is the micro-batch index. TawPipe has a lower bubble ratio as the pipeline flush happens sooner in the timeline.
  • Figure 5: Weak scaling. The number of GPUs scales from 8 to 24 (8 GPUs in a node) with global batch size increasing proportionally from 512 to 1536.
  • ...and 1 more figure
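The weak-scaling setup in the Figure 5 caption keeps the per-GPU batch constant while the GPU count grows. A small sketch of that arithmetic follows; the function name is hypothetical, and the intermediate 16-GPU point is an assumption, since the caption states only the 8-GPU and 24-GPU endpoints.

```python
from typing import Dict, Iterable


def weak_scaling_batches(gpu_counts: Iterable[int],
                         base_gpus: int = 8,
                         base_global_batch: int = 512) -> Dict[int, int]:
    """Global batch size per GPU count when the per-GPU batch is held fixed."""
    per_gpu, remainder = divmod(base_global_batch, base_gpus)
    if remainder:
        raise ValueError("base_global_batch must be divisible by base_gpus")
    # Scale the global batch proportionally to the number of GPUs.
    return {n: per_gpu * n for n in gpu_counts}


if __name__ == "__main__":
    # 8 and 24 come from the caption; 16 is an assumed intermediate point.
    print(weak_scaling_batches([8, 16, 24]))
```

With the caption's endpoints this gives 64 samples per GPU, so 8 GPUs correspond to a global batch of 512 and 24 GPUs to 1536.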