Birdie: Advancing State Space Models with Reward-Driven Objectives and Curricula

Sam Blouir, Jimmy T. H. Smith, Antonios Anastasopoulos, Amarda Shehu

TL;DR

Birdie introduces a training paradigm that substantially improves long-context retrieval in efficient state space models (SSMs). It combines bidirectional input processing with pre-training objectives whose mixture is scheduled dynamically by reinforcement learning. It preserves the SSM architecture while markedly improving in-context recall on tasks such as multi-number phone book lookup and long-paragraph question answering, narrowing the gap to Transformers. The approach comprises a new bidirectional SSM design, a suite of specialized objectives (e.g., Selective Copying, Deshuffling, Infilling), and an RL-based scheduler that adapts the objective mixture during training. Across pre-training, instruction tuning, and evaluation on retrieval-heavy benchmarks, Birdie shows clear gains over Next Token Prediction baselines and competitive results on standard benchmarks, highlighting training dynamics as a critical lever for fixed-state, efficient models. The work offers a practical pathway to deploying efficient SSMs for long-context NLP and outlines its limitations and avenues for further improvement.
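To make the objective names concrete, here is a minimal sketch of how (input, target) pairs for three of the objectives could be built from a token list. It assumes whitespace tokens, hypothetical sentinel tokens (`[MASK]`, `[COPY_START]`, `[COPY_END]`), and hypothetical helper names that do not come from the Birdie code base; the Selective Copying variant in particular is a simplified reading of the task, not the authors' exact formulation.

```python
import random

# Hypothetical sentinel tokens; the real Birdie vocabulary may differ.
COPY_START, COPY_END = "[COPY_START]", "[COPY_END]"
MASK = "[MASK]"


def deshuffling_pair(tokens, rng):
    """Deshuffling: input is the tokens in shuffled order, target is the original order."""
    shuffled = list(tokens)
    rng.shuffle(shuffled)
    return shuffled, list(tokens)


def infilling_pair(tokens, span_len, rng):
    """Infilling: replace one contiguous span with a single mask token; target is that span."""
    start = rng.randrange(0, len(tokens) - span_len + 1)
    corrupted = tokens[:start] + [MASK] + tokens[start + span_len:]
    return corrupted, tokens[start:start + span_len]


def selective_copying_pair(tokens, span_len, rng):
    """Simplified Selective Copying: wrap a span in start/end tokens; target is the wrapped text."""
    start = rng.randrange(0, len(tokens) - span_len + 1)
    end = start + span_len
    marked = tokens[:start] + [COPY_START] + tokens[start:end] + [COPY_END] + tokens[end:]
    return marked, tokens[start:end]
```

Each helper returns a plain (input, target) pair, so a data loader could pick one per segmented sample according to whatever mixture the scheduler currently prescribes.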

Abstract

Efficient state space models (SSMs), such as linear recurrent neural networks and linear attention variants, offer computational advantages over Transformers but struggle with tasks requiring long-range in-context retrieval, such as text copying, associative recall, and question answering over long contexts. Previous efforts to address these challenges have focused on architectural modifications, often reintroducing computational inefficiencies. In this paper, we propose a novel training procedure, Birdie, that significantly enhances the in-context retrieval capabilities of SSMs without altering their architecture. Our approach combines bidirectional input processing with dynamic mixtures of specialized pre-training objectives, optimized via reinforcement learning. We introduce a new bidirectional SSM architecture that seamlessly transitions from bidirectional context processing to causal generation. Experimental evaluations demonstrate that Birdie markedly improves performance on retrieval-intensive tasks such as multi-number phone book lookup, long-paragraph question answering, and infilling. This narrows the performance gap with Transformers while retaining computational efficiency. Our findings highlight the importance of training procedures in leveraging the fixed-state capacity of SSMs, offering a new direction to advance their capabilities. All code and pre-trained models are available at https://www.github.com/samblouir/birdie, with support for JAX and PyTorch.
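The idea of a layer that is bidirectional over the context yet causal over generated tokens can be illustrated with a scalar linear recurrence. The sketch below is a toy under stated assumptions, not the Birdie architecture: it runs a forward scan over the whole sequence and a backward scan over the prefix only, then sums the two, so prefix positions see both directions while later positions see only the past. The fixed `decay` is a hypothetical stand-in for learned, input-dependent gates.

```python
from typing import List


def linear_scan(xs: List[float], decay: float) -> List[float]:
    """Causal linear recurrence h_t = decay * h_{t-1} + x_t with a scalar state."""
    h, out = 0.0, []
    for x in xs:
        h = decay * h + x
        out.append(h)
    return out


def prefix_bidirectional(xs: List[float], prefix_len: int, decay: float = 0.5) -> List[float]:
    """Toy prefix-bidirectional layer (hypothetical, not Birdie's exact design).

    - Forward scan over the full sequence (causal everywhere).
    - Backward scan over the first `prefix_len` positions only, so the prefix
      never sees the generated suffix and generation stays causal.
    Note: x_t contributes to both directions at prefix positions; a real layer
    would use separate learned projections for each direction.
    """
    fwd = linear_scan(xs, decay)
    bwd_prefix = linear_scan(xs[:prefix_len][::-1], decay)[::-1]
    return [f + (bwd_prefix[t] if t < prefix_len else 0.0) for t, f in enumerate(fwd)]
```

With `xs = [1.0, 2.0, 3.0, 4.0, 5.0]` and `prefix_len=3`, changing `xs[2]` alters the output at position 0 (bidirectional context), while changing `xs[3]` leaves the first three outputs untouched (the prefix is isolated from the suffix).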

Paper Structure

This paper contains 65 sections, 23 equations, 9 figures, and 13 tables.

Figures (9)

  • Figure 1: The Multi-Phone Number Retrieval Task requires finding and retrieving 1 to 32 phone numbers over a sequence length of 16,384. State Space Models (SSMs) trained with Birdie substantially narrow the performance gap with Transformers. For further details, see the Multi-Phone Number Retrieval subsection. (A) An ablation comparing Hawk trained with Birdie, Birdie - Causal, and Next Token Prediction, alongside a Transformer trained with Birdie and with Next Token Prediction. Hawk trained with Birdie or Birdie - Causal performs significantly better than Hawk trained with Next Token Prediction. (B) An ablation that adds UL2 and the Fixed Ratio Mixture on our Gated SSM.
  • Figure 2: In our self-supervised Selective Copying pre-training task, models retrieve the text found between special start and end tokens. See the Selective Copying section of the main text and the Selective Copying example in the Appendix for details. Model performance on this task appears in the "Selective Copying" column of the corresponding figure in the Appendix.
  • Figure 3: SQuAD V2 question-answering results for instruction-tuned models. Training with the Birdie procedure strongly improves SSM performance compared to Next Token Prediction. Average results are shown in the main SQuAD V2 results table; further details and ablations are in the SQuAD V2 Question-Answering section and its Appendix section. (A) Answer Contains Label measures how often the model's output contains the label verbatim. (B) The F1 Score awards partial credit for words that match the label and penalizes words in the output that do not appear in the label.
  • Figure 4: Pseudocode for sampling an action from Birdie.
  • Figure 5: How several metrics evolve during training. Loss and Accuracy are measured on validation data from The Pile; Accuracy denotes greedy-decoding accuracy. Sampling Probability (%) is the probability, as set by Birdie, that an objective from a given class is selected for each segmented sample drawn from the training dataloader. The parameterizations of each objective are described in the Appendix overview of Birdie's controls.
  • ...and 4 more figures
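Figures 4 and 5 describe Birdie choosing objective mixtures during training based on observed rewards. As a rough illustration only, the sketch below swaps in a simple bandit-style scheduler: it keeps an exponential moving average of a per-objective reward (for example, a drop in validation loss) and converts those estimates into sampling probabilities with a temperature-scaled softmax plus an exploration floor. The class, its parameters, and the reward definition are hypothetical and are not the paper's pseudocode.

```python
import math
import random
from typing import Dict, List


class MixtureScheduler:
    """Hypothetical bandit-style scheduler over pre-training objectives.

    Not Birdie's algorithm: a stand-in that maps running reward estimates to
    sampling probabilities via softmax, mixed with a uniform floor so every
    objective keeps being explored.
    """

    def __init__(self, objectives: List[str], temperature: float = 1.0,
                 ema: float = 0.9, floor: float = 0.05, seed: int = 0):
        self.objectives = list(objectives)
        self.temperature = temperature
        self.ema = ema          # smoothing factor for reward estimates
        self.floor = floor      # total probability mass spread uniformly
        self.rewards = {name: 0.0 for name in self.objectives}
        self.rng = random.Random(seed)

    def probabilities(self) -> Dict[str, float]:
        scaled = [self.rewards[n] / self.temperature for n in self.objectives]
        m = max(scaled)  # subtract the max for numerical stability
        exps = [math.exp(s - m) for s in scaled]
        total = sum(exps)
        k = len(self.objectives)
        # Mix softmax with a uniform floor; the result still sums to 1.
        return {n: self.floor / k + (1.0 - self.floor) * e / total
                for n, e in zip(self.objectives, exps)}

    def sample(self) -> str:
        """Choose the objective for the next segmented training sample."""
        probs = self.probabilities()
        weights = [probs[n] for n in self.objectives]
        return self.rng.choices(self.objectives, weights=weights)[0]

    def update(self, name: str, reward: float) -> None:
        """Fold a new reward for `name` into its moving-average estimate."""
        self.rewards[name] = self.ema * self.rewards[name] + (1.0 - self.ema) * reward
```

The uniform floor keeps every objective sampled occasionally, so a scheduler of this kind can notice when an objective that looked unhelpful early in training starts paying off later.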