FAST: Factorizable Attention for Speeding up Transformers

Armin Gerami, Monte Hoover, Pranav S. Dulepet, Ramani Duraiswami

TL;DR

The paper tackles the quadratic complexity of standard transformer attention by introducing FAST, a factorable, linearly scaling attention mechanism called Fastmax that preserves full all-to-all interactions. It uses a polynomial kernel $f(x)=\sum_{\ell=0}^p x^\ell/\ell!$ (the Taylor series of $e^x$ truncated at order $p$) with normalized queries and keys to derive a computable and differentiable attention metric, supported by a gradient bound that promotes stable training. In the unmasked case, the method has per-head computational complexity $O(ND^{p+1})$ and memory $O(ND^p+D^{p+1})$, and its cost across all heads is $O(NH(C/H)^{p+1})$, where $N$ is the number of tokens, $D = C/H$ the dimension per head, and $H$ the number of heads. A memory-efficient variant using custom gradients reduces per-head memory to $O(ND^{p-1})$. Empirical results on MNIST and Long Range Arena show that Fastmax matches Softmax in expressivity while scaling linearly, enabling efficient long-context modeling and potential deployment on edge devices. Future work includes exploring higher-order terms and CPU implementations.
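To make the kernel concrete, here is a minimal single-head NumPy sketch of unmasked Fastmax with $p=2$, checked against a direct version that materializes the $N \times N$ score matrix. It is an illustration, not the authors' implementation: it assumes "normalized" means unit L2 norm per token row, and the names `normalize`, `fastmax_p2`, and `fastmax_direct` are hypothetical. The factorized version sums key/value moments over all tokens once per Taylor order, so no $N \times N$ matrix is ever formed.

```python
import numpy as np


def normalize(x, eps=1e-6):
    # Assumption: "normalized" = unit L2 norm for each token (row).
    return x / (np.linalg.norm(x, axis=-1, keepdims=True) + eps)


def fastmax_direct(q, k, v):
    """Reference: builds the full N x N score matrix, O(N^2 D)."""
    q, k = normalize(q), normalize(k)
    s = q @ k.T
    f = 1.0 + s + 0.5 * s**2                      # f(x) = 1 + x + x^2/2 (p = 2)
    return (f / f.sum(axis=1, keepdims=True)) @ v


def fastmax_p2(q, k, v):
    """Factorized: moments summed over tokens once, O(N D^3) for p = 2."""
    q, k = normalize(q), normalize(k)
    n = q.shape[0]
    # Numerator sum_j f(q_i . k_j) v_j, one moment per Taylor order.
    v0 = v.sum(axis=0)                            # order 0: (Dv,)
    v1 = k.T @ v                                  # order 1: (D, Dv)
    v2 = np.einsum("na,nb,nc->abc", k, k, v)      # order 2: (D, D, Dv)
    num = v0 + q @ v1 + 0.5 * np.einsum("na,nb,abc->nc", q, q, v2)
    # Denominator sum_j f(q_i . k_j): same moments with v_j replaced by 1.
    k1 = k.sum(axis=0)
    k2 = k.T @ k
    den = n + q @ k1 + 0.5 * np.einsum("na,nb,ab->n", q, q, k2)
    return num / den[:, None]


rng = np.random.default_rng(0)
q, k, v = rng.standard_normal((3, 256, 16))
print(np.allclose(fastmax_p2(q, k, v), fastmax_direct(q, k, v)))  # True
```

Because $f(x) = 1 + x + x^2/2 = \big((x+1)^2 + 1\big)/2 > 0$ for every real $x$, the denominator never vanishes, so this sketch needs no max-subtraction trick of the kind used to stabilize Softmax. The order-2 tensor `v2` has $D^2 D_v$ entries, which corresponds to the $D^{p+1}$ memory term.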

Abstract

Motivated by the factorization inherent in the original fast multipole method and the improved fast Gauss transform, we introduce a factorable form of attention that operates efficiently in high dimensions. This approach reduces the computational and memory complexity of the attention mechanism in transformers from $O(N^2)$ to $O(N)$. Compared with previous attempts, our work presents a linearly scaled attention mechanism that maintains the full representation of the attention matrix without resorting to sparsification and incorporates the all-to-all relationship between tokens. We explore the properties of our new attention metric and conduct tests in various standard settings. Results indicate that our attention mechanism performs robustly and holds significant promise for diverse applications where self-attention is used.
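A short derivation sketches how factorization removes the $N \times N$ matrix. It assumes the truncated-exponential kernel $f(x)=\sum_{\ell=0}^p x^\ell/\ell!$ and writes $q_i$ for the $i$-th query and $k_j$, $v_j$ for the $j$-th key and value (notation chosen here for illustration), using the identity $(q^\top k)^\ell = \langle q^{\otimes \ell}, k^{\otimes \ell} \rangle$:

```latex
\sum_{j=1}^{N} f(q_i^\top k_j)\, v_j
  = \sum_{\ell=0}^{p} \frac{1}{\ell!} \sum_{j=1}^{N} (q_i^\top k_j)^{\ell}\, v_j
  = \sum_{\ell=0}^{p} \frac{1}{\ell!}
    \Big\langle q_i^{\otimes \ell},\;
      \underbrace{\sum_{j=1}^{N} k_j^{\otimes \ell} \otimes v_j}_{\text{independent of } i}
    \Big\rangle
```

The bracketed sum for order $\ell$ is a tensor with $D^{\ell+1}$ entries (taking the value dimension to be $D$). It is computed once in $O(ND^{\ell+1})$ and then contracted with every query, so the order-$p$ term dominates at $O(ND^{p+1})$ per head. The normalizer $\sum_j f(q_i^\top k_j)$ follows by setting $v_j = 1$.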

Paper Structure (16 sections, 22 equations, 6 figures, 2 tables)

Figures (6)

  • Figure 1: Flowchart of calculating Score using Fastmax. The purple terms on the upper-left of each step indicate their computational cost. The backward pass is computed using automatic differentiation, but further optimizations are possible using custom gradients (see the section on custom gradients).
  • Figure 2: Empirical results for different dropout approaches. Even small amounts of dropout on the quadratic term improve test generalization.
  • Figure 3: Comparison of the time taken to calculate the Scores per head using Fastmax and Softmax on an RTX A6000 (48 GB memory) for various per-head dimensions $D$. Softmax scales quadratically with the number of tokens $N$, whereas Fastmax scales linearly. The 'x' marks indicate an "out of memory" condition.
  • Figure 4: Attention maps from transformers trained on MNIST and Tiny Shakespeare data: (a) Softmax attention on MNIST, (b) Fastmax2 attention on MNIST, (c) Softmax attention on Tiny Shakespeare, (d) Fastmax attention on Tiny Shakespeare. The two mechanisms produce different scores, but both converge during training.
  • Figure 5: Long Range Arena results, after Tay et al. (2020), showing speed versus accuracy for alternative transformer formulations, including the vanilla Transformer (Softmax), several others discussed in the paper, and the two Fastmax variants proposed in this paper. GPU memory usage is represented by circle area. Timings were measured on an RTX A5000 GPU, and the hyperparameters for each algorithm were optimized. Further details are given in the LRA accuracy and speed tables.
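The linear-versus-quadratic behaviour compared in Figure 3 follows from the stated complexities. The following is a back-of-the-envelope sketch, not a reproduction of those measurements. It assumes that constant factors, memory traffic, and causal masking can be ignored, and the helper `costs` and width parameter `c` are illustrative names not taken from the paper. The ratio of the two counts is $N/D^p$, so under these assumptions Fastmax performs fewer operations than Softmax once $N$ exceeds roughly $D^p$. Its total cost also falls when the same width is split across more heads, while the Softmax total stays at about $N^2 C$.

```python
def costs(n, c, h, p=2):
    """Leading-order operation counts summed over all heads, constants ignored.

    n: number of tokens, c: total width (assumed c = h * d),
    h: number of heads, p: Taylor order of the Fastmax kernel.
    """
    d = c // h                           # dimension per head
    softmax = h * n * n * d              # O(N^2 D) per head
    fastmax = h * n * d ** (p + 1)       # O(N D^(p+1)) per head, O(N H (C/H)^(p+1)) total
    return softmax, fastmax


for n in (1_024, 4_096, 65_536):
    s, f = costs(n, c=256, h=8)          # d = 32, so the ratio is n / 32**2
    print(f"N={n:>6}: softmax/fastmax ~ {s / f:g}")
```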