Parallel Loop Transformer for Efficient Test-Time Computation Scaling

Bohong Wu, Mengzhao Chen, Xiang Luo, Shen Yan, Qifan Yu, Fan Xia, Tianqi Zhang, Hongrui Zhan, Zheng Zhong, Xun Zhou, Siyuan Qiao, Xingyan Bin

TL;DR

Large Language Models suffer from slow, costly inference. Looped transformers achieve parameter-efficient depth, but their per-token latency and KV-cache size grow with the number of loops. The Parallel Loop Transformer (PLT) introduces Cross-Loop Parallelism (CLP), which overlaps loop computations across tokens, and Efficient Representation Enhancement (KV-cache sharing plus Gated Sliding-Window Attention, G-SWA), which caps memory growth. Together they yield near non-looped latency while preserving deep-loop accuracy. Empirically, PLT matches or exceeds the accuracy of vanilla looped models with substantially lower latency and a smaller KV-cache footprint, and a smaller PLT can outperform a larger vanilla model in both accuracy and efficiency. This enables fast, scalable test-time computation for high-performance LLMs and makes looped architectures practical for real-time applications.
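A back-of-the-envelope sketch of the memory claim follows. It assumes that a vanilla looped transformer keeps a full-length KV cache for every loop, and that PLT keeps one full-length cache (from the first loop, shared by all loops) plus a sliding-window cache of `window` tokens for each later loop. The function name `kv_cache_bytes` and the model dimensions in the usage lines (24 layers, 8 KV heads, head dimension 128, 2-byte BF16 elements, 32K context, window 64) are hypothetical and not taken from the paper.

```python
from typing import Optional


def kv_cache_bytes(seq_len: int, num_loops: int, num_layers: int,
                   num_kv_heads: int, head_dim: int,
                   bytes_per_elem: int = 2, window: Optional[int] = None) -> int:
    """Estimate KV-cache size in bytes (hypothetical helper, assumed cache layout)."""
    # Keys and values for one token in one pass over all layers.
    per_token = 2 * num_layers * num_kv_heads * head_dim * bytes_per_elem
    if window is None:
        # Vanilla looped transformer (assumed): every loop keeps a full-length cache.
        return num_loops * seq_len * per_token
    # PLT-style (assumed): one full-length cache from the first loop, shared by all
    # loops, plus a sliding-window cache of at most `window` tokens per later loop.
    return (seq_len + (num_loops - 1) * min(seq_len, window)) * per_token


if __name__ == "__main__":
    vanilla = kv_cache_bytes(32768, 3, 24, 8, 128)
    plt_like = kv_cache_bytes(32768, 3, 24, 8, 128, window=64)
    print(f"vanilla looped: {vanilla / 2**30:.2f} GiB")  # 9.00 GiB
    print(f"PLT-style:      {plt_like / 2**30:.2f} GiB")  # 3.01 GiB
```

Under these assumptions the looped cache grows linearly with the loop count, while the PLT-style cache stays close to that of a single non-looped pass (3 GiB here with `num_loops=1`).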

Abstract

Large Language Models (LLMs) are powerful but often too slow and costly for real-world use during inference. Looped transformers save on parameters by reusing the same weights for multiple computational steps, or "loops." However, this approach has a major flaw: the loops run one after another, so inference latency and memory requirements increase with each added loop. This makes them impractical for latency-sensitive applications. To solve this problem, we introduce the Parallel Loop Transformer (PLT), a new architecture that delivers the accuracy benefits of a deep, looped model with the low latency of a standard, non-looped model. PLT relies on two key techniques. First, Cross-Loop Parallelism (CLP) breaks the sequential dependency by computing different loops for different tokens at the same time, all within a single pass. Second, to keep memory costs from growing, we use an Efficient Representation Enhancement strategy. This method shares the KV cache from the first loop with all subsequent loops. It then uses Gated Sliding-Window Attention (G-SWA) to combine this shared global information with local information, maintaining high accuracy. Our experiments show that PLT achieves the high accuracy of a traditional looped model with almost no extra latency or memory cost compared to a standard transformer.
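A minimal single-head sketch of how a later loop might combine the shared first-loop KV cache (global context) with its own sliding-window cache (local context) is shown below. The sigmoid gate, its scalar `gate_logit` input, the convex mixing rule, and the names `attend` and `gated_swa` are assumptions for illustration; this summary does not specify PLT's exact gating formula.

```python
import math
from typing import List, Sequence

Vector = List[float]


def attend(q: Sequence[float], keys: Sequence[Sequence[float]],
           values: Sequence[Sequence[float]]) -> Vector:
    """Single-head scaled dot-product attention for one query."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    peak = max(scores)  # subtract the max score for numerical stability
    weights = [math.exp(s - peak) for s in scores]
    total = sum(weights)
    out = [0.0] * len(values[0])
    for w, v in zip(weights, values):
        for j, vj in enumerate(v):
            out[j] += (w / total) * vj
    return out


def gated_swa(t: int, q: Sequence[float],
              shared_k: Sequence[Sequence[float]], shared_v: Sequence[Sequence[float]],
              local_k: Sequence[Sequence[float]], local_v: Sequence[Sequence[float]],
              window: int, gate_logit: float) -> Vector:
    """Hypothetical G-SWA output for position t in a loop after the first."""
    # Global branch: causal attention over the shared first-loop cache (positions 0..t).
    global_out = attend(q, shared_k[: t + 1], shared_v[: t + 1])
    # Local branch: attention over this loop's own cache, last `window` positions only.
    start = max(0, t - window + 1)
    local_out = attend(q, local_k[start : t + 1], local_v[start : t + 1])
    # Assumed scalar sigmoid gate that mixes the two branches.
    gate = 1.0 / (1.0 + math.exp(-gate_logit))
    return [gate * g + (1.0 - gate) * l for g, l in zip(global_out, local_out)]
```

In a real decoder the local cache would hold only the last `window` entries (for example, in a ring buffer), which is what keeps the memory of every loop after the first bounded; the slicing here just makes that window explicit.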

Paper Structure

This paper contains 37 sections, 2 equations, 4 figures, 6 tables, and 3 algorithms.

Figures (4)

  • Figure 1: Illustration of the computation flow. (a) Vanilla loop transformer, where every loop of every token must be computed serially. (b) Parallel loop transformer (PLT), where transformer loops within the same blue dashed box can be computed in parallel.
  • Figure 2: Training and inference pipeline of PLT with loop count $L=3$. Training (left): boxes of the same color trace how input tokens traverse the loops to predict their targets (e.g., token $T_1$ passes through three loops to predict $T_4$, consistent with the inference panel on the right). Training is parallel along the token dimension and serial along the loop dimension. Inference (right): the parallelized forward pass of PLT when decoding $T_4$ and $T_5$ with $L=3$. Because training introduces no horizontal (same-step, cross-loop) activation dependencies, the computations within one step (each row; see the blue dashed box) run in parallel during decoding.
  • Figure 3: Latency versus batch size for Seed-MoE (2.5B/60B) and PLT-2 (1.7B/40B).
  • Figure 4: Inference efficiency analysis (latency and throughput) of a vanilla transformer, PLT, and a looped transformer, for models with over 1 billion activated parameters. Inference uses FP8 quantization with vLLM (Kwon2023vllm).
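A minimal sketch of the decoding order described in the Figure 1 and Figure 2 captions follows. It assumes the diagonal shift shown in Figure 2: with $L=3$, the loop chain started by token $T_1$ finishes its third loop two steps later and predicts $T_4$. The helper names `vanilla_schedule`, `clp_schedule`, and `respects_dependencies` are hypothetical, and the code models only which (token, loop) pairs share a forward pass, not the attention or KV-cache details.

```python
from typing import List, Tuple

Work = Tuple[int, int]  # (token index, loop index), both 1-based


def vanilla_schedule(num_tokens: int, num_loops: int) -> List[List[Work]]:
    # Vanilla looped transformer: loop l of a token needs loop l-1 of the same
    # token, so every (token, loop) pair is its own serial forward pass.
    return [[(t, l)] for t in range(1, num_tokens + 1) for l in range(1, num_loops + 1)]


def clp_schedule(num_tokens: int, num_loops: int) -> List[List[Work]]:
    # Cross-loop parallelism (assumed diagonal shift): at step s, loop l works on
    # the chain started by token s - l + 1, so all loops share one batched pass.
    return [[(s - l + 1, l) for l in range(1, num_loops + 1) if s - l + 1 >= 1]
            for s in range(1, num_tokens + 1)]


def respects_dependencies(schedule: List[List[Work]]) -> bool:
    # A pass is valid if no token appears twice in it (no same-step, cross-loop
    # dependency) and every loop l > 1 follows loop l-1 of the same chain.
    done = set()
    for step in schedule:
        tokens = [t for t, _ in step]
        if len(tokens) != len(set(tokens)):
            return False
        for t, l in step:
            if l > 1 and (t, l - 1) not in done:
                return False
        done.update(step)
    return True


if __name__ == "__main__":
    print(len(vanilla_schedule(5, 3)), len(clp_schedule(5, 3)))  # 15 5
    for step, work in enumerate(clp_schedule(5, 3), start=1):
        print(step, work)
```

With 5 tokens and 3 loops, the vanilla order needs 15 serial passes while the CLP order needs 5, and step 4 batches loop 1 of token 4, loop 2 of token 3, and loop 3 of token 2. Each CLP pass carries up to $L$ pairs, so the saving is in serial depth rather than in total compute.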