RecurFormer: Not All Transformer Heads Need Self-Attention

Ruiqing Yan, Linghan Zheng, Xingbo Du, Han Zou, Yufeng Guo, Jianfei Yang

TL;DR

RecurFormer is a novel architecture that replaces recency-aware attention heads in Transformer-based LLMs with linear recurrent neural networks (RNNs), specifically the Mamba architecture. This reduces the cache size without evicting tokens, which maintains generation quality and offers a practical solution to the computational challenges of Transformer-based LLM inference.
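To make "reduces the cache size without evicting tokens" concrete, here is a back-of-the-envelope sketch. It assumes standard multi-head attention with no KV-head sharing. It also assumes that each replaced head keeps a recurrent state of `head_dim x state_dim` elements, independent of sequence length. The function names, the `state_dim` default, and the model configuration in the usage example are hypothetical illustrations, not values from the paper.

```python
def kv_cache_bytes(num_layers: int, heads_per_layer: int, head_dim: int,
                   seq_len: int, bytes_per_elem: int = 2) -> int:
    """KV cache of a standard multi-head attention model (no KV-head sharing)."""
    # Every head caches one key and one value vector of size head_dim per token,
    # so the cache grows linearly with seq_len.
    return num_layers * heads_per_layer * 2 * seq_len * head_dim * bytes_per_elem


def hybrid_cache_bytes(num_layers: int, heads_per_layer: int, head_dim: int,
                       seq_len: int, replaced_fraction: float,
                       state_dim: int = 16, bytes_per_elem: int = 2) -> int:
    """Cache when a fraction of heads is replaced by a linear RNN (hypothetical sizing)."""
    total_heads = num_layers * heads_per_layer
    replaced = round(total_heads * replaced_fraction)
    kept = total_heads - replaced
    # Remaining attention heads still cache K and V for every token.
    attention_part = kept * 2 * seq_len * head_dim * bytes_per_elem
    # Assumption: each replaced head keeps a recurrent state of head_dim x state_dim
    # elements whose size does not depend on seq_len (no tokens are evicted).
    recurrent_part = replaced * head_dim * state_dim * bytes_per_elem
    return attention_part + recurrent_part


if __name__ == "__main__":
    # Hypothetical configuration: 32 layers x 32 heads, head_dim 128, fp16, 32K tokens.
    full = kv_cache_bytes(32, 32, 128, 32_768)
    hybrid = hybrid_cache_bytes(32, 32, 128, 32_768, replaced_fraction=0.5)
    print(f"full attention: {full / 2**30:.3f} GiB")
    print(f"half heads replaced: {hybrid / 2**30:.3f} GiB")
```

Under these assumptions, the script prints 16.000 GiB for full attention and about 8.002 GiB when half the heads are replaced. The saving therefore scales with the fraction of heads replaced, and the recurrent states add only a small, length-independent term.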

Abstract

Transformer-based large language models (LLMs) excel in modeling complex language patterns but face significant computational costs during inference, especially with long inputs, due to the memory overhead of the attention mechanism. We observe that certain attention heads exhibit a distribution in which the attention weights concentrate on tokens near the query token; we term these heads recency aware, as they focus on local and short-range dependencies. Leveraging this insight, we propose RecurFormer, a novel architecture that replaces these attention heads with linear recurrent neural networks (RNNs), specifically the Mamba architecture. This replacement reduces the cache size without evicting tokens, thus maintaining generation quality. RecurFormer retains the ability to model long-range dependencies through the remaining attention heads and allows the weights of pre-trained Transformer-based LLMs to be reused with continual training. Experiments demonstrate that RecurFormer matches the original model's performance while significantly enhancing inference efficiency. Our approach provides a practical solution to the computational challenges of Transformer-based LLM inference, making it highly attractive for tasks involving long inputs.
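The notion of a recency-aware head can be illustrated with a small NumPy sketch. It computes causal attention weights with f(x) = softmax(x / sqrt(d_k)). It then measures how much of each query's attention mass falls on the most recent `window` tokens. The scoring rule, the `window` and `threshold` values, and all function names are hypothetical. This summary does not state the paper's actual selection criterion, so treat the score as an illustrative stand-in.

```python
import numpy as np


def causal_attention_weights(q: np.ndarray, k: np.ndarray) -> np.ndarray:
    """Row i holds softmax(q_i . k_j / sqrt(d_k)) over positions j <= i."""
    seq_len, d_k = q.shape
    scores = q @ k.T / np.sqrt(d_k)
    future = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)             # causal mask
    scores = scores - scores.max(axis=-1, keepdims=True)   # numerical stability
    w = np.exp(scores)                                     # exp(-inf) -> 0
    return w / w.sum(axis=-1, keepdims=True)


def recency_score(weights: np.ndarray, window: int) -> float:
    """Hypothetical metric: mean share of attention mass on the `window` most
    recent positions, the query token included."""
    seq_len = weights.shape[0]
    i = np.arange(seq_len)[:, None]
    j = np.arange(seq_len)[None, :]
    local = (j <= i) & (i - j < window)
    return float((weights * local).sum(axis=-1).mean())


def is_recency_aware(weights: np.ndarray, window: int = 16,
                     threshold: float = 0.9) -> bool:
    # Hypothetical rule: a head is recency aware if most of its mass is local.
    return recency_score(weights, window) >= threshold
```

A head that attends only to the current token, such as `np.eye(10)`, scores 1.0 with `window=1`. A head that spreads its attention uniformly over 64 causal positions scores roughly 0.59 with `window=16`, so it is not flagged. Under this sketch, only heads like the first would be candidates for replacement by a linear RNN.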

Paper Structure

This paper contains 29 sections, 3 equations, 8 figures, 5 tables, 1 algorithm.

Figures (8)

  • Figure 1: The left diagram shows how recency-aware attention updates values via weighted summation, where $f(x) = \text{softmax}(x/\sqrt{d_k})$ and $d_k$ is the key dimension. The right diagram illustrates the linear RNN update: $A$ represents the state-transition weights, while $B_i$ and $C_i$ are the input and output gates. Darker shades of orange indicate a greater influence on $V_t^{\text{update}}$ (for example, higher attention weights), while lighter-colored regions have less influence.
  • Figure 2: L2 norm
  • Figure 3: Attention weight
  • Figure 4: Contribution
  • Figure 6: Loss values of RecurFormer and the corresponding original models during continual training on the masked prediction task using the Wikipedia English training set.
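The two update rules contrasted in Figure 1 can be sketched in NumPy as follows. The attention update is a softmax-weighted sum over every cached key and value. The linear RNN update follows h_t = A h_{t-1} + B_t x_t and y_t = C_t h_t, so its state has a fixed size. This is a simplified illustration under stated assumptions: A is diagonal and shared across channels, and B_t and C_t are supplied as plain arrays. The actual Mamba layer used by RecurFormer makes these parameters input-dependent and discretized, and that machinery is not reproduced here.

```python
import numpy as np


def attention_update(q_t, keys, values):
    """Left panel: V_t^update = sum_i f(q_t . k_i) v_i with f = softmax(x / sqrt(d_k)).
    keys/values hold every cached token, so memory grows with sequence length."""
    d_k = q_t.shape[-1]
    scores = keys @ q_t / np.sqrt(d_k)
    w = np.exp(scores - scores.max())
    w = w / w.sum()
    return w @ values, w


def linear_rnn_scan(x, a, B, C):
    """Right panel, simplified: h_t = A h_{t-1} + B_t x_t and y_t = C_t h_t,
    with a diagonal A = diag(a) shared by all channels (an assumption).
    x: (T, d) inputs; a: (n,); B, C: (T, n) input and output gates."""
    T, d = x.shape
    h = np.zeros((d, a.shape[0]))  # fixed-size state: no per-token cache
    ys = np.empty((T, d))
    for t in range(T):
        h = a[None, :] * h + x[t][:, None] * B[t][None, :]
        ys[t] = h @ C[t]
    return ys


def input_weight(a, B, C, t, i):
    """Coefficient of x_i in y_t (i <= t): C_t . (a**(t - i) * B_i).
    With 0 < a < 1 it decays geometrically in t - i, i.e. a recency bias."""
    return float(C[t] @ (a ** (t - i) * B[i]))
```

Unrolling the recurrence gives y_t = sum over i <= t of `input_weight(a, B, C, t, i) * x_i`. The linear RNN therefore also produces a weighted sum over past inputs, much like the attention update. The difference is that its weights shrink with distance from the current token and require no per-token cache. This matches the recency-aware pattern the caption describes.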