Fast Transformers with Clustered Attention

Apoorv Vyas, Angelos Katharopoulos, François Fleuret

TL;DR

This paper tackles the quadratic complexity of self-attention in Transformers by introducing clustered attention, which groups queries into C clusters, computes attention only for the cluster centroids, and broadcasts the result to the members of each cluster. It further improves accuracy by recomputing exact attention for the top-k keys of each cluster, supported by theoretical bounds on the approximation error. Empirically, the approach has linear complexity in the sequence length for a fixed number of clusters and outperforms vanilla transformers under the same computational budget on the WSJ and Switchboard ASR tasks. It can also approximate pretrained models such as RoBERTa with as few as 25 clusters. The results suggest practical benefits for long-sequence modeling and for reusing large pretrained models in resource-constrained settings.

Abstract

Transformers have been proven a successful model for a variety of tasks in sequence modeling. However, computing the attention matrix, which is their key component, has quadratic complexity with respect to the sequence length, thus making them prohibitively expensive for large sequences. To address this, we propose clustered attention, which instead of computing the attention for every query, groups queries into clusters and computes attention just for the centroids. To further improve this approximation, we use the computed clusters to identify the keys with the highest attention per query and compute the exact key/query dot products. This results in a model with linear complexity with respect to the sequence length for a fixed number of clusters. We evaluate our approach on two automatic speech recognition datasets and show that our model consistently outperforms vanilla transformers for a given computational budget. Finally, we demonstrate that our model can approximate arbitrarily complex attention distributions with a minimal number of clusters by approximating a pretrained BERT model on GLUE and SQuAD benchmarks with only 25 clusters and no loss in performance.
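The centroid-then-broadcast idea described in the abstract can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the authors' implementation: the helper names (`kmeans`, `clustered_attention`) are hypothetical, the clustering is plain Euclidean Lloyd's K-means chosen for simplicity, and the usual $1/\sqrt{D}$ scaling of the dot products is assumed.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def kmeans(Q, C, iters=10, seed=0):
    # Plain Lloyd's K-means in Euclidean space (an assumption made for
    # this sketch; any clustering of the queries could be plugged in).
    rng = np.random.default_rng(seed)
    centroids = Q[rng.choice(len(Q), size=C, replace=False)]
    for _ in range(iters):
        # Squared distance of every query to every centroid: (N, C).
        d = ((Q[:, None, :] - centroids[None, :, :]) ** 2).sum(-1)
        labels = d.argmin(axis=1)
        for c in range(C):
            members = Q[labels == c]
            if len(members):  # keep the old centroid if a cluster is empty
                centroids[c] = members.mean(axis=0)
    return labels, centroids

def clustered_attention(Q, K, V, C):
    # 1. Group the N queries into C clusters.
    labels, centroids = kmeans(Q, C)
    # 2. Attention is computed only for the C centroids: (C, N) weights.
    A_c = softmax(centroids @ K.T / np.sqrt(Q.shape[1]))
    V_c = A_c @ V  # new values for the centroids, shape (C, Dv)
    # 3. Broadcast each centroid's value to every query in its cluster.
    return V_c[labels]

rng = np.random.default_rng(1)
Q, K, V = (rng.normal(size=(128, 16)) for _ in range(3))
approx = clustered_attention(Q, K, V, C=8)
print(approx.shape)
```

The attention matrix here has C rows instead of N, which is where the linear dependence on sequence length for fixed C comes from. When every query is its own cluster (C = N), the sketch reduces to ordinary softmax attention.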

Paper Structure

This paper contains 28 sections, 3 theorems, 26 equations, 8 figures, 4 tables.

Key Result

Proposition 1

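Given two queries $Q_i$ and $Q_j$ such that $\left\|Q_i - Q_j\right\|_2 \leq \epsilon$, $$\left\|\text{softmax}\left(Q_i K^T\right) - \text{softmax}\left(Q_j K^T\right)\right\|_2 \leq \epsilon \left\|K\right\|_2,$$ where $\left\|K\right\|_2$ denotes the spectral norm of $K$.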
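A quick numerical sanity check can make the proposition concrete. The sketch below assumes the bound takes the form $\left\|\text{softmax}(Q_i K^T) - \text{softmax}(Q_j K^T)\right\|_2 \leq \epsilon \left\|K\right\|_2$ and samples random query pairs at distance exactly $\epsilon$. The sizes, the Gaussian sampling, and the variable names are arbitrary choices for illustration and do not come from the paper.

```python
import numpy as np

def softmax(x):
    # Numerically stable softmax of a 1-D vector.
    e = np.exp(x - x.max())
    return e / e.sum()

rng = np.random.default_rng(0)
N, D, eps = 64, 16, 0.1          # illustrative sizes, not from the paper
K = rng.normal(size=(N, D))
spectral_norm = np.linalg.norm(K, ord=2)  # largest singular value of K

worst_ratio = 0.0
for _ in range(1000):
    q_i = rng.normal(size=D)
    delta = rng.normal(size=D)
    delta *= eps / np.linalg.norm(delta)  # so that ||q_i - q_j||_2 == eps
    q_j = q_i + delta
    lhs = np.linalg.norm(softmax(K @ q_i) - softmax(K @ q_j))
    # Ratio of the observed difference to the assumed bound eps * ||K||_2.
    worst_ratio = max(worst_ratio, lhs / (eps * spectral_norm))

print(f"largest observed lhs / bound ratio: {worst_ratio:.4f}")
```

A ratio that never exceeds 1 is consistent with the bound. This is what one would expect, since softmax is 1-Lipschitz in the Euclidean norm and $\left\|K(Q_i - Q_j)\right\|_2 \leq \left\|K\right\|_2 \epsilon$.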

Figures (8)

  • Figure 1: We compare the achieved performance of various transformer models under an equalized computational budget. The numbers near the datapoints denote the number of layers and the number of clusters or hashing rounds where applicable. i-clustered is consistently better than all baselines for a given computational budget on both the WSJ and Switchboard datasets. The details can be found in § \ref{subsec:asr_wsj} and § \ref{subsec:asr_swbd} respectively.
  • Figure 2: Flow-chart demonstrating the computation for clustered attention. We use different colors to represent the query groups and the computed centroids. The same colors are then used to show the attention weights $A^c$, new values for the centroids $\hat{V}^c$, and the resulting values $\hat{V}$ after broadcasting. For more details refer to § \ref{subsec:supp_clustered_attn} or § 3.2 in the main paper.
  • Figure 3: Flow-chart demonstrating the computation for i-clustered attention. The lower half of the figure shows the new value $\hat{V}^t$ computed by sparse dot-products with the keys $K$ and values $V$ corresponding to the top-$k$ keys in $T$. The top half of the figure shows the computation for $\hat{V}^b$, which is the weighted average of the rest of the values with weights coming from the clustered attention $A^c$. The resulting value $\hat{V}$ is the sum of $\hat{V}^b$ and $\hat{V}^t$. For more details refer to § \ref{subsec:supp_improved_attn} or to § 3.3 in the main paper.
  • Figure 4: Per-element GPU time and memory consumption for a forward/backward pass. All models, except full, scale linearly with respect to the sequence length since they have constant time and memory per element. Detailed analysis can be found in § \ref{subsec:supp_benchmark}.
  • Figure 5: The heatmaps depict the achieved accuracy on an artificial copy task (§ \ref{sec:supp_ablation}) as the sequence length, the number of clusters and the number of hashing rounds vary. Improved clustered (\ref{fig:supp_ablation_ic}) is the only fast transformer variant that can solve the task perfectly for any combination of sequence length and number of clusters.
  • ...and 3 more figures
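
The split into $\hat{V}^t$ and $\hat{V}^b$ described in the Figure 3 caption can be sketched as follows. This is a rough illustration under assumptions, not the authors' code: the function name `improved_clustered_attention` is hypothetical, cluster assignments are taken as given, every cluster id is assumed to have at least one member, and the exact top-$k$ weights are rescaled by the mass that $A^c$ places on those keys so that each row of the implied attention still sums to 1 (my reading of § 3.3 of the main paper).

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def improved_clustered_attention(Q, K, V, labels, k):
    # labels[i] in {0, ..., C-1} is the cluster of query i; every cluster
    # is assumed to be non-empty (an assumption of this sketch).
    N, D = Q.shape
    C = labels.max() + 1
    scale = 1.0 / np.sqrt(D)
    centroids = np.stack([Q[labels == c].mean(axis=0) for c in range(C)])
    A_c = softmax(centroids @ K.T * scale)     # clustered attention, (C, N_keys)
    top = np.argsort(-A_c, axis=1)[:, :k]      # top-k key indices per cluster
    out = np.empty((N, V.shape[1]))
    for c in range(C):
        idx = top[c]
        mass = A_c[c, idx].sum()               # attention mass on the top-k keys
        rest = A_c[c].copy()
        rest[idx] = 0.0
        V_b = rest @ V                         # \hat{V}^b: remaining keys, weights from A^c
        members = np.where(labels == c)[0]
        # \hat{V}^t: exact dot products of each member query with the top-k keys,
        # renormalised over those keys and rescaled by `mass`.
        A_t = mass * softmax(Q[members] @ K[idx].T * scale)
        out[members] = A_t @ V[idx] + V_b      # \hat{V} = \hat{V}^t + \hat{V}^b
    return out

rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(64, 16)) for _ in range(3))
labels = np.arange(64) % 4                     # toy cluster assignment
print(improved_clustered_attention(Q, K, V, labels, k=8).shape)
```

With $k$ equal to the number of keys, the remainder term vanishes and the sketch reduces to ordinary softmax attention, which is a convenient way to test it.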

Theorems & Definitions (5)

  • Proposition 1
  • proof
  • Proposition 2
  • Proposition 3
  • proof