Scaling State-Space Models on Multiple GPUs with Tensor Parallelism

Anurag Dutt, Nimit Shah, Hazem Masarani, Anshul Gandhi

TL;DR

This paper presents a communication-efficient tensor-parallel (TP) design for selective state space model (SSM) inference that addresses three practical engineering challenges: improving time to first token (TTFT) via an SSM state cache across prefill and decode, partitioning the mixer's packed parameter tensor so that recurrent updates remain local while minimizing communication, and reducing TP aggregation overhead with quantized AllReduce.

Abstract

Selective state space models (SSMs) have rapidly become a compelling backbone for large language models, especially for long-context workloads. Yet in deployment, their inference performance is often bounded by the memory capacity, bandwidth, and latency limits of a single GPU, making multi-GPU execution increasingly necessary. Although tensor parallelism (TP) is widely used to scale Transformer inference, applying it to selective SSM blocks is non-trivial because the SSM mixer couples large projections with a sequence-wise recurrent state update and local mixing whose efficiency depends on preserving locality and avoiding synchronization in the critical path. This paper presents a communication-efficient TP design for selective SSM inference that addresses three practical engineering challenges: enabling TTFT improvements via an SSM state cache across prefill and decode, partitioning the mixer's packed parameter tensor so that recurrent updates remain local while minimizing communication, and reducing TP aggregation overhead with quantized AllReduce. We evaluate on three representative SSM-based LLMs spanning pure-SSM and hybrid architectures - Mamba, Falcon-Mamba, and Zamba - on NVIDIA A6000 and A100 clusters. Our experiments show substantial throughput gains from tensor-parallel SSM inference, improving batch-request throughput by ~1.6-2.1x on 2 GPUs and ~2.6-4.0x on 4 GPUs for Mamba, with the largest benefits at long context lengths, and achieving a further ~10-18% throughput improvement from quantized all-reduce by lowering synchronization bandwidth overhead.

Paper Structure (24 sections, 10 figures, 1 table)


Figures (10)

  • Figure 1: Mamba-style SSM mixer block. The input projection produces a packed tensor that is split into an SSM path and a gating path. The SSM path applies a channel-separable 1D convolution, projects activations into token-dependent SSM fields ($\Delta$, $B$, $C$) together with per-channel parameters ($A$, $D$), performs the state update over the sequence, gates the result, and applies an output projection back to the residual stream.
  • Figure 2: Illustration of our tensor-parallel inference implementation for Mamba, consisting of our four key design components: (1) SSM cache, (2) channel splitter and depthwise convolution, (3) packed parameter handling, and (4) AllReduce quantization.
  • Figure 3: Maximum feasible input sequence length at a fixed batch size of 256 under our TP design, compared with data parallelism ("DP") and no parallelism ("1x").
  • Figure 4: Throughput gains afforded by our tensor-parallel inference design (for 2-GPU, 4-GPU) for Mamba, Mamba-2, Falcon-Mamba, and Zamba, compared to no parallelism ("1x", first row) and compared to data parallelism ("DP", second row) on the A6000 cluster.
  • Figure 5: Throughput gains afforded by our tensor-parallel inference design (for 2-GPU, 4-GPU) for Mamba, Mamba-2, Falcon-Mamba, and Zamba, compared to no parallelism ("1x", first row) and compared to data parallelism ("DP", second row) on the A100 cluster.
  • ...and 5 more figures
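
The data flow in the Figure 1 caption can be illustrated with a small pure-Python sketch. It is not the paper's implementation: the function names (`depthwise_causal_conv`, `selective_scan`, `sharded_scan`), the toy shapes, and the choice to pass $\Delta$, $B$, and $C$ in as ready-made inputs are all assumptions made for illustration. The sketch uses the standard Mamba-style recurrence $h_t = \exp(\Delta_t A)\,h_{t-1} + \Delta_t B_t x_t$, $y_t = C_t h_t + D x_t$, and shows that, once those fields are available, the depthwise convolution and the state update treat each channel independently, so a channel-wise shard can run them with no cross-shard communication.

```python
import math
import random


def depthwise_causal_conv(x, w):
    """Channel-separable causal 1D convolution. x: [L][D], w: [D][K]."""
    L, D, K = len(x), len(w), len(w[0])
    out = [[0.0] * D for _ in range(L)]
    for t in range(L):
        for d in range(D):
            # Tap k reads the input k steps in the past; positions before 0 count as zero.
            out[t][d] = sum(w[d][k] * x[t - k][d] for k in range(K) if t - k >= 0)
    return out


def selective_scan(x, delta, A, B, C, D_skip):
    """Sequential selective SSM state update.

    x, delta: [L][D]; A: [D][N]; B, C: [L][N]; D_skip: [D]. Returns y: [L][D].
    """
    L, D, N = len(x), len(A), len(A[0])
    h = [[0.0] * N for _ in range(D)]  # recurrent state, one row per channel
    y = [[0.0] * D for _ in range(L)]
    for t in range(L):
        for d in range(D):
            for n in range(N):
                decay = math.exp(delta[t][d] * A[d][n])
                h[d][n] = decay * h[d][n] + delta[t][d] * B[t][n] * x[t][d]
            y[t][d] = sum(C[t][n] * h[d][n] for n in range(N)) + D_skip[d] * x[t][d]
    return y


def cols(m, lo, hi):
    """Select channels lo..hi-1 from a [L][D] matrix."""
    return [row[lo:hi] for row in m]


def sharded_scan(x, w, delta, A, B, C, D_skip, world_size):
    """Hypothetical channel-wise sharding: each 'rank' owns D / world_size channels.

    Ranks are simulated by a loop. Each one runs conv + scan on its own channels
    only; the per-rank outputs are concatenated along the channel axis.
    """
    L, D = len(x), len(A)
    assert D % world_size == 0, "channel count must divide evenly across ranks"
    step = D // world_size
    y = [[] for _ in range(L)]
    for rank in range(world_size):
        lo, hi = rank * step, (rank + 1) * step
        xs = depthwise_causal_conv(cols(x, lo, hi), w[lo:hi])
        ys = selective_scan(xs, cols(delta, lo, hi), A[lo:hi], B, C, D_skip[lo:hi])
        for t in range(L):
            y[t].extend(ys[t])
    return y


def random_matrix(rng, rows, columns, lo=-1.0, hi=1.0):
    return [[rng.uniform(lo, hi) for _ in range(columns)] for _ in range(rows)]


if __name__ == "__main__":
    rng = random.Random(0)
    L, D, N, K = 6, 8, 4, 3  # toy sizes: sequence length, channels, state size, conv taps
    x, w = random_matrix(rng, L, D), random_matrix(rng, D, K)
    delta = random_matrix(rng, L, D, 0.01, 0.1)  # positive step sizes
    A = random_matrix(rng, D, N, -1.0, -0.1)     # negative entries give decaying state
    B, C = random_matrix(rng, L, N), random_matrix(rng, L, N)
    D_skip = [rng.uniform(-1.0, 1.0) for _ in range(D)]

    full = selective_scan(depthwise_causal_conv(x, w), delta, A, B, C, D_skip)
    for world_size in (2, 4):
        shard = sharded_scan(x, w, delta, A, B, C, D_skip, world_size)
        print(world_size, "shards match single-device result:", shard == full)
```

In Mamba, $\Delta$, $B$, and $C$ are themselves produced by a projection that mixes channels. The sketch takes them as given and does not model how a tensor-parallel implementation would compute them, the input and output projections, or the quantized AllReduce.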