Understanding and Improving Length Generalization in Recurrent Models

Ricardo Buitrago Ruiz, Albert Gu

TL;DR

This work addresses length generalization in recurrent models through the unexplored states hypothesis: models fail to generalize when training exposes them to only a subset of the state distribution that arises on long sequences. It introduces Effective Remembrance, a diagnostic that quantifies how much a model relies on early context, and shows that simple interventions, notably State Passing and truncated backpropagation through time (TBTT), expose models to attainable states at a small post-training cost. With as few as ~500 post-training steps (≈0.1% of the pre-training budget), models generalize from 2k to 128k contexts and show improved performance on long-context tasks such as BABILong, passkey retrieval, and synthetic copying. This points to a practical, architecture-agnostic path to robust length generalization in recurrent models, enabling fairer comparisons across newer recurrent architectures.

Abstract

Recently, recurrent models such as state space models and linear attention have become popular due to their linear complexity in the sequence length. Thanks to their recurrent nature, in principle they can process arbitrarily long sequences, but their performance sometimes drops considerably beyond their training context lengths, i.e., they fail to length generalize. In this work, we provide comprehensive empirical and theoretical analysis to support the unexplored states hypothesis, which posits that models fail to length generalize when, during training, they are only exposed to a limited subset of the distribution of all attainable states (i.e., states that would be attained if the recurrence were applied to long sequences). Furthermore, we investigate simple training interventions that aim to increase the coverage of the states that the model is trained on, e.g., by initializing the state with Gaussian noise or with the final state of a different input sequence. With only 500 post-training steps ($\sim 0.1\%$ of the pre-training budget), these interventions enable length generalization for sequences that are orders of magnitude longer than the training context (e.g., $2k \to 128k$) and show improved performance in long-context tasks, thus presenting a simple and efficient way to enable robust length generalization in general recurrent models.
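The hypothesis can be illustrated with a toy example that is not taken from the paper. Assume a scalar linear recurrence $h_t = a\,h_{t-1} + x_t$ with zero initial state and i.i.d. unit-variance inputs (all of these are assumptions made for illustration). Under those assumptions the state variance at position $t$ is $(1 - a^{2t})/(1 - a^2)$, so when the decay $a$ is close to 1 the states reached at the training length occupy a narrower range than the states reached on longer sequences. The function names below are hypothetical.

```python
import random


def state_variance(decay: float, t: int) -> float:
    """Closed-form Var(h_t) for h_t = decay * h_{t-1} + x_t, h_0 = 0, Var(x_t) = 1."""
    if not 0.0 <= decay < 1.0:
        raise ValueError("decay must lie in [0, 1)")
    return (1.0 - decay ** (2 * t)) / (1.0 - decay ** 2)


def simulate_variance(decay: float, t: int, num_sequences: int = 2000, seed: int = 0) -> float:
    """Monte Carlo estimate of Var(h_t) using standard normal inputs."""
    rng = random.Random(seed)
    finals = []
    for _ in range(num_sequences):
        h = 0.0
        for _ in range(t):
            h = decay * h + rng.gauss(0.0, 1.0)
        finals.append(h)
    mean = sum(finals) / len(finals)
    return sum((h - mean) ** 2 for h in finals) / len(finals)


if __name__ == "__main__":
    # Toy decay chosen so that the state has not saturated by position 2048.
    for position in (2048, 8192, 32768):
        print(position, state_variance(0.9999, position))
```

In this toy setting, a model trained only up to a short context never sees the wider state distribution that appears later, which is the situation the interventions aim to avoid by changing the initial state.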

Paper Structure

This paper contains 36 sections, 8 equations, 14 figures, 2 tables.

Figures (14)

  • Figure 1: (Top) Perplexity as a function of token position on the Pile validation dataset for the official Mamba-1 and Mamba-2 checkpoints trained with context $T=2048$, as well as for Gated Linear Attention (GLA) models trained with context $T=512$. Dashed lines show the same models post-trained with State Passing (SP), an intervention that initializes the state with the final state of a different sequence (see the State Passing section). State Passing is a simple technique that enables length generalization across several recurrent architectures. Mamba-2 and GLA are post-trained for 500 steps and Mamba-1 is post-trained for 1000 steps. A similar plot for RWKV-v6 (Peng et al., 2024) is shown in a separate figure. (Bottom) Effective Remembrance for recurrent models and their State Passing post-trained counterparts. Effective Remembrance at time $t$ roughly measures the impact of the tokens at positions $[0,t)$ on the output of the model at a later position $T$, with 0 indicating no impact (no "effective remembrance" of tokens $[0,t)$) and 1 indicating maximal impact (see the Effective Remembrance section, Definition 3.3, for a precise definition). The baseline models are disproportionately affected by tokens that are very far away in the past, indicating that they are not correctly handling the recent context. State Passing fixes this behavior.
  • Figure 2: Position-wise perplexities for Mamba-2 and GLA (Yang et al., 2024) trained from scratch with different context lengths $T$ on the Pile. The longer the training context, the better the length generalization. The 45m model is trained for 22.5B tokens (25x Chinchilla scaling laws), the 70m and 85m models are trained for 34B tokens (20x Chinchilla scaling laws), and the 125m model is trained for 25B tokens (10x Chinchilla scaling laws).
  • Figure 3: Position-wise perplexities of a Mamba-2 85m model trained for different numbers of tokens with a context of 1024 on the Pile. 2.5x means that the model is trained for 2.5 times the number of tokens that Chinchilla scaling laws dictate for that model size. Thus, the checkpoints correspond to a range between 4.25B and 34B tokens. In this case, the failure to length generalize occurs after training for more than 7.5x the Chinchilla-optimal token count.
  • Figure 4: Norm of the full state of the Mamba-2 130m official checkpoints versus the sequence position $t$ ($h_t$ in the notation of the state space model preliminaries section). The norm is taken across all elements of the state in all layers. The Mamba-2 130m model post-trained with State Passing (see the State Passing section) produces states whose standard deviation does not change significantly after the training context $T=2048$. In contrast, the official Mamba-2 checkpoint reaches a standard deviation at position $t=8192$ that is almost twice as large as the one at position $t=2048$.
  • Figure 5: Position-wise perplexity of official Mamba-2 models (Base) and our four interventions, each applied to these models with 100 post-training steps. For the State Passing and TBTT interventions, 100 steps are enough to enable length generalization on 32k-length sequences for all models. The interventions modify the initial state of the recurrent models, thus facilitating the exploration of a wider range of states. They sample an initial state from distributions that progressively get closer to the true distribution of attainable states: (1) Random Noise samples from a Gaussian distribution with fixed variance; (2) Fitted Noise samples from a Gaussian distribution with mean and variance calibrated to the final states seen during training; (3) State Passing uses the final state of a different sequence as the initial state; and (4) TBTT splits a sequence into several chunks and uses the final state of the previous chunk as the initial state. State Passing directly samples from the distribution of attainable states and, together with TBTT, has the best performance, supporting the unexplored states hypothesis. For State Passing, the results for Mamba-1 and Gated Linear Attention are also shown in Figure 1.
  • ...and 9 more figures
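
The four initial-state interventions described in the Figure 5 caption can be sketched as follows. This is a minimal illustration under stated assumptions, not the authors' implementation: `toy_final_state` is a hypothetical stand-in for a recurrent layer, states are plain lists of floats, and the helper names are invented for this sketch. In a real training loop the carried state would typically be treated as a constant, with no gradient flowing back across the boundary.

```python
import random
import statistics
from typing import Iterator, List, Sequence, Tuple

State = List[float]


def toy_final_state(inputs: Sequence[float], init: State, decay: float = 0.9) -> State:
    """Hypothetical recurrent layer: h <- decay * h + x, applied per coordinate."""
    h = list(init)
    for x in inputs:
        h = [decay * h_i + x for h_i in h]
    return h


def random_noise_state(dim: int, std: float, rng: random.Random) -> State:
    """(1) Random Noise: zero-mean Gaussian with a fixed standard deviation."""
    return [rng.gauss(0.0, std) for _ in range(dim)]


def fitted_noise_state(final_states: Sequence[State], rng: random.Random) -> State:
    """(2) Fitted Noise: Gaussian with per-coordinate mean/std of observed final states."""
    columns = list(zip(*final_states))
    return [rng.gauss(statistics.fmean(c), statistics.pstdev(c)) for c in columns]


def state_passing(sequences: Sequence[Sequence[float]], dim: int) -> Iterator[Tuple[State, State]]:
    """(3) State Passing: start each sequence from the final state of the previous,
    unrelated sequence. Yields (initial_state, final_state) pairs."""
    state = [0.0] * dim
    for seq in sequences:
        init = state
        state = toy_final_state(seq, init)
        yield init, state


def tbtt(sequence: Sequence[float], chunk_len: int, dim: int) -> Iterator[State]:
    """(4) TBTT: split one long sequence into chunks and carry the state across
    chunk boundaries. Yields the state at the end of each chunk."""
    state = [0.0] * dim
    for start in range(0, len(sequence), chunk_len):
        # A training loop would compute the loss on this chunk only.
        state = toy_final_state(sequence[start:start + chunk_len], state)
        yield state
```

The sketch makes the ordering in the caption visible: the two noise variants draw from a hand-built distribution, while State Passing and TBTT reuse states that the recurrence itself produced.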

Theorems & Definitions (3)

  • Definition 3.1: Position-wise Perplexity
  • Definition 3.2: Length Generalization
  • Definition 3.3: Effective Remembrance
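
The definition bodies are not reproduced here, so the following is only a hedged sketch of how two of these quantities might be computed. It assumes that position-wise perplexity is the exponential of the average negative log-likelihood at each position, and that Effective Remembrance compares the model's output distribution given the full context with the one given the context truncated to positions $[t, T)$, using total variation distance (which lies in $[0, 1]$, matching the range stated in the Figure 1 caption). The choice of distance and the `next_token_dist` callable are assumptions of this sketch.

```python
import math
from typing import Callable, List, Sequence

Distribution = List[float]


def position_wise_perplexity(nll: Sequence[Sequence[float]]) -> List[float]:
    """nll[i][t] is the negative log-likelihood (in nats) of token t in sequence i.
    Returns exp(mean over sequences) for every position t."""
    num_positions = len(nll[0])
    return [
        math.exp(sum(row[t] for row in nll) / len(nll))
        for t in range(num_positions)
    ]


def total_variation(p: Distribution, q: Distribution) -> float:
    """Total variation distance between two distributions over the same vocabulary."""
    return 0.5 * sum(abs(a - b) for a, b in zip(p, q))


def effective_remembrance(
    next_token_dist: Callable[[Sequence[int]], Distribution],
    tokens: Sequence[int],
    t: int,
) -> float:
    """Distance between the output given tokens[0:T] and the output given tokens[t:T].
    0 means dropping tokens [0, t) changes nothing; 1 means maximal change."""
    if not 0 <= t < len(tokens):
        raise ValueError("t must satisfy 0 <= t < len(tokens)")
    full = next_token_dist(tokens)
    truncated = next_token_dist(tokens[t:])
    return total_variation(full, truncated)
```

As a sanity check on the design, a hypothetical model whose output depends only on the last token would score 0 for every valid $t$, while one whose output depends only on the first token in its context can score as high as 1.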