FAST: Factorizable Attention for Speeding up Transformers
Armin Gerami, Monte Hoover, Pranav S. Dulepet, Ramani Duraiswami
TL;DR
The paper tackles the quadratic complexity of standard transformer attention by introducing FAST, whose factorizable, linearly scaling attention mechanism, Fastmax, preserves full all-to-all interactions between tokens. Fastmax applies the polynomial kernel $f(x)=\sum_{\ell=0}^p x^\ell/\ell!$ (the order-$p$ Taylor truncation of $e^x$) to normalized queries and keys, yielding an attention metric that is efficiently computable and differentiable, and a bound on its gradients supports stable training. Without masking, the method costs $O(ND^{p+1})$ time and $O(ND^p+D^{p+1})$ memory per head, where $D=C/H$ is the per-head dimension ($C$ split across $H$ heads), giving $O(NH(C/H)^{p+1})$ across all heads; a memory-efficient variant with custom gradients lowers per-head memory to $O(ND^{p-1})$. Experiments on MNIST and Long Range Arena show that Fastmax matches Softmax in expressivity while scaling linearly, which enables efficient long-context modeling and potentially deployment on edge devices. Future work includes higher-order terms and CPU implementations.
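The factorization behind the linear cost can be sketched for $p=2$. Expanding $f(\mathbf q\cdot\mathbf k)=1+\mathbf q\cdot\mathbf k+\tfrac12(\mathbf q\cdot\mathbf k)^2$ into monomials separates query terms from key terms, so key/value sums can be accumulated once and reused by every query. The pure-Python code below is a minimal, unoptimized sketch under stated assumptions: a single head, no causal mask, $p$ fixed at 2, and L2 normalization (the hypothetical helper `l2_normalize`) as an illustrative stand-in for the paper's query/key normalization. All function names are hypothetical and this is not the authors' implementation; it only checks that the factorized form reproduces a naive quadratic computation.

```python
import math
import random
from typing import List

Vector = List[float]
Matrix = List[Vector]


def f(x: float) -> float:
    """Order-2 truncation of exp(x): 1 + x + x^2/2 (always > 0)."""
    return 1.0 + x + 0.5 * x * x


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def l2_normalize(rows: Matrix) -> Matrix:
    """Assumed, illustrative stand-in for the paper's query/key normalization."""
    out = []
    for r in rows:
        norm = math.sqrt(dot(r, r))
        out.append([x / norm for x in r])
    return out


def random_matrix(rows: int, cols: int) -> Matrix:
    return [[random.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


def naive_attention(Q: Matrix, K: Matrix, V: Matrix) -> Matrix:
    """Quadratic reference: forms every score f(q_i . k_j) explicitly."""
    out = []
    for q in Q:
        w = [f(dot(q, k)) for k in K]
        total = sum(w)
        out.append([sum(wj * v[c] for wj, v in zip(w, V)) / total
                    for c in range(len(V[0]))])
    return out


def factorized_attention(Q: Matrix, K: Matrix, V: Matrix) -> Matrix:
    """Linear in N: sum key-side moments once, then evaluate each query
    against them. No N x N score matrix is ever built."""
    n, D, Dv = len(K), len(K[0]), len(V[0])
    R = range(D)
    # Numerator moments: sum_j v_j, sum_j k_jd v_j, sum_j k_jd k_je v_j
    s0 = [0.0] * Dv
    s1 = [[0.0] * Dv for _ in R]
    s2 = [[[0.0] * Dv for _ in R] for _ in R]
    # Denominator moments: the same sums with v_j replaced by 1
    z1 = [0.0] * D
    z2 = [[0.0] * D for _ in R]
    for k, v in zip(K, V):
        for c in range(Dv):
            s0[c] += v[c]
        for d in R:
            z1[d] += k[d]
            for c in range(Dv):
                s1[d][c] += k[d] * v[c]
            for e in R:
                z2[d][e] += k[d] * k[e]
                for c in range(Dv):
                    s2[d][e][c] += k[d] * k[e] * v[c]
    out = []
    for q in Q:
        den = (n + sum(q[d] * z1[d] for d in R)
               + 0.5 * sum(q[d] * q[e] * z2[d][e] for d in R for e in R))
        row = []
        for c in range(Dv):
            num = (s0[c] + sum(q[d] * s1[d][c] for d in R)
                   + 0.5 * sum(q[d] * q[e] * s2[d][e][c] for d in R for e in R))
            row.append(num / den)
        out.append(row)
    return out


if __name__ == "__main__":
    random.seed(0)
    N, D = 6, 3
    Q = l2_normalize(random_matrix(N, D))
    K = l2_normalize(random_matrix(N, D))
    V = random_matrix(N, D)
    fast, slow = factorized_attention(Q, K, V), naive_attention(Q, K, V)
    err = max(abs(a - b) for r1, r2 in zip(fast, slow) for a, b in zip(r1, r2))
    print(f"max |factorized - naive| = {err:.2e}")
```

The moment sums take $O(ND^2D_v)$ work and $O(D^2D_v)$ storage, which for $D_v=D$ matches the $O(ND^{p+1})$ per-head cost at $p=2$. Because $1+x+x^2/2=\tfrac12\big((x+1)^2+1\big)>0$, the denominator is always positive here; for $p=1$, $1+x$ is negative when $x<-1$, whereas unit-norm $\mathbf q,\mathbf k$ keep $|\mathbf q\cdot\mathbf k|\le 1$ by Cauchy-Schwarz, so $f$ stays non-negative.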
Abstract
Motivated by the factorization inherent in the original fast multipole method and the improved fast Gauss transform, we introduce a factorable form of attention that operates efficiently in high dimensions. This approach reduces the computational and memory complexity of the attention mechanism in transformers from $O(N^2)$ to $O(N)$. Compared with previous attempts, our work presents a linearly scaling attention mechanism that maintains the full representation of the attention matrix without resorting to sparsification and incorporates the all-to-all relationship between tokens. We explore the properties of our new attention metric and conduct tests in various standard settings. Results indicate that our attention mechanism performs robustly and holds significant promise for diverse applications where self-attention is used.
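The $O(N)$ claim is linear in the sequence length $N$; the head dimension still enters polynomially. Comparing the per-head Fastmax cost $O(ND^{p+1})$ with the standard $O(N^2D)$ cost of softmax attention, and ignoring constant factors and hardware effects (an assumption; measured crossovers depend on the implementation), gives a rough break-even length. The values $C=512$, $H=8$ below are a hypothetical illustration, not settings taken from the paper.

```latex
\begin{aligned}
\frac{\text{cost}_{\text{Fastmax}}}{\text{cost}_{\text{Softmax}}}
  &\approx \frac{N D^{p+1}}{N^{2} D} = \frac{D^{p}}{N}
  \quad\Longrightarrow\quad \text{Fastmax is cheaper once } N > D^{p},\\
C = 512,\ H = 8,\ p = 2 &\;\Rightarrow\; D = C/H = 64,\quad D^{p} = 4096 .
\end{aligned}
```

Under these assumptions, 64-dimensional heads with $p=2$ pay off only beyond roughly four thousand tokens, while $p=1$ moves the break-even point down to $N>D=64$.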
