Incremental Learning of Sparse Attention Patterns in Transformers

Oğuz Kaan Yüksel, Rodrigo Alvarez Lucendo, Nicolas Flammarion

TL;DR

This work investigates how transformers learn to compose information across multiple past positions using sparse attention, framing the problem as a high-order Markov chain task. It shows that learning unfolds in stage-like phases where heads first capture the most statistically salient patterns and later specialize cooperatively on additional patterns, a dynamic captured by simplified gradient-flow equations linked to tensor factorization. The authors provide convergence results for the competitive and cooperative phases and demonstrate how early stopping induces a beneficial misspecification regularization, with implications for generalization in language and reasoning tasks. Through a regression-variant analysis and a minimal architectural setting, the paper elucidates the mechanisms by which sparse attention patterns emerge and coordinate to solve complex sequential tasks. Overall, the results offer a theoretical foundation for staged learning in transformers and its impact on generalization and sample efficiency in data-constrained regimes.

Abstract

This paper introduces a high-order Markov chain task to investigate how transformers learn to integrate information from multiple past positions with varying statistical significance. We demonstrate that transformers learn this task incrementally: each stage is defined by the acquisition of specific information through sparse attention patterns. Notably, we identify a shift in learning dynamics from competitive, where heads converge on the most statistically dominant pattern, to cooperative, where heads specialize in distinct patterns. We model these dynamics using simplified differential equations that characterize the trajectory and prove stage-wise convergence results. Our analysis reveals that transformers ascend a complexity ladder by passing through simpler, misspecified hypothesis classes before reaching the full model class. We further show that early stopping acts as an implicit regularizer, biasing the model toward these simpler classes. These results provide a theoretical foundation for the emergence of staged learning and complex behaviors in transformers, offering insights into generalization for natural language processing and algorithmic reasoning.
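
As a rough illustration of the kind of task the abstract describes, the sketch below samples sequences from a toy high-order Markov chain in pure Python. The functional form (a softmax over summed per-group contributions), the averaging of one-hot tokens within a group, and the names `VOCAB`, `GROUPS`, `SCALES`, `make_feature_matrices`, `next_token_probs`, and `sample_sequence` are all assumptions made for illustration. The paper's own definition is its ground-truth equation, which this summary does not reproduce.

```python
import math
import random

# All constants below are hypothetical and chosen only for illustration.
VOCAB = 4                            # vocabulary size
GROUPS = [[1, 2], [3, 4], [5, 6]]    # position groups I(1), I(2), I(3), given as lags
SCALES = [4.0, 2.0, 1.0]             # importance of each group, most important first


def make_feature_matrices(rng):
    """Draw one random VOCAB x VOCAB feature matrix per group, scaled by importance."""
    return [
        [[s * rng.gauss(0.0, 1.0) for _ in range(VOCAB)] for _ in range(VOCAB)]
        for s in SCALES
    ]


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def next_token_probs(context, mats, n_groups=None):
    """Next-token distribution; context[-lag] is the token `lag` steps back.

    Only the first `n_groups` groups contribute, which gives a predictor that
    depends on a subset of the relevant positions.
    """
    n_groups = len(GROUPS) if n_groups is None else n_groups
    logits = [0.0] * VOCAB
    for lags, A in zip(GROUPS[:n_groups], mats[:n_groups]):
        # Aggregate the group: average of one-hot encodings of its tokens.
        pooled = [0.0] * VOCAB
        for lag in lags:
            pooled[context[-lag]] += 1.0 / len(lags)
        # Apply the group's feature matrix to the pooled representation.
        for out in range(VOCAB):
            logits[out] += sum(A[out][tok] * pooled[tok] for tok in range(VOCAB))
    return softmax(logits)


def sample_sequence(length, mats, rng):
    """Sample a sequence, seeding the first `order` tokens uniformly at random."""
    order = max(max(g) for g in GROUPS)
    seq = [rng.randrange(VOCAB) for _ in range(order)]
    while len(seq) < length:
        probs = next_token_probs(seq, mats)
        seq.append(rng.choices(range(VOCAB), weights=probs)[0])
    return seq


rng = random.Random(0)
mats = make_feature_matrices(rng)
seq = sample_sequence(50, mats, rng)
```

Calling `next_token_probs(seq, mats, n_groups=1)` gives a predictor that only sees the most important group, which mirrors the simpler, misspecified hypothesis classes mentioned above.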

Paper Structure

This paper contains 36 sections, 15 theorems, 120 equations, 19 figures, and 4 tables.

Key Result

Proposition 1

The gradient flow dynamics of the loss in eq:loss is equivalent to that on

Figures (19)

  • Figure 1: (Top left) The task is based on a high-order Markov chain, where the next token depends on multiple past tokens with different importance weights. The context is divided into groups of positions, each aggregated and processed by an associated feature matrix $A_k^\star$ of varying importance, represented by the size of the feature matrix. (Top right) An idealized representation of the task in a multi-head, single-layer attention model. Each head represents an individual sparse attention pattern required to solve the task. (Bottom left) Transformers learn the task incrementally, with each stage corresponding to the acquisition of a sparse attention pattern. This is indicated by the KL divergence between the transformer and the predictors $A_{1:i}^\star$, which depend only on a subset of the relevant positions, as defined in \ref{eq:ground_truth}. (Bottom right) The learning dynamics transition from competitive, where all heads focus on the statistically most important pattern (indicated by high combined attention on the main diagonal), to cooperative, where different heads specialize in different patterns.
  • Figure 2: The sum of learned attention patterns for $h=3, w=12$ at different stages of training, where blue, yellow and green correspond to different heads. At $t=0$, the attention is uniform as the model is randomly initialized. At $t=60$, all heads learn from the positions in $I(1)$, indicated by the overlapping blue, yellow and green colors, with one head placing a small amount of attention on the positions in $I(2)$. At $t=300$, one head learns from the positions in $I(2)$ whereas two heads still focus on $I(1)$. At $t=1000$, the model finally learns to integrate all positions, with each head specializing in a different pattern. The main diagonal does not have the same intensity as the other positions because it is learned via the skip connection directly from the input.
  • Figure 3: (Left) KL divergence between the ground truths that only depend on the positions in $I(1)$, $I(1) \cup I(2)$ and $I(1) \cup I(2) \cup I(3)$, and the predictions of the transformer with unrestricted context length. (Right) KL divergence between the predictions of the transformers with restricted context lengths $c = 4, 8, 12$ and the transformer without any context length restriction. The transformers learn the task incrementally, with each stage corresponding to the acquisition of information from a subset of positions.
  • Figure 4: (Left) Excess loss of the minimal architecture with different initialization scales. (Right) Excess loss of the minimal architecture with different multiplicative constants $m$ that determine the importance hierarchy.
  • Figure 5: The impact of the dataset size on the incremental learning behavior. (Left) The best validation loss as a function of the dataset size. (Right) The KL divergence between the predictions of the model with different context lengths and the trained transformer. Dashed lines indicate the first step that obtains the best excess loss.
  • ...and 14 more figures
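
Several of the captions above track training stages through a KL divergence between predictive distributions. The following is a minimal sketch of that quantity for two categorical distributions. The function name `kl_divergence`, the `eps` floor, the direction of the divergence, and the example probabilities are assumptions made for illustration, not values or conventions taken from the paper.

```python
import math


def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) for categorical distributions given as lists of probabilities.

    Terms with p_i == 0 contribute nothing; q_i is floored at `eps` to avoid
    division by zero (an illustrative choice, not the paper's).
    """
    return sum(
        pi * math.log(pi / max(qi, eps))
        for pi, qi in zip(p, q)
        if pi > 0.0
    )


# Hypothetical next-token distributions over a 3-token vocabulary.
partial_truth = [0.7, 0.2, 0.1]      # predictor using only a subset of positions
model_close = [0.68, 0.21, 0.11]     # model output that nearly matches it
model_far = [0.4, 0.4, 0.2]          # model output that does not

kl_close = kl_divergence(partial_truth, model_close)
kl_far = kl_divergence(partial_truth, model_far)
```

Under this reading, a model sitting at a given stage would show a small divergence to the matching partial predictor and a larger one to the others.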

Theorems & Definitions (26)

  • Proposition 1
  • Theorem 1
  • Theorem 2
  • Remark 1
  • Theorem 3
  • Lemma 1
  • Theorem 4
  • Proposition 1
  • proof
  • Lemma 2
  • ...and 16 more