
LevAttention: Time, Space, and Streaming Efficient Algorithm for Heavy Attentions

Ravindran Kannan, Chiranjib Bhattacharyya, Praneeth Kacham, David P. Woodruff

TL;DR

The attention mechanism that uses only the keys in a "universal set" is called LevAttention. The paper proves that for any $K$ there is a universal set $U \subset [n]$ of size independent of $n$ such that for any $Q$ and any row $i$, the large attention scores $A_{i,j}$ all have $j \in U$.

Abstract

A central problem related to transformers can be stated as follows: given two $n \times d$ matrices $Q$ and $K$ and a non-negative function $f$, define the matrix $A$ by (1) applying $f$ to each entry of the $n \times n$ matrix $Q K^T$, and then (2) normalizing each row of the result to sum to $1$. The matrix $A$ can be computed in $O(n^2 d)$ time, assuming $f$ can be applied to a number in constant time, but the quadratic dependence on $n$ is prohibitive in applications where $n$ corresponds to long context lengths. For a large class of functions $f$, we show how to find all the "large attention scores", i.e., entries of $A$ that are at least a parameter $\varepsilon > 0$, in time with linear dependence on $n$ (i.e., $n \cdot \textrm{poly}(d/\varepsilon)$). Our class of functions includes all functions of the form $f(x) = |x|^p$, as explored recently in transformer models. Using recently developed tools from randomized numerical linear algebra, we prove that for any $K$ there is a "universal set" $U \subset [n]$ of size independent of $n$ such that for any $Q$ and any row $i$, the large attention scores $A_{i,j}$ in row $i$ of $A$ all have $j \in U$. We also find $U$ in $n \cdot \textrm{poly}(d/\varepsilon)$ time. Notably, (1) we make no assumptions on the data, (2) our workspace does not grow with $n$, and (3) our algorithms can be computed in streaming and parallel settings. We call the attention mechanism that uses only the subset of keys in the universal set LevAttention, since our algorithm to identify the universal set $U$ is based on leverage scores. We empirically show the benefits of our scheme for vision transformers, showing how to train new models that use the universal set during training as well, and showing that our model is able to consistently select "important keys" during training.
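
The definition of $A$ above can be made concrete with a brute-force reference. The pure-Python sketch below is our illustration, not the authors' implementation; `attention_matrix`, `power_f`, and `large_scores` are hypothetical names. It spends $O(n^2 d)$ time forming every entry, which is exactly the cost the $n \cdot \textrm{poly}(d/\varepsilon)$ algorithm avoids, so it is only useful for checking faster methods on small inputs. It assumes every row of $f(QK^T)$ has a positive sum.

```python
def attention_matrix(Q, K, f):
    """Naive O(n^2 d) computation of A: apply f to every entry of Q K^T,
    then normalize each row to sum to 1."""
    A = []
    for q in Q:
        # Row i of Q K^T: the inner products <q_i, k_j> for every key k_j.
        scores = [f(sum(a * b for a, b in zip(q, k))) for k in K]
        total = sum(scores)  # assumed > 0, i.e. f is not zero on the whole row
        A.append([s / total for s in scores])
    return A


def power_f(p):
    """f(x) = |x|^p, the family of functions named in the abstract."""
    return lambda x: abs(x) ** p


def large_scores(A, eps):
    """All (i, j) with A[i][j] >= eps, found by scanning every entry of A."""
    return [(i, j) for i, row in enumerate(A)
            for j, a in enumerate(row) if a >= eps]


if __name__ == "__main__":
    Q = [[1.0, 0.0], [0.0, 1.0]]
    K = [[1.0, 0.0], [0.0, 1.0], [2.0, 0.0]]
    A = attention_matrix(Q, K, power_f(2))
    print(large_scores(A, 0.5))  # [(0, 2), (1, 1)]
```

With $p = 2$, the first query's row of $f(QK^T)$ is $(1, 0, 4)$, which normalizes to $(0.2, 0, 0.8)$, so only key 2 clears $\varepsilon = 0.5$ in that row.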

Paper Structure (15 sections, 8 theorems, 11 equations, 4 figures, 1 table)

Key Result

Theorem 1.1

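Let $f$ be a non-negative function, let $\sigma^f_i(K) = \sup_{x} f(\langle k_i, x \rangle) / \sum_{j=1}^n f(\langle k_j, x \rangle)$ denote the $f$-sensitivity of the $i$-th row $k_i$ of $K$, and let $\Psi^f = \sup_K \sum_{i=1}^n \sigma^f_i(K)$. There is a subset $U \subset [n]$ of size at most $\Psi^f / \varepsilon$ such that for any query matrix $Q$ and any $i \in \{1, 2, \ldots, n\}$, if $A_{i,j} \geq \varepsilon$, then $j \in U$.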
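
The key step can be checked numerically in the special case $f(x) = x^2$. Reading $\sigma^f_j(K)$ as the $f$-sensitivity of key $j$, for this $f$ it equals the leverage score $\ell_j = k_j^T (K^T K)^{-1} k_j$, and leverage scores sum to $\mathrm{rank}(K) \le d$. Since $A_{i,j} = f(\langle q_i, k_j \rangle) / \sum_l f(\langle q_i, k_l \rangle) \le \sigma^f_j(K)$ for every query, the set $U = \{j : \ell_j \ge \varepsilon\}$ has at most $d/\varepsilon$ elements and contains every large score. The sketch below is a minimal pure-Python illustration of this one choice of $U$, assuming $K$ has full column rank; it computes leverage scores exactly rather than with the fast randomized tools the paper uses, may differ from the paper's construction, and all function names are hypothetical.

```python
import random


def solve(G, b):
    """Solve G x = b for a d x d matrix G by Gaussian elimination with partial pivoting."""
    d = len(G)
    M = [row[:] + [bi] for row, bi in zip(G, b)]  # augmented matrix [G | b]
    for c in range(d):
        piv = max(range(c, d), key=lambda r: abs(M[r][c]))
        M[c], M[piv] = M[piv], M[c]
        for r in range(c + 1, d):
            factor = M[r][c] / M[c][c]
            for k in range(c, d + 1):
                M[r][k] -= factor * M[c][k]
    x = [0.0] * d
    for r in range(d - 1, -1, -1):  # back substitution
        x[r] = (M[r][d] - sum(M[r][k] * x[k] for k in range(r + 1, d))) / M[r][r]
    return x


def leverage_scores(K):
    """l_j = k_j^T (K^T K)^{-1} k_j for each row k_j; assumes K has full column rank."""
    d = len(K[0])
    G = [[sum(k[a] * k[b] for k in K) for b in range(d)] for a in range(d)]  # K^T K
    return [sum(kj[a] * y for a, y in enumerate(solve(G, kj))) for kj in K]


def universal_set(K, eps):
    """Keys with leverage score >= eps; since the scores sum to d, at most d / eps keys."""
    return {j for j, lev in enumerate(leverage_scores(K)) if lev >= eps}


def attention_row(q, K):
    """One row of A for f(x) = x^2; assumes K q != 0 so the normalizer is positive."""
    s = [sum(a * b for a, b in zip(q, k)) ** 2 for k in K]
    total = sum(s)
    return [v / total for v in s]


if __name__ == "__main__":
    random.seed(0)
    n, d, eps = 200, 4, 0.05
    K = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
    U = universal_set(K, eps)
    assert len(U) <= d / eps
    for _ in range(100):  # random queries: every score >= eps must land in U
        q = [random.gauss(0, 1) for _ in range(d)]
        row = attention_row(q, K)
        assert all(j in U for j, a in enumerate(row) if a >= eps)
    print(len(U), "of", n, "keys kept")
```

Note that $U$ is computed from $K$ alone, before any query is seen, which is what makes it "universal".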

Figures (4)

  • Figure 1: Example of an image partitioned into 196 patches using a 14 x 14 grid. The green patches represent the neighbors of the red patch within a Manhattan distance of 3.
  • Figure 2: Histograms of top-32 attention weights. Each of the histograms plots the distribution of the sum of top-32 attention weights for query tokens. We note that at many attention heads, for many query tokens, the largest 32 attention weights constitute a significant fraction of the total attention weight.
  • Figure 3: Histograms of attention weights captured by local tokens. Each of the histograms plots the distribution of attention weight captured by keys within a Manhattan distance of 3.
  • Figure 4: Histograms of attention weights captured by the important keys along with local tokens. As discussed in the subsection on structural properties, important keys are defined as the 32 keys capturing the largest attention weight at a given attention head. The histograms show that at a large number of attention heads, the important keys together with local keys capture a significant fraction of the attention weight for a large number of queries.
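
The quantities plotted in Figures 2 to 4 can be computed from one head's attention matrix over the 196 patches of a 14 x 14 grid. The sketch below is a hypothetical illustration under stated assumptions, not the authors' evaluation code: patches are indexed row-major, query $i$ and key $i$ refer to the same patch (no class token), a patch counts as its own local neighbor, and "important keys" are read as the 32 keys with the largest column sums of $A$. The paper may define these details differently.

```python
def local_neighbors(patch, grid=14, radius=3):
    """Patches within Manhattan distance `radius` of `patch` (row-major index on a
    grid x grid layout, as in Figure 1). The patch itself is included."""
    r0, c0 = divmod(patch, grid)
    return {r * grid + c
            for r in range(grid) for c in range(grid)
            if abs(r - r0) + abs(c - c0) <= radius}


def top_k_mass(row, k=32):
    """Sum of the k largest attention weights in one row of A (Figure 2)."""
    return sum(sorted(row, reverse=True)[:k])


def important_keys(A, k=32):
    """Assumed reading of 'important keys' (Figure 4): the k keys with the largest
    total attention weight, i.e. column sum of A, at one head."""
    n = len(A[0])
    col = [sum(row[j] for row in A) for j in range(n)]
    return set(sorted(range(n), key=lambda j: col[j], reverse=True)[:k])


def local_plus_important_mass(A, k=32, grid=14, radius=3):
    """Per-query weight captured by important keys together with local keys,
    assuming query i and key i are the same patch (no class token)."""
    imp = important_keys(A, k)
    return [sum(row[j] for j in imp | local_neighbors(i, grid, radius))
            for i, row in enumerate(A)]
```

An interior patch has 25 neighbors within distance 3 (including itself), while a corner patch has only 10. Calling `local_plus_important_mass(A, k=0)` gives the local-only quantity of Figure 3, since no important keys are added; histogramming each function's per-query values over all queries gives the corresponding figure's distribution.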

Theorems & Definitions (18)

  • Theorem 1.1
  • Remark 2.1
  • Theorem 2.2
  • Lemma 2.3
  • proof
  • Theorem 2.4
  • proof
  • Theorem 3.1
  • proof
  • Theorem 3.2
  • ...and 8 more