GatedFWA: Linear Flash Windowed Attention with Gated Associative Memory

Jiaxu Liu, Yuhe Bai, Christos-Savvas Bouganis

TL;DR

The paper tackles the quadratic cost of Softmax attention and the gradient instabilities inherent in sliding-window attention by proposing GatedFWA, a memory-gated linear-time attention mechanism that injects a learnable contraction into the memory recurrence. By reframing attention as associative memory, it analyzes why Softmax suffers gradient vanishing and why SWA can exhibit unbounded updates, and then introduces a data-dependent gate that stabilizes updates while preserving linear complexity and hardware efficiency through a FlashAttention-compatible kernel. It provides a hardware-aligned two-phase design (gate preprocessing and a gated FlashAttention kernel) and demonstrates compatibility with token compression/selection (NSA), achieving strong performance on long-context language modeling and recall-heavy tasks with minimal preprocessing overhead. The work shows that GatedFWA offers efficient, stable long-range context processing suitable for deployment in NSA pipelines, and outlines future directions for extending expressivity beyond $\mathsf{TC}^0$ using non-commutative updates.
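The first phase of the two-phase design, gate preprocessing, can be sketched as a single prefix-sum pass. This is a minimal sketch under stated assumptions, not the authors' fused kernel: it assumes one head, a scalar gate per token, and a hypothetical parameterisation $\boldsymbol{\alpha}_t = \mathrm{softplus}(g_t)$ of a gate logit $g_t$ (the summary only says the gate is data-dependent and non-negative). The helper names `preprocess_gates` and `decay_bias` are illustrative. Storing prefix sums $C_t = \sum_{q \le t} \boldsymbol{\alpha}_q$ lets any decay bias $\mathbf{B}_{ti}=\sum_{q=i+1}^t -\boldsymbol{\alpha}_q$ be read off as $C_i - C_t$.

```python
import math
from typing import List


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x)); always >= 0."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def preprocess_gates(gate_logits: List[float]) -> List[float]:
    """Single pass over the sequence (hypothetical gate parameterisation).

    Maps each gate logit g_t to a non-negative gate alpha_t = softplus(g_t)
    and returns prefix sums C_0..C_T, where C_0 = 0 and
    C_t = alpha_1 + ... + alpha_t.
    """
    prefix = [0.0]
    for g in gate_logits:
        prefix.append(prefix[-1] + softplus(g))
    return prefix


def decay_bias(prefix: List[float], t: int, i: int) -> float:
    """B_{ti} = sum_{q=i+1}^{t} -alpha_q = C_i - C_t, for 1 <= i <= t (1-indexed)."""
    if not 1 <= i <= t < len(prefix):
        raise ValueError("need 1 <= i <= t <= T")
    return prefix[i] - prefix[t]


if __name__ == "__main__":
    prefix = preprocess_gates([0.0, 1.0, -2.0, 3.0])
    print(decay_bias(prefix, 4, 4))  # 0.0: the current token is never decayed
    print(decay_bias(prefix, 4, 1) < decay_bias(prefix, 4, 3))  # True: older tokens decay more
```

Keeping only the $T+1$ prefix sums, rather than a $T \times T$ bias matrix, keeps the preprocessing output linear in sequence length. Because every gate is non-negative, the bias is non-positive and becomes more negative the further a key lies from the query.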

Abstract

Modern autoregressive models rely on attention, yet the Softmax full attention in Transformers scales quadratically with sequence length. Sliding Window Attention (SWA) achieves linear-time encoding/decoding by constraining the attention pattern, but under an Associative Memory interpretation, its difference-style update renders the training objective effectively unbounded. In contrast, Softmax attention normalizes updates, leading to memory shrinkage and gradient vanishing. We propose GatedFWA: a Memory-Gated (Flash) Windowed Attention mechanism that preserves SWA's efficiency while stabilizing memory updates and making gradient flow controllable. In essence, GatedFWA accumulates a per-token/head gate into a decay bias added to the attention logits, which acts as a learnable contraction in the memory recurrence. We implement a fused one-pass gate preprocessing step and a FlashAttention-compatible kernel that injects the gate under a sliding mask, ensuring I/O efficiency and numerical stability. On language modelling benchmarks, GatedFWA delivers competitive throughput with negligible overhead and better use of global context; it also integrates cleanly with token compression/selection methods such as NSA and generalizes to various autoregressive domains.
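The second phase adds the decay bias to the attention logits under a sliding mask. Below is a minimal pure-Python sketch for a single query and head, not the authors' fused GPU kernel (whose tiling and memory layout are not described here). It assumes the gate prefix sums $C_0,\dots,C_T$ are already computed, uses the window $t-w < i \le t$, and computes the same output twice: once with a plain softmax, and once with the tiled online softmax (running max and running denominator) that lets a FlashAttention-style kernel stream keys block by block. All function names are illustrative.

```python
import math
from typing import List

Vec = List[float]
Mat = List[List[float]]


def dot(a: Vec, b: Vec) -> float:
    return sum(x * y for x, y in zip(a, b))


def window_logits(q: Vec, K: Mat, prefix: Vec, t: int, idx: range) -> Vec:
    """Scaled dot-product logits plus the decay bias B_{ti} = C_i - C_t."""
    scale = 1.0 / math.sqrt(len(q))
    return [dot(q, K[i - 1]) * scale + prefix[i] - prefix[t] for i in idx]


def gated_swa_naive(q: Vec, K: Mat, V: Mat, prefix: Vec, t: int, w: int) -> Vec:
    """Reference: one softmax over the window t-w < i <= t (1-indexed)."""
    idx = range(max(1, t - w + 1), t + 1)
    s = window_logits(q, K, prefix, t, idx)
    m = max(s)
    p = [math.exp(x - m) for x in s]
    z = sum(p)
    return [sum(pi * V[i - 1][j] for pi, i in zip(p, idx)) / z
            for j in range(len(V[0]))]


def gated_swa_online(q: Vec, K: Mat, V: Mat, prefix: Vec, t: int, w: int,
                     block: int = 2) -> Vec:
    """Same result, visiting the window in key tiles with a running max and
    running denominator (the online-softmax pattern FlashAttention uses)."""
    lo = max(1, t - w + 1)
    m, z = -math.inf, 0.0
    acc = [0.0] * len(V[0])
    for start in range(lo, t + 1, block):
        idx = range(start, min(start + block, t + 1))
        s = window_logits(q, K, prefix, t, idx)
        m_new = max(m, max(s))
        rescale = math.exp(m - m_new)  # 0.0 on the first tile (m = -inf)
        z *= rescale
        acc = [a * rescale for a in acc]
        for si, i in zip(s, idx):
            p = math.exp(si - m_new)
            z += p
            acc = [a + p * v for a, v in zip(acc, V[i - 1])]
        m = m_new
    return [a / z for a in acc]


if __name__ == "__main__":
    K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5], [-1.0, 0.2]]
    V = [[1.0], [2.0], [3.0], [4.0], [5.0]]
    prefix = [0.0, 0.1, 0.4, 0.6, 1.5, 1.7]  # C_0..C_5; non-decreasing since alpha >= 0
    q = [0.3, -0.7]
    print(gated_swa_naive(q, K, V, prefix, t=5, w=3))
    print(gated_swa_online(q, K, V, prefix, t=5, w=3))
```

Because the bias only shifts logits, it folds into each tile's score computation without changing the online-softmax rescaling. The $-C_t$ term is constant across the window and cancels in the softmax, so in practice only $C_i$ varies per key.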

Paper Structure

This paper contains 50 sections, 3 theorems, 45 equations, 17 figures, 4 tables, 4 algorithms.

Key Result

Theorem 1

Assume a perfectly defined feature map $\phi(\cdot): \mathbb{R}^{d} \to \mathbb{R}^{\mathrm{dim}(\phi)}$ such that $\langle \phi(\mathbf{q}), \phi(\mathbf{k}) \rangle$ approximates $\exp (\frac{\mathbf{q} \mathbf{k}^\top}{\sqrt{d_h}})$ arbitrarily well. Then the memory recurrences of Softmax attention and of SWA ($t>w$), with normalization, are formulated by
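The display completing Theorem 1 is cut off in this summary. A hedged reading, taken from the Figure 1 caption rather than from the paper's own statement (so any SWA normalization constant is not shown), gives the two recurrences and their unrolled closed forms:

```latex
\begin{aligned}
\text{Softmax:}\quad \mathbf{M}_t &= \tfrac{t-1}{t}\,\mathbf{M}_{t-1} + \tfrac{1}{t}\,\phi(\mathbf{k}_t)^\top\mathbf{v}_t
&&\Longrightarrow\quad \mathbf{M}_t = \tfrac{1}{t}\sum_{i=1}^{t}\phi(\mathbf{k}_i)^\top\mathbf{v}_i,\\
\text{SWA } (t>w):\quad \mathbf{M}_t &= \mathbf{M}_{t-1} + \phi(\mathbf{k}_t)^\top\mathbf{v}_t - \phi(\mathbf{k}_{t-w})^\top\mathbf{v}_{t-w}
&&\Longrightarrow\quad \mathbf{M}_t = \sum_{i=t-w+1}^{t}\phi(\mathbf{k}_i)^\top\mathbf{v}_i .
\end{aligned}
```

The SWA closed form assumes the state for $t \le w$ is the plain running sum. The closed forms make the contrast concrete: under Softmax each new token enters with weight $\frac{1}{t}$, which shrinks with position, while SWA keeps every in-window token at weight one and corrects the state only by subtracting the evicted token.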

Figures (17)

  • Figure 1: Memory recurrence interpretation of (a) Softmax: the carried memory is scaled by $\frac{t-1}{t}$ and a $\frac{1}{t}$-weighted new term is added, so normalization steadily shrinks per-step updates and drives gradient vanishing through $\mathbf{M}_t$. (b) SWA: within a window of width $w$ the state does not decay but is updated by the difference term $\phi(\mathbf{k}_t)^\top\mathbf{v}_t-\phi(\mathbf{k}_{t-w})^\top\mathbf{v}_{t-w}$; this implicitly optimizes an unbounded linear objective and can over-amplify memory (unstable gradients). (c) GatedFWA: a non-negative gate accumulates into a decay bias ($\mathbf{B}_{ti}=\sum_{q=i+1}^t -\boldsymbol{\alpha}_q$), yielding a learnable contraction ($\mathbf{M}_t=\exp(-\boldsymbol{\alpha}_t)\mathbf{M}_{t-1}+\cdots$) that softly erases off-path history, bounds the update, and makes gradient flow controllable while retaining SWA’s linear cost. We draw (c) in two steps because, for stability, its update depends on multiple prior states.
  • Figure 1: Language modelling scaling laws compared with LLaMA (with and without SWA), RetNet, RWKV, and Mamba. All models are trained on the OpenWebText dataset, with sizes from $120$M to $360$M parameters and context lengths from $1024$ to $4096$.
  • Figure 2: Memory hierarchy with bandwidth & memory size.
  • Figure 3: Schematic comparison between (upper) vanilla preprocessing and (lower) our 1-pass fused preprocessing.
  • Figure 4: Schematic comparison between (left) FlashAttention (Dao et al.) and (right) our hardware-efficient GatedFWA.
  • ...and 12 more figures
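The three recurrences in the Figure 1 caption can be checked numerically. The sketch below is illustrative only: it treats each outer product $\phi(\mathbf{k}_t)^\top\mathbf{v}_t$ as a scalar $x_t$ (the real memory is a matrix, updated entrywise in the same way) and ignores normalization in the gated case. The caption elides the remaining terms of the gated update, so the eviction term in the hypothetical `gated_memory` is our own derivation from the bias $\mathbf{B}_{ti}$, not the paper's stated recurrence.

```python
import math
import random
from typing import List


def softmax_memory(xs: List[float]) -> List[float]:
    """Fig. 1(a): M_t = (t-1)/t * M_{t-1} + (1/t) * x_t."""
    mem, out = 0.0, []
    for t, x in enumerate(xs, start=1):
        mem = (t - 1) / t * mem + x / t
        out.append(mem)
    return out


def swa_memory(xs: List[float], w: int) -> List[float]:
    """Fig. 1(b): M_t = M_{t-1} + x_t - x_{t-w}, with x_{t-w} := 0 while t <= w."""
    mem, out = 0.0, []
    for t, x in enumerate(xs, start=1):
        evicted = xs[t - 1 - w] if t > w else 0.0
        mem += x - evicted
        out.append(mem)
    return out


def gated_memory(xs: List[float], alphas: List[float], w: int) -> List[float]:
    """Illustrative windowed gated update (our derivation, unnormalized):
    M_t = exp(-alpha_t) M_{t-1} + x_t - exp(B_{t,t-w}) x_{t-w}."""
    mem, out = 0.0, []
    for t, x in enumerate(xs, start=1):
        mem = math.exp(-alphas[t - 1]) * mem + x  # learnable contraction
        if t > w:
            b = -sum(alphas[t - w:t])  # B_{t,t-w} = -(alpha_{t-w+1} + ... + alpha_t)
            mem -= math.exp(b) * xs[t - 1 - w]  # evict the token leaving the window
        out.append(mem)
    return out


def gated_closed_form(xs: List[float], alphas: List[float], w: int, t: int) -> float:
    """Sum over t-w < i <= t of exp(B_{ti}) x_i, B_{ti} = -sum_{q=i+1}^t alpha_q."""
    return sum(math.exp(-sum(alphas[i:t])) * xs[i - 1]
               for i in range(max(1, t - w + 1), t + 1))


if __name__ == "__main__":
    random.seed(0)
    T, w = 12, 4
    xs = [random.uniform(-1.0, 1.0) for _ in range(T)]
    alphas = [random.uniform(0.0, 0.5) for _ in range(T)]  # non-negative gates
    soft, swa, gated = softmax_memory(xs), swa_memory(xs, w), gated_memory(xs, alphas, w)
    for t in range(1, T + 1):
        assert abs(soft[t - 1] - sum(xs[:t]) / t) < 1e-12          # running mean
        assert abs(swa[t - 1] - sum(xs[max(0, t - w):t])) < 1e-12  # window sum
        assert abs(gated[t - 1] - gated_closed_form(xs, alphas, w, t)) < 1e-12
    print("recurrences match their closed forms")
```

With every $\boldsymbol{\alpha}_q \ge 0$, each weight $\exp(\mathbf{B}_{ti})$ lies in $(0, 1]$, so older in-window tokens are down-weighted rather than held at weight one as in SWA.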

Theorems & Definitions (5)

  • Theorem 1: Memory Recurrence of Exact Attention
  • Proposition 1: Optimization Objective of Exact Attention
  • Proposition 2: Memory Recurrence and Optimization Objective of GatedFWA
  • Proof
  • Proof