Zebra: Extending Context Window with Layerwise Grouped Local-Global Attention

Kaiqiang Song, Xiaoyang Wang, Sangwoo Cho, Xiaoman Pan, Dong Yu

TL;DR

Transformer-based models struggle with quadratic attention costs when processing long sequences. Zebra combines layerwise grouped local-global attention with rotary positional embeddings to extend context windows efficiently, balancing local and global computations. Through Long Context Adaptation Training (LCAT) and Long Instruction Tuning (LIT) on Llama-2 and related baselines, Zebra achieves comparable short-context performance and improved long-context perplexity, along with faster training and inference. This work demonstrates a practical path to scalable, long-context LLMs applicable to long-form understanding and generation tasks.

Abstract

This paper introduces a novel approach to enhance the capabilities of Large Language Models (LLMs) in processing and understanding extensive text sequences, a critical aspect in applications requiring deep comprehension and synthesis of large volumes of information. Recognizing the inherent challenges in extending the context window for LLMs, which are primarily built on the Transformer architecture, we propose a new model architecture, referred to as Zebra. This architecture efficiently manages the quadratic time and memory complexity associated with full attention in the Transformer by employing grouped local-global attention layers. Our model, akin to a zebra's alternating stripes, balances local and global attention layers, significantly reducing computational requirements and memory consumption. Comprehensive experiments, including pretraining from scratch, continued long context adaptation training, and long instruction tuning, are conducted to evaluate Zebra's performance. The results show that Zebra achieves comparable or superior performance on both short and long sequence benchmarks, while also enhancing training and inference efficiency.
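
To make the efficiency argument concrete, the following is a rough sketch that counts the query-key pairs scored under causal attention when only the first layer of each group is global and the remaining layers are local. The function names, the sliding-window form of local attention, and the parameters `window` and `group_size` are illustrative assumptions, not values or code from the paper; the count covers attention scores only and ignores projections and feed-forward layers.

```python
def causal_pairs_global(n: int) -> int:
    """Pairs scored when each token attends to itself and all previous tokens."""
    return n * (n + 1) // 2


def causal_pairs_local(n: int, window: int) -> int:
    """Pairs scored when each token attends to at most `window` recent tokens.

    Assumption: local attention is a causal sliding window that includes
    the token itself.
    """
    return sum(min(i + 1, window) for i in range(n))


def grouped_pairs(n: int, num_layers: int, group_size: int, window: int) -> int:
    """Total pairs across layers when the first layer of each group is global.

    `group_size` and `window` are hypothetical hyperparameters chosen for
    illustration.
    """
    total = 0
    for layer in range(num_layers):
        if layer % group_size == 0:
            total += causal_pairs_global(n)   # global layer: quadratic in n
        else:
            total += causal_pairs_local(n, window)  # local layer: about n * window
    return total
```

With `group_size=1` every layer is global, so `grouped_pairs` reduces to the full-attention count; larger groups replace most quadratic layers with layers whose cost grows linearly in the sequence length.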

Paper Structure

This paper contains 29 sections, 13 equations, 6 figures, 9 tables, 1 algorithm.

Figures (6)

  • Figure 1: Four attention strategies compared in this work. (a) Global Attention, where each token attends to all previous tokens and itself; (b) Local Attention, where each token attends only within its local window; (c) Local Attention with Global Approximation, newly introduced in this work, where each token attends to its local window and also receives an approximated attention from the remaining non-local chunks; (d) Group Attention, our layerwise grouped local-global attention strategy, where we group $L$ layers and apply global attention at the first layer of each group (the remaining layers use local attention).
  • Figure 2: The test PPL gap between each method and the baseline system (global attention) at training sequence lengths of 1024, 4096, and 16384. Smaller is better. In this experiment, we split the entire test set into splits according to instance length. Each split contains the instances with lengths from $\frac{x}{2}+1$ to $x$, except the first one (length $\leq 128$).
  • Figure 3: The validation PPL vs. TFLOPs for global attention (red) and group attention (blue) at training sequence lengths of 1024, 4096, and 16384.
  • Figure 4: Perplexity on test sequences with 1k, 4k, and 16k training sequence lengths. In this experiment, we split the entire test set into splits according to instance length. Each split contains the instances with lengths from $\frac{x}{2}+1$ to $x$, except the first one (length $\leq 128$).
  • Figure 5: Sequence Length vs. Number of Instances on The Pile Dataset.
  • ...and 1 more figure
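
The global, local, and group strategies described in the Figure 1 caption can be sketched as boolean causal masks plus a per-layer schedule. This is a minimal illustration under stated assumptions: the function names are hypothetical, local attention is modeled as a causal sliding window of size `window`, and `group_size` stands in for the caption's $L$. It does not cover strategy (c), and it is not the authors' implementation.

```python
def global_mask(n: int) -> list[list[bool]]:
    """mask[i][j] is True when token i may attend to token j (all j <= i)."""
    return [[j <= i for j in range(n)] for i in range(n)]


def local_mask(n: int, window: int) -> list[list[bool]]:
    """Causal mask restricted to the `window` most recent tokens.

    Assumption: the local window is a sliding window that includes the
    token itself.
    """
    return [[(j <= i) and (i - j < window) for j in range(n)] for i in range(n)]


def layer_schedule(num_layers: int, group_size: int) -> list[str]:
    """Attention type per layer: global at the first layer of each group,
    local at the remaining layers of the group."""
    return ["global" if layer % group_size == 0 else "local"
            for layer in range(num_layers)]
```

For example, `layer_schedule(8, 4)` marks layers 0 and 4 as global and the other six as local, which is the striped pattern the model name alludes to.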