
Enhancing Linear Attention with Residual Learning

Xunhao Lai, Jialiang Kang, Jianqiao Lu, Tong Lin, Pengyu Zhao

TL;DR

Transformer self-attention scales quadratically with sequence length, which makes it costly for long sequences. The authors reframe linear attention as a recurrent process and introduce Residual Linear Attention (RLA), which adds an explicit residual state to correct base predictions, plus Residual Delta Net (RDN), a delta-rule variant with adaptive gating and residual clipping. The core idea is explicit residual fitting: an auxiliary state $\boldsymbol{R}_t$ accumulates a clipped residual $\boldsymbol{r}_t$, yielding outputs $\boldsymbol{o}_t = \boldsymbol{S}_{t-1} \boldsymbol{q}_t + \boldsymbol{R}_t \boldsymbol{q}_t$ (RLA) or its delta-rule counterpart (RDN). With kernel-optimized, linear-time implementations, RLA and RDN outperform contemporary linear-attention baselines on language modeling and recall-intensive tasks and narrow the gap to standard Transformers. The framework is modular and can be applied to various linear backbones, offering a scalable path to stronger long-sequence models.
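
For intuition, the RLA output rule above can be run token by token. This is a minimal sketch, not the authors' kernel: it drops the gates $\alpha_t$, $\beta_t$, $\gamma_t$ from Figure 1 and assumes the simplest updates $\boldsymbol{S}_t = \boldsymbol{S}_{t-1} + \boldsymbol{v}_t \boldsymbol{k}_t^\top$ and $\boldsymbol{R}_t = \boldsymbol{R}_{t-1} + \boldsymbol{r}_t \boldsymbol{k}_t^\top$, with $\boldsymbol{r}_t = \text{Clip}(\boldsymbol{v}_t - \boldsymbol{S}_{t-1}\boldsymbol{k}_t)$ taken to be a symmetric element-wise clip (the bound `clip_bound` is hypothetical). Each step touches only fixed-size $d_v \times d_k$ states, so the cost grows linearly with sequence length.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]  # shape (d_v, d_k), row-major


def matvec(M: Matrix, x: Vector) -> Vector:
    """Return M @ x."""
    return [sum(m_ij * x_j for m_ij, x_j in zip(row, x)) for row in M]


def add_outer(M: Matrix, u: Vector, w: Vector) -> Matrix:
    """Return M + u w^T as a new matrix."""
    return [[m_ij + u_i * w_j for m_ij, w_j in zip(row, w)]
            for row, u_i in zip(M, u)]


def clip(x: Vector, c: float) -> Vector:
    """Element-wise clip to [-c, c]; the symmetric form is an assumption."""
    return [max(-c, min(c, x_i)) for x_i in x]


def rla_recurrent(qs: List[Vector], ks: List[Vector], vs: List[Vector],
                  clip_bound: float = 1.0) -> List[Vector]:
    """Ungated RLA sketch: O(T * d_k * d_v) time, O(d_k * d_v) state."""
    d_k, d_v = len(ks[0]), len(vs[0])
    S = [[0.0] * d_k for _ in range(d_v)]  # base state, holds S_{t-1}
    R = [[0.0] * d_k for _ in range(d_v)]  # auxiliary residual state
    outputs = []
    for q, k, v in zip(qs, ks, vs):
        # Clipped residual of the base prediction for the current key.
        pred = matvec(S, k)
        r = clip([v_i - p_i for v_i, p_i in zip(v, pred)], clip_bound)
        # Accumulate the residual: R_t = R_{t-1} + r_t k_t^T (assumed rule).
        R = add_outer(R, r, k)
        # o_t = S_{t-1} q_t + R_t q_t, matching the TL;DR formula.
        base, corr = matvec(S, q), matvec(R, q)
        outputs.append([b + c for b, c in zip(base, corr)])
        # Advance the base state: S_t = S_{t-1} + v_t k_t^T (assumed rule).
        S = add_outer(S, v, k)
    return outputs


print(rla_recurrent([[1.0, 0.0]], [[1.0, 0.0]], [[0.5, 2.0]]))  # [[0.5, 1.0]]
```

For a single token the base state is still zero, so the output is just the clipped residual scaled by $\boldsymbol{k}_t^\top \boldsymbol{q}_t$; setting `clip_bound=0.0` disables the residual path and recovers plain linear attention outputs $\boldsymbol{S}_{t-1}\boldsymbol{q}_t$.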

Abstract

Linear attention offers a linear-time alternative to self-attention but often struggles to capture long-range patterns. We revisit linear attention through a prediction-correction lens and show that prevalent variants can be written as a combination of a historical prediction and a single-token correction, which creates an expressivity bottleneck. To address this bottleneck, we introduce Residual Linear Attention (RLA), a framework that equips linear attention with an explicit residual-fitting mechanism. RLA maintains an auxiliary recurrent state that learns to accumulate residual errors over time and correct the base prediction. We further instantiate a delta-rule version, Residual Delta Net (RDN), incorporating adaptive gating and residual clipping for enhanced correction control and stability. Our implementation leverages highly optimized linear attention kernels and preserves linear time and memory. Across language modeling and recall-intensive evaluations, RLA and RDN consistently outperform their respective baselines and other modern linear-attention methods, narrowing the gap to standard Transformers while retaining linear scaling.
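
The prediction-correction reading can be checked numerically for one common variant. The sketch below assumes a gated linear-attention update $\boldsymbol{S}_t = \alpha_t \boldsymbol{S}_{t-1} + \beta_t \boldsymbol{v}_t \boldsymbol{k}_t^\top$ (chosen here for illustration; the random toy values are arbitrary) and verifies that $\boldsymbol{S}_t \boldsymbol{q}_t = \alpha_t \boldsymbol{S}_{t-1}\boldsymbol{q}_t + \beta_t (\boldsymbol{k}_t^\top \boldsymbol{q}_t)\,\boldsymbol{v}_t$, i.e. a historical prediction plus a correction built from the current token alone.

```python
import math
import random
from typing import List

Vector = List[float]
Matrix = List[List[float]]  # shape (d_v, d_k)


def matvec(M: Matrix, x: Vector) -> Vector:
    return [sum(m * xj for m, xj in zip(row, x)) for row in M]


def dot(a: Vector, b: Vector) -> float:
    return sum(ai * bi for ai, bi in zip(a, b))


def gated_la_step(S: Matrix, k: Vector, v: Vector, alpha: float, beta: float) -> Matrix:
    """Gated linear attention update: S_t = alpha * S_{t-1} + beta * v_t k_t^T."""
    return [[alpha * s + beta * vi * kj for s, kj in zip(row, k)]
            for row, vi in zip(S, v)]


def prediction_and_correction(S_prev: Matrix, q: Vector, k: Vector, v: Vector,
                              alpha: float, beta: float):
    """Split S_t q_t into a historical prediction and a single-token correction."""
    prediction = [alpha * p for p in matvec(S_prev, q)]  # alpha * S_{t-1} q_t
    correction = [beta * dot(k, q) * vi for vi in v]     # beta * (k_t . q_t) v_t
    return prediction, correction


rng = random.Random(0)
d_k, d_v = 4, 3
S_prev = [[rng.gauss(0.0, 1.0) for _ in range(d_k)] for _ in range(d_v)]
q = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
k = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
v = [rng.gauss(0.0, 1.0) for _ in range(d_v)]
alpha, beta = 0.9, 0.5

full = matvec(gated_la_step(S_prev, k, v, alpha, beta), q)
prediction, correction = prediction_and_correction(S_prev, q, k, v, alpha, beta)
combined = [p + c for p, c in zip(prediction, correction)]
print(all(math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-9)
          for a, b in zip(full, combined)))  # True
```

The correction term depends only on $\boldsymbol{k}_t$ and $\boldsymbol{v}_t$; this is the single-token correction the abstract identifies as a bottleneck, which RLA supplements by accumulating residuals across tokens in an auxiliary state.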

Paper Structure

This paper contains 33 sections, 18 equations, 4 figures, 7 tables.

Figures (4)

  • Figure 1: The architecture of our proposed model. The model structure (left) consists of $N$ stacked blocks. The detailed Attention Block (right) illustrates our core mechanism. Our primary contribution, the explicit residual fitting process, is highlighted with purple dashed lines. This path computes the clipped residual $\bm{r}_{t} = \text{Clip}(\bm{v}_{t} - \bm{S}_{t-1}\bm{k}_{t})$, which is then modulated by a dedicated correction factor $\gamma_{t}=\sigma(\bm{W}_{\gamma}\bm{x})$ to dynamically correct the base prediction from the model's primary state. The model also uses gates $\alpha_{t}=\exp(-a\,\text{softplus}(\bm{W}_{\alpha}\bm{x}+b))$ and $\beta_{t}=\sigma(\bm{W}_{\beta}\bm{x})$ to control the state dynamics, where $a$ and $b$ are learnable scalars.
  • Figure 2: Comparison of attention kernel computation time (left) and model throughput (right) with respect to sequence length.
  • Figure 3: Ablation study on the correction factor $\gamma$. Using a dedicated $\gamma$ consistently lowers validation loss compared to tying it to $\beta$. The evaluation uses the same benchmarks as the main results, divided into three task types, and confirms that a dedicated $\gamma$ improves performance across several categories.
  • Figure 4: Ablation study on normalization and residual clipping. RLA variants without normalization or clipping exhibit exploding activation norms, indicating training instability. This instability leads to a higher training loss, highlighting that both components are crucial for stable training and better performance. In contrast, residual clipping has a negligible impact on the RDN training process.
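
The per-token quantities in the Figure 1 caption can be written out directly. The sketch below is a scalar, single-head reading of the caption: it treats each of $\bm{W}_{\alpha}$, $\bm{W}_{\beta}$, $\bm{W}_{\gamma}$ as a vector mapping the token input $\bm{x}$ to one gate value, assumes Clip is a symmetric element-wise clip to $[-c, c]$ with a hypothetical bound `c`, and uses made-up toy weights. The caption does not say how $\gamma_t \bm{r}_t$ enters the residual state, so the sketch stops at the scaled residual.

```python
import math
from typing import List, Tuple

Vector = List[float]
Matrix = List[List[float]]  # shape (d_v, d_k)


def dot(a: Vector, b: Vector) -> float:
    return sum(ai * bi for ai, bi in zip(a, b))


def sigmoid(z: float) -> float:
    # Numerically stable logistic function.
    if z >= 0.0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def softplus(z: float) -> float:
    # log(1 + e^z) computed without overflow.
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def gates(x: Vector, w_alpha: Vector, w_beta: Vector, w_gamma: Vector,
          a: float, b: float) -> Tuple[float, float, float]:
    """Gates from the Figure 1 caption, each W_* reduced to a vector (scalar gate)."""
    alpha = math.exp(-a * softplus(dot(w_alpha, x) + b))  # state decay
    beta = sigmoid(dot(w_beta, x))                         # state-update gate
    gamma = sigmoid(dot(w_gamma, x))                       # correction factor
    return alpha, beta, gamma


def clipped_residual(S_prev: Matrix, k: Vector, v: Vector, c: float) -> Vector:
    """r_t = Clip(v_t - S_{t-1} k_t), assuming a symmetric clip to [-c, c]."""
    pred = [dot(row, k) for row in S_prev]
    return [max(-c, min(c, vi - pi)) for vi, pi in zip(v, pred)]


# Hypothetical toy values, for illustration only.
x = [0.2, -1.0, 0.5]
alpha, beta, gamma = gates(x, [0.3, 0.1, -0.2], [1.0, 0.0, 0.5],
                           [-0.5, 0.2, 0.0], a=1.0, b=0.0)
r = clipped_residual([[1.0, 0.0], [0.0, 1.0]], k=[1.0, 2.0], v=[4.0, 2.5], c=1.0)
scaled = [gamma * ri for ri in r]  # gamma_t * r_t, the modulated correction signal
print(r)  # [1.0, 0.5]
```

Because softplus is always positive, $\alpha_t$ stays in $(0, 1)$ whenever $a > 0$ and acts as a decay on the state; $\beta_t$ and $\gamma_t$ lie in $(0, 1)$ through the sigmoid. Clipping bounds each residual entry, which is consistent with the stabilizing role Figure 4 attributes to it for RLA.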