MeSH: Memory-as-State-Highways for Recursive Transformers

Chengting Yu, Xiaobo Shu, Yadao Wang, Yizhen Zhang, Haoyi Wu, Jiaang Li, Rujiao Long, Ziheng Chen, Yuchi Xu, Wenbo Su, Bo Zheng

TL;DR

This work identifies two core bottlenecks hindering parameter-efficient recursive transformers: undifferentiated computation and information overload. It proposes Memory-as-State-Highways (MeSH), a memory-buffered, router-guided framework that externalizes state and dynamically routes information across iterations. Empirical results on Pythia-scale models show MeSH consistently improves recursive models and, at 1.4B, even beats larger non-recursive baselines while using roughly a third fewer non-embedding parameters, demonstrating strong parameter efficiency. The findings establish explicit routed state management as a scalable architectural principle for strengthening recursive models and suggest potential extensions to non-recursive networks.

Abstract

Recursive transformers reuse parameters and iterate over hidden states multiple times, decoupling compute depth from parameter depth. However, under matched compute, recursive models with fewer parameters often lag behind non-recursive counterparts. By probing hidden states, we trace this performance gap to two primary bottlenecks: undifferentiated computation, where the core is forced to adopt a similar computational pattern at every iteration, and information overload, where long-lived and transient information must coexist in a single hidden state. To address these issues, we introduce a Memory-as-State-Highways (MeSH) scheme, which externalizes state management into an explicit memory buffer and employs lightweight routers to dynamically diversify computation across iterations. Probing visualizations confirm that MeSH successfully resolves the pathologies by inducing functional specialization across iterations. On the Pythia suite (160M-1.4B), MeSH-enhanced recursive transformers consistently improve over recursive baselines and outperform their larger non-recursive counterpart at the 1.4B scale, improving average downstream accuracy by +1.06% with 33% fewer non-embedding parameters. Our analysis establishes MeSH as a scalable and principled architecture for building stronger recursive models.
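
As a rough illustration of the recurrence schemes contrasted here (a plain loop, a fixed additive update with an anchor or residual term, and a routed memory buffer), the following NumPy sketch may help. It is not the paper's implementation: `f_core` is a toy stand-in for the shared transformer core, and the router form (a per-iteration linear map followed by a softmax over memory slots), the slot count, and all function and parameter names are assumptions made for illustration only.

```python
import numpy as np


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def f_core(h, w):
    # Toy stand-in for the shared core block (assumption, not a transformer).
    return np.tanh(h @ w)


def recur_plain(h0, w, steps):
    # Plain recursion: h(t+1) = f_core(h(t)).
    h = h0
    for _ in range(steps):
        h = f_core(h, w)
    return h


def recur_additive(h0, w, steps, mode="residual"):
    # Fixed additive update: h(t+1) = f_core(h(t)) + h_sup,
    # with h_sup = h(0) for "anchor" or h(t) for "residual".
    h = h0
    for _ in range(steps):
        h_sup = h0 if mode == "anchor" else h
        h = f_core(h, w) + h_sup
    return h


def recur_memory(h0, w, write_w, read_w, steps):
    # Routed memory buffer (hypothetical form). write_w and read_w have shape
    # (steps, dim, slots): one small router per iteration, so each iteration
    # can write to and read from the buffer differently.
    seq, dim = h0.shape
    slots = write_w.shape[-1]
    memory = np.zeros((slots, seq, dim))
    memory[0] = h0  # the initial state occupies the first slot
    h = h0
    for t in range(steps):
        out = f_core(h, w)
        write_gate = softmax(out @ write_w[t])  # (seq, slots), per-token weights
        memory += write_gate.T[:, :, None] * out[None]  # distribute output over slots
        read_gate = softmax(out @ read_w[t])  # (seq, slots)
        h = np.einsum("sb,bsd->sd", read_gate, memory)  # weighted read forms next state
    return h


rng = np.random.default_rng(0)
h0 = rng.standard_normal((5, 4))  # (seq, dim)
w = 0.5 * rng.standard_normal((4, 4))
write_w = rng.standard_normal((3, 4, 2))  # 3 iterations, 2 memory slots
read_w = rng.standard_normal((3, 4, 2))
print(recur_memory(h0, w, write_w, read_w, steps=3).shape)
```

The point of the sketch is the contrast in the update rule: the additive variants hard-code which past state is mixed in, whereas the routed variant lets iteration-specific weights decide what is stored and what is retrieved.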

Paper Structure

This paper contains 40 sections, 8 equations, 9 figures, 6 tables, 2 algorithms.

Figures (9)

  • Figure 1: Diagnostic visualizations of the recursive transformer. Analyses are performed on a Pythia-410M-based model with the Prelude-Recurrent-Coda architecture (3 core loops). Hidden state matrices ($\mathbf{h} \in \mathbb{R}^{\text{seq} \times \text{dim}}$) are extracted from 500 samples from the Pile dataset. The states $\mathbf{h}_{\text{emb}}, \mathbf{h}^{(0)} \dots \mathbf{h}_{\text{out}}$ refer to the initial token embeddings, the states passed to each block, and the final output state. Further experimental details and analysis are left to the analysis section. (a) Skewed computational pattern. Plots the relative magnitude of the state update, calculated for each computational block ($f$) as $2||f(\mathbf{h}) - \mathbf{h}||_F / (||f(\mathbf{h})||_F + ||\mathbf{h}||_F)$, where $||\cdot||_F$ is the Frobenius norm; this ratio serves as a proxy for the computational effort of each block. Bars show the mean and standard deviation across 500 samples. (b) Representational stagnation. Displays the pairwise Centered Kernel Alignment (CKA; Kornblith et al., 2019) similarity with an RBF kernel between the hidden state matrices. (c) Loop representational collapse. Shows the top 50 normalized singular values ($\sigma_i / \sigma_0$) for each hidden state matrix on a logarithmic Y-axis. The decay rate of the spectrum indicates the effective rank or intrinsic dimensionality of each state matrix.
  • Figure 2: Comparison of recurrence schemes. (a) In the general architecture of a recursive transformer, a state $\mathbf{h}^{(t)}$ is passed through a core computational block $f_{\text{core}}$ to produce the next state $\mathbf{h}^{(t+1)}$. (b) Common heuristic variants employ a fixed, additive state update to improve the information flow, where the core output is supplemented by a historical state $\mathbf{h}_{\text{sup}}$ (e.g., the initial state $\mathbf{h}^{(0)}$ for anchor or the previous state $\mathbf{h}^{(t)}$ for residual). (c) Our proposed MeSH replaces this rigid addition with a dynamic memory mechanism, which explicitly manages historical states via learnable write and read operations, allowing the model to flexibly retrieve and combine information to form the next state $\mathbf{h}^{(t+1)}$.
  • Figure 3: Skewed Computational Pattern. Plots the relative magnitude of the state update, calculated for each computational block ($f$) as $2||f(\mathbf{h}) - \mathbf{h}||_F / (||f(\mathbf{h})||_F + ||\mathbf{h}||_F)$, where $||\cdot||_F$ is the Frobenius norm; this ratio serves as a proxy for the computational effort of each block. Bars show the mean and standard deviation across 500 samples.
  • Figure 4: Representational Stagnation. Displays the pairwise Centered Kernel Alignment (CKA) similarity with an RBF kernel between hidden state matrices ($\mathbf{h} \in \mathbb{R}^{\text{seq} \times \text{dim}}$) at different stages of the model. The matrix shows the mean similarity across 500 samples. High similarity (values near 1.0) between consecutive loop states indicates that representations have stopped evolving.
  • Figure 5: Loop Representational Collapse. Shows the top 50 normalized singular values ($\sigma_i / \sigma_0$) for key hidden state matrices on a logarithmic Y-axis. The decay rate of the spectrum indicates the effective rank or intrinsic dimensionality of each state matrix. A faster decay signifies a collapse into a lower-dimensional representation. Lines and shaded areas represent the mean and standard deviation across 500 samples.
  • ...and 4 more figures
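
The three diagnostics named in the captions above (relative update magnitude, RBF-kernel CKA, and the normalized singular-value spectrum) can be sketched in NumPy as follows. This is a minimal sketch under stated assumptions, not the authors' analysis code: the function names are hypothetical, and the RBF bandwidth (the median squared pairwise distance) is an assumed heuristic, since the captions do not specify one.

```python
import numpy as np


def relative_update(h_in, h_out):
    # 2 * ||f(h) - h||_F / (||f(h)||_F + ||h||_F), with h_out = f(h_in).
    # np.linalg.norm on a 2-D array defaults to the Frobenius norm.
    num = 2.0 * np.linalg.norm(h_out - h_in)
    den = np.linalg.norm(h_out) + np.linalg.norm(h_in)
    return num / den


def normalized_spectrum(h, k=50):
    # Top-k singular values divided by the largest one (sigma_i / sigma_0).
    s = np.linalg.svd(h, compute_uv=False)  # returned in descending order
    return s[:k] / s[0]


def rbf_gram(x):
    # RBF Gram matrix over the rows of x (one row per token position).
    sq = np.sum(x * x, axis=1)
    d2 = np.maximum(sq[:, None] + sq[None, :] - 2.0 * (x @ x.T), 0.0)
    sigma2 = np.median(d2[d2 > 0])  # assumed bandwidth heuristic
    return np.exp(-d2 / (2.0 * sigma2))


def cka_rbf(x, y):
    # Kernel CKA: <Kc, Lc>_F / (||Kc||_F * ||Lc||_F) on centered Gram matrices.
    n = x.shape[0]
    center = np.eye(n) - np.ones((n, n)) / n
    kc = center @ rbf_gram(x) @ center
    lc = center @ rbf_gram(y) @ center
    return np.sum(kc * lc) / (np.linalg.norm(kc) * np.linalg.norm(lc))


# Toy usage on random stand-ins for hidden states of shape (seq, dim).
rng = np.random.default_rng(0)
h_prev = rng.standard_normal((16, 8))
h_next = h_prev + 0.1 * rng.standard_normal((16, 8))
print(relative_update(h_prev, h_next))
print(cka_rbf(h_prev, h_next))
print(normalized_spectrum(h_next, k=5))
```

Read against the captions: a small relative update together with CKA near 1.0 between consecutive loop states would correspond to stagnation, and a fast-decaying normalized spectrum would correspond to a lower effective rank.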