Attention with Markov: A Framework for Principled Analysis of Transformers via Markov Chains

Ashok Vardhan Makkuva, Marco Bondaschi, Adway Girish, Alliot Nagle, Martin Jaggi, Hyeji Kim, Michael Gastpar

TL;DR

This work introduces a principled framework for analyzing transformers through Markov chains, explaining why single-layer transformers can get stuck in local minima corresponding to the unigram distribution while deeper models reliably learn the ground-truth bigram kernel. It proves the existence of global minima and bad local minima, depending on the switching probabilities of the data and on whether weights are tied. Experiments agree with the theory, extend the picture to multi-state Markov chains, and show that deeper models escape the local minima even at switching probabilities where single-layer models get trapped. Open questions include higher-order Markov behavior and the precise thresholds for local minima in more complex architectures.

Abstract

Attention-based transformers have achieved tremendous success across a variety of disciplines including natural languages. To deepen our understanding of their sequential modeling capabilities, there is a growing interest in using Markov input processes to study them. A key finding is that when trained on first-order Markov chains, transformers with two or more layers consistently develop an induction head mechanism to estimate the in-context bigram conditional distribution. In contrast, single-layer transformers, unable to form an induction head, directly learn the Markov kernel but often face a surprising challenge: they become trapped in local minima representing the unigram distribution, whereas deeper models reliably converge to the ground-truth bigram. While single-layer transformers can theoretically model first-order Markov chains, their empirical failure to learn this simple kernel in practice remains a curious phenomenon. To explain this contrasting behavior of single-layer models, in this paper we introduce a new framework for a principled analysis of transformers via Markov chains. Leveraging our framework, we theoretically characterize the loss landscape of single-layer transformers and show the existence of global minima (bigram) and bad local minima (unigram) contingent on data properties and model architecture. We precisely delineate the regimes under which these local optima occur. Backed by experiments, we demonstrate that our theoretical findings are in congruence with the empirical results. Finally, we outline several open problems in this arena. Code is available at https://github.com/Bond1995/Markov .

Paper Structure (27 sections, 5 theorems, 74 equations, 5 figures, 2 tables)


Key Result

Theorem 1

Let the input sequence be $\{x_n\}_{n=1}^N \sim (\boldsymbol{\pi}(p,q), \boldsymbol{P}(p,q))$ for some fixed $(p,q) \in (0,1)^2$ and let $\boldsymbol{\theta} \in \mathbb{R}^{D-d}$ be the transformer parameters for the weight-tied case. Then for all $(p,q)$, there exists a $\boldsymbol{\theta}_\star \in \mathbb{R}^{D-d}$ [...] Further, $\boldsymbol{\theta}_\star$ satisfies: [...] In addition, the same result holds for the non-weight-tied [...]
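To make the data model in Theorem 1 concrete, the following is a minimal sketch of sampling a binary first-order Markov chain and checking its empirical statistics. It assumes the parameterization suggested by the figure captions, namely that $p$ is the probability of switching from state 0 to 1 and $q$ the probability of switching from 1 to 0, so that $\pi_1 = p/(p+q)$. The function names are hypothetical and are not taken from the authors' code.

```python
import random


def stationary(p, q):
    """Stationary distribution (pi_0, pi_1) of the two-state chain.

    Assumes P(0 -> 1) = p and P(1 -> 0) = q (an assumption inferred
    from the figure captions, where pi_1 = p / (p + q)).
    """
    return q / (p + q), p / (p + q)


def sample_chain(p, q, n, seed=0):
    """Sample x_1, ..., x_n with x_1 drawn from the stationary distribution."""
    rng = random.Random(seed)
    _, pi1 = stationary(p, q)
    x = [1 if rng.random() < pi1 else 0]
    for _ in range(n - 1):
        if x[-1] == 0:
            # From state 0, switch to 1 with probability p.
            x.append(1 if rng.random() < p else 0)
        else:
            # From state 1, switch to 0 with probability q.
            x.append(0 if rng.random() < q else 1)
    return x


def empirical_p(x):
    """Fraction of visits to state 0 that are followed by state 1."""
    zeros = sum(1 for a in x[:-1] if a == 0)
    zero_to_one = sum(1 for a, b in zip(x, x[1:]) if a == 0 and b == 1)
    return zero_to_one / zeros


def empirical_pi1(x):
    """Fraction of positions in state 1 (the unigram estimate)."""
    return sum(x) / len(x)
```

On a long sample, `empirical_p` should approach the bigram quantity $p$, while `empirical_pi1` should approach the unigram quantity $\pi_1 = p/(p+q)$. These are the two predictions that the bigram (global) and unigram (local) minima are described as producing at positions where $x_n = 0$.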

Figures (5)

  • Figure 1: Single-layer transformers get stuck at local minima, corresponding to the unigram model, when the input is a first-order Markov chain with switching probabilities $p=0.5$ and $q=0.8$. However, deeper models escape to global minima corresponding to the bigram model.
  • Figure 2: Analysis of transformers via Markov chains.
  • Figure 3: Effect of weight tying on test loss and predicted probabilities $f_{\boldsymbol{\theta}}(x_1^{n_k})$ for zero indices $\{n_k\}_{k=1}^{100}$ such that $x_{n_k}=0$. For (a),(c): $p=0.5, q = 0.8$. With weight tying, the loss converges to a local minimum, and the predicted probability is $\pi_1 = p/(p+q)$. Without weight tying, we predict the correct probability $p$ and converge to a global minimum. For (b),(d): $p=0.2, q = 0.3$. The test loss always converges to a global minimum, and the predicted probability is $p$.
  • Figure 4: Average of predicted probabilities across 5 runs for different values of $p$ and $q$, with and without weight tying. With weight tying, there is a clear demarcation between the cases where $p+q < 1$ and those where $p+q>1$. For $p+q < 1$, all runs accurately predict the correct conditional probability. For $p+q>1$, some of the runs predict the stationary probability instead, causing the average to diverge from the correct $p$. Without weight tying, the model always predicts the correct probability for all $p$ and $q$.
  • Figure 5: Effect of weight tying for the multi-state Markov chain with $S=5$ states. (a) shows the symmetric multi-state Markov data model used, while (b), (c), (d) and (e) show the test loss for different values of $p$. Similar to the binary case, when $p$ is large enough, a $1$-layer transformer with weight tying gets stuck in local minima, whereas without weight tying it escapes to global minima.
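The gap between the two loss plateaus in the figures can be illustrated numerically. The sketch below assumes the standard cross-entropy (log) loss in nats and the same parameterization as the captions ($P(0 \to 1) = p$, $P(1 \to 0) = q$). Under that assumption, a predictor that always outputs the stationary distribution incurs the entropy of $\boldsymbol{\pi}$, while a predictor that outputs the true kernel incurs the entropy rate of the chain. The function names are hypothetical.

```python
import math


def binary_entropy(t):
    """Binary entropy h(t) in nats, for t in (0, 1)."""
    return -t * math.log(t) - (1 - t) * math.log(1 - t)


def unigram_loss(p, q):
    """Expected log loss when always predicting the stationary distribution."""
    pi1 = p / (p + q)
    return binary_entropy(pi1)


def bigram_loss(p, q):
    """Expected log loss when predicting the true kernel (the entropy rate)."""
    pi0, pi1 = q / (p + q), p / (p + q)
    return pi0 * binary_entropy(p) + pi1 * binary_entropy(q)


def loss_gap(p, q):
    """Unigram loss minus bigram loss; equals the mutual information
    between consecutive symbols, hence is non-negative."""
    return unigram_loss(p, q) - bigram_loss(p, q)


if __name__ == "__main__":
    # The two settings used in Figure 3.
    for p, q in [(0.5, 0.8), (0.2, 0.3)]:
        regime = "p + q > 1" if p + q > 1 else "p + q < 1"
        print(f"p={p}, q={q} ({regime}): "
              f"unigram={unigram_loss(p, q):.4f}, "
              f"bigram={bigram_loss(p, q):.4f}, "
              f"gap={loss_gap(p, q):.4f}")
```

The gap vanishes exactly when $p + q = 1$, where consecutive symbols are independent and the unigram and bigram predictors coincide. Whether training actually lands on the unigram plateau depends on the model and on weight tying, as the captions describe; this sketch only computes the two loss values.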

Theorems & Definitions (20)

  • Theorem 1: Global minimum
  • Remark 1
  • proof : Proof sketch
  • Theorem 2: Bad local minimum
  • Remark 2
  • proof : Proof sketch
  • Theorem 3: Saddle point
  • Lemma 1: Loss as KL divergence
  • Remark 3
  • proof
  • ...and 10 more