Learning When to Attend: Conditional Memory Access for Long-Context LLMs

Sakshi Choudhary, Aditya Chattopadhyay, Luca Zancato, Elvis Nunez, Matthew Trager, Wei Xia, Stefano Soatto

Abstract

Language models struggle to generalize beyond pretraining context lengths, limiting long-horizon reasoning and retrieval. Continued pretraining on long-context data can help but is expensive due to the quadratic scaling of Attention. We observe that most tokens do not require (Global) Attention over the entire sequence and can rely on local context. Based on this, we propose L2A (Learning To Attend), a layer that enables conditional (token-wise) long-range memory access by deciding when to invoke Global Attention. We evaluate L2A on Qwen 2.5 and Qwen 3 models, extending their effective context length from 32K to 128K tokens. L2A matches the performance of standard long-context training to within 3% while skipping Global Attention for $\sim$80% of tokens, outperforming prior baselines. We also design custom Triton kernels to efficiently implement this token-wise conditional Attention on GPUs, achieving up to $\sim$2$\times$ improvements in training throughput and time-to-first-token over FlashAttention. Moreover, L2A enables post-training pruning of highly sparse Global Attention layers, reducing KV cache memory by up to 50% with negligible performance loss.

Paper Structure (30 sections, 18 equations, 18 figures, 6 tables)

Figures (18)

  • Figure 1: Overview of L2A, a sequence modeling layer with token-wise conditional routing. All tokens are first processed by Local Attention. Next, tokens with routing decision $\mathbf{d}_t=1$ invoke Global Attention, whereas tokens with $\mathbf{d}_t=0$ bypass it. This enables efficient long-context modeling by invoking expensive Global Attention only when needed. We refer to the combined {Local Attention + L2A Block} as the L2A layer.
  • Figure 2: Long-context performance and efficiency of L2A at 128K context length. Left: Performance comparison of L2A against CLP for the base model and other baselines on Qwen 2.5 7B, where L2A outperforms existing approaches. Right: Kernel-level speedups achieved by L2A relative to FlashAttention-2, with $\sim$10$\times$ and $\sim$8$\times$ speedups for the forward and backward passes, respectively, at up to $90\%$ sparsity in practice.
  • Figure 3: Task-wise performance comparison of L2A and baseline methods on the Qwen 2.5 7B model. L2A remains close to CLP performance while outperforming baselines on the majority of tasks.
  • Figure 4: Aggregate long-context performance of Qwen models at 1.5B and 7B scales. L2A achieves comparable performance to CLP across several context lengths, while significantly outperforming prior baselines.
  • Figure 5: Time-to-first-token (TTFT) for Qwen 2.5 7B, averaged over different tasks from our long-context benchmark suites. L2A achieves aggregate long-context performance similar to CLP while being $\sim$2$\times$ faster.
  • ...and 13 more figures
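
The routing described in the Figure 1 caption can be illustrated with a small, unoptimized sketch. It is not the authors' implementation: the linear threshold router (`router_w`, `router_b`), the residual connections, the use of the inputs directly as queries, keys, and values, and the function names are all assumptions made for illustration. It only shows the control flow, in which every token passes through Local Attention and only tokens with $\mathbf{d}_t=1$ pay for Global Attention.

```python
import math

def softmax(xs):
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    total = sum(es)
    return [e / total for e in es]

def causal_attention(q, k, v, t, window=None):
    """Output for position t, attending over keys 0..t (or only the last `window` keys)."""
    start = 0 if window is None else max(0, t - window + 1)
    scale = 1.0 / math.sqrt(len(q[t]))
    scores = [scale * sum(a * b for a, b in zip(q[t], k[s]))
              for s in range(start, t + 1)]
    weights = softmax(scores)
    dim = len(v[0])
    return [sum(weights[i] * v[start + i][j] for i in range(len(weights)))
            for j in range(dim)]

def l2a_layer_sketch(x, router_w, router_b, window):
    """Hypothetical sketch of token-wise conditional routing.

    Assumptions (not from the paper): q = k = v = inputs (no learned
    projections), residual connections, and a hard linear threshold router.
    """
    n = len(x)
    # Step 1: every token goes through Local (sliding-window) Attention.
    local = [causal_attention(x, x, x, t, window) for t in range(n)]
    h = [[a + b for a, b in zip(x[t], local[t])] for t in range(n)]

    out, decisions = [], []
    for t in range(n):
        # Step 2: routing decision d_t in {0, 1} for this token.
        logit = sum(w * a for w, a in zip(router_w, h[t])) + router_b
        d_t = 1 if logit > 0 else 0
        decisions.append(d_t)
        if d_t == 1:
            # Step 3a: invoke Global Attention over the full causal prefix.
            g = causal_attention(h, h, h, t)
            out.append([a + b for a, b in zip(h[t], g)])
        else:
            # Step 3b: bypass Global Attention entirely.
            out.append(list(h[t]))
    return out, decisions
```

In this sketch the cost of the global branch scales with the number of tokens routed to it, which is the property a fused kernel would exploit. A per-token Python loop like this one is only meant to make the data flow explicit.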