Sequential-Parallel Duality in Prefix Scannable Models

Morris Yau, Sharut Gupta, Valerie Engelmayer, Kazuki Irie, Stefanie Jegelka, Jacob Andreas

TL;DR

This work defines Prefix-Scannable Models (PSMs), a class that generalizes state space models by relaxing the state aggregation operator to allow arbitrary (potentially non-associative) functions such as softmax attention. It evaluates such models empirically on illustrative small-scale language modeling and on canonical synthetic tasks, including state tracking and associative recall.

Abstract

Modern neural sequence models are designed to meet the dual mandate of parallelizable training and fast sequential inference. Recent developments have given rise to various models, such as Gated Linear Attention (GLA) and Mamba, that achieve such "sequential-parallel duality." This raises a natural question: can we characterize the full class of neural sequence models that support near-constant-time parallel evaluation and linear-time, constant-space sequential inference? We begin by describing a broad class of such models, state space models, as those whose state updates can be computed using the classic parallel prefix scan algorithm with a custom associative aggregation operator. We then define a more general class, Prefix-Scannable Models (PSMs), by relaxing the state aggregation operator to allow arbitrary (potentially non-associative) functions such as softmax attention. This generalization unifies many existing architectures, including element-wise RNNs (e.g., Mamba) and linear transformers (e.g., GLA, Mamba2, mLSTM), while also introducing new models with softmax-like operators that achieve O(1) amortized compute per token and O(log N) memory for sequence length N. We empirically evaluate such models on illustrative small-scale language modeling and canonical synthetic tasks, including state tracking and associative recall. We find that PSMs retain the expressivity of transformer-based architectures while matching the inference efficiency of state space models, in some cases exhibiting better length generalization than either.

Paper Structure

This paper contains 19 sections, 11 theorems, 30 equations, 6 figures, 1 table, 4 algorithms.

Key Result

Proposition 3.2

Every Prefix–Scannable Model is in the class $\textsf{SPD}(n,\,\log n)$. That is, its training work is $\Theta(n)$ with parallel depth $\tilde{O}(1)$, while online inference runs in $O(1)$ amortised time and $O(\log n)$ memory per token.
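The logarithmic memory bound can be illustrated with a binary-counter sketch. This is a minimal illustration rather than the authors' implementation: the class `BinaryCounterState` and its methods `push` and `prefix` are hypothetical names, string concatenation stands in for a learned chunk aggregator, and the left-to-right fold used in `prefix` is an assumption about how the retained summaries are combined.

```python
from typing import Callable, Generic, List, Optional, Tuple, TypeVar

S = TypeVar("S")


class BinaryCounterState(Generic[S]):
    """Keeps one summary per set bit of the number of chunks seen so far."""

    def __init__(self, agg: Callable[[S, S], S]) -> None:
        self.agg = agg                        # chunk aggregator (need not be associative)
        self.roots: List[Tuple[int, S]] = []  # (level, summary), levels strictly decreasing
        self.merges = 0                       # total aggregator calls made by push()

    def push(self, chunk_state: S) -> None:
        """Insert a finished chunk; equal-level summaries merge like binary carries."""
        level, state = 0, chunk_state
        while self.roots and self.roots[-1][0] == level:
            _, left = self.roots.pop()
            state = self.agg(left, state)     # older summary goes on the left
            self.merges += 1
            level += 1
        self.roots.append((level, state))

    def prefix(self) -> Optional[S]:
        """Fold the retained summaries left to right (assumed combination order)."""
        if not self.roots:
            return None
        out = self.roots[0][1]
        for _, summary in self.roots[1:]:
            out = self.agg(out, summary)
        return out


# Usage: 13 chunks, with concatenation as a stand-in aggregator.
counter = BinaryCounterState(lambda left, right: left + right)
for chunk in "abcdefghijklm":
    counter.push(chunk)
print(counter.prefix(), len(counter.roots), counter.merges)
```

After n pushes the structure holds one summary per set bit of n, so at most floor(log2 n) + 1 summaries, and the total number of merges equals the number of binary carries, which is smaller than n. That is the sense in which the update cost is constant on average per chunk.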

Figures (6)

  • Figure 1: An illustration of the Blelloch parallel scan used to compute prefix states in Prefix-Scannable Models (PSMs). Here the input has 16 tokens grouped into 8 chunks $\{{\bm{x}}[0],\dots,{\bm{x}}[7]\}$ (see (a) bottom), and the goal is to produce prefix states $\{{\bm{e}},{\bm{x}}[0],{\bm{x}}[0\mathpunct{:}\!1],\dots,{\bm{x}}[0\mathpunct{:}\!6]\}$, where ${\bm{x}}[i\mathpunct{:}\!j]$ aggregates all tokens from chunks $i$ to $j$, and ${\bm{e}}$ is the identity. (a) In the upsweep, chunks are aggregated along a binary tree through a series of chunk aggregation operations (solid arrows), producing intermediate values and some of the final prefix states (e.g., ${\bm{x}}[0\mathpunct{:}\!1],{\bm{x}}[0\mathpunct{:}\!3]$). (b) In the downsweep, the missing prefix states are filled in by propagating values backward: ${\bm{x}}[0\mathpunct{:}\!7]$ is reset to ${\bm{e}}$, and copy (dotted arrows) and aggregation (solid arrows) operations complete the sequence. When each chunk is treated as an atomic element, this recovers the classic Blelloch scan.
  • Figure 2: An illustration of the autoregressive state computation of "Transformer-PSM" (see the experiments section) at inference time. Here the model uses a chunk size of 2. From left to right, a single new token is fed to the model at a time. First two panels: when predicting tokens in chunk ${\bm{x}}[2]$, the model only requires the prefix state ${\bm{x}}[0\mathpunct{:}\!1]$ and the tokens within ${\bm{x}}[2]$. Third panel: predicting tokens in chunk ${\bm{x}}[3]$ requires the prefix state ${\bm{x}}[0\mathpunct{:}\!1]$, and chunks ${\bm{x}}[2]$ and ${\bm{x}}[3]$. Last panel: once all tokens in chunk ${\bm{x}}[3]$ are processed, a new prefix state ${\bm{x}}[0\mathpunct{:}\!3]$ is computed, which is later used to predict tokens in ${\bm{x}}[4]$, and so on. The prefix state ${\bm{s}}_i$ denotes ${\bm{x}}[0\mathpunct{:}\!i]$.
  • Figure 3: Error rate on the state tracking $S_5$ task. After training on sequences with lengths up to 18, Transformer-PSM generalizes to more than 160 tokens, far beyond Transformer and Mamba.
  • Figure 4: Error rate on MQAR of Transformer-PSM (T-PSM), Sliding Window Transformer (SWT) and Mamba. Evaluated lengths are in-distribution.
  • Figure 5: Evaluation perplexity of Transformer-PSM with varying chunk sizes on WikiText-103.
  • ...and 1 more figure
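The upsweep and downsweep described in the Figure 1 caption can be sketched as a sequential simulation of the classic Blelloch exclusive scan. This is a minimal sketch under stated assumptions: the function name `blelloch_exclusive_scan` is hypothetical, the input length is assumed to be a power of two, and string concatenation stands in for the chunk aggregation operator so that operand order is visible in the output.

```python
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")


def blelloch_exclusive_scan(
    xs: Sequence[T], op: Callable[[T, T], T], identity: T
) -> List[T]:
    """Exclusive prefix scan via upsweep and downsweep over an implicit binary tree."""
    n = len(xs)
    assert n > 0 and n & (n - 1) == 0, "this sketch assumes a power-of-two length"
    a = list(xs)

    # Upsweep: each right child accumulates the aggregate of its subtree.
    d = 1
    while d < n:
        for i in range(2 * d - 1, n, 2 * d):
            a[i] = op(a[i - d], a[i])
        d *= 2

    # Downsweep: reset the root to the identity, then push prefixes down.
    a[n - 1] = identity
    d = n // 2
    while d >= 1:
        for i in range(2 * d - 1, n, 2 * d):
            left_sum = a[i - d]
            a[i - d] = a[i]              # copy: left child inherits the parent's prefix
            a[i] = op(a[i], left_sum)    # aggregate: right child gets prefix, then left subtree
        d //= 2
    return a


# Eight chunks, as in the caption; the last output aggregates chunks 0 to 6.
chunks = ["a", "b", "c", "d", "e", "f", "g", "h"]
print(blelloch_exclusive_scan(chunks, lambda l, r: l + r, ""))
# -> ['', 'a', 'ab', 'abc', 'abcd', 'abcde', 'abcdef', 'abcdefg']
```

The inner loops at each tree level touch disjoint positions, which is what lets a parallel implementation run each level in one step. With a non-associative operator the same code still runs, but the result then depends on this particular tree-shaped parenthesization.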

Theorems & Definitions (30)

  • Definition 2.1: State kernel
  • Definition 2.2: Inference module
  • Definition 2.3: Sequence model
  • Definition 2.4: Parallel circuit family
  • Definition 2.5: Sequential–Parallel Duality $\textsf{SPD}\bigl(T(n),\,m(n)\bigr)$
  • Definition 3.1: Prefix–Scannable Model
  • Proposition 3.2
  • Proof sketch of Proposition 3.2
  • Definition 3.3: Affine recurrence
  • Lemma 3.4: Affine aggregator
  • ...and 20 more
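The listing names an affine recurrence (Definition 3.3) and an affine aggregator (Lemma 3.4) without reproducing their statements. As a hedged illustration, the sketch below assumes the usual scalar form s_t = a_t * s_{t-1} + b_t, which may differ from the paper's exact definition. The helper names `compose` and `run_sequential` are hypothetical. The example shows that composing affine maps is associative, which is the property that makes such recurrences computable with a prefix scan.

```python
from functools import reduce
from typing import List, Tuple

Affine = Tuple[int, int]  # (a, b) represents the map s -> a * s + b


def compose(f: Affine, g: Affine) -> Affine:
    """Return the affine map 'apply f first, then g'."""
    a1, b1 = f
    a2, b2 = g
    # g(f(s)) = a2 * (a1 * s + b1) + b2
    return (a2 * a1, a2 * b1 + b2)


def run_sequential(steps: List[Affine], s0: int) -> int:
    """Unroll the recurrence one step at a time, as in sequential inference."""
    s = s0
    for a, b in steps:
        s = a * s + b
    return s


steps = [(2, 1), (3, -1), (1, 4), (-2, 5)]
s0 = 1

# Aggregate all steps into a single affine map, then apply it once.
a, b = reduce(compose, steps)
assert a * s0 + b == run_sequential(steps, s0)

# Associativity: regrouping the composition does not change the result.
f, g, h = steps[:3]
assert compose(compose(f, g), h) == compose(f, compose(g, h))
```

Because the grouping does not matter, the aggregation can follow any tree shape, including the balanced tree of a parallel scan, and still match the step-by-step recurrence.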