
SWAN-GPT: An Efficient and Scalable Approach for Long-Context Language Modeling

Krishna C. Puvvada, Faisal Ladhak, Santiago Akle Serrano, Cheng-Ping Hsieh, Shantanu Acharya, Somshubra Majumdar, Fei Jia, Samuel Kriman, Simeng Sun, Dima Rekesh, Boris Ginsburg

TL;DR

SWAN-GPT tackles the long-context problem by showing that a decoder-only Transformer can robustly extrapolate to sequences far longer than its training length. It achieves this with a hybrid architecture that interleaves NoPE global attention and SWA-RoPE local attention, plus a dynamic, logarithmic scaling of attention logits to sustain performance at extreme lengths. The authors provide mechanistic evidence via position-probing and attention-pattern analyses, and demonstrate practical value by converting pre-trained RoPE models to SWAN with only continued pre-training. The results indicate strong long-context performance, competitive short-context benchmark scores, and a cost-efficient upgrade path for existing models.
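The hybrid layout described above can be sketched as follows. This is a minimal, hypothetical illustration: the NoPE-to-SWA ratio (`nope_every`), the window size, and the helper names `layer_schedule`, `attention_mask`, and `attended_pairs` are assumptions, not values or code from the paper. Global NoPE layers use a full causal mask and no positional rotation; SWA-RoPE layers restrict each query to its most recent `window` tokens and would apply RoPE to queries and keys (RoPE itself is not implemented here).

```python
from typing import List, Optional


def layer_schedule(num_layers: int, nope_every: int = 4) -> List[str]:
    """Label each layer as global NoPE attention or local SWA-RoPE attention.

    Hypothetical layout: layer i is a NoPE layer when i % nope_every == 0,
    and a sliding-window RoPE layer otherwise. The real ratio is an assumption.
    """
    return ["nope" if i % nope_every == 0 else "swa_rope" for i in range(num_layers)]


def attention_mask(seq_len: int, window: Optional[int]) -> List[List[bool]]:
    """Return mask[q][k], True when query position q may attend to key position k.

    window=None: full causal mask (global NoPE layers).
    window=w:    causal sliding window over the w most recent positions,
                 including q itself (SWA-RoPE layers).
    """
    return [
        [k <= q and (window is None or q - k < window) for k in range(seq_len)]
        for q in range(seq_len)
    ]


def attended_pairs(seq_len: int, window: Optional[int]) -> int:
    """Count the query-key pairs a layer scores, a rough proxy for attention cost."""
    return sum(sum(row) for row in attention_mask(seq_len, window))


if __name__ == "__main__":
    print(layer_schedule(8))
    # Full causal attention scores 1 + 2 + ... + n pairs (quadratic in n) ...
    print(attended_pairs(6, None))  # 21
    # ... while a window of 3 caps each query at 3 keys (linear in n).
    print(attended_pairs(6, 3))     # 15
```

Because a sliding-window layer scores at most `window` keys per query, its cost grows linearly with sequence length; this is one way to see why interleaving such layers with a minority of global layers could reduce compute relative to an all-global model.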

Abstract

We present a decoder-only Transformer architecture that robustly generalizes to sequence lengths substantially longer than those seen during training. Our model, SWAN-GPT, interleaves layers without positional encodings (NoPE) and sliding-window attention layers equipped with rotary positional encodings (SWA-RoPE). Experiments demonstrate strong performance on sequence lengths significantly longer than the training length without the need for additional long-context training. This robust length extrapolation is achieved through our novel architecture, enhanced by a straightforward dynamic scaling of attention scores during inference. In addition, SWAN-GPT is more computationally efficient than standard GPT architectures, resulting in cheaper training and higher throughput. Further, we demonstrate that existing pre-trained decoder-only models can be efficiently converted to the SWAN architecture with minimal continued training, enabling longer contexts. Overall, our work presents an effective approach for scaling language models to longer contexts in a robust and efficient manner.
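The dynamic scaling of attention scores at inference can be illustrated with a small sketch. Assumptions: the scale takes the form s(n) = 1 + alpha * ln(n / L_train) once the number of visible tokens n exceeds the training length L_train (1024 in Figure 1) and is 1 otherwise; `alpha` and the helper names `log_scale` and `scaled_attention_weights` are hypothetical, not the paper's fitted function or values. A common intuition for log-length scaling is that, as a query attends over more keys, the softmax spreads its mass more thinly; multiplying the logits by a factor that grows with ln n sharpens the distribution to counteract that dilution.

```python
import math
from typing import List


def log_scale(num_keys: int, train_len: int, alpha: float) -> float:
    """Hypothetical dynamic scale for attention logits.

    Returns 1.0 within the training range and 1 + alpha * ln(num_keys / train_len)
    beyond it, so the factor grows slowly with context length.
    """
    if num_keys <= train_len:
        return 1.0
    return 1.0 + alpha * math.log(num_keys / train_len)


def scaled_attention_weights(logits: List[float], scale: float) -> List[float]:
    """Multiply one query's raw attention logits by `scale`, then apply a stable softmax."""
    scaled = [scale * x for x in logits]
    peak = max(scaled)  # subtract the max before exp() to avoid overflow
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


if __name__ == "__main__":
    train_len, alpha = 1024, 0.1  # alpha is an illustrative value only
    for n in (512, 1024, 4096, 16384):
        print(n, round(log_scale(n, train_len, alpha), 4))
    logits = [0.5, 1.0, 2.0]
    print(scaled_attention_weights(logits, 1.0))
    print(scaled_attention_weights(logits, log_scale(16384, train_len, alpha)))
```

The factor is "dynamic" in the sense that it is recomputed from the current context length rather than fixed at training time; which layers it is applied to (presumably the global NoPE layers, given Figure 5) is also an assumption in this sketch.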

Paper Structure

This paper contains 12 sections, 6 figures, and 6 tables.

Figures (6)

  • Figure 1: Mean negative log-likelihood by token position for a GPT with rotary positional encodings (RoPE GPT, blue), a GPT with no positional encodings (NoPE, orange), a SWAN model (red), and a model composed only of sliding-window attention layers (SWA, green). Both the RoPE GPT and NoPE models fail beyond the training sequence length (1024). The SWA model avoids this catastrophic failure because each layer attends only to a limited local context. The SWAN model matches the SWA model's robustness while avoiding its limited context, thanks to its global NoPE layers.
  • Figure 2: Predictions of token indices by 8 different probes. Each probe is trained on tokens from one model within a particular context region (regions demarcated by dashed lines). Probes on NoPE models (blue) extrapolate correctly up to the maximum NoPE training length (solid line). Probes on SWAN (red) do not predict token indices.
  • Figure 3: Attention maps for the 6th layer of the NoPE model, averaged over all heads and all validation records (left), with cross sections for sequences of length 512, 1024 (the limit of the model's training range), and 1536 (the length-extrapolation regime). The attention patterns over the leading 256 tokens differ for sequences within and beyond the training range.
  • Figure 4: Attention maps for the 20th layer of our SWAN model (its 6th NoPE layer), averaged over all heads and all validation records (left), with cross sections for sequences of length 512, 1024 (the limit of the model's training range), and 1536 (the length-extrapolation regime). The attention patterns over the leading 256 tokens show consistent decay for sequences both within and beyond the training range.
  • Figure 5: Estimated optimal scaling factors (black), compared with fits of our logarithmic scaling function and of YaRN scaling. We find that YaRN scaling works less well for NoPE layers.
  • ...and 1 more figure.
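The position-probing analysis behind Figure 2 can be sketched with a linear regression probe. Assumptions: the probe is an ordinary (very lightly ridge-regularized) linear regression from a hidden-state vector to the token index, fit in closed form; the helper names `fit_linear_probe`, `predict_index`, and `toy_hidden_state` and the synthetic 2-d "hidden states" are hypothetical and only show the mechanics, not the paper's probes or results. A probe that extrapolates beyond its training region indicates that position is linearly decodable from the representation, which Figure 2 reports for NoPE models but not for SWAN.

```python
import math
from typing import List


def fit_linear_probe(features: List[List[float]], targets: List[float],
                     ridge: float = 1e-6) -> List[float]:
    """Fit weights (last entry is the bias) minimizing ||Xw - y||^2 + ridge * ||w||^2."""
    rows = [list(f) + [1.0] for f in features]  # append a constant bias feature
    d = len(rows[0])
    # Normal equations: (X^T X + ridge * I) w = X^T y
    a = [[sum(r[i] * r[j] for r in rows) + (ridge if i == j else 0.0)
          for j in range(d)] for i in range(d)]
    b = [sum(r[i] * t for r, t in zip(rows, targets)) for i in range(d)]
    # Gaussian elimination with partial pivoting
    for col in range(d):
        pivot = max(range(col, d), key=lambda row: abs(a[row][col]))
        a[col], a[pivot] = a[pivot], a[col]
        b[col], b[pivot] = b[pivot], b[col]
        for row in range(col + 1, d):
            factor = a[row][col] / a[col][col]
            for c in range(col, d):
                a[row][c] -= factor * a[col][c]
            b[row] -= factor * b[col]
    # Back substitution on the upper-triangular system
    w = [0.0] * d
    for i in reversed(range(d)):
        w[i] = (b[i] - sum(a[i][j] * w[j] for j in range(i + 1, d))) / a[i][i]
    return w


def predict_index(w: List[float], feature: List[float]) -> float:
    """Apply the probe to one hidden-state vector."""
    return sum(wi * xi for wi, xi in zip(w, list(feature) + [1.0]))


def toy_hidden_state(pos: int) -> List[float]:
    """Hypothetical 2-d 'hidden state': one coordinate drifts linearly with position."""
    return [0.02 * pos + 0.1, math.sin(pos)]


if __name__ == "__main__":
    train_positions = range(100)  # the probe's training region
    w = fit_linear_probe([toy_hidden_state(p) for p in train_positions],
                         [float(p) for p in train_positions])
    # Query outside the training region: a linearly encoded position extrapolates.
    print(round(predict_index(w, toy_hidden_state(150)), 2))
```

With a real model, `features` would presumably be hidden states taken from a chosen layer at each token; if those states carried no linear position signal (as Figure 2 suggests for SWAN), the same probe would fail to track the token index.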