SpiralFormer: Looped Transformers Can Learn Hierarchical Dependencies via Multi-Resolution Recursion

Chengting Yu, Xiaobo Shu, Yadao Wang, Yizhen Zhang, Haoyi Wu, You Wu, Rujiao Long, Ziheng Chen, Yuchi Xu, Wenbo Su, Bo Zheng

TL;DR

This paper proposes SpiralFormer, a looped Transformer that executes recurrence under a multi-resolution recursion schedule, and provides probing evidence that multi-resolution recursion enables the model to learn hierarchical dependencies by inducing iteration-wise functional specialization across different scales.

Abstract

Recursive (looped) Transformers decouple computational depth from parameter depth by repeatedly applying shared layers, providing an explicit architectural primitive for iterative refinement and latent reasoning. However, early looped Transformers often underperform non-recursive baselines of equal compute. While recent literature has introduced more effective recursion mechanisms to mitigate this gap, existing architectures still operate at a fixed, full-token resolution, neglecting the potential efficiency of computing over compressed latent representations. In this paper, we propose SpiralFormer, a looped Transformer that executes recurrence under a multi-resolution recursion schedule. We provide probing evidence that multi-resolution recursion enables the model to learn hierarchical dependencies by inducing iteration-wise functional specialization across different scales. Empirically, SpiralFormer achieves better parameter and compute efficiency than both looped and non-looped baselines across model scales from 160M to 1.4B, establishing sequence resolution as a potential axis for scaling recursive architectures.

Paper Structure (64 sections, 42 equations, 8 figures, 3 tables, 2 algorithms)

Figures (8)

  • Figure 1: SpiralFormer overview. Left: we adopt a Middle-cycle backbone (pre→loop→post) where a shared Transformer core is iterated $T$ times (looped recursion) and combined with the running state via a topology update $\mathcal{U}$. Right: one multi-resolution recursion step at iteration $t$: token-level states $\bm{h}^{(t)}\in\mathbb{R}^{L\times d}$ are causally downsampled into chunk-level latents $\bm{z}^{(t)}\in\mathbb{R}^{L_t\times d}$ (chunk size $g_t = 1/r_t$ with offset $\omega_t$), processed by the shared core to obtain $\widehat{\bm{z}}^{(t)}$, then upsampled back to token-level updates $\bm{u}^{(t)}$. A causal right-shift by $s_t$ produces $\widetilde{\bm{u}}^{(t)}$, ensuring strict autoregressive causality under compression before updating the loop state.
  • Figure 2: Scaling behavior of SpiralFormer. Left: validation loss versus training compute (FLOPs). Right: downstream 0-shot accuracy versus total parameters.
  • Figure 3: Impact of recurrence ratio on validation loss. We vary the recurrence ratio, defined as the fraction of layers placed in the shared core, $N_{\text{loop}}/N_{\text{total}}$, while keeping the total parameter count fixed. Results at the 410M scale are shown under two compute budgets (8e19 and 16e19 FLOPs).
  • Figure 4: Cross-loop distribution shifts of attention statistics on dynamic heads. We visualize the distributions of (top) key-marginal entropy and (bottom) Local Attention Mass (LAM) across loop iterations (resolutions) for the dynamic heads (top 40% by cross-loop range for each metric) of the 410M SpiralFormer-B† model. Statistics are computed by averaging each head's response over 500 samples from the Pile validation set. For each loop (from coarse $1/8$ to full resolution $1$), the left subplots show ridgeline histograms of head-wise values, and the right subplots show box plots with the loop-wise mean indicated by a black dot and connected across loops. For completeness, we further include the distributions over all heads in Appendix \ref{sec:appendix-all-heads}.
  • Figure 5: Head-wise localization of cross-loop variability. We visualize the cross-loop ranges of (top) key-marginal entropy ($\Delta H$) and (bottom) Local Attention Mass ($\Delta\mathrm{LAM}$) as layer--head heatmaps, where each cell corresponds to one attention head indexed by its layer and head indices $(\ell,h)$. To compare patterns across layers and heads, we normalize each heatmap to $[0,1]$ using min--max normalization over all cells of that heatmap: $x'=(x-\min(x))/(\max(x)-\min(x))$; if $\max(x)-\min(x)$ is numerically negligible (or values are non-finite), we set all entries to $0$. Black boxes mark dynamic heads, defined as the top 40% of heads ranked by cross-loop range, selected separately for each metric. Statistics are computed on 500 sequences sampled from the Pile validation set: for each loop $t$ and head $(\ell,h)$, we first average the metric over the 500 sequences, and then compute cross-loop ranges over $t$.
  • ...and 3 more figures
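
The post-processing described in the Figure 5 caption (cross-loop range, min-max normalization with a degenerate-case fallback, and top-40% dynamic-head selection) can be sketched in a few lines of plain Python. This is a minimal illustration, not the authors' code: the function names are hypothetical, "range" is assumed to mean max minus min over loops, the tolerance `eps` is an assumed threshold for "numerically negligible", and rounding the 40% cutoff up with `math.ceil` is also an assumption.

```python
import math

def cross_loop_range(per_loop_means):
    """Range of one head's sample-averaged metric across loop iterations.

    Assumption: "range" means max minus min over the loops.
    """
    return max(per_loop_means) - min(per_loop_means)

def min_max_normalize(grid, eps=1e-12):
    """Normalize a layer x head grid to [0, 1] over all of its cells.

    Returns all zeros when any value is non-finite or when max - min is
    below eps (eps is an assumed tolerance, not a value from the paper).
    """
    flat = [x for row in grid for x in row]
    if not all(math.isfinite(x) for x in flat):
        return [[0.0 for _ in row] for row in grid]
    lo, hi = min(flat), max(flat)
    if hi - lo < eps:
        return [[0.0 for _ in row] for row in grid]
    return [[(x - lo) / (hi - lo) for x in row] for row in grid]

def dynamic_heads(grid, top_fraction=0.4):
    """Return the (layer, head) indices with the largest cross-loop ranges.

    Assumption: the cutoff count is rounded up; ties keep original order.
    """
    cells = [(x, l, h) for l, row in enumerate(grid) for h, x in enumerate(row)]
    k = math.ceil(top_fraction * len(cells))
    cells.sort(key=lambda c: c[0], reverse=True)
    return {(l, h) for _, l, h in cells[:k]}

# Toy input: metric[layer][head] holds one sample-averaged value per loop.
metric = [
    [[0.1, 0.5, 0.3], [0.2, 0.2, 0.2]],
    [[1.0, 3.0, 2.0], [0.0, 1.0, 0.5]],
]
ranges = [[cross_loop_range(head) for head in layer] for layer in metric]
heatmap = min_max_normalize(ranges)
selected = dynamic_heads(ranges)
```

Selection is done on the raw ranges rather than the normalized heatmap; since min-max normalization preserves ordering, both give the same ranking except in the degenerate all-zero case.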