Fast Transformers with Clustered Attention
Apoorv Vyas, Angelos Katharopoulos, François Fleuret
TL;DR
This paper tackles the quadratic complexity of self-attention in Transformers by introducing clustered attention. Clustered attention groups queries into C clusters, computes attention only for the cluster centroids, and broadcasts each centroid's result to the queries in its cluster. The paper further improves accuracy by selectively recomputing exact attention for the top-k keys of each cluster, and it supports the approach with theoretical bounds on the approximation error. Empirically, the method has linear complexity in sequence length for a fixed number of clusters. Under the same compute budget, it outperforms vanilla transformers on the WSJ and Switchboard speech recognition tasks. It can also approximate pretrained models such as RoBERTa with as few as 25 clusters. The results suggest substantial practical benefits for long-sequence modeling and for efficiently approximating large pretrained models, with potential for broader deployment in resource-constrained settings.
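The following is a minimal NumPy sketch of the basic (non-improved) clustered attention described above, for a single head without batching. It makes two assumptions. First, queries are clustered with plain Lloyd's K-means on the raw query vectors. The helper `kmeans` and its `iters`/`seed` parameters are hypothetical, and this is not necessarily the clustering procedure the paper uses. Second, attention is the usual scaled dot-product softmax. Each centroid attends to all N keys once, and its output is copied to every member query. The attention step therefore costs O(CN) instead of O(N²), and each K-means iteration is also linear in N.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis.
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def kmeans(X, C, iters=10, seed=0):
    # Hypothetical helper: plain Lloyd's K-means on the rows of X,
    # used here as a stand-in for the paper's query-clustering step.
    rng = np.random.default_rng(seed)
    centroids = X[rng.choice(len(X), size=C, replace=False)].copy()
    for _ in range(iters):
        # Assign each row to its nearest centroid (squared Euclidean distance).
        d = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(-1)
        labels = d.argmin(axis=1)
        for c in range(C):
            members = X[labels == c]
            if len(members) > 0:  # keep the old centroid if a cluster empties
                centroids[c] = members.mean(axis=0)
    # Final assignment so the labels match the returned centroids.
    d = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(-1)
    return d.argmin(axis=1), centroids

def clustered_attention(Q, K, V, C, iters=10, seed=0):
    # Q: (N, D) queries, K: (N, D) keys, V: (N, Dv) values.
    labels, centroids = kmeans(Q, C, iters, seed)
    D = Q.shape[1]
    # Attention is computed once per centroid: (C, N) instead of (N, N).
    A_c = softmax(centroids @ K.T / np.sqrt(D), axis=1)
    V_c = A_c @ V  # (C, Dv) output for each centroid
    # Broadcast each centroid's output to every query in its cluster.
    return V_c[labels]
```

When C equals the number of (distinct) queries, every query becomes its own centroid and the sketch reduces to exact softmax attention. Smaller values of C trade accuracy for speed.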
Abstract
Transformers have proven successful for a variety of sequence modeling tasks. However, computing the attention matrix, their key component, has quadratic complexity with respect to the sequence length, which makes them prohibitively expensive for long sequences. To address this, we propose clustered attention: instead of computing attention for every query, it groups queries into clusters and computes attention only for the centroids. To further improve this approximation, we use the computed clusters to identify the keys with the highest attention per query and compute the exact key/query dot products for those keys. The result is a model with linear complexity with respect to the sequence length for a fixed number of clusters. We evaluate our approach on two automatic speech recognition datasets and show that our model consistently outperforms vanilla transformers for a given computational budget. Finally, we demonstrate that our model can approximate arbitrarily complex attention distributions with a minimal number of clusters: it approximates a pretrained BERT model on the GLUE and SQuAD benchmarks with only 25 clusters and no loss in performance.
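The top-k refinement can be sketched in the same style. This is a hedged illustration, not the authors' implementation. The function name `improved_clustered_attention` is hypothetical. The cluster assignments `labels` are assumed to be given, for instance from K-means. The rule for merging exact and approximate scores is also an assumption. Each query computes an exact softmax over its cluster's top-k keys, rescaled to the probability mass the centroid assigned to those keys so that every attention row still sums to 1. All other keys keep the centroid's weights.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis.
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def improved_clustered_attention(Q, K, V, labels, topk):
    # Q: (N, D), K: (N, D), V: (N, Dv); labels[i] in [0, C) is query i's cluster.
    N, D = Q.shape
    C = labels.max() + 1
    # Centroid of each cluster = mean of its member queries (assumes no empty cluster).
    centroids = np.stack([Q[labels == c].mean(axis=0) for c in range(C)])
    A_c = softmax(centroids @ K.T / np.sqrt(D), axis=1)  # (C, N) centroid attention
    # Keys that receive the highest centroid attention, per cluster.
    top = np.argsort(-A_c, axis=1)[:, :topk]  # (C, topk)
    V_c = A_c @ V  # (C, Dv) plain clustered-attention output
    out = V_c[labels].copy()
    for i in range(N):
        c = labels[i]
        idx = top[c]
        # Probability mass the centroid puts on its top-k keys.
        mass = A_c[c, idx].sum()
        # Exact dot products for the top-k keys, rescaled to keep that mass.
        exact = mass * softmax(Q[i] @ K[idx].T / np.sqrt(D))
        # Swap the centroid's contribution from those keys for the exact one.
        out[i] += (exact - A_c[c, idx]) @ V[idx]
    return out
```

`A_c @ V` is computed once per cluster (O(CN)), and the per-query correction touches only k keys, so for fixed C and k the cost stays linear in N. Setting `topk` to N recovers exact attention for any clustering.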
