GatedFWA: Linear Flash Windowed Attention with Gated Associative Memory
Jiaxu Liu, Yuhe Bai, Christos-Savvas Bouganis
TL;DR
The paper tackles the quadratic cost of Softmax attention and the gradient instabilities inherent in sliding-window attention by proposing GatedFWA, a memory-gated linear-time attention mechanism that injects a learnable contraction into the memory recurrence. By reframing attention as associative memory, it analyzes why Softmax suffers gradient vanishing and SWA can exhibit unbounded updates, and then introduces a data-dependent gate that stabilizes updates while preserving linear complexity and hardware efficiency through a FlashAttention-compatible kernel. It provides a hardware-aligned two-phase design (gate preprocessing and gated FA kernel) and demonstrates compatibility with token compression/selection (NSA), achieving strong performance on long-context language modeling and recall-heavy tasks with minimal preprocessing overhead. The work shows that GatedFWA offers efficient, stable long-range context processing suitable for deployment in NSA pipelines, while outlining future directions to extend expressivity beyond TC0 using non-commutative updates.
Abstract
Modern autoregressive models rely on attention, yet the Softmax full attention in Transformers scales quadratically with sequence length. Sliding Window Attention (SWA) achieves linear-time encoding/decoding by constraining the attention pattern, but under an \textit{Associative Memory} interpretation, its difference-style update renders the training objective effectively \emph{unbounded}. In contrast, Softmax attention normalizes updates, leading to \emph{memory shrinkage and gradient vanishing}. We propose GatedFWA: a Memory-\underline{Gated} (\underline{F}lash) \underline{W}indowed \underline{A}ttention mechanism that preserves SWA's efficiency while stabilizing memory updates and making gradient flow controllable. In essence, GatedFWA accumulates a per-token/head gate into a decay bias added to the attention logits, acting as a learnable contraction in the memory recurrence. We implement a fused one-pass gate preprocessing and a FlashAttention-compatible kernel that injects the gate under a sliding mask, ensuring I/O efficiency and numerical stability. On language modelling benchmarks, GatedFWA delivers competitive throughput with negligible overhead and better use of global context, and it integrates cleanly with token compression/selection methods such as NSA and generalizes to various autoregressive domains.
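
The mechanism described above (a per-token gate accumulated into a decay bias that is added to the attention logits under a sliding mask) can be illustrated with a minimal single-head reference sketch in plain Python. This is not the paper's fused kernel. The function name `gated_window_attention`, the arguments `gates` and `window`, and the specific choice of bias (the difference of prefix sums of log-gates, with gates assumed to lie in (0, 1]) are assumptions made for illustration only; the abstract does not specify the exact gate parametrization.

```python
import math

def gated_window_attention(q, k, v, gates, window):
    """Single-head causal sliding-window attention with a gate-derived decay bias.

    q, k, v : lists of T vectors (lists of floats)
    gates   : T floats, ASSUMED to lie in (0, 1] (e.g. outputs of a sigmoid)
    window  : number of most recent tokens each query may attend to
    """
    T, d = len(q), len(q[0])

    # Phase 1 (gate preprocessing): one pass computing prefix sums of log-gates.
    cum, total = [], 0.0
    for g in gates:
        total += math.log(g)
        cum.append(total)

    out = []
    for t in range(T):
        start = max(0, t - window + 1)  # sliding causal mask
        idx = range(start, t + 1)

        # Phase 2: scaled dot-product logits plus the decay bias
        # cum[t] - cum[j] <= 0, which down-weights older keys (ASSUMED form).
        logits = [
            sum(a * b for a, b in zip(q[t], k[j])) / math.sqrt(d)
            + (cum[t] - cum[j])
            for j in idx
        ]

        # Numerically stable softmax over the window.
        m = max(logits)
        weights = [math.exp(x - m) for x in logits]
        z = sum(weights)

        out.append([
            sum(w * v[j][c] for w, j in zip(weights, idx)) / z
            for c in range(len(v[0]))
        ])
    return out
```

With every gate equal to 1 the bias vanishes and the sketch reduces to plain sliding-window attention; gates below 1 shrink the contribution of older keys inside the window, which is one way to read the "learnable contraction" mentioned above.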
