Cluster-Former: Clustering-based Sparse Transformer for Long-Range Dependency Encoding

Shuohang Wang, Luowei Zhou, Zhe Gan, Yen-Chun Chen, Yuwei Fang, Siqi Sun, Yu Cheng, Jingjing Liu

TL;DR

Long-range dependency modeling with Transformers is hampered by quadratic memory and compute costs on long sequences. The authors propose Cluster-Former, a two-tier Transformer framework combining a Sliding-Window Layer for local encoding with a Cluster-Former Layer that encodes global context by clustering hidden states using a memory bank and periodically updating $p$ cluster centroids via $K$-Means. The approach yields state-of-the-art results on QA benchmarks such as Natural Questions long answer, SearchQA, and Quasar-T, with ablations showing the importance of cluster counts and middle-layer placement. This work demonstrates that integrating local and global sparse attention through learnable clustering is effective for long-context QA and may extend to other long-range NLP tasks.

Abstract

The Transformer has become ubiquitous in deep learning. A key ingredient of its success is the self-attention mechanism, which allows fully-connected contextual encoding over input tokens. However, despite its effectiveness in modeling short sequences, self-attention struggles with inputs that have extremely long-range dependencies, because its complexity grows quadratically with sequence length. Long sequences are therefore often encoded by Transformers in chunks using a sliding window. In this paper, we propose Cluster-Former, a novel clustering-based sparse Transformer that performs attention across chunked sequences. The proposed framework is built on two types of Transformer layer: the Sliding-Window Layer and the Cluster-Former Layer, which encode local sequence information and global context jointly and iteratively. This design allows information integration beyond local windows, which is especially beneficial for question answering (QA) tasks that rely on long-range dependencies. Experiments show that Cluster-Former achieves state-of-the-art performance on several major QA benchmarks.
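
The sliding-window chunking mentioned in the abstract can be illustrated with a minimal sketch. The function name `sliding_window_chunks` and its parameters are hypothetical and not taken from the paper; the window size of 3 and stride of 2 are borrowed from the illustration described in the Figure 1 caption.

```python
def sliding_window_chunks(tokens, window, stride):
    """Split a sequence into overlapping chunks.

    Each chunk holds up to `window` items and consecutive chunks start
    `stride` positions apart. The final chunk may be shorter than `window`.
    (Hypothetical helper for illustration only.)
    """
    if window <= 0 or stride <= 0:
        raise ValueError("window and stride must be positive")
    if not tokens:
        return []
    chunks = []
    start = 0
    while True:
        chunks.append(tokens[start:start + window])
        # Stop once the current chunk reaches the end of the sequence.
        if start + window >= len(tokens):
            break
        start += stride
    return chunks


# Seven positions, window size 3, stride 2.
chunks = sliding_window_chunks(list(range(7)), window=3, stride=2)
print(chunks)  # [[0, 1, 2], [2, 3, 4], [4, 5, 6]]
```

Each chunk would be encoded independently by a local layer, which is why positions in different chunks cannot attend to one another without an additional global mechanism.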

Paper Structure

This paper contains 19 sections, 9 equations, 2 figures, 6 tables, 1 algorithm.

Figures (2)

  • Figure 1: Illustration of different methods for processing long sequences. Each square represents a hidden state. The black-dotted boxes are Transformer layers. (a) is the sliding-window-based method to chunk a long sequence into short ones with window size 3 and stride 2. (b) builds cross-sequence attention based on sliding window over pre-selected positions (red-dotted boxes). (c) hashes the hidden states into different buckets by randomly-initialized vectors. (d) is our proposed approach to cluster the hidden states. Our final model is a combination of (a) and (d) that processes both local and global context.
  • Figure 2: An overview of the proposed Transformer layers. (a) Sliding-Window layer over a sequence. (b) Cluster-Former layer over clustered hidden states from the output of (a). Cluster centroids are periodically updated based on the memory bank of the hidden states in the corresponding layer.
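
The clustering step described in the Figure 2 caption can be sketched roughly as follows. This is an illustrative approximation, not the authors' implementation: it assumes squared Euclidean distance and plain Lloyd-style K-Means, and every function name (`assign_clusters`, `update_centroids`, `group_by_cluster`) is hypothetical. Hidden states are represented as plain lists of floats to keep the example dependency-free.

```python
def squared_distance(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def assign_clusters(states, centroids):
    """Return, for each hidden state, the index of its nearest centroid."""
    return [
        min(range(len(centroids)), key=lambda k: squared_distance(s, centroids[k]))
        for s in states
    ]


def update_centroids(memory_bank, centroids, iterations=10):
    """Refine centroids with K-Means over a memory bank of hidden states.

    Assumption: in training this would be called only periodically,
    not at every step.
    """
    centroids = [list(c) for c in centroids]
    for _ in range(iterations):
        labels = assign_clusters(memory_bank, centroids)
        for k in range(len(centroids)):
            members = [s for s, label in zip(memory_bank, labels) if label == k]
            if members:  # an empty cluster keeps its previous centroid
                dim = len(members[0])
                centroids[k] = [
                    sum(m[d] for m in members) / len(members) for d in range(dim)
                ]
    return centroids


def group_by_cluster(labels, num_clusters):
    """Collect sequence positions per cluster; attention then runs within each group."""
    groups = [[] for _ in range(num_clusters)]
    for position, label in enumerate(labels):
        groups[label].append(position)
    return groups


memory_bank = [[0.0, 0.0], [0.0, 1.0], [10.0, 10.0], [10.0, 11.0]]
centroids = update_centroids(memory_bank, [[0.0, 0.0], [10.0, 10.0]])
labels = assign_clusters([[0.1, 0.2], [9.0, 10.0], [0.0, 0.9]], centroids)
print(centroids)                    # [[0.0, 0.5], [10.0, 10.5]]
print(group_by_cluster(labels, 2))  # [[0, 2], [1]]
```

In this toy run, positions 0 and 2 fall into the same cluster even though they are not adjacent, which is the kind of cross-chunk grouping the caption describes for panel (b).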