Conv-Basis: A New Paradigm for Efficient Attention Inference and Gradient Computation in Transformers

Yingyu Liang, Heshan Liu, Zhenmei Shi, Zhao Song, Zhuoyan Xu, Junze Yin

TL;DR

This work exploits the convolution-like structure of attention matrices to approximate attention computation efficiently using convolution matrices. It proposes a $\mathsf{conv}$ basis system, analogous to the rank basis, and shows that any lower triangular matrix can be decomposed as a sum of structured convolution matrices in this basis.

Abstract

The self-attention mechanism is the key to the success of transformers in recent Large Language Models (LLMs). However, the quadratic computational cost $O(n^2)$ in the input sequence length $n$ is a notorious obstacle to further improvement and scalability in longer contexts. In this work, we leverage the convolution-like structure of attention matrices to develop an efficient approximation method for attention computation using convolution matrices. We propose a $\mathsf{conv}$ basis system, analogous to the rank basis, and show that any lower triangular matrix can always be decomposed as a sum of structured convolution matrices in this basis. We then design a fast algorithm to approximate the attention matrix via a sum of $k$ such convolution matrices. This allows us to compute attention inference via Fast Fourier Transforms (FFT) in $O(knd \log n)$ time, where $d$ is the hidden dimension, and thus achieve almost linear time $n^{1+o(1)}$ in the practical scenario where $kd = n^{o(1)}$. Furthermore, the attention training forward pass and backward gradient can be computed in $n^{1+o(1)}$ time as well. We provide theoretical guarantees on the run time and approximation error and conduct preliminary experiments to evaluate its effectiveness. We hope our new paradigm for accelerating attention computation in transformer models can help their application to longer contexts.
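
The FFT step mentioned in the abstract can be illustrated with a minimal NumPy sketch. It assumes that $\mathsf{conv}(a)$ is the $n \times n$ lower triangular Toeplitz matrix whose first column is $a$; the function names `conv_matrix` and `conv_apply_fft` are hypothetical and are not taken from the paper. The sketch only shows how one such matrix can be applied to an $n \times d$ matrix in $O(nd \log n)$ time without materializing it; it is not the paper's full attention algorithm.

```python
import numpy as np

def conv_matrix(a):
    """Dense n x n lower triangular Toeplitz matrix with first column a.

    Assumed form of conv(a); used here only as an O(n^2) reference.
    """
    n = a.shape[0]
    idx = np.arange(n)[:, None] - np.arange(n)[None, :]  # entry (i, j) holds i - j
    return np.where(idx >= 0, a[np.clip(idx, 0, n - 1)], 0.0)

def conv_apply_fft(a, V):
    """Compute conv(a) @ V for V of shape (n, d) using zero-padded FFTs."""
    n = a.shape[0]
    m = 2 * n  # any length >= 2n - 1 avoids circular wrap-around
    fa = np.fft.rfft(a, n=m)
    fV = np.fft.rfft(V, n=m, axis=0)
    out = np.fft.irfft(fa[:, None] * fV, n=m, axis=0)
    return out[:n]  # the first n rows of the linear convolution

# Small check against the dense reference.
rng = np.random.default_rng(0)
a = rng.standard_normal(64)
V = rng.standard_normal((64, 8))
print(np.allclose(conv_matrix(a) @ V, conv_apply_fft(a, V)))
```

Applying a sum of $k$ such matrices repeats this step $k$ times, which is consistent with the $O(knd \log n)$ cost stated above.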

Paper Structure (45 sections, 48 theorems, 111 equations, 4 figures, 6 algorithms)

Key Result

Theorem 1.1

Let $\epsilon > 0$, $k\in [n]$ and $Q, K \in \mathbb{R}^{n \times d}$. If $QK^\top$ is $\epsilon$-close in $\ell_\infty$ norm to a matrix with $k$-$\mathsf{conv}$ basis (Definition def:non_degen), then we can solve the Exact Attention Computation (Definition def:exact_attention_mask) in $O(knd\log(n

Figures (4)

  • Figure 1: (a) In the left two figures, we compare the cost of computing $\mathsf{conv}(a) \cdot w$ with the naive method and with the FFT method, where $a, w \in \mathbb{R}^n$ are random vectors and $\mathsf{conv}(a) \in \mathbb{R}^{n \times n}$ (Definition \ref{def:conv}). The $x$-axis is the number of input tokens $n$. The $y$-axis is the average CPU time over $n$ in the first figure and the average number of floating-point operations (FLOPs) over $n$ in the second. Each reported number is an average over 100 runs of a NumPy implementation. The naive method clearly takes $O(n^2)$ time, while the FFT method takes $O(n \log n)$. (b) In the right figure, we plot one $QK^\top \in \mathbb{R}^{n \times n}$ in Llama3 \cite{llama3}, where the input is from SST-2 \cite{glue} with $n=47$ tokens. The $\mathsf{conv}$-like structure in the attention matrix is clearly visible.
  • Figure 2: A matrix with $3$-$\mathsf{conv}$ basis. We present an example of the matrix defined in Definition \ref{def:conv_basis} when $k = 3$. The matrix with $3$-$\mathsf{conv}$ basis is on the left-hand side of the equation in this figure. The red entries in this matrix come from the first matrix on the right-hand side. The purple entries in this matrix are the sum of the red entries from the first matrix on the right-hand side and the blue entries from the second matrix on the right-hand side. The dark green entries are equal to the sum of red, green, and blue entries from the matrices on the right-hand side.
  • Figure 3: Three $16 \times 16$ mask matrices. Left: row change by amortized constant mask (Definition \ref{def:mask_constant}). Middle: continuous row mask (Definition \ref{def:mask_continuous}). Right: distinct $3$ rows mask (Definition \ref{def:mask_r_row}). Green means $1$ and yellow means $0$.
  • Figure 4: Comparison of Llama3 8B Instruct with and without our Algorithm \ref{alg:conv_forward} on the IMDB dataset, with input sequence length $n=2048$. The $x$-axis is the number of $\mathsf{conv}$ basis matrices $k$. The $y$-axis is the relative difference $\frac{\|Y-\widetilde{Y}\|_F^2}{\|Y\|_F^2}$ in the left figure and the classification accuracy in the right figure. Note that $k=2048$ represents the baseline of the original model, as this is the input sequence length.
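
The structure shown in Figure 2 can be sketched in NumPy under a stated assumption: each term of a $k$-$\mathsf{conv}$ basis matrix is taken to be a sub-convolution matrix $\mathsf{conv}(b_i, m_i)$, meaning an $m_i \times m_i$ lower triangular Toeplitz block placed in the bottom-right corner of an otherwise zero $n \times n$ matrix. This layout is inferred from the figure description, and the names `sub_conv_apply_fft`, `k_conv_apply`, and `k_conv_dense` are hypothetical. The sketch applies the sum to an $n \times d$ matrix with FFTs and checks the result against a dense reference.

```python
import numpy as np

def sub_conv_apply_fft(b, m, V):
    """Apply the assumed conv(b, m) to V of shape (n, d).

    Only the last m rows of V interact with the m x m Toeplitz block.
    """
    n = V.shape[0]
    out = np.zeros(V.shape, dtype=float)
    length = 2 * m  # zero padding avoids circular wrap-around
    fb = np.fft.rfft(b[:m], n=length)
    fV = np.fft.rfft(V[n - m:], n=length, axis=0)
    out[n - m:] = np.fft.irfft(fb[:, None] * fV, n=length, axis=0)[:m]
    return out

def k_conv_apply(bs, ms, V):
    """Apply sum_i conv(b_i, m_i) to V without forming any n x n matrix."""
    return sum(sub_conv_apply_fft(b, m, V) for b, m in zip(bs, ms))

def k_conv_dense(bs, ms, n):
    """Dense O(k n^2) reference for the same sum, for checking only."""
    H = np.zeros((n, n))
    for b, m in zip(bs, ms):
        for i in range(m):
            for j in range(i + 1):
                H[n - m + i, n - m + j] += b[i - j]
    return H

# Example with k = 3 nested blocks, mirroring the layout described for Figure 2.
rng = np.random.default_rng(0)
n, d = 16, 4
ms = [16, 9, 4]
bs = [rng.standard_normal(n) for _ in ms]
V = rng.standard_normal((n, d))
print(np.allclose(k_conv_dense(bs, ms, n) @ V, k_conv_apply(bs, ms, V)))
```

Each term costs $O(m_i d \log m_i)$, so the sum over $k$ terms stays within $O(knd \log n)$.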

Theorems & Definitions (143)

  • Theorem 1.1: Main result, informal version of Theorem \ref{thm:conv_formal}
  • Definition 3.1: Input and weight matrix
  • Definition 3.2: Causal attention mask
  • Definition 3.3: Exact attention computation
  • Remark 3.4
  • Definition 3.5: Convolution matrix
  • Claim 3.6
  • Claim 3.7
  • Claim 3.8
  • Definition 3.9: Sub-convolution matrix
  • ...and 133 more