STree: Speculative Tree Decoding for Hybrid State-Space Models

Yangchao Wu, Zongyue Qin, Alex Wong, Stefano Soatto

TL;DR

STree introduces a scalable tree-based speculative decoding algorithm for state-space models (SSMs) and hybrid SSM-Transformer architectures, enabling efficient multi-step generation by accumulating state transitions along a token tree. It packs a prefix tree into a single sequence and constructs a topology mask $L$ that selects, for each token, the transitions on its path from the root. Under the assumption that each $A_i$ is diagonal, the accumulated transition is $(A_{tree})_t = \sum_{i=1}^N L_{t,i} \operatorname{diag}(\log A_i)$, which gives the output $y_t = C_t \exp\{(A_{tree})_t\} x_0 + \sum_{s=1}^t L_{t,s} \exp\{(A_{tree})_t - (A_{tree})_s\} \circ (C_t B_s u_s)$. A hardware-aware tree-scan kernel with activation replay is implemented to minimize overhead. Empirical results on multiple benchmarks show that STree outperforms vanilla speculative decoding even with a baseline drafting model and a static tree, and that its relative overhead decreases as model size grows, indicating favorable scalability for larger LLMs. The approach thus enables faster inference for hybrid architectures by exploiting tree-based verification without large memory penalties.
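The two formulas above can be checked numerically on a small tree. The following is a minimal NumPy sketch, not the paper's hardware-aware kernel: it assumes scalar inputs $u_s$, vector-valued $B_s$ and $C_t$, diagonal transitions stored as vectors `a[t]`, and a packed order in which every parent precedes its children. The helper names (`ancestor_mask`, `tree_scan`, `sequential_reference`) and the `parent` array encoding are hypothetical, introduced only for illustration.

```python
import numpy as np

def ancestor_mask(parent):
    """Topology mask L: L[t, s] = 1 if s is t itself or an ancestor of t.
    parent[t] is the packed index of t's parent, or -1 for the root (assumed encoding)."""
    n = len(parent)
    L = np.zeros((n, n))
    for t in range(n):
        s = t
        while s != -1:
            L[t, s] = 1.0
            s = parent[s]
    return L

def tree_scan(parent, a, B, C, u, x0):
    """Evaluate y_t for every node of the packed tree at once.
    a, B, C: (n, d) arrays; u: (n,) array; x0: (d,) initial state."""
    L = ancestor_mask(parent)
    A_tree = L @ np.log(a)  # (n, d): sum of log-transitions along each root-to-node path
    # Contribution of the initial state: C_t exp(A_tree_t) x0
    y = np.einsum("td,td,d->t", C, np.exp(A_tree), x0)
    # Pairwise decay exp(A_tree_t - A_tree_s); entries off the path are masked by L
    decay = np.exp(A_tree[:, None, :] - A_tree[None, :, :])  # (n, n, d)
    contrib = np.einsum("td,tsd,sd,s->ts", C, decay, B, u)
    return y + (L * contrib).sum(axis=1)

def sequential_reference(parent, a, B, C, u, x0):
    """Plain recurrence x_t = a_t * x_parent(t) + B_t u_t, y_t = C_t . x_t."""
    n = len(parent)
    states = [None] * n
    y = np.zeros(n)
    for t in range(n):  # relies on parents preceding children
        prev = x0 if parent[t] == -1 else states[parent[t]]
        states[t] = a[t] * prev + B[t] * u[t]
        y[t] = C[t] @ states[t]
    return y

# Small example tree: node 0 is the root, nodes 1 and 2 are its children, and so on.
parent = [-1, 0, 0, 1, 1, 2]
rng = np.random.default_rng(0)
n, d = len(parent), 4
a = rng.uniform(0.5, 0.99, size=(n, d))  # diagonal transitions in (0, 1)
B, C = rng.normal(size=(n, d)), rng.normal(size=(n, d))
u, x0 = rng.normal(size=n), rng.normal(size=d)
print(np.allclose(tree_scan(parent, a, B, C, u, x0),
                  sequential_reference(parent, a, B, C, u, x0)))
```

The dense `(n, n, d)` decay tensor is used here only for readability; it materializes every pair of nodes and would not be how an efficient kernel is organized.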

Abstract

Speculative decoding is a technique that leverages hardware concurrency to enable multiple steps of token generation in a single forward pass, thus improving the efficiency of large-scale autoregressive (AR) Transformer models. State-space models (SSMs) are already more efficient than AR Transformers, since their state summarizes all past data with no need to cache or re-process tokens in the sliding window context. However, their state can also comprise thousands of tokens, so speculative decoding has recently been extended to SSMs. Existing approaches, however, do not leverage tree-based verification methods, since current SSMs lack the means to compute a token tree efficiently. We propose the first scalable algorithm to perform tree-based speculative decoding in SSMs and in hybrid architectures of SSM and Transformer layers. We exploit the structure of accumulated state transition matrices to facilitate tree-based speculative decoding with minimal overhead relative to current SSM implementations. Along with the algorithm, we describe a hardware-aware implementation that improves on the naive application of AR Transformer tree-based speculative decoding methods to SSMs. Furthermore, we outperform vanilla speculative decoding with SSMs even with a baseline drafting model and tree structure on three different benchmarks, opening up opportunities for further speedup of SSM and hybrid model inference. Code can be found at: https://github.com/wyc1997/stree.

Paper Structure

This paper contains 28 sections, 13 equations, 5 figures, 8 tables, 1 algorithm.

Figures (5)

  • Figure 1: Methods to decode a prefix token tree with AR Transformers. A prefix token tree can be unrolled into multiple sequences (used by current SSMs) and computed as a batch, or packed into one sequence using a mask to indicate the tree structure (first used by SpecInfer (Miao et al., 2023) and extended to SSMs here). The former leads to inefficiency due to repeatedly computed tokens (in orange).
  • Figure 2: Left: Runtime of a call to the Mamba2-2.7B model vs. input size with STree and Fused Selective Scan (FSS). A linear regression is performed to obtain the slope and intercept. Middle: Runtime of a call to the MambaInLlama-8B model vs. input size with STree and FSS. A degree-2 polynomial regression is used to obtain the parameters. Right: The ratio of acceptance length $\tau$ required to get a wall-clock time improvement vs. input length $N_2$, with a fixed $N_1 = 5$.
  • Figure 3: Left: Comparison of number of states required and number of tokens computed to get an output for a full binary tree in one forward pass. STree is able to decode a packed tree sequence, while other methods need to unroll the tree into multiple sequences. Right: Runtime in milliseconds (ms) for a forward pass for different full binary trees using different algorithms.
  • Figure 4: Left: Structure of the static tree used to generate a prefix token tree with the drafting model. We draft 4 steps for each iteration and 3 tokens for each layer, resulting in 13 tokens in every input sequence. Middle: Generation speed of STree and Vanilla Speculative Decoding (SD) under different temperatures. Right: Average acceptance length of STree and Vanilla Speculative Decoding (SD) under different temperatures.
  • Figure 5: Static tree structure used in our ablation study on the effect of different tree structures.
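
The packed-versus-unrolled comparison described in the captions of Figures 1 and 3 can be illustrated with a simple count for a full binary tree. This is a rough sketch under stated assumptions, not the paper's exact accounting: `depth` is taken to be the number of edges from root to leaf, the unrolled variant is assumed to process one full root-to-leaf sequence per leaf, and the function name `full_binary_tree_costs` is hypothetical.

```python
def full_binary_tree_costs(depth):
    """Token counts for decoding a full binary tree of the given depth.

    Returns (packed_tokens, num_sequences, unrolled_tokens):
      packed_tokens   - every tree node appears once in the packed sequence
      num_sequences   - one unrolled sequence per leaf
      unrolled_tokens - each unrolled sequence repeats the shared prefix nodes
    """
    packed_tokens = 2 ** (depth + 1) - 1
    num_sequences = 2 ** depth
    unrolled_tokens = num_sequences * (depth + 1)
    return packed_tokens, num_sequences, unrolled_tokens

for depth in range(1, 5):
    packed, seqs, unrolled = full_binary_tree_costs(depth)
    print(f"depth={depth}: packed={packed}, sequences={seqs}, unrolled={unrolled}")
```

Under these assumptions the unrolled token count grows as $(d+1)\,2^d$ while the packed count grows as $2^{d+1}-1$, so the redundancy from recomputed prefix tokens increases with tree depth.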