Optimizing Attention with Mirror Descent: Generalized Max-Margin Token Selection

Addison Kristanto Julistiono, Davoud Ataee Tarzanagh, Navid Azizan

TL;DR

The paper investigates optimizing softmax attention with a generalized mirror-descent (MD) framework using $\ell_p$-norm potentials, showing that MD converges in direction to a generalized hard-margin SVM that separates locally optimal tokens. It extends prior gradient-descent analyses by deriving directional convergence results with a poly-logarithmic rate, and it establishes joint convergence of the key-query matrix $W$ and the decoder $v$ under regularization-path assumptions. Empirically, MD improves generalization and token selection on synthetic and real data, yielding sparser attention weights and performance competitive with Adam. The findings illuminate the implicit bias of MD in attention models and point to practical benefits for interpretability and efficiency; future work will extend the analysis to multi-head and multi-layer transformers.

Abstract

Attention mechanisms have revolutionized several domains of artificial intelligence, such as natural language processing and computer vision, by enabling models to selectively focus on relevant parts of the input data. While recent work has characterized the optimization dynamics of gradient descent (GD) in attention-based models and the structural properties of its preferred solutions, less is known about more general optimization algorithms such as mirror descent (MD). In this paper, we investigate the convergence properties and implicit biases of a family of MD algorithms tailored for softmax attention mechanisms, with the potential function chosen as the $p$-th power of the $\ell_p$-norm. Specifically, we show that these algorithms converge in direction to a generalized hard-margin SVM with an $\ell_p$-norm objective when applied to a classification problem using a softmax attention model. Notably, our theoretical results reveal that the convergence rate is comparable to that of traditional GD in simpler models, despite the highly nonlinear and nonconvex nature of the present problem. Additionally, we delve into the joint optimization dynamics of the key-query matrix and the decoder, establishing conditions under which this complex joint optimization converges to their respective hard-margin SVM solutions. Lastly, our numerical experiments on real data demonstrate that MD algorithms improve generalization over standard GD and excel in optimal token selection.
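The MD family described in the abstract can be sketched as follows. This is a minimal sketch, not the authors' implementation: it assumes the potential $\psi(w) = \frac{1}{p}\|w\|_p^p$ applied entrywise to a flattened parameter (for example, the entries of $W$), with the gradient of the loss supplied by the caller. The factor $1/p$ only rescales the stepsize relative to using $\|w\|_p^p$. Each step maps $w$ to the dual space via $\nabla\psi(w) = \mathrm{sign}(w)\,|w|^{p-1}$, takes a gradient step there, and maps back. All function names are hypothetical.

```python
import math


def mirror_map(w, p):
    """Gradient of psi(w) = (1/p) * sum_j |w_j|^p, i.e. sign(w) * |w|^(p-1) entrywise."""
    return [math.copysign(abs(x) ** (p - 1), x) for x in w]


def inverse_mirror_map(z, p):
    """Inverse of the mirror map: sign(z) * |z|^(1/(p-1)) entrywise (requires p > 1)."""
    return [math.copysign(abs(x) ** (1.0 / (p - 1)), x) for x in z]


def md_step(w, grad, eta, p):
    """One l_p mirror-descent step: grad_psi(w_next) = grad_psi(w) - eta * grad."""
    z = [zi - eta * gi for zi, gi in zip(mirror_map(w, p), grad)]
    return inverse_mirror_map(z, p)


def lpp_norm(w, p):
    """Entrywise l_p norm; equals the l_{p,p} norm when w is a flattened matrix."""
    return sum(abs(x) ** p for x in w) ** (1.0 / p)
```

For $p=2$ the mirror map is the identity, so `md_step` reduces to plain gradient descent. For $p \neq 2$ the effective step on each coordinate depends on $|w_j|$, which is one way to see why different choices of $p$ can lead to different implicit biases.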

Paper Structure

This paper contains 36 sections, 25 theorems, 202 equations, 17 figures, 1 table.

Key Result

Lemma 4

Let Assumption \ref{assumption-loss} hold. Consider the sequence $\{W(k)\}_{k \geq 0}$ generated by Algorithm \ref{alg:p:agd} with stepsize $\eta > 0$. Then the increment of the $\ell_{p,p}$-norm between consecutive iterations can be bounded as follows:

Figures (17)

  • Figure 1: Visualization of \eqref{eqn:w-svm} for $p=3$.
  • Figure 2: Illustration of Lemma \ref{lemma-stay-in-cone}: for all $k>0$, the iterates $W(k)$ remain within the larger set.
  • Figure 3: Effect of token selection on margin size in \eqref{eqn:vp:svm} for Example \ref{exm:svm_token_visualization}. The first plot shows the largest class margin, attained with the optimal tokens $X_{11}$ and $X_{21}$. In subsequent plots, where other tokens are selected, the class margin (light blue shaded area) shrinks, reflecting suboptimal class separation.
  • Figure 4: Directional Bregman divergence between the (a) $\ell_{1.75}$, (b) $\ell_2$, and (c) $\ell_3$ optimization paths and the \eqref{eqn:w-svm} solutions for $p=1.75$, $2$, and $3$, respectively, at each training iteration, averaged over 100 trials. The shaded area shows the standard deviation of the directional Bregman divergence.
  • Figure 5: Direction of change of two entries of $W$ updated by Algorithm \ref{alg:p:agd} with $p=1.75$, $p=2$, and $p=3$ in one trial, shown in (a), (b), and (c), respectively. Each axis represents a different entry. The orange line shows the direction of the \eqref{eqn:w-svm} solution.
  • ...and 12 more figures
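Figure 4 reports a directional Bregman divergence. The sketch below assumes that Definition 1 is the standard Bregman divergence $D_\psi(x,y)=\psi(x)-\psi(y)-\langle\nabla\psi(y),x-y\rangle$ with $\psi(w)=\frac{1}{p}\|w\|_p^p$, and that "directional" means both arguments are first normalized to unit $\ell_{p,p}$-norm. The normalization is an assumption, since the paper's exact choice is not shown here. Matrices are flattened to lists, and the function names are hypothetical.

```python
import math


def psi(w, p):
    """Potential psi(w) = (1/p) * sum_j |w_j|^p (assumed form)."""
    return sum(abs(x) ** p for x in w) / p


def grad_psi(w, p):
    """Gradient of psi: sign(w) * |w|^(p-1) entrywise."""
    return [math.copysign(abs(x) ** (p - 1), x) for x in w]


def bregman(x, y, p):
    """D_psi(x, y) = psi(x) - psi(y) - <grad psi(y), x - y>; nonnegative for p > 1."""
    g = grad_psi(y, p)
    inner = sum(gi * (xi - yi) for gi, xi, yi in zip(g, x, y))
    return psi(x, p) - psi(y, p) - inner


def normalize(w, p):
    """Scale w to unit entrywise l_p norm (assumed normalization for 'directional')."""
    n = sum(abs(x) ** p for x in w) ** (1.0 / p)
    return [x / n for x in w]


def directional_bregman(w, w_svm, p):
    """Bregman divergence between the directions of an iterate and an SVM solution."""
    return bregman(normalize(w, p), normalize(w_svm, p), p)
```

For $p=2$ this divergence equals $\tfrac12\|x-y\|_2^2$, so under these assumptions panel (b) tracks half the squared Euclidean distance between the normalized directions.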

Theorems & Definitions (34)

  • Definition 1: Bregman Divergence
  • Definition 2: Token Score
  • Definition 3: Globally and Locally Optimal Tokens
  • Lemma 4: $\ell_{p,p}$-Growth Bound of Attention Weights
  • Definition 5: Attention SVM with $\ell_p$-norm Objective
  • Example 1
  • Theorem 6: $\ell_p$-norm Regularization Path
  • Definition 7
  • Theorem 8
  • Remark 9
  • ...and 24 more
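Definitions 2 and 5 concern token scores and the attention SVM with an $\ell_p$-norm objective. The sketch below uses assumed forms, not ones quoted from the paper, which may use different conventions. It assumes the single-layer model $f(X)=v^\top X^\top \mathrm{softmax}(XWz)$ from prior GD analyses of max-margin token selection, a token score $\gamma_t = y\,v^\top x_t$, and SVM constraints $(x_{\mathrm{opt}}-x_t)^\top W z \ge 1$ for all $t \neq \mathrm{opt}$. It does not solve the SVM, which would minimize $\|W\|_{p,p}$ subject to these constraints. It only checks whether a candidate $W$ is feasible, and shows that scaling a feasible $W$ drives the softmax toward selecting the highest-scoring token. All names are hypothetical.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def matvec(A, x):
    # A is a list of rows; returns the matrix-vector product A @ x
    return [dot(row, x) for row in A]


def softmax(s):
    m = max(s)  # subtract the max for numerical stability
    e = [math.exp(x - m) for x in s]
    total = sum(e)
    return [x / total for x in e]


def attention_output(X, W, z, v):
    """Assumed model f(X) = v^T X^T softmax(X W z); X is a list of token vectors."""
    Wz = matvec(W, z)
    probs = softmax([dot(x, Wz) for x in X])
    pooled = [sum(pr * x[j] for pr, x in zip(probs, X)) for j in range(len(X[0]))]
    return dot(v, pooled)


def token_scores(X, v, y):
    """Assumed token score gamma_t = y * v^T x_t."""
    return [y * dot(v, x) for x in X]


def svm_margin(X, W, z, opt):
    """min over t != opt of (x_opt - x_t)^T W z; W is feasible if this is >= 1."""
    Wz = matvec(W, z)
    base = dot(X[opt], Wz)
    return min(base - dot(x, Wz) for t, x in enumerate(X) if t != opt)


X = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
v, y, z = [1.0, -1.0], 1.0, [1.0, 0.0]
scores = token_scores(X, v, y)
opt = scores.index(max(scores))          # token 0 has the highest score
W = [[2.0, 0.0], [0.0, 0.0]]
print(svm_margin(X, W, z, opt))          # 1.0, so W satisfies the constraints
W_big = [[10 * a for a in row] for row in W]
print(attention_output(X, W_big, z, v))  # close to scores[opt] = 1.0
```

Scaling $W$ along a feasible direction sharpens the softmax toward the selected token. This mirrors how the attention weights become sparser as the norm of $W$ grows while its direction converges.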