
Mixture-of-Recursions: Learning Dynamic Recursive Depths for Adaptive Token-Level Computation

Sangmin Bae, Yujin Kim, Reza Bayat, Sungnyun Kim, Jiyoun Ha, Tal Schuster, Adam Fisch, Hrayr Harutyunyan, Ziwei Ji, Aaron Courville, Se-Young Yun

TL;DR

MoR introduces a unified Transformer framework that simultaneously achieves parameter efficiency and adaptive token-level computation by sharing a single recursion block and routing tokens to variable recursion depths. It couples lightweight routing (expert-choice or token-choice) with KV caching strategies (recursion-wise caching or recursive KV sharing) to substantially reduce compute and memory traffic while preserving or improving model quality. Across 135M–1.7B parameter scales, MoR establishes a new Pareto frontier, delivering lower perplexity and better few-shot accuracy under equal compute, and higher inference throughput through depth-wise batching. The results indicate that MoR is a viable path to large-model performance at substantially reduced training and inference cost, with ample room for future scaling, adaptive capacity control, and multimodal extensions.
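To make the token-choice option concrete, here is a minimal NumPy sketch, assuming a hypothetical linear router with one output per candidate depth and a softmax/top-1 decision. Each token receives its recursion depth once, up front, and that depth alone decides at which recursion steps the token is processed. The names `token_choice_depths` and `active_mask` are illustrative, not taken from the paper's code.

```python
import numpy as np

def token_choice_depths(hidden, router_w):
    """Pick one recursion depth per token with a single routing decision.

    hidden:   (T, d) token states before the first recursion step (toy input)
    router_w: (d, n_r) hypothetical linear router, one column per candidate depth
    """
    logits = hidden @ router_w
    # Numerically stable softmax over the n_r candidate depths.
    probs = np.exp(logits - logits.max(axis=1, keepdims=True))
    probs /= probs.sum(axis=1, keepdims=True)
    depths = probs.argmax(axis=1) + 1  # top-1 choice, shifted to a 1-based depth
    return depths, probs

def active_mask(depths, n_r):
    """(n_r, T) boolean mask: token t runs at step r (1-based) iff depths[t] >= r."""
    steps = np.arange(1, n_r + 1)[:, None]
    return depths[None, :] >= steps

rng = np.random.default_rng(0)
T, d, n_r = 6, 8, 3
depths, probs = token_choice_depths(rng.standard_normal((T, d)), rng.standard_normal((d, n_r)))
mask = active_mask(depths, n_r)
print(depths)             # assigned depth per token, each in {1, 2, 3}
print(mask.astype(int))   # row r: which tokens pass through recursion step r + 1
```

Because the path is fixed at the outset, each column of the mask is a prefix of ones: every token runs step 1, and the count of ones in a column equals that token's assigned depth.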

Abstract

Scaling language models unlocks impressive capabilities, but the accompanying computational and memory demands make both training and deployment expensive. Existing efficiency efforts typically target either parameter sharing or adaptive computation, leaving open the question of how to attain both simultaneously. We introduce Mixture-of-Recursions (MoR), a unified framework that combines the two axes of efficiency inside a single Recursive Transformer. MoR reuses a shared stack of layers across recursion steps to achieve parameter efficiency, while lightweight routers enable adaptive token-level thinking by dynamically assigning different recursion depths to individual tokens. This allows MoR to restrict quadratic attention computation to the tokens still active at a given recursion depth, and further improves memory access efficiency by caching only their key-value pairs. Beyond these core mechanisms, we also propose a KV sharing variant that reuses KV pairs from the first recursion, specifically designed to further decrease memory footprint. Across model scales ranging from 135M to 1.7B parameters, MoR forms a new Pareto frontier: at equal training FLOPs and smaller model sizes, it significantly lowers validation perplexity and improves few-shot accuracy, while delivering higher throughput compared with vanilla and existing recursive baselines. These gains demonstrate that MoR is an effective path towards large-model quality without incurring large-model cost.
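The core loop described above (one shared block reused across recursion steps, with a router narrowing the set of active tokens) can be sketched as follows. This is a toy under stated assumptions, not the authors' implementation: the shared stack is replaced by a hypothetical ReLU MLP, the router is a hypothetical sigmoid-gated linear scorer, the per-step capacity schedule `capacities` is an assumed example, and selected tokens receive a gated residual update while unselected tokens exit with their current state.

```python
import numpy as np

def shared_block(h, w1, w2):
    """Hypothetical stand-in for the shared stack of layers: a ReLU MLP."""
    return np.maximum(h @ w1, 0.0) @ w2

def mor_expert_choice(h, params, router_w, n_r, capacities):
    """Reuse one shared block up to n_r times with expert-choice routing.

    At each step the router scores the still-active tokens and keeps the top-k,
    so the active set can only shrink with depth. `capacities[r]` is the assumed
    fraction of the T tokens kept at step r. Returns the final states and the
    number of recursion steps each token went through.
    """
    T = h.shape[0]
    h = h.copy()
    active = np.arange(T)                  # indices of tokens still recursing
    depth = np.zeros(T, dtype=int)
    for r in range(n_r):
        k = min(max(1, round(capacities[r] * T)), active.size)
        scores = 1.0 / (1.0 + np.exp(-(h[active] @ router_w)))  # sigmoid gate per token
        keep = np.argsort(-scores)[:k]     # top-k among the active tokens
        active = active[keep]
        gate = scores[keep][:, None]
        # Selected tokens get a gated residual update; the rest keep their state (exit).
        h[active] = h[active] + gate * shared_block(h[active], *params)
        depth[active] += 1
    return h, depth

rng = np.random.default_rng(0)
T, d, n_r = 8, 16, 3
h0 = rng.standard_normal((T, d))
params = (0.1 * rng.standard_normal((d, 4 * d)), 0.1 * rng.standard_normal((4 * d, d)))
router_w = rng.standard_normal(d)
out, depth = mor_expert_choice(h0, params, router_w, n_r, capacities=[1.0, 2 / 3, 1 / 3])
print(depth)  # recursion depth reached by each token
```

With 8 tokens and the assumed schedule, 8, 5, and 3 tokens pass through steps 1, 2, and 3, so only the shrinking active set would take part in attention at deeper steps. Note that this toy picks the top-k over the whole sequence, which looks at future tokens; it therefore ignores the causality concern that expert-choice routing raises during autoregressive decoding.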

Paper Structure

This paper contains 95 sections, 8 equations, 9 figures, 17 tables.

Figures (9)

  • Figure 1: Overview of Mixture-of-Recursions (MoR). (Left) Each recursion step consists of a fixed stack of layers and a router that determines whether each token should pass through or exit. This recursion block corresponds to the gray box in the middle. (Middle) The full model structure, where the shared recursion step is applied up to $N_r$ times for each token depending on the router decision. (Right) An example routing pattern showing token-wise recursion depth, where darker cells indicate active computation through the recursion block. Below it, colors show the number of recursion steps (1, 2, or 3) that each subword token undergoes to predict the next token.
  • Figure 2: Architectural components of Mixture-of-Recursions (MoR). (a) Expert-choice routing: At each recursion step, a router selects the top-$k$ tokens to continue, progressively narrowing the set of active tokens with depth. (b) Token-choice routing: Each token is assigned a fixed recursion depth at the outset via a single routing decision, which defines its complete compute path through the model. (c) KV caching strategies: Each square in the matrix indicates whether a token (row) attends to another token's cached key (column). In "recursion-wise KV caching" (Top), only the keys of tokens selected (not dropped) at each recursion step are cached (blue), and attention is restricted to these entries. In "recursive KV sharing" (Bottom), the keys of all previous tokens are cached at the first recursion step (purple) and shared across subsequent recursion steps for attention operations.
  • Figure 3: Validation loss across different compute budgets for four model sizes: 135M, 360M, 730M, and 1.7B parameters. For MoR models, we use expert-choice routing and recursion-wise caching. MoR consistently outperforms recursive baselines and matches or exceeds the standard Transformers at larger scales, despite using significantly fewer parameters (approximately one-third, due to layer tying with $N_r=3$).
  • Figure 4: (a) Pareto frontier of inference throughput and log-likelihood for MoR and the Vanilla Transformer under fixed and maximum batching scenarios. Setting details are in the throughput appendix. (b) Negative log-likelihood (NLL) of Recursive Transformers with $N_r=3$ across four different parameter-sharing strategies. We pretrained the models on 10 billion tokens. The dashed red and black lines denote the full-size Vanilla Transformer and parameter-matched vanilla models (approximately one-third scale), respectively. (c) NLL comparison across four different architectures with KV sharing. For MoR, green (disabled) and blue (enabled) refer to the recursion-wise KV caching and recursive KV sharing strategies. MoR-E and MoR-T denote expert-choice and token-choice MoR, respectively. All models are based on the 360M scale and trained on 10 billion tokens.
  • Figure 5: (a) Compute-optimal scaling analysis for three model architectures. Each star indicates the optimal model size for a given compute budget. We visualize the results of the scaling-law analysis by fitting polynomial functions for each architecture and FLOPs budget, and derive the optimal points from these fits. (b) Distribution of router outputs for selected and unselected tokens at each recursion step. As an example, we use a 360M-scale MoR model with $N_r=3$, an expert-choice router with auxiliary loss, and recursion-wise caching. (c) Test-time scaling analysis illustrating the cumulative log-likelihood improvement with increasing recursion depth, measured over 500 samples. As $N_r$ increases at a fixed 360M model size, the number of unique parameters in MoR decreases, resulting in a gradual decline in overall performance (i.e., lower log-likelihood). All models are trained with an expert-choice router with auxiliary loss and a recursion-wise caching mechanism.
  • ...and 4 more figures
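The two KV caching strategies in Figure 2(c) can be read as boolean attention masks for one recursion step. The sketch below assumes a causal language model and a hypothetical `active` vector marking the tokens selected at that step; it models only which query/key entries are attended, not the actual cache layout in memory.

```python
import numpy as np

def recursion_wise_mask(active):
    """Mask for one recursion step under recursion-wise KV caching.

    active: (T,) bool, True for tokens selected at this step. Only their keys are
    cached, so query i sees key j iff j <= i (causal) and both tokens are active.
    """
    causal = np.tril(np.ones((active.size, active.size), dtype=bool))
    return causal & active[:, None] & active[None, :]

def recursive_sharing_mask(active):
    """Mask under recursive KV sharing: keys of all tokens are cached at the first
    recursion and reused, so an active query i sees every key j <= i."""
    causal = np.tril(np.ones((active.size, active.size), dtype=bool))
    return causal & active[:, None]

active_step2 = np.array([True, False, True, True, False])
rw = recursion_wise_mask(active_step2)
share = recursive_sharing_mask(active_step2)
print(rw.astype(int))
print(share.astype(int))
print(rw.sum(), share.sum())  # 6 8
```

In this toy step, recursion-wise caching attends to 6 entries versus 8 under sharing, reflecting that it restricts attention to active tokens. Recursive sharing instead attends to more entries but keeps a single cache filled at the first recursion, which is how the variant reduces memory footprint.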