PRISM: Parallel Residual Iterative Sequence Model

Jie Jiang, Ke Cheng, Xin Xu, Mengyang Pang, Tianhao Lu, Jiaheng Li, Yue Liu, Yuan Wang, Jun Zhang, Huan Yu, Zhouchen Lin

TL;DR

PRISM addresses the expressivity-efficiency trade-off in generative sequence modeling by replacing serial, multi-step optimization with an amortized, input-anchored refinement that yields Rank-$L$ updates within a parallelizable linear recurrence. It introduces Write-Forget Decoupling to keep forgetting linear and injection high-rank, and employs an Input-Anchored ShortConv proxy plus a learned predictor to synthesize multi-step refinements in parallel. Theoretical analysis establishes Rank Accumulation and robust stability of the forgetting path, while experiments show accuracy competitive with explicit solvers and substantial throughput gains (up to 174x) on long-horizon recommendation benchmarks. These results indicate that distilling iterative optimization trajectories into parallelizable, input-grounded operators is a promising approach for scalable, high-fidelity foundation models.

Abstract

Generative sequence modeling faces a fundamental tension between the expressivity of Transformers and the efficiency of linear sequence models. Existing efficient architectures are theoretically bounded by shallow, single-step linear updates, while powerful iterative methods like Test-Time Training (TTT) break hardware parallelism due to state-dependent gradients. We propose PRISM (Parallel Residual Iterative Sequence Model) to resolve this tension. PRISM introduces a solver-inspired inductive bias that captures key structural properties of multi-step refinement in a parallelizable form. We employ a Write-Forget Decoupling strategy that isolates non-linearity within the injection operator. To bypass the serial dependency of explicit solvers, PRISM utilizes a two-stage proxy architecture: a short convolution anchors the initial residual using local history energy, while a learned predictor estimates the refinement updates directly from the input. This design distills structural patterns associated with iterative correction into a parallelizable feedforward operator. Theoretically, we prove that this formulation achieves Rank-$L$ accumulation, structurally expanding the update manifold beyond the single-step Rank-$1$ bottleneck. Empirically, PRISM performs comparably to explicit optimization methods while delivering up to 174x higher throughput.
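The parallelism argument rests on the transition and injection terms not depending on the running state. A minimal NumPy sketch of why that matters is given here, assuming a recurrence of the form $S_t = S_{t-1}\mathbf{A}_t + \mathbf{B}_t$ (this form and the helper names `combine`, `sequential`, and `tree_reduce` are illustrative assumptions, not taken from the paper). When $\mathbf{A}_t$ and $\mathbf{B}_t$ are computed from the input alone, each step is an affine map, affine maps compose associatively, and the sequence can be reduced in a balanced tree instead of step by step.

```python
import numpy as np

def combine(left, right):
    """Compose two affine maps S -> S @ A + B; `left` is applied first."""
    A1, B1 = left
    A2, B2 = right
    # (S @ A1 + B1) @ A2 + B2 == S @ (A1 @ A2) + (B1 @ A2 + B2)
    return A1 @ A2, B1 @ A2 + B2

def sequential(As, Bs, S0):
    """Reference implementation: one serial step per time index."""
    S = S0
    for A, B in zip(As, Bs):
        S = S @ A + B
    return S

def tree_reduce(pairs):
    """Combine neighbours pairwise; depth is logarithmic in sequence length."""
    pairs = list(pairs)
    while len(pairs) > 1:
        nxt = [combine(pairs[i], pairs[i + 1])
               for i in range(0, len(pairs) - 1, 2)]
        if len(pairs) % 2:          # carry an unpaired last element forward
            nxt.append(pairs[-1])
        pairs = nxt
    return pairs[0]

rng = np.random.default_rng(0)
T, d = 7, 4
As = [rng.standard_normal((d, d)) * 0.3 for _ in range(T)]
Bs = [rng.standard_normal((d, d)) for _ in range(T)]
S0 = rng.standard_normal((d, d))

A_total, B_total = tree_reduce(zip(As, Bs))
S_parallel = S0 @ A_total + B_total
S_serial = sequential(As, Bs, S0)
print(np.allclose(S_parallel, S_serial))
```

If $\mathbf{A}_t$ depended on $S_{t-1}$, as with the state-dependent gradients mentioned above, `combine` could not be evaluated before the earlier states were known, and the tree reduction would no longer apply.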

Paper Structure

This paper contains 48 sections, 3 theorems, 48 equations, 2 figures, 8 tables, 2 algorithms.

Key Result

Lemma 4.1

The forgetting operator in PRISM is parameterized as $\mathbf{A}_t = \mathbf{I} - \beta_t \mathbf{k}_t \mathbf{k}_t^\top$, where $\beta_t \in [0, 1]$ and $\left\lVert\mathbf{k}_t\right\rVert \leq 1$. The eigenvalues of this matrix satisfy:
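For a matrix of this form the spectrum can be read off directly: $\mathbf{A}_t$ is a symmetric rank-1 perturbation of the identity, so it has eigenvalue $1 - \beta_t \lVert\mathbf{k}_t\rVert^2$ along $\mathbf{k}_t$ and eigenvalue $1$ on the orthogonal complement. Under the stated constraints $\beta_t \in [0, 1]$ and $\lVert\mathbf{k}_t\rVert \leq 1$, every eigenvalue therefore lies in $[0, 1]$. The following NumPy sketch checks this numerically; the helper names `forgetting_operator` and `spectrum` are illustrative and not from the paper.

```python
import numpy as np

def forgetting_operator(k, beta):
    """Build A = I - beta * k k^T for a key vector k and gate beta."""
    d = k.shape[0]
    return np.eye(d) - beta * np.outer(k, k)

def spectrum(k, beta):
    """Eigenvalues of A in ascending order (A is symmetric, so eigvalsh applies)."""
    return np.linalg.eigvalsh(forgetting_operator(k, beta))

# Example key with norm 0.5 and gate 0.8, both within the stated constraints.
k = 0.5 * np.array([0.6, 0.0, 0.8])
beta = 0.8
eigs = spectrum(k, beta)

# Smallest eigenvalue should equal 1 - beta * ||k||^2; the rest should be 1.
predicted_min = 1.0 - beta * float(k @ k)
print(eigs, predicted_min)
```

Because the eigenvalues never exceed 1 in magnitude, repeated application of the forgetting operator cannot amplify the state, which is consistent with the stability results listed for Section 4.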

Figures (2)

  • Figure 1: The PRISM Architecture. The framework operates in two phases to approximate the Ideal Non-Linear Solver within a parallelizable linear recurrence. Phase 1 (Input-Anchored Simulation): A ShortConv anchor captures the local pre-activation proxy ($u_t \approx S_{t-1}k_t$). Parallel predictors (${p}$) generate the Contextual Gain vectors to simulate the derivative of the activation ($\sigma'$), while basis projections ($K$) determine the geometric subspaces. Phase 2 (Iterative Rank Accumulation): An unrolled residual loop constructs the high-rank injection matrix $\mathbf{B}$. Each layer $l$ performs a Greedy Residual Subtraction, adding an orthogonal rank-1 update to the accumulator. This expands the update manifold from Rank-1 to Rank-$L$. State Update: The accumulated high-rank update is injected into a Decoupled Linear Recurrence, where the forgetting operator $\mathbf{A}$ remains state-independent to preserve parallel scan efficiency.
  • Figure 2: Training throughput comparison of 0.13B models on a single H20 GPU.
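The rank-accumulation idea in Phase 2 of Figure 1 can be illustrated with a generic greedy residual loop. This is a sketch under stated assumptions, not the paper's algorithm: the target matrix, the orthonormal directions, and the function name `accumulate_rank` are all hypothetical stand-ins, and the learned gain vectors and ShortConv anchor are omitted. Each pass peels the rank-1 component of the current residual along one direction and adds it to the accumulator, so $L$ orthogonal directions produce an update of rank $L$ rather than rank 1.

```python
import numpy as np

def accumulate_rank(target, directions):
    """Greedily subtract rank-1 pieces of `target` along each direction.

    Returns the accumulated update B and the remaining residual.
    """
    residual = target.copy()
    B = np.zeros_like(target)
    for k in directions:
        k = k / np.linalg.norm(k)            # work with a unit direction
        update = np.outer(residual @ k, k)   # rank-1 part of residual along k
        B += update
        residual -= update
    return B, residual

rng = np.random.default_rng(0)
d, L = 8, 3
target = rng.standard_normal((d, d))              # stand-in for the ideal update
Q, _ = np.linalg.qr(rng.standard_normal((d, L)))  # d x L, orthonormal columns

B_single, _ = accumulate_rank(target, Q.T[:1])    # one step: rank-1 update
B_multi, residual = accumulate_rank(target, Q.T)  # L steps: rank-L update

print(np.linalg.matrix_rank(B_single), np.linalg.matrix_rank(B_multi))
```

With orthonormal directions the residual left after the loop has no component along any of them, which is what makes each added term contribute a new dimension to the update.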

Theorems & Definitions (7)

  • Definition 3.1: State-Independent Recurrence
  • Lemma 4.1: Bounded Spectrum of Gated DeltaNet
  • proof
  • Theorem 4.2: Logarithmic Worst-Case Stability
  • proof
  • Theorem 4.3: Constant Average-Case Error
  • proof