Compiler-First State Space Duality and Portable $O(1)$ Autoregressive Caching for Inference

Cosmo Santoni

TL;DR

Mamba-2's state space duality (SSD) algorithm, with its diagonal state structure, chunkable recurrence, and einsum-dominated compute under static control flow, maps cleanly onto what XLA's fusion and tiling passes actually optimise, making custom kernels optional rather than required.
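A minimal sketch of the duality this claim rests on, assuming a single head with the scalar-per-head decay Mamba-2 uses (the scalar-times-identity special case of a diagonal state), no chunking, and no skip term or gating. The function names `ssm_recurrent` and `ssm_dual` are hypothetical, not the repository's API. The same output is computed once as a sequential recurrence with `jax.lax.scan` and once as a causally masked quadratic form built from a single `jnp.einsum`; the second form has static shapes and no data-dependent control flow, which is the kind of computation XLA can fuse and tile.

```python
import jax
import jax.numpy as jnp


def ssm_recurrent(a, B, C, x):
    """Sequential form: h_t = a_t * h_{t-1} + outer(B_t, x_t); y_t = C_t @ h_t.

    Shapes (single head): a (T,), B (T, N), C (T, N), x (T, P).
    """
    N, P = B.shape[1], x.shape[1]

    def step(h, inputs):
        a_t, b_t, c_t, x_t = inputs
        h = a_t * h + jnp.outer(b_t, x_t)   # (N, P) state, fixed size
        return h, c_t @ h                    # emit y_t with shape (P,)

    h0 = jnp.zeros((N, P), dtype=x.dtype)
    _, y = jax.lax.scan(step, h0, (a, B, C, x))
    return y


def ssm_dual(a, B, C, x):
    """Quadratic form: y_t = sum_{s<=t} (C_t . B_s) * prod_{k=s+1..t} a_k * x_s."""
    T = a.shape[0]
    cum = jnp.cumsum(jnp.log(a))                    # cum[t] = sum_{k<=t} log a_k
    diff = cum[:, None] - cum[None, :]              # sum_{k=s+1..t} log a_k
    causal = jnp.tril(jnp.ones((T, T), dtype=bool))
    L = jnp.exp(jnp.where(causal, diff, -jnp.inf))  # decay mask, zero above diagonal
    return jnp.einsum("tn,sn,ts,sp->tp", C, B, L, x)


T, N, P = 16, 4, 3
k_a, k_b, k_c, k_x = jax.random.split(jax.random.PRNGKey(0), 4)
a = jax.nn.sigmoid(jax.random.normal(k_a, (T,)))    # decays in (0, 1)
B = jax.random.normal(k_b, (T, N))
C = jax.random.normal(k_c, (T, N))
x = jax.random.normal(k_x, (T, P))

y_rec = ssm_recurrent(a, B, C, x)
y_dual = ssm_dual(a, B, C, x)
# Expected to be near zero: the two forms differ only by float32 rounding.
print(float(jnp.max(jnp.abs(y_rec - y_dual))))
```

The chunked SSD algorithm sits between these two endpoints: it applies the quadratic form inside fixed-size chunks and carries the recurrent state only across chunk boundaries.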

Abstract

State-space model releases are typically coupled to fused CUDA and Triton kernels, inheriting a hard dependency on NVIDIA hardware. We show that Mamba-2's state space duality algorithm, with its diagonal state structure, chunkable recurrence, and einsum-dominated compute under static control flow, maps cleanly onto what XLA's fusion and tiling passes actually optimise, making custom kernels optional rather than required. We implement the full inference path (prefill and cached autoregressive decoding) in standard, statically shaped primitives under XLA, without hand-written kernels, and realise the architecture's theoretical $O(1)$ state management as a compiled on-device cache that requires no host synchronisation during generation. The implementation runs unmodified on CPU, NVIDIA GPU, and Google Cloud TPU from a single JAX source. On TPU v6e across five model scales (130M–2.7B parameters), XLA-generated code reaches approximately 140 TFLOPS on single-stream prefill (15% MFU) and up to 64% HBM bandwidth utilisation on decode. Greedy decoding matches the PyTorch/CUDA reference token-for-token across 64 steps, with hidden-state agreement within float32 rounding tolerance. The pattern transfers to any SSM recurrence satisfying the same structural conditions, on any platform with a mature XLA backend. The implementation is publicly available at https://github.com/CosmoNaught/mamba2-jax and merged into the Bonsai JAX model library.
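A minimal sketch of on-device cached decoding, assuming a toy single-head recurrence whose projection weights (`w_a`, `W_b`, `W_c`, `W_out`) and helper names (`decode_step`, `generate`) are hypothetical stand-ins for the real Mamba-2 block, not the repository's code. The cache is a fixed-shape `(N, P)` state, so each step costs the same no matter how many tokens came before, and `jax.lax.fori_loop` keeps the whole generation loop inside one compiled program instead of returning to the host between tokens.

```python
import functools

import jax
import jax.numpy as jnp


def decode_step(params, h, x_t):
    """One O(1) step: update the fixed-size state h (N, P) and emit y_t (P,)."""
    a_t = jax.nn.sigmoid(x_t @ params["w_a"])    # scalar decay in (0, 1)
    b_t = x_t @ params["W_b"]                     # (N,)
    c_t = x_t @ params["W_c"]                     # (N,)
    h = a_t * h + jnp.outer(b_t, x_t)             # state shape never changes
    return h, c_t @ h


@functools.partial(jax.jit, static_argnums=(2,))
def generate(params, x0, num_steps):
    """Autoregressive loop compiled as a single program via lax.fori_loop."""
    N, P = params["W_b"].shape[1], x0.shape[0]

    def body(i, carry):
        h, x_t, ys = carry
        h, y = decode_step(params, h, x_t)
        ys = ys.at[i].set(y)                      # write into a preallocated buffer
        x_next = jnp.tanh(y @ params["W_out"])    # feed the output back as next input
        return h, x_next, ys

    init = (jnp.zeros((N, P)), x0, jnp.zeros((num_steps, P)))
    h, _, ys = jax.lax.fori_loop(0, num_steps, body, init)
    return ys, h


N, P = 4, 3
keys = jax.random.split(jax.random.PRNGKey(0), 5)
params = {
    "w_a": 0.5 * jax.random.normal(keys[0], (P,)),
    "W_b": 0.5 * jax.random.normal(keys[1], (P, N)),
    "W_c": 0.5 * jax.random.normal(keys[2], (P, N)),
    "W_out": 0.5 * jax.random.normal(keys[3], (P, P)),
}
x0 = jax.random.normal(keys[4], (P,))

ys_short, h_short = generate(params, x0, 8)
ys_long, h_long = generate(params, x0, 64)
# The cache has the same size after 8 steps and after 64 steps.
print(h_short.shape, h_long.shape)  # (4, 3) (4, 3)
```

Because every loop carry has a static shape, XLA compiles the loop once per distinct `num_steps`; a Python loop around a jitted step would instead dispatch one call per token from the host.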

Paper Structure

This paper contains 43 sections, 5 equations, 6 figures, 13 tables, and 2 algorithms.

Figures (6)

  • Figure 1: Autoregressive generation on Cloud TPU v6e across five model scales and six sequence lengths. (a) Speedup from caching. (b) Generation latency: cached (solid) grows linearly, non-cached (dashed) grows quadratically. (c) Per-step throughput: cached throughput is flat regardless of sequence length; non-cached throughput collapses.
  • Figure 2: Peak memory during autoregressive generation on Cloud TPU v6e. Cached decoding (solid) is constant; non-cached (dashed) grows linearly. At sequence length 4096, the 2.7B non-cached path consumes over 16 GB versus a constant 10.9 GB cached.
  • Figure 3: Hardware utilisation on Cloud TPU v6e. (a) Prefill MFU versus model size at three prompt lengths. (b) Decode HBU versus sequence length. HBU varies by less than 1.7 percentage points across sequence lengths for every model.
  • Figure 4: Fraction of hardware peak on Cloud TPU v6e. Solid bars: best prefill MFU (% of 918 TFLOPS); faded bars: mean decode HBU (% of 1600 GB/s). Utilisation increases with model size in both regimes.
  • Figure 5: Decode strategy comparison on Cloud TPU v6e. For small models (130M, 370M), the on-device fori_loop delivers substantially higher throughput. For larger models ($\geq$780M), per-step compute dominates and the paths converge.
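The utilisation numbers in these captions reduce to two ratios against the Cloud TPU v6e peaks quoted in Figure 4 (918 TFLOPS and 1600 GB/s). The sketch below reproduces the abstract's roughly 15% prefill MFU from 140 TFLOPS and shows how a decode HBU would be computed. The decode inputs are hypothetical: it assumes bf16 weights that are each read once per step, ignores the small SSM state, and uses a made-up 10 ms step time, so the second number is illustrative only and not a measured result.

```python
PEAK_FLOPS = 918e12         # Cloud TPU v6e peak compute, as quoted in Figure 4
PEAK_BYTES_PER_S = 1600e9   # Cloud TPU v6e HBM bandwidth, as quoted in Figure 4


def mfu(achieved_flops_per_s: float) -> float:
    """Model FLOPs utilisation: achieved FLOP/s as a fraction of peak."""
    return achieved_flops_per_s / PEAK_FLOPS


def hbu(bytes_per_step: float, seconds_per_step: float) -> float:
    """HBM bandwidth utilisation: bytes moved per second as a fraction of peak."""
    return (bytes_per_step / seconds_per_step) / PEAK_BYTES_PER_S


# About 140 TFLOPS on single-stream prefill (from the abstract).
print(f"{mfu(140e12):.1%}")  # 15.3%

# Hypothetical decode step: 2.7e9 weights stored in bf16 (2 bytes each),
# all read once per token, with an assumed 10 ms per step.
print(f"{hbu(2.7e9 * 2, 10e-3):.1%}")
```

Decode is bandwidth-bound because each step does little arithmetic per byte of weights read, which is why the captions report HBU rather than MFU for that regime.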