Provable optimal transport with transformers: The essence of depth and prompt engineering

Hadi Daneshmand

TL;DR

The paper links token alignment in transformers to discrete optimal transport (OT), providing a mechanistic interpretation in which softmax self-attention effectively performs gradient-descent steps on the OT dual objective. It proves depth controls OT approximation accuracy with a constructive bound, and shows that deep, pre-trained transformers can solve OT and even sorting tasks without retraining, aided by engineered prompts that extend memory. Empirical results in English–French translation and embedding-based OT corroborate that attention progressively aligns semantically related word pairs and that prompt design dramatically enhances in-context computation. Together, these results offer both theoretical insight into transformer dynamics and practical guidance for prompting strategies in alignment and order-related tasks.

Abstract

Despite their empirical success, the internal mechanism by which transformer models align tokens during language processing remains poorly understood. This paper provides a mechanistic and theoretical explanation of token alignment in LLMs. We first present empirical evidence showing that, in machine translation, attention weights progressively align translated word pairs across layers, closely approximating Optimal Transport (OT) between word embeddings. Building on this observation, we prove that softmax self-attention layers can simulate gradient descent on the dual of the entropy-regularized OT problem, providing a theoretical foundation for the alignment. Our analysis yields a constructive convergence bound showing that transformer depth controls OT approximation accuracy. A direct implication is that standard transformers can sort lists of varying lengths without any parameter adjustment, up to an error term that vanishes with transformer depth.
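
To make the central claim concrete, the following is a minimal NumPy sketch of plain gradient steps on the dual of an entropy-regularized OT problem, followed by sorting through the resulting plan. It is not the paper's transformer construction: the dual convention (KL regularization relative to uniform marginals), the step-size rule, the sorted reference grid, and the names `entropic_ot_dual_gd`, `sinkhorn`, `sort_via_ot`, `lam`, and `step_scale` are all assumptions made for illustration. Each loop iteration plays the role that one layer plays in the abstract's statement, so more steps correspond to more depth.

```python
import numpy as np

def _plan(C, f, g, a, b, lam):
    # Plan induced by dual potentials (f, g): P_ij = a_i b_j exp((f_i + g_j - C_ij) / lam)
    return a[:, None] * b[None, :] * np.exp((f[:, None] + g[None, :] - C) / lam)

def entropic_ot_dual_gd(C, lam=0.05, steps=5000, step_scale=0.5):
    """Gradient ascent on the concave dual (descent on its negative).

    Assumed dual: D(f, g) = <a, f> + <b, g> - lam * sum_ij P_ij(f, g),
    with uniform marginals a and b. Its gradient is the marginal violation.
    """
    C = np.asarray(C, dtype=float)
    n, m = C.shape
    a, b = np.full(n, 1.0 / n), np.full(m, 1.0 / m)
    f, g = np.zeros(n), np.zeros(m)
    # Heuristic step size (an assumption, not taken from the paper)
    lr = step_scale * lam / max(a.max(), b.max())
    for _ in range(steps):
        P = _plan(C, f, g, a, b, lam)
        f += lr * (a - P.sum(axis=1))  # dD/df = a - row sums of P
        g += lr * (b - P.sum(axis=0))  # dD/dg = b - column sums of P
    return _plan(C, f, g, a, b, lam)

def sinkhorn(C, lam=0.05, iters=2000):
    """Reference solver: classical Sinkhorn scaling for the same plan."""
    C = np.asarray(C, dtype=float)
    n, m = C.shape
    a, b = np.full(n, 1.0 / n), np.full(m, 1.0 / m)
    K = np.exp(-C / lam)
    q = np.ones(m)
    for _ in range(iters):
        p = a / (K @ q)
        q = b / (K.T @ p)
    return p[:, None] * K * q[None, :]

def sort_via_ot(x, lam=0.05, steps=5000):
    """Sort by transporting x onto an increasing reference grid.

    Assumes distinct, non-constant inputs; ties are not handled.
    """
    x = np.asarray(x, dtype=float)
    n = len(x)
    z = (x - x.min()) / (x.max() - x.min())  # rescale to [0, 1]
    y = np.linspace(0.0, 1.0, n)             # sorted reference points
    C = (z[:, None] - y[None, :]) ** 2
    P = entropic_ot_dual_gd(C, lam=lam, steps=steps)
    ranks = P.argmax(axis=1)                 # most likely sorted position of x[i]
    out = np.empty(n)
    out[ranks] = x
    return out
```

With these defaults, `sort_via_ot` handles short lists of different lengths, such as `[0.9, 0.1, 0.5]` and `[3.0, 1.0, 4.0, 2.0]`, using the same fixed hyperparameters. Longer or tightly clustered lists would likely need a smaller `lam` and more steps, which mirrors the abstract's point that the error shrinks with depth.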

Paper Structure

This paper contains 43 sections, 4 theorems, 74 equations, 9 figures, 2 tables.

Key Result

Theorem D.1

Let $C \in \mathbb{R}^{n \times n}$ be a cost matrix with entries $C_{ij}$, and fix $\lambda>0$. Let $P_\lambda^* \in \mathbb{R}^{n\times n}_{+}$ denote the entropy-regularized optimal transport plan, which admits a factorization that is unique up to positive rescaling of $p^*,q^*$. Consider a transformer with a fixed choice of parameters (independent of $n$, $d$, and the input) as constructed

Figures (9)

  • Figure 1: Translation with OT. The rightmost plot shows the optimal transport (OT) solution from \ref{eq:optimal_transport}, computed between the English word embeddings ($x_1, \dots, x_n$) and the French word embeddings ($y_1, \dots, y_n$). Red dots mark correctly aligned translation pairs. Observe that the OT solution matches words with the same meaning. The other plots depict attention-weight heatmaps from layers 1, 6, and 12, showing how the model iteratively approximates the OT solution.
  • Figure 2: Observations on In-context Learning for OT. (1) The model is trained to solve OT with 7 data points and evaluated on 9 data points. The left image shows the attention weights, which closely approximate the OT solution shown on the right. (2) After specific prompt engineering, the attention weights between tokens estimate the OT solution. Notably, this prompt engineering is used in (1). (3) The attention weights evolve across layers, progressively yielding a more accurate approximation of the optimal solution. See Appendix \ref{sec:experiments_app} for details.
  • Figure 3: Convergence of attention matrices. The plotted matrices are attention weights in layers 1, 300, and 600. These matrices converge to the regularized OT solution (the rightmost plot), as proven in Theorem \ref{thm:convergence}.
  • Figure 4: Sample Size. Left: $n=8$; right: $n=4$. The transformer's weights remain unchanged. Observe that the transformer can solve the OT problem for both values of $n$, demonstrating a form of out-of-distribution generalization proven in Thm. \ref{thm:gd}.
  • Figure 5: Attention dynamics for translation (En-Fr). x-axis: transformer layer index $\ell$. y-axis: the metrics defined in \ref{eq:mrr} and \ref{eq:hits}; increasing metric values indicate that attention weights across layers progressively provide better estimates of translated word alignments, closely resembling the attention dynamics for OT.
  • ...and 4 more figures
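
The metrics in Figure 5 are not defined in this excerpt. The sketch below assumes the standard retrieval definitions of mean reciprocal rank (MRR) and Hits@k, applied to one attention matrix and a known word alignment; the paper's own equations may differ in detail. The names `alignment_ranks`, `mrr`, `hits_at_k`, `attn`, and `target` are hypothetical.

```python
import numpy as np

def alignment_ranks(attn, target):
    """Rank of the correct target token within each row of attention weights.

    attn:   (n, m) array, attn[i, j] = attention from source token i to target token j
    target: length-n integer array, target[i] = index of the correct translation of token i
    Rank 1 means the correct token receives the largest weight in its row.
    """
    attn = np.asarray(attn, dtype=float)
    target = np.asarray(target)
    correct = attn[np.arange(len(target)), target]
    # 1 + number of strictly larger weights in the same row
    return 1 + (attn > correct[:, None]).sum(axis=1)

def mrr(attn, target):
    # Mean reciprocal rank of the correct alignment (assumed standard definition)
    return float(np.mean(1.0 / alignment_ranks(attn, target)))

def hits_at_k(attn, target, k):
    # Fraction of source tokens whose correct target is among the top-k weights
    return float(np.mean(alignment_ranks(attn, target) <= k))
```

Evaluating these on the attention matrix of each layer and plotting the values against the layer index would give curves of the kind the caption describes, under the stated assumptions.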

Theorems & Definitions (6)

  • Theorem D.1: Formal version of Thm. 3.2: Convergence to $P_\lambda^*$ in the Hilbert projective metric
  • Remark D.2
  • Remark D.3: On scaling, metric choice, and non-monotonicity
  • Proposition D.4
  • Proposition D.5
  • Lemma D.6
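
Theorem D.1 measures convergence in the Hilbert projective metric. As a reminder of what that metric computes, here is a short sketch using its standard definition for entrywise positive arrays, $d_H(u, v) = \max_i \log(u_i/v_i) - \min_i \log(u_i/v_i)$. Whether the paper applies it to the scaling vectors $p^*, q^*$ or to the plan itself is not visible in this excerpt, and the function name `hilbert_projective_metric` is hypothetical.

```python
import numpy as np

def hilbert_projective_metric(u, v):
    """Hilbert projective metric between two entrywise positive arrays of equal shape."""
    u = np.asarray(u, dtype=float)
    v = np.asarray(v, dtype=float)
    if u.shape != v.shape:
        raise ValueError("inputs must have the same shape")
    if np.any(u <= 0) or np.any(v <= 0):
        raise ValueError("inputs must be strictly positive")
    log_ratio = np.log(u) - np.log(v)
    # Spread of the log-ratios; zero exactly when u is a positive multiple of v
    return float(log_ratio.max() - log_ratio.min())
```

The distance is zero whenever one argument is a positive multiple of the other, which makes it a natural fit for quantities that are only determined up to positive rescaling.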