PRISM: Parallel Residual Iterative Sequence Model
Jie Jiang, Ke Cheng, Xin Xu, Mengyang Pang, Tianhao Lu, Jiaheng Li, Yue Liu, Yuan Wang, Jun Zhang, Huan Yu, Zhouchen Lin
TL;DR
PRISM addresses the expressivity-efficiency trade-off in generative sequence modeling by replacing serial, multi-step optimization with an amortized, input-anchored refinement that yields Rank-$L$ updates within a parallelizable linear recurrence. It introduces Write-Forget Decoupling to keep forgetting linear and injection high-rank, and employs an Input-Anchored ShortConv proxy plus a learned predictor to synthesize multi-step refinements in parallel. Theoretical analysis shows Rank Accumulation and robust stability of the forgetting path, while experiments demonstrate accuracy competitive with explicit solvers and substantial throughput gains (up to 174x) on long-horizon recommendation benchmarks. These results indicate that distilling iterative optimization trajectories into parallelizable, input-grounded operators is a promising approach for scalable, high-fidelity foundation models.
Abstract
Generative sequence modeling faces a fundamental tension between the expressivity of Transformers and the efficiency of linear sequence models. Existing efficient architectures are theoretically bounded by shallow, single-step linear updates, while powerful iterative methods like Test-Time Training (TTT) break hardware parallelism due to state-dependent gradients. We propose PRISM (Parallel Residual Iterative Sequence Model) to resolve this tension. PRISM introduces a solver-inspired inductive bias that captures key structural properties of multi-step refinement in a parallelizable form. We employ a Write-Forget Decoupling strategy that isolates non-linearity within the injection operator. To bypass the serial dependency of explicit solvers, PRISM uses a two-stage proxy architecture: a short convolution anchors the initial residual using local history energy, while a learned predictor estimates the refinement updates directly from the input. This design distills structural patterns associated with iterative correction into a parallelizable feedforward operator. Theoretically, we prove that this formulation achieves Rank-$L$ accumulation, structurally expanding the update manifold beyond the single-step Rank-$1$ bottleneck. Empirically, it matches the performance of explicit optimization methods while delivering 174x higher throughput.
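The Rank-$L$ claim can be illustrated with a toy calculation. The sketch below is not the PRISM implementation. It assumes, purely for illustration, that a state update has the form (decay * previous state) + (sum of $L$ outer products), where the scalar decay stands in for the linear forgetting path and the outer products stand in for the injected writes. The helper names (`rank_l_write`, `step`, `matrix_rank`) and the toy vectors are hypothetical. It only demonstrates the linear-algebra fact that one outer product has rank 1, while a sum of $L$ outer products built from linearly independent vectors has rank $L$.

```python
from fractions import Fraction


def outer(u, v):
    """Outer product u v^T as a list of rows, using exact rationals."""
    return [[Fraction(a) * Fraction(b) for b in v] for a in u]


def matrix_rank(matrix):
    """Exact rank via Gaussian elimination over the rationals."""
    rows = [[Fraction(x) for x in row] for row in matrix]
    rank = 0
    n_cols = len(rows[0]) if rows else 0
    for col in range(n_cols):
        # Find a row at or below the current pivot row with a nonzero entry.
        pivot = next(
            (r for r in range(rank, len(rows)) if rows[r][col] != 0), None
        )
        if pivot is None:
            continue
        rows[rank], rows[pivot] = rows[pivot], rows[rank]
        # Eliminate this column from every row below the pivot row.
        for r in range(rank + 1, len(rows)):
            factor = rows[r][col] / rows[rank][col]
            rows[r] = [x - factor * y for x, y in zip(rows[r], rows[rank])]
        rank += 1
    return rank


def rank_l_write(us, vs):
    """Sum of L outer products; its rank is at most L (illustrative form)."""
    total = [[Fraction(0)] * len(vs[0]) for _ in range(len(us[0]))]
    for u, v in zip(us, vs):
        term = outer(u, v)
        total = [[a + b for a, b in zip(ra, rb)] for ra, rb in zip(total, term)]
    return total


def step(state, decay, us, vs):
    """Assumed toy update: linear (scalar) forgetting plus a multi-term write."""
    write = rank_l_write(us, vs)
    return [
        [Fraction(decay) * s + w for s, w in zip(rs, rw)]
        for rs, rw in zip(state, write)
    ]


# Hypothetical toy vectors with L = 3 refinement terms in a 4x4 state.
us = [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]]
vs = [[1, 2, 0, 0], [0, 1, 3, 0], [0, 0, 1, 4]]
zero_state = [[Fraction(0)] * 4 for _ in range(4)]

single_step = step(zero_state, Fraction(1, 2), us[:1], vs[:1])
multi_step = step(zero_state, Fraction(1, 2), us, vs)

print(matrix_rank(single_step))  # 1
print(matrix_rank(multi_step))   # 3
```

Exact rational arithmetic is used so that the rank computation involves no floating-point tolerance. In a real model the write vectors would be produced by learned components rather than fixed by hand.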
