SEA: Sparse Linear Attention with Estimated Attention Mask

Heejun Lee, Jina Kim, Jeffrey Willette, Sung Ju Hwang

TL;DR

SEA (Sparse linear attention with an Estimated Attention mask) estimates the attention matrix with linear complexity via kernel-based linear attention, then applies a top-k selection to that estimate to build a sparse attention mask and performs sparse attention with it.

Abstract

The transformer architecture has driven breakthroughs in recent years on tasks that require modeling pairwise relationships between sequential elements, as is the case in natural language understanding. However, long sequences pose a problem due to the quadratic complexity of the attention operation. Previous research has aimed to lower the complexity by sparsifying or linearly approximating the attention matrix. Yet these approaches cannot straightforwardly distill knowledge from a teacher's attention matrix, and they often require complete retraining from scratch. Furthermore, previous sparse and linear approaches lose interpretability if they cannot produce full attention matrices. To address these challenges, we propose SEA: Sparse linear attention with an Estimated Attention mask. SEA estimates the attention matrix with linear complexity via kernel-based linear attention, then creates a sparse attention matrix with a top-k selection to perform a sparse attention operation. For language modeling tasks (Wikitext2), previous linear and sparse attention methods show roughly two-fold worse perplexity than the quadratic OPT-1.3B baseline, while SEA achieves better perplexity than OPT-1.3B using roughly half the memory, and it still provides an interpretable attention matrix. We believe that our work will have a large practical impact, as it opens the possibility of running large transformers on resource-limited devices with less memory.
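The estimate, select, attend sequence described in the abstract can be illustrated with a small sketch. This is a toy under stated assumptions, not the authors' implementation: the feature map `phi` (ELU + 1), the helper names, and the per-row top-k are all hypothetical choices made here. For readability it also materializes the full estimated score matrix, which is quadratic; SEA instead estimates a compressed attention matrix and uses a grouped top-k selection, so that every step stays linear at test time.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def phi(x):
    # ELU(x) + 1: a positive feature map (an assumption for this sketch,
    # not necessarily the kernel used in the paper).
    return [v + 1.0 if v > 0 else math.exp(v) for v in x]

def estimate_attention(Q, K):
    """Kernel-style estimate of each attention row, normalized to sum to 1.
    Toy version: builds all T x T scores explicitly for clarity."""
    fq = [phi(q) for q in Q]
    fk = [phi(k) for k in K]
    rows = []
    for q in fq:
        scores = [dot(q, k) for k in fk]
        total = sum(scores)  # strictly positive because phi is positive
        rows.append([s / total for s in scores])
    return rows

def topk_mask(est_rows, k):
    """For each query, keep the indices of the k largest estimated scores."""
    return [sorted(sorted(range(len(r)), key=lambda j: -r[j])[:k])
            for r in est_rows]

def sparse_attention(Q, K, V, mask):
    """Exact softmax attention, restricted to the selected keys per query."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q, idx in zip(Q, mask):
        logits = [dot(q, K[j]) * scale for j in idx]
        m = max(logits)  # subtract the max for numerical stability
        w = [math.exp(l - m) for l in logits]
        z = sum(w)
        out.append([sum(wi / z * V[j][d] for wi, j in zip(w, idx))
                    for d in range(len(V[0]))])
    return out

def sea_like_attention(Q, K, V, k):
    """Estimate attention, select top-k keys per query, then attend sparsely."""
    mask = topk_mask(estimate_attention(Q, K), k)
    return sparse_attention(Q, K, V, mask), mask

if __name__ == "__main__":
    Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
    K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
    V = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
    out, mask = sea_like_attention(Q, K, V, k=2)
    print(mask)
    print(out)
```

With `k` equal to the sequence length, the mask keeps every key and the result reduces to ordinary dense softmax attention. Smaller `k` lowers the cost of the final attention step, but the output then depends on how good the estimate is. The sketch omits causal masking, which a language model such as OPT would need.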

Paper Structure

This paper contains 31 sections, 1 equation, 11 figures, 10 tables.

Figures (11)

  • Figure 1: Concept. We estimate the attention matrix in a compressed size ($\hat{{\bm{A}}}$), then perform a grouped top-$\hat{k}$ selection, and subsequently perform sparse attention with our novel FlatCSR operation using an estimated attention mask on the full attention matrix. SEA has linear complexity in all steps at test-time, and requires direct attention matrix distillation from the quadratic teacher at train-time.
  • Figure 1: Ablation study on grouped top-$\hat{k}$ modes
  • Figure 2: Visualization of Input and Output of CNN Decoder
  • Figure 4: (left) Intermediate attention examples. (right) The first row is the attention probability of the teacher model, and the second row is the compressed attention interpolated to full size. Interpolation to the full size attention matrix is for visualizing our estimated attention $\hat{{\bm{A}}}$ and is not part of the regular linear inference procedure. (a) MNLI with BERT-base ($K = 128$) (b) Wikitext2 with OPT-125m ($K = 256$).
  • Figure 5: Visualization of intermediate buffers during masking and sparse attention.
  • ...and 6 more figures