
Short-Range Dependency Effects on Transformer Instability and a Decomposed Attention Solution

Suvadeep Hajra

TL;DR

This work identifies training instability in transformer models as stemming from self-attention’s difficulty in modeling dense local dependencies, which leads to logit explosion in the pre-softmax layer. It introduces Long Short-attention (LS-attention), a decomposed attention mechanism that merges short-range local heads with a small number of long-range global heads, achieving improved training stability and inference efficiency. Empirical results on long-sequence language modeling show LS-attention reduces perplexity faster and with far less GPU time than two stabilization baselines, and provides substantial latency savings over state-of-the-art MHSA implementations on long sequences. The approach has practical implications for stable, scalable pretraining of autoregressive transformers on long inputs, with potential applicability to tasks requiring robust local and global dependency modeling.

Abstract

Transformer language models have driven significant progress across various fields, including natural language processing and computer vision. A central component of these models is the self-attention (SA) mechanism, which learns rich vector representations of tokens by modeling their relationships with other tokens in a sequence. However, despite extensive research, transformers continue to suffer from training instability -- often manifesting as spikes or divergence in the training loss during a run. In this work, we identify one source of this instability: SA's limited ability to capture short-range dependencies, especially in tasks like language modeling, where almost every token relies heavily on its nearby neighbors. This limitation causes the pre-softmax logits of SA to grow rapidly, destabilizing training. To address this, we propose decomposing the SA into local (short-range) and global (long-range) attention heads. This decomposed attention, referred to as Long Short-attention (LS-attention), mitigates logit explosion and yields more stable training than an equivalent multi-head self-attention (MHSA). Empirical comparisons with two alternative training stabilization methods show that LS-attention reduces the validation perplexity to nearly 2/5 of that achieved by one method and reaches a perplexity similar to that of the other method using only 1/20 of the GPU hours. Additionally, our experiments demonstrate that LS-attention reduces inference latency by up to 36% compared to a state-of-the-art implementation of equivalent MHSA.
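
The decomposition into local and global heads described in the abstract can be illustrated with a small sketch. This is not the paper's implementation: the function names, the window size, the number of global heads, and the choice to concatenate head outputs are all assumptions made for illustration, and the actual LS-attention formulation may differ. The sketch only shows the general idea that local heads score a bounded window of recent positions while a few global heads score the full causal prefix, and it reports the largest absolute pre-softmax logit, the quantity the abstract ties to instability.

```python
import math
import random


def causal_mask(n, window=None):
    """mask[i][j] is True when query position i may attend to key position j.

    window=None gives a full causal mask (global head); window=w keeps only
    the w most recent positions, including i itself (local head).
    """
    return [
        [j <= i and (window is None or i - j < window) for j in range(n)]
        for i in range(n)
    ]


def attention_head(q, k, v, mask):
    """Scaled dot-product attention for one head.

    Returns the per-position outputs and the largest absolute pre-softmax logit.
    """
    scale = 1.0 / math.sqrt(len(q[0]))
    outputs, max_abs_logit = [], 0.0
    for i, q_i in enumerate(q):
        allowed = [j for j, ok in enumerate(mask[i]) if ok]
        logits = [scale * sum(a * b for a, b in zip(q_i, k[j])) for j in allowed]
        max_abs_logit = max(max_abs_logit, max(abs(s) for s in logits))
        peak = max(logits)  # subtract the max for a numerically stable softmax
        weights = [math.exp(s - peak) for s in logits]
        total = sum(weights)
        outputs.append([
            sum(w * v[j][c] for w, j in zip(weights, allowed)) / total
            for c in range(len(v[0]))
        ])
    return outputs, max_abs_logit


def ls_attention(heads, window, num_global):
    """Hypothetical split: the first num_global heads are global, the rest local."""
    n = len(heads[0][0])
    global_mask, local_mask = causal_mask(n), causal_mask(n, window)
    per_head = [
        attention_head(q, k, v, global_mask if h < num_global else local_mask)[0]
        for h, (q, k, v) in enumerate(heads)
    ]
    # Concatenate the head outputs position by position.
    return [[x for head in per_head for x in head[i]] for i in range(n)]


def random_matrix(rng, rows, cols):
    """Helper for the demo: a rows x cols matrix of standard normal samples."""
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


if __name__ == "__main__":
    rng = random.Random(0)
    n, d, num_heads = 8, 4, 4  # toy sizes, chosen only for the demo
    heads = [
        tuple(random_matrix(rng, n, d) for _ in range(3))  # (q, k, v) per head
        for _ in range(num_heads)
    ]
    out = ls_attention(heads, window=3, num_global=1)
    print("output shape:", len(out), "x", len(out[0]))
    print("scored pairs, global head:", sum(map(sum, causal_mask(n))))
    print("scored pairs, local head: ", sum(map(sum, causal_mask(n, 3))))
```

Under these toy settings a global head scores every causal query-key pair, which grows quadratically in the sequence length, while a local head scores at most `window` pairs per query, which grows linearly. The sketch loops over positions for readability; a practical implementation would use batched tensor operations.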

Paper Structure

This paper contains 25 sections, 15 equations, 6 figures, 1 table.

Figures (6)

  • Figure 1: Mitigation of training instability and logit explosion using LS-attention. The upper plots show that the training loss of an autoregressive transformer model with Flash-attention begins to diverge after some training steps, whereas the same model with LS-attention remains stable. The bottom plots compare the maximum absolute pre-softmax logits of vanilla MHSA and LS-attention during training. LS-attention prevents logit explosion by reducing the maximum logit magnitude to less than one-twentieth that of vanilla MHSA.
  • Figure 2: Comparison of representing dense local dependencies by local and global attention. (a) Global attention attempts to represent $\mathcal{O}(n^2)$ attention scores (shown in blue) using only $\mathcal{O}(nd)$ degrees of freedom. (b) Local attention focuses on $\mathcal{O}(nl')$ attention scores, where $l' \ll n$, making it a better fit for the available $\mathcal{O}(nd)$ capacity. (c) In a synthetic dense local dependency learning task, local attention achieves lower training loss. (d) Local attention is more resilient to logit explosion.
  • Figure 3: Training instability and logit explosion in Flash-attention at longer sequence lengths.
  • Figure 4: Mitigation of logit explosion and training instability using LS-attention.
  • Figure 5: Performance comparison of LS-attention (in mixed BF16) with two alternatives: (1) Flash-attention trained with full FP32 precision, and (2) Flash-attention with QK-normalization (in mixed BF16). Sequence length $n$ is set to $8192$.
  • ...and 1 more figure