RAT+: Train Dense, Infer Sparse -- Recurrence Augmented Attention for Dilated Inference

Xiuying Wei, Caglar Gulcehre

TL;DR

This work introduces RAT+, a dense-pretraining architecture that augments attention with full-sequence recurrence and active recurrence learning, and shows that RAT+ outperforms attention when sparsified to top-k block attention.

Abstract

Structured dilated attention offers an appealing inference-time efficiency knob: it reduces the attention FLOPs and the KV cache size by a factor of the dilation size D while preserving long-range connectivity. However, we identify a persistent failure mode: sparsifying a pretrained attention model to a dilated pattern leads to severe accuracy degradation. We introduce RAT+, a dense-pretraining architecture that augments attention with full-sequence recurrence and active recurrence learning. A single RAT+ model is pretrained densely once and then flexibly switched at inference time to dilated attention (optionally with local windows) or to hybrid layer/head compositions, requiring only a short 1B-token resolution adaptation rather than retraining separate sparse models. At 1.5B parameters trained on 100B tokens, RAT+ closely matches dense accuracy at D=16 and drops by about 2-3 points at D=64 on commonsense reasoning and LongBench tasks, respectively. Moreover, RAT+ outperforms attention when sparsified to top-k block attention. We further scale to 2.6B parameters and 200B tokens and observe the same trend.
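To make the efficiency knob concrete, here is a minimal NumPy sketch of one possible dilated causal mask and the KV cache it implies. The pattern is an assumption for illustration, not necessarily the paper's exact layout: each query attends to itself, to every D-th earlier position (positions j with j % D == D - 1, which a recurrence could have summarized), and optionally to a local window. The helper names `dilated_mask`, `masked_attention`, and `kv_entries_needed` are hypothetical.

```python
import numpy as np

def dilated_mask(T: int, D: int, window: int = 1) -> np.ndarray:
    """Boolean [T, T] mask; True where query t may attend to key j.

    Assumed pattern (illustrative only): causal, keys on a stride of D
    (j % D == D - 1), plus a local band of the last `window` tokens.
    window >= 1 guarantees every query can at least attend to itself.
    """
    assert D >= 1 and window >= 1
    q = np.arange(T)[:, None]
    k = np.arange(T)[None, :]
    causal = k <= q
    strided = (k % D) == (D - 1)
    local = (q - k) < window
    return causal & (strided | local)

def masked_attention(q, k, v, mask):
    """Single-head softmax attention restricted to `mask`."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores = np.where(mask, scores, -np.inf)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerically stable softmax
    w = np.exp(scores)
    w = w / w.sum(axis=-1, keepdims=True)
    return w @ v

def kv_entries_needed(T: int, D: int, window: int = 1) -> int:
    """KV entries a decoder must keep after T tokens under the assumed pattern."""
    strided = {j for j in range(T) if j % D == D - 1}
    local = set(range(max(0, T - window), T))
    return len(strided | local)

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.standard_normal((8, 4))
    out = masked_attention(x, x, x, dilated_mask(8, D=4))
    print(out.shape)                                                  # (8, 4)
    print(kv_entries_needed(4096, 16), kv_entries_needed(4096, 64))  # 256 64
```

Under this assumed pattern, a 4096-token context needs 256 cached entries at D=16 and 64 at D=64, i.e. the factor-of-D reduction stated above; a local window adds at most `window` further entries. With D=1 the mask reduces to the ordinary dense causal mask.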

Paper Structure (32 sections, 2 equations, 8 figures, 17 tables)

Figures (8)

  • Figure 1: (a) For architectural simplicity, we adopt an extreme overlapped setting, i.e., full-sequence recurrence with $\mathtt{L=T}$. (b) Joint training preserves dense attention capability while enforcing active recurrence learning with a desired effective length $\mathtt{L^*}=64$. (c) After pretraining, the model can be efficiently adapted to various sparse inference patterns, giving effective results with dilated attention (optionally combined with local attention) and better performance than attention with top-k block attention.
  • Figure 2: Results of top-k block attention with block size $\mathtt{D}$ and number of selected blocks $\mathtt{K}$ on the hard NIAH-MK-2 and NIAH-MK-3 tasks from the RULER benchmark with $\mathtt{T=4096}$. RAT+ (no ARL in SFT) disables active recurrence learning during SFT; the comparison further demonstrates the benefit of the recurrence. Additional results cover the remaining tasks (Table \ref{tab:pretrain_free_niah_topk_quest}), MoBA-style top-k block attention (Table \ref{tab:pretrain_free_niah_topk_moba}), and dilated attention (Table \ref{tab:pretrain_free_niah}).
  • Figure 3: Efficiency results of the temporal-mixing operator on a single GH200 GPU, covering both prefilling and decoding with hidden dimension $\mathtt{H}$. Prefilling latency is measured on sequences of 262K tokens. Decoding latency is measured with batch sizes of 256 and 128 for the two hidden dimensions, respectively; the baseline runs out of memory beyond 32K tokens. We use FlexAttention \cite{flexattention} for prefilling (with $\mathtt{D=1}$ reducing to FlashAttention) and FlashAttention \cite{flashattention} for decoding. The recurrence is implemented with an associative scan in PyTorch for prefilling and a simple step update for decoding. Due to space constraints, additional results are reported in Section \ref{sec:exp_app}.
  • Figure 4: Maximum decoding throughput (tokens/sec) of the full 1.5B and 7B models when decoding 1024 tokens, measured at context lengths of 4096 and 16384, corresponding to prefilling lengths of 3072 and 15360 tokens, respectively.
  • Figure 5: Scaling-up experiments. We report validation loss on a held-out 0.5B-token subset of the FineWeb dataset to illustrate the loss gap between dense and sparse configurations.
  • ...and 3 more figures
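The top-k block attention evaluated in Figure 2 can be sketched for a single decoding query as follows. This is an illustrative NumPy sketch, not the paper's kernel: keys are split into blocks of D tokens, each block is scored against the query, and exact attention is computed only over the K highest-scoring blocks. Two scoring rules are shown as assumptions: a MoBA-style score using the mean-pooled block key, and a Quest-style upper bound built from per-channel key minima and maxima. The function names are hypothetical, and practical systems typically also force-include the most recent block, which this sketch omits.

```python
import numpy as np

def block_scores(q, k, D, mode="mean"):
    """Score each length-D key block against query q (shape [d])."""
    T, d = k.shape
    assert T % D == 0, "sketch assumes T is a multiple of the block size D"
    kb = k.reshape(T // D, D, d)
    if mode == "mean":
        # MoBA-style: dot product with the mean-pooled block key
        return kb.mean(axis=1) @ q
    # Quest-style: sum over channels of max(q_c*kmax_c, q_c*kmin_c),
    # an upper bound on q.k for every key in the block
    kmax, kmin = kb.max(axis=1), kb.min(axis=1)
    return np.maximum(q * kmax, q * kmin).sum(axis=-1)

def topk_block_attention(q, k, v, D, K, mode="mean"):
    """Exact softmax attention over the K top-scoring blocks only."""
    scores = block_scores(q, k, D, mode)
    chosen = np.sort(np.argsort(scores)[::-1][:K])  # top-K block ids, in position order
    idx = (chosen[:, None] * D + np.arange(D)[None, :]).ravel()  # token positions kept
    s = k[idx] @ q / np.sqrt(q.shape[-1])
    w = np.exp(s - s.max())
    w = w / w.sum()
    return w @ v[idx], chosen

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    T, d, D, K = 64, 8, 16, 2
    q = rng.standard_normal(d)
    k, v = rng.standard_normal((T, d)), rng.standard_normal((T, d))
    out, blocks = topk_block_attention(q, k, v, D, K, mode="quest")
    print(out.shape, blocks)  # output shape (8,), plus the 2 selected block ids
```

Attention cost and KV reads scale with K*D rather than T. Selecting all T/D blocks recovers dense attention exactly, which is a convenient sanity check.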
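Figure 3 notes that the recurrence is computed with an associative scan for prefilling and a simple step update for decoding. The sketch below illustrates why both produce the same states, assuming a diagonal linear recurrence h_t = a_t * h_{t-1} + b_t with h_{-1} = 0 (for instance a gated form with a_t = g_t and b_t = (1 - g_t) * x_t; the exact RAT+ parameterization may differ). It uses NumPy rather than PyTorch and a Hillis-Steele scan; the function names are hypothetical.

```python
import numpy as np

def combine(a1, b1, a2, b2):
    """Compose h -> a1*h + b1 (earlier) with h -> a2*h + b2 (later)."""
    return a1 * a2, a2 * b1 + b2

def scan_prefill(a, b):
    """All states h_0..h_{T-1} via an inclusive Hillis-Steele scan (O(log T) steps)."""
    A, B = a.copy(), b.copy()
    T = a.shape[0]
    shift = 1
    while shift < T:
        # Element t absorbs the segment ending at t - shift; rows t < shift use the identity (1, 0).
        A_prev, B_prev = np.ones_like(A), np.zeros_like(B)
        A_prev[shift:], B_prev[shift:] = A[:-shift], B[:-shift]
        A, B = combine(A_prev, B_prev, A, B)
        shift *= 2
    return B  # with h_{-1} = 0, the composed offset B[t] equals h_t

def step_decode(h_prev, a_t, b_t):
    """One constant-memory decoding update."""
    return a_t * h_prev + b_t

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    T, H = 10, 4
    g = 1 / (1 + np.exp(-rng.standard_normal((T, H))))  # assumed sigmoid gate
    x = rng.standard_normal((T, H))
    a, b = g, (1 - g) * x
    h_scan = scan_prefill(a, b)
    h = np.zeros(H)
    for t in range(T):
        h = step_decode(h, a[t], b[t])
    print(np.allclose(h_scan[-1], h))  # True
```

The scan performs O(T log T) work in only O(log T) sequential steps, which suits parallel prefilling. Decoding needs only the latest state, so the step update keeps the recurrent state at a fixed size regardless of context length.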