
CAST: Clustering Self-Attention using Surrogate Tokens for Efficient Transformers

Adjorn van Engelenhoven, Nicola Strisciuglio, Estefanía Talavera

TL;DR

CAST introduces learnable surrogate tokens to form cluster affiliations for self-attention. It replaces the quadratic $O(N^2)$ computation with a linearized $O(\alpha N)$ approach while preserving information flow through intra-cluster attention and inter-cluster summaries. The method supports single-head and multi-head variants, two clustering regimes (Top-K and Single Assignment Top-K, or SA Top-K), and a complexity-optimizing relationship between cluster size and cluster count. Empirical results on Long Range Arena (LRA) show that CAST offers substantial speedups and memory savings with competitive accuracy. It performs particularly well on image-related tasks but presents a trade-off on certain tasks such as Pathfinder. Overall, CAST is a practical, efficient transformer variant that maintains core attention capabilities while enabling scalable long-range sequence modeling. Future work includes decoder adaptations and further clustering optimizations.
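The two clustering regimes can be contrasted with a small NumPy sketch. It assumes a token-to-cluster affinity matrix `A_g` of shape (N, number of clusters), as in Figure 1, and a fixed `cluster_size`. In Top-K, each cluster independently picks its highest-scoring tokens, so a token can be selected by several clusters or by none. The single-assignment function is a hypothetical greedy reading of SA Top-K in which each token is placed in exactly one cluster. It is not the authors' implementation, and both function names are illustrative.

```python
import numpy as np

def top_k_clusters(A_g, cluster_size):
    # Top-K: each cluster (column of A_g) independently keeps its `cluster_size`
    # highest-scoring tokens, so one token may appear in several clusters, or in none.
    return np.argsort(-A_g, axis=0, kind="stable")[:cluster_size].T

def sa_top_k_clusters(A_g, cluster_size):
    # HYPOTHETICAL single-assignment variant: visit (token, cluster) pairs from the
    # highest affinity down and place a token only if it is still unassigned and
    # the cluster still has room, so every token ends up in exactly one cluster.
    n_tokens, n_clusters = A_g.shape
    assert n_tokens <= n_clusters * cluster_size, "not enough cluster slots"
    order = np.argsort(-A_g, axis=None, kind="stable")  # flat indices, best first
    members = [[] for _ in range(n_clusters)]
    assigned = np.zeros(n_tokens, dtype=bool)
    for flat in order:
        token, cluster = divmod(int(flat), n_clusters)  # row-major flat index
        if not assigned[token] and len(members[cluster]) < cluster_size:
            members[cluster].append(token)
            assigned[token] = True
    return members

# Token 0 is the best match for both surrogate tokens; token 3 is nobody's first choice.
A_g = np.array([[5.0, 5.0],
                [4.0, 0.0],
                [0.0, 4.0],
                [1.0, 1.0]])
print(top_k_clusters(A_g, 2).tolist())  # [[0, 1], [0, 2]]: token 0 twice, token 3 dropped
print(sa_top_k_clusters(A_g, 2))        # [[0, 1], [2, 3]]: every token exactly once
```

The greedy loop is used here only because it makes the single-assignment constraint easy to see. A practical implementation would likely vectorize this step.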

Abstract

The Transformer architecture has been shown to be a powerful tool for a wide range of tasks. It is based on the self-attention mechanism, an inherently expensive operation with quadratic computational complexity: memory usage and compute time grow quadratically with the length of the input sequence, which limits the application of Transformers. In this work, we propose a novel Clustering self-Attention mechanism using Surrogate Tokens (CAST) to optimize the attention computation and achieve efficient transformers. CAST uses learnable surrogate tokens to construct a cluster affinity matrix, which is used to cluster the input sequence and generate novel cluster summaries. The self-attention within each cluster is then combined with the cluster summaries of the other clusters, enabling information flow across the entire input sequence. CAST improves efficiency by reducing the complexity from $O(N^2)$ to $O(\alpha N)$, where $N$ is the sequence length and $\alpha$ is a constant determined by the number of clusters and the number of samples per cluster. We show that CAST performs better than or comparably to the baseline Transformers on long-range sequence modeling tasks, while also being more time- and memory-efficient than other efficient transformers.
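The complexity claim can be made concrete by counting attention-score entries for one head. The tally below is rough and rests on stated assumptions. Intra-cluster attention costs $N_c \kappa^2$ scores for $N_c$ clusters of $\kappa$ tokens. Combining the results costs one weight per (token, cluster) pair, or $N N_c$. The symbols $N_c$ and $\kappa$ and the function name are introduced here for illustration. For fixed $N_c$ and $\kappa$, the total $N_c\kappa^2 + N N_c$ grows linearly in $N$, whereas full attention needs $N^2$.

```python
def attention_score_counts(n, n_clusters, cluster_size):
    """Rough per-head count of attention-score entries (feature dimension ignored).

    Illustrative assumptions: intra-cluster attention scores every pair of tokens
    inside each cluster, and the combination step uses one weight per
    (token, cluster) pair, as A_q does in Figure 1.
    """
    full = n * n                            # vanilla self-attention: N x N
    intra = n_clusters * cluster_size ** 2  # N_c blocks of size kappa x kappa
    inter = n * n_clusters                  # one weight per (token, cluster)
    return full, intra + inter

for n in (1024, 2048, 4096):
    full, cast = attention_score_counts(n, n_clusters=32, cluster_size=32)
    print(f"N={n}: full={full}, CAST~{cast}, ratio={full / cast:.1f}x")
```

Doubling $N$ leaves the intra-cluster term unchanged and doubles the per-token combination term, so the gap to full attention widens as sequences get longer.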

Paper Structure (44 sections, 6 equations, 9 figures, 6 tables, 1 algorithm)

Figures (9)

  • Figure 1: Sketch of the proposed method. The colors red and blue correspond to two different clusters. From the queries ($Q$), keys ($K$), and values ($V$), we compute the surrogate token similarities $A_q$ (similarity between the queries and the surrogate tokens) and $A_k$ (similarity between the keys and the surrogate tokens). These are combined into a final similarity $A_g$ between each token and each cluster. We then use $A_g$ to cluster the tokens, producing the clustered queries ($Q_g$), keys ($K_g$), and values ($V_g$). Self-attention is applied within each cluster, resulting in $R_{intra}$. In addition, $A_k$ is clustered and matrix-multiplied with $V_g$ to create a summary per cluster, resulting in $R_{inter}$. $R_{intra}$ and $R_{inter}$ are then combined in a weighted sum with $A_q$ as the weights, resulting in $R$. A final linear projection $O$ is applied to $R$, and the result is passed to the feedforward layer of the Transformer.
  • Figure 2: The practical difference between the Top-K and SA Top-K clustering mechanisms. $S$ indicates the clustering direction of each of the two surrogate tokens. The blue and green dashed circles indicate the clusters that the Top-K and SA Top-K clustering mechanisms would create, respectively.
  • Figure 3: Ablations on the cluster size for CAST with the Top-K clustering mechanism (blue) and the Single Assignment Top-K clustering mechanism (orange) on the Text and Image tasks of the LRA benchmark, plotted against (a & d) task performance, (b & e) peak allocated memory, and (c & f) time efficiency.
  • Figure 4: Visualizations of the clusters learned by a CAST model with SA Top-K on the LRA Image task, using 8 clusters. (a) An example image. (b, left) Clustered pixels, where each color represents a different cluster. (b, middle and right) Example scores of $\textbf{A}_g$, where each image corresponds to a different cluster, for the first layer (b, top) and the last layer (b, bottom).
  • Figure 5: A modularized sketch of the proposed method, with some details omitted for readability. (a) Intra-cluster self-attention, (b) the creation of the cluster summaries $R_{inter}$, and (c) the combination of $R_{inter}$ and $R_{intra}$.
  • ...and 4 more figures
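The data flow described for Figures 1 and 5 can be followed in a single-head NumPy sketch. It is illustrative only. Several details are assumptions rather than details taken from the paper: the function name `cast_attention`, the summation that combines $A_q$ and $A_k$ into $A_g$, and the softmax normalizations of the cluster-summary and combination weights. Clustering uses plain Top-K, so a token can belong to several clusters. For each cluster, a token uses its intra-cluster result if it is a member of that cluster and the cluster summary otherwise.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along `axis`.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def full_attention(Q, K, V):
    # Standard O(N^2) scaled dot-product attention, used as a reference.
    return softmax(Q @ K.T / np.sqrt(Q.shape[-1])) @ V

def cast_attention(X, Wq, Wk, Wv, Wo, S, cluster_size):
    """Illustrative single-head CAST-style attention with Top-K clustering.

    X: (N, d) input tokens; S: (n_clusters, d) learnable surrogate tokens.
    """
    N, d = X.shape
    n_clusters = S.shape[0]
    Q, K, V = X @ Wq, X @ Wk, X @ Wv

    # Surrogate-token similarities A_q and A_k, each of shape (N, n_clusters).
    A_q = Q @ S.T
    A_k = K @ S.T
    A_g = A_q + A_k  # ASSUMPTION: combine by summation

    # Top-K clustering: each cluster keeps its `cluster_size` best-scoring tokens.
    idx = np.argsort(-A_g, axis=0)[:cluster_size].T  # (n_clusters, cluster_size)
    Q_g, K_g, V_g = Q[idx], K[idx], V[idx]            # (n_clusters, cluster_size, d)

    # R_intra: self-attention inside each cluster (n_clusters small score blocks).
    scores = Q_g @ K_g.transpose(0, 2, 1) / np.sqrt(d)
    R_intra = softmax(scores) @ V_g                   # (n_clusters, cluster_size, d)

    # R_inter: one summary per cluster from the clustered A_k and V_g.
    cluster_ids = np.arange(n_clusters)[:, None]      # (n_clusters, 1)
    w_k = softmax(A_k[idx, cluster_ids], axis=-1)     # ASSUMPTION: softmax weights
    R_inter = np.einsum("ck,ckd->cd", w_k, V_g)       # (n_clusters, d)

    # Per (token, cluster): the intra-cluster result for members, the summary otherwise.
    per_cluster = np.broadcast_to(R_inter, (N, n_clusters, d)).copy()
    per_cluster[idx, cluster_ids] = R_intra

    # Weighted sum over clusters with A_q-derived weights, then output projection O.
    w_q = softmax(A_q, axis=-1)                       # ASSUMPTION: softmax over clusters
    R = np.einsum("nc,ncd->nd", w_q, per_cluster)
    return R @ Wo

rng = np.random.default_rng(0)
N, d = 64, 16
X = rng.standard_normal((N, d))
Wq, Wk, Wv, Wo = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(4))
S = rng.standard_normal((8, d))  # 8 surrogate tokens, so 8 clusters
out = cast_attention(X, Wq, Wk, Wv, Wo, S, cluster_size=8)
print(out.shape)  # (64, 16)
```

A useful sanity check follows from this construction. With a single cluster that contains every token, the sketch reduces to ordinary full self-attention followed by the projection `Wo`. With several small clusters, the score matrices shrink from one `N x N` block to `n_clusters` blocks of `cluster_size x cluster_size`.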