Reinforced Fast Weights with Next-Sequence Prediction

Hee Seung Hwang, Xindi Wu, Sanghyuk Chun, Olga Russakovsky

TL;DR

ReFINE replaces next-token prediction with next-sequence prediction for fast weight language models, using reinforcement learning to optimize multi-token continuations at high-uncertainty positions. The framework combines entropy-based token selection, rollout generation, and cosine-based or hybrid sequence-level rewards, and optimizes the model with Group Relative Policy Optimization (GRPO) to provide sequence-level supervision for long-context memory. Across mid-training, post-training, and test-time training, ReFINE consistently improves long-context retrieval, multi-document QA, and LongBench results on LaCT-760M and DeltaNet-1.3B, outperforming standard NTP-based fine-tuning while maintaining performance on short-context tasks. The approach is phase-agnostic and offers a flexible way to improve long-context modeling in fast weight architectures.

Abstract

Fast weight architectures offer a promising alternative to attention-based transformers for long-context modeling by maintaining constant memory overhead regardless of context length. However, their potential is limited by the next-token prediction (NTP) training paradigm. NTP optimizes single-token predictions and ignores semantic coherence across multiple tokens following a prefix. Consequently, fast weight models, which dynamically update their parameters to store contextual information, learn suboptimal representations that fail to capture long-range dependencies. We introduce ReFINE (Reinforced Fast weIghts with Next sEquence prediction), a reinforcement learning framework that trains fast weight models under the next-sequence prediction (NSP) objective. ReFINE selects informative token positions based on prediction entropy, generates multi-token rollouts, assigns self-supervised sequence-level rewards, and optimizes the model with group relative policy optimization (GRPO). ReFINE is applicable throughout the training lifecycle of pre-trained language models: mid-training, post-training, and test-time training. Our experiments on LaCT-760M and DeltaNet-1.3B demonstrate that ReFINE consistently outperforms supervised fine-tuning with NTP across needle-in-a-haystack retrieval, long-context question answering, and diverse tasks in LongBench. ReFINE provides an effective and versatile framework for improving long-context modeling in fast weight architectures.
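To make the entropy-based selection step concrete, the following is a minimal sketch rather than the authors' implementation. It assumes that per-token entropies are already available, that the sequence is split into equal contiguous chunks, and that one position per chunk is sampled with probability proportional to its entropy. The function names `token_entropy` and `select_positions` and the proportional sampling rule are illustrative assumptions; the abstract only states that positions are selected based on prediction entropy.

```python
import math
import random


def token_entropy(probs):
    """Shannon entropy (in nats) of a single next-token distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def select_positions(entropies, num_chunks, rng):
    """Pick one target position per contiguous chunk.

    Assumption (not from the paper): within each chunk, a position is
    sampled with probability proportional to its entropy, falling back
    to uniform sampling when every entropy in the chunk is zero.
    """
    n = len(entropies)
    if num_chunks < 1 or num_chunks > n:
        raise ValueError("num_chunks must be between 1 and len(entropies)")
    # Chunk boundaries: chunk i covers positions bounds[i] .. bounds[i + 1] - 1.
    bounds = [i * n // num_chunks for i in range(num_chunks + 1)]
    picks = []
    for start, end in zip(bounds, bounds[1:]):
        weights = entropies[start:end]
        if sum(weights) <= 0.0:
            weights = [1.0] * (end - start)
        picks.append(rng.choices(range(start, end), weights=weights, k=1)[0])
    return picks


if __name__ == "__main__":
    rng = random.Random(0)
    # Hypothetical entropies for an 8-token sequence split into 2 chunks.
    entropies = [0.1, 0.2, 2.0, 0.1, 0.3, 0.1, 0.2, 1.5]
    print(select_positions(entropies, num_chunks=2, rng=rng))
```

Sampling instead of always taking the highest-entropy position keeps some diversity in which prefixes are trained on; whether the paper samples or takes an argmax should be checked against its method section.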

Paper Structure (52 sections, 12 equations, 8 figures, 14 tables)

This paper contains 52 sections, 12 equations, 8 figures, 14 tables.

Figures (8)

  • Figure 1: Comparison of standard NTP and ReFINE. Standard NTP (top) computes cross-entropy loss at each token position, providing only token-level supervision to fast weight models. ReFINE (bottom) provides sequence-level supervision by generating multi-token rollouts at high-entropy positions, assigning sequence-level rewards from hidden states, and optimizing with RL.
  • Figure 2: Comparison of Standard Transformer and Fast Weight Models, adapted from zhang2025test. Fast weight models replace attention with a fixed-size memory implemented as a weight matrix ($W$) that is updated according to Eq. \ref{eq:update}.
  • Figure 3: ReFINE. We forward the sequence through the policy model and compute token-level entropy values. Sequences are split into chunks and a target token position is sampled from each chunk based on the entropy (Entropy-Based Token Selection). Prefixes are copied from the original sequence up to each target token. The policy model predicts continuations from the prefixes (Rollout Generation). Reward is computed based on the generated rollouts and ground truth tokens (Reward Assignment). Finally, we update the policy model with GRPO (Optimization with RL).
  • Figure 4: NTP Accuracy on Booksum. ReFINE mid-training on DeltaNet-1.3B (a) and LaCT-760M (b) leads to a consistent increase in NTP accuracy on the validation dataset, while that of SFT mid-training is stagnant. The error bars show the minimum and maximum values from three independent trials.
  • Figure 5: Ablation on $k$ and $c$. We mid-train models with different numbers of tokens per rollout $k$ (left) and numbers of chunks per sequence $c$ (right). We evaluate on 16K-context samples from 12 tasks in LongBench bai2024longbench. With cosine similarity reward, there is an optimal $k$. Higher $c$ leads to more NSP training per sequence, which leads to better overall performance.
  • ...and 3 more figures
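
The reward assignment and GRPO steps described in the Figure 3 caption can be sketched as follows. This is an illustrative sketch under stated assumptions, not the paper's code: it assumes the reward for a rollout is the mean cosine similarity between the hidden states of the generated tokens and those of the ground-truth tokens, and it uses the standard GRPO advantage, which normalizes each rollout's reward by the mean and standard deviation of its group. The function names and the averaging over the $k$ rollout positions are hypothetical.

```python
import math


def cosine_similarity(u, v):
    """Cosine similarity of two equal-length vectors (0.0 if either is zero)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0.0 or norm_v == 0.0:
        return 0.0
    return dot / (norm_u * norm_v)


def sequence_reward(rollout_states, target_states):
    """Assumed reward: mean cosine similarity over the k rollout positions."""
    sims = [cosine_similarity(r, t) for r, t in zip(rollout_states, target_states)]
    return sum(sims) / len(sims)


def group_relative_advantages(rewards, eps=1e-8):
    """GRPO-style advantages: (reward - group mean) / (group std + eps)."""
    mean = sum(rewards) / len(rewards)
    var = sum((r - mean) ** 2 for r in rewards) / len(rewards)
    std = math.sqrt(var)
    return [(r - mean) / (std + eps) for r in rewards]


if __name__ == "__main__":
    # Hypothetical 2-dimensional hidden states for k = 2 ground-truth tokens.
    target = [[1.0, 0.0], [0.0, 1.0]]
    # Three rollouts sampled from the same prefix form one group.
    rollouts = [
        [[1.0, 0.0], [0.0, 1.0]],  # matches the target at both positions
        [[1.0, 0.0], [1.0, 0.0]],  # matches at the first position only
        [[0.0, 1.0], [1.0, 0.0]],  # matches at neither position
    ]
    rewards = [sequence_reward(r, target) for r in rollouts]
    print(rewards)
    print(group_relative_advantages(rewards))
```

Because the advantages are centered within each group, rollouts are only rewarded relative to other continuations of the same prefix, so no separate value model is needed.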