Chunk, Align, Select: A Simple Long-sequence Processing Method for Transformers

Jiawen Xie, Pengyu Cheng, Xiao Liang, Yong Dai, Nan Du

TL;DR

SimCAS introduces Chunk-Align-Select to scale transformers to long sequences by partitioning input into chunks, aligning inter-chunk semantics across encoder layers, and selectively aggregating token representations for decoding. The method leverages a PPO-based token selector that uses encoder outputs and decoder feedback to decide which hidden states to carry forward, achieving near-linear compute and memory growth with input length. Across seven long-context datasets spanning summarization and reading comprehension, SimCAS consistently outperforms strong baselines, including full-attention and sparse-attention methods, with notable gains on NarrativeQA and PubMed. The work demonstrates that transformers can effectively operate as environments for policy learning, using attention scores and generation likelihood to guide selective information retention, and highlights practical scalability and resource efficiency benefits for real-world long-document tasks.

Abstract

Although dominant in natural language processing, transformer-based models remain challenged by the task of long-sequence processing, because the computational cost of self-attention operations in transformers swells quadratically with the input sequence length. To alleviate the complexity of long-sequence processing, we propose a simple framework to enable off-the-shelf pre-trained transformers to process much longer sequences, while the computation and memory costs remain growing linearly with the input sequence lengths. More specifically, our method divides each long-sequence input into a batch of chunks, then aligns the inter-chunk information during the encoding steps, and finally selects the most representative hidden states from the encoder for the decoding process. To extract inter-chunk semantic information, we align the start and end token embeddings among chunks in each encoding transformer block. To learn an effective hidden selection policy, we design a dual updating scheme inspired by reinforcement learning, which regards the decoders of transformers as environments, and the downstream performance metrics as the rewards to evaluate the hidden selection actions. Our empirical results on real-world long-text summarization and reading comprehension tasks demonstrate effective improvements compared to prior long-sequence processing baselines.
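
The chunking step described in the abstract can be illustrated with a minimal sketch. It is not the authors' implementation: the function name `chunk_with_markers`, its parameters, and the placement of padding after the end token are assumptions made for illustration only.

```python
from typing import List


def chunk_with_markers(token_ids: List[int], chunk_len: int,
                       start_id: int, end_id: int, pad_id: int) -> List[List[int]]:
    """Split a long token sequence into fixed-length chunks.

    Each chunk has the form [S] tokens... [E] [P]*, so all chunks share the
    same length and can be stacked into one batch for the encoder.
    Hypothetical helper: the layout (padding after [E]) is an assumption.
    """
    if chunk_len < 3:
        raise ValueError("chunk_len must fit [S], [E] and at least one token")
    body = chunk_len - 2  # content tokens per chunk, excluding [S] and [E]
    chunks = []
    for i in range(0, len(token_ids), body):
        piece = token_ids[i:i + body]
        chunk = [start_id] + piece + [end_id]
        # Pad the (possibly shorter) last chunk up to the common length.
        chunk += [pad_id] * (chunk_len - len(chunk))
        chunks.append(chunk)
    return chunks


# Seven content tokens, chunk length 5 -> three chunks of three content tokens at most.
batch = chunk_with_markers(list(range(10, 17)), chunk_len=5,
                           start_id=0, end_id=2, pad_id=1)
```

With a fixed chunk length, self-attention inside one chunk costs on the order of `chunk_len` squared, and the number of chunks grows in proportion to the input length, which is where the linear scaling claimed in the abstract comes from.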

Paper Structure (42 sections, 12 equations, 9 figures, 11 tables, 1 algorithm)

This paper contains 42 sections, 12 equations, 9 figures, 11 tables, 1 algorithm.

Figures (9)

  • Figure 1: The SimCAS framework: The long inputs are first divided into a batch of chunks, each of which is filled with start token [S], padding token [P] and end token [E]. Then the inter-chunk information can be transferred via the alignment of [S] and [E] representations after each encoder layer. Next, hidden states are selected for decoding steps. The decoder output logits and attention scores are utilized as rewards for updating the token selector.
  • Figure 2: The left-hand-side plot shows the change of the actual number of tokens entered into the model and the number of tokens selected by our selector during training. The red dashed line represents the conditional boundary for the skipping reward. The right-hand-side plot shows the effect of increasing the number of input tokens in the inference phase on the time latency and the number of selected tokens. The area marked in blue on the right represents the limit of the number of tokens that the V100 can handle (350K tokens).
  • Figure 3: System performance comparison on NarrativeQA test set.
  • Figure 4: The AVG ROUGE scores (R-1, R-2, and R-L) of the pre-trained models (BART$_\text{base}$, LED$_\text{base}$, SimCAS$_\text{base}$) with 0, 10, and 100 training examples with variance. All results are obtained by the average of 5 random runs with different seeds.
  • Figure 5: Input text length distributions on the six summarization datasets.
  • ...and 4 more figures
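
The Figure 1 caption describes transferring inter-chunk information by aligning the [S] and [E] representations after each encoder layer. One plausible reading, sketched below, is to replace every chunk's [S] state with the mean [S] state over all chunks, and likewise for [E]. The averaging rule, the function name `align_boundary_states`, and the `end_pos` argument are assumptions for illustration; this page does not specify the exact operation.

```python
from typing import List

Vector = List[float]


def align_boundary_states(hidden: List[List[Vector]],
                          end_pos: List[int]) -> List[List[Vector]]:
    """Share boundary-token states across chunks (assumed: simple averaging).

    hidden[c][t] is the hidden vector of token t in chunk c after one encoder
    layer. [S] is assumed to sit at position 0 and [E] at end_pos[c].
    Returns a new structure; the input is left unchanged.
    """
    n = len(hidden)
    dim = len(hidden[0][0])
    # Mean [S] and mean [E] vectors over all chunks.
    mean_s = [sum(hidden[c][0][d] for c in range(n)) / n for d in range(dim)]
    mean_e = [sum(hidden[c][end_pos[c]][d] for c in range(n)) / n
              for d in range(dim)]
    out = [[list(v) for v in chunk] for chunk in hidden]
    for c in range(n):
        out[c][0] = list(mean_s)           # every chunk sees the shared [S]
        out[c][end_pos[c]] = list(mean_e)  # every chunk sees the shared [E]
    return out


# Two chunks, three tokens each ([S], one content token, [E]), hidden size 2.
states = [
    [[1.0, 0.0], [5.0, 5.0], [2.0, 2.0]],
    [[3.0, 4.0], [7.0, 7.0], [4.0, 6.0]],
]
aligned = align_boundary_states(states, end_pos=[2, 2])
```

Under this assumption, only the boundary positions change; content-token states pass through untouched and pick up the shared information through attention in the next layer.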