
MetaLA: Unified Optimal Linear Approximation to Softmax Attention Map

Yuhong Chou, Man Yao, Kexin Wang, Yuqi Pan, Ruijie Zhu, Yiran Zhong, Yu Qiao, Jibin Wu, Bo Xu, Guoqi Li

TL;DR

This work identifies three conditions for optimal linear attention design and finds that none of the current linear models meets all three, which results in suboptimal performance. It then proposes Meta Linear Attention (MetaLA) as a solution that satisfies all three conditions.

Abstract

Various linear-complexity models, such as Linear Transformer (LinFormer), State Space Model (SSM), and Linear RNN (LinRNN), have been proposed to replace the conventional softmax attention in Transformer structures. However, the optimal design of these linear models is still an open question. In this work, we attempt to answer this question by finding the best linear approximation to softmax attention from a theoretical perspective. We start by unifying existing linear-complexity models into a general linear attention form and then identify three conditions for the optimal linear attention design: 1) Dynamic memory ability; 2) Static approximation ability; 3) Least parameter approximation. We find that none of the current linear models meets all three conditions, which results in suboptimal performance. Instead, we propose Meta Linear Attention (MetaLA) as a solution that satisfies these conditions. Our experiments on the Multi-Query Associative Recall (MQAR) task, language modeling, image classification, and the Long-Range Arena (LRA) benchmark demonstrate that MetaLA is more effective than existing linear models.
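
One way to picture a unified linear attention form is as a recurrence on a matrix-valued state $\mathbf{S}_t\in\mathbb{R}^{d_k\times d_v}$. The notation here is an illustrative assumption, not necessarily the paper's exact parameterization: row vectors $\mathbf{q}_t,\mathbf{k}_t\in\mathbb{R}^{1\times d_k}$ and $\mathbf{v}_t\in\mathbb{R}^{1\times d_v}$, a per-channel decay $\boldsymbol{\alpha}_t\in(0,1]^{d_k}$, and $\odot$ for the elementwise product. The recurrent form and its unrolled (parallel) form would then read:

$$\mathbf{S}_t = \mathrm{diag}(\boldsymbol{\alpha}_t)\,\mathbf{S}_{t-1} + \mathbf{k}_t^{\top}\mathbf{v}_t, \qquad \mathbf{o}_t = \mathbf{q}_t\,\mathbf{S}_t$$

$$\mathbf{o}_t = \sum_{s=1}^{t}\left(\left(\mathbf{q}_t \odot \prod_{j=s+1}^{t}\boldsymbol{\alpha}_j\right)\mathbf{k}_s^{\top}\right)\mathbf{v}_s$$

Under this sketch, setting $\boldsymbol{\alpha}_t=\mathbf{1}$ gives unnormalized linear attention, a fixed $\boldsymbol{\alpha}$ gives static decay, and an input-dependent $\boldsymbol{\alpha}_t$ gives dynamic decay.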

Paper Structure

This paper contains 23 sections, 3 theorems, 31 equations, 5 figures, 17 tables.

Key Result

Proposition A2.1

Only models with dynamic decay can satisfy C1 (Dynamic memory ability). Let $\mathbf{S}_t\in\mathbb{R}^{d_k\times d_v}$ be the hidden state of general linear attention (see the section on the general form). At a time $t$, the information about $\mathbf{v}_{t_1},\dots,\mathbf{v}_{t_{d_k}}$ is successfully stored
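
A toy illustration of why the kind of decay matters for memory. This is not the paper's proof, only an intuition aid built on our own assumptions: the unnormalized recurrence $\mathbf{S}_t = \mathrm{diag}(\boldsymbol{\alpha}_t)\mathbf{S}_{t-1} + \mathbf{k}_t^{\top}\mathbf{v}_t$, one-hot keys, and a hand-picked gate instead of a learned one. A fixed decay shrinks every stored value by the same factor at every step regardless of the input, while an input-dependent decay can hold one slot at 1 and clear only the slot being overwritten. The function `run_recurrence` and the slot layout are hypothetical.

```python
import numpy as np

def run_recurrence(keys, values, decays):
    """Hypothetical helper: S_t = diag(a_t) S_{t-1} + k_t^T v_t, starting from S_0 = 0."""
    d_k, d_v = keys.shape[1], values.shape[1]
    S = np.zeros((d_k, d_v))
    for k, v, a in zip(keys, values, decays):
        S = a[:, None] * S + np.outer(k, v)   # decay each row of S, then write k^T v
    return S

d_k, d_v, T = 2, 3, 6
rng = np.random.default_rng(0)
values = rng.standard_normal((T, d_v))

# Toy setup: step 0 writes into slot 0; every later step writes into slot 1.
keys = np.zeros((T, d_k))
keys[0, 0] = 1.0
keys[1:, 1] = 1.0

# Static decay: the same factor for every channel at every step.
static = np.full((T, d_k), 0.5)

# "Dynamic" decay, hand-set here: keep slot 0, clear slot 1 before each new write.
dynamic = np.ones((T, d_k))
dynamic[:, 1] = 0.0

S_static = run_recurrence(keys, values, static)
S_dynamic = run_recurrence(keys, values, dynamic)

# Read slot 0 back with a one-hot query.
q = np.array([1.0, 0.0])
print(np.allclose(q @ S_dynamic, values[0]))                  # True: v_0 survives intact
print(np.allclose(q @ S_static, 0.5 ** (T - 1) * values[0]))  # True: v_0 shrunk by 0.5^(T-1)
```

In the static case the first value can only fade, because the decay cannot depend on whether later inputs actually need that slot.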

Figures (5)

  • Figure 1: General form of LinFormer/SSM/LinRNN mechanisms. The general form supports two computation modes, parallel and recurrent, which give it both training and inference efficiency.
  • Figure 2: Recurrent form of MetaLA. We mark all three enhancements in red.
  • Figure 3: Accuracy (%) on the synthetic MQAR task.
  • Figure A1: MetaLA Transformer, a stack of $N$ MetaLA blocks. Each block is composed of two modules in sequence: a token mixer and a channel mixer.
  • Figure A2: Training efficiency evaluations: throughput and memory usage of Transformer and various linear models on a single A800 GPU. Transformer++ is implemented using FlashAttention [flashattention] and SwiGLU.
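
The parallel and recurrent modes named in Figure 1 can be checked numerically. The sketch below assumes an unnormalized gated recurrence with a per-channel decay $\boldsymbol{\alpha}_t\in(0,1]^{d_k}$ (our assumption, not the paper's exact parameterization). The recurrent mode updates a $d_k\times d_v$ state token by token, which suits inference. The parallel mode builds a decay-weighted causal score matrix for the whole sequence at once, which suits training. The function names are hypothetical.

```python
import numpy as np

def recurrent_mode(Q, K, V, A):
    """Token by token: S_t = diag(a_t) S_{t-1} + k_t^T v_t, o_t = q_t S_t."""
    T, d_k = Q.shape
    S = np.zeros((d_k, V.shape[1]))
    out = np.empty((T, V.shape[1]))
    for t in range(T):
        S = A[t][:, None] * S + np.outer(K[t], V[t])
        out[t] = Q[t] @ S
    return out

def parallel_mode(Q, K, V, A):
    """Whole sequence at once via cumulative decays and a causal mask."""
    log_cum = np.cumsum(np.log(A), axis=0)  # (T, d_k): sum_{j<=t} log a_j
    Q_tilde = Q * np.exp(log_cum)           # q_t scaled by prod_{j<=t} a_j
    K_tilde = K * np.exp(-log_cum)          # k_s divided by prod_{j<=s} a_j
    scores = np.tril(Q_tilde @ K_tilde.T)   # (T, T); keep only s <= t
    return scores @ V

rng = np.random.default_rng(0)
T, d_k, d_v = 8, 4, 3
Q = rng.standard_normal((T, d_k))
K = rng.standard_normal((T, d_k))
V = rng.standard_normal((T, d_v))
A = rng.uniform(0.5, 1.0, size=(T, d_k))  # random here; a real model would compute it from the input
print(np.allclose(recurrent_mode(Q, K, V, A), parallel_mode(Q, K, V, A)))  # True
```

The ratio $\prod_{j\le t}\alpha_j / \prod_{j\le s}\alpha_j$ equals the decay accumulated between steps $s$ and $t$, so both modes compute the same outputs. Dividing by cumulative products can overflow on long sequences, so practical kernels typically work in chunks; this naive version is only meant for short toy inputs.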

Theorems & Definitions (7)

  • Definition 4.1
  • Proposition A2.1
  • Proof
  • Proposition A2.2
  • Proof
  • Proposition A2.3
  • Proof