Implicit Optimization Bias of Next-Token Prediction in Linear Models

Christos Thrampoulidis

TL;DR

Extends previous research on implicit bias in one-hot classification to the next-token prediction (NTP) setting, highlighting key differences and prompting further research into the optimization and generalization properties of NTP, irrespective of the specific architecture used to generate the context embeddings.

Abstract

We initiate an investigation into the optimization properties of next-token prediction (NTP), the dominant training paradigm for modern language models. Specifically, we study the structural properties of the solutions selected by gradient-based optimizers among the many possible minimizers of the NTP objective. By framing NTP as cross-entropy minimization across distinct contexts, each tied with a sparse conditional probability distribution across a finite vocabulary of tokens, we introduce "NTP-separability conditions" that enable reaching the data-entropy lower bound. With this setup, and focusing on linear models with fixed context embeddings, we characterize the optimization bias of gradient descent (GD): Within the data subspace defined by the sparsity patterns of distinct contexts, GD selects parameters that equate the logits' differences of in-support tokens to their log-odds. In the orthogonal subspace, the GD parameters diverge in norm and select the direction that maximizes a margin specific to NTP. These findings extend previous research on implicit bias in one-hot classification to the NTP setting, highlighting key differences and prompting further research into the optimization and generalization properties of NTP, irrespective of the specific architecture used to generate the context embeddings.
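The two claims about GD (in-support logit differences approach the log-odds, while the parameter norm diverges) can be checked on a tiny example. The following is a minimal sketch of plain GD on the NTP cross-entropy of a linear model with fixed context embeddings. Everything in it is a hypothetical toy instance (two contexts in $d=2$, four tokens, invented probabilities), not the paper's experimental setup, and the helper names `ce_and_grad`, `data_entropy` and `run_gd` are illustrative. Because the two embeddings are linearly independent, any target logits ${\bm{W}}\bar{\bm{h}}_j$ are reachable, so both the compatibility and the separability requirements can be met.

```python
import math

# Hypothetical toy instance (not from the paper): m = 2 distinct contexts,
# embedding dimension d = 2, vocabulary of V = 4 tokens; token 3 never occurs.
EMB = [[1.0, 0.0], [0.6, 0.8]]       # fixed context embeddings h_j
P = [[0.75, 0.25, 0.0, 0.0],         # sparse next-token distribution p_j
     [0.0, 0.2, 0.8, 0.0]]
PI = [0.5, 0.5]                       # context frequencies pi_j
V, D = 4, 2

def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def logits(W, h):
    # Linear decoder: the logit of token z is the inner product w_z . h.
    return [sum(W[z][k] * h[k] for k in range(D)) for z in range(V)]

def ce_and_grad(W):
    """CE(W) = sum_j pi_j sum_{z in S_j} p_jz * (-log softmax(W h_j)_z) and dCE/dW."""
    loss = 0.0
    grad = [[0.0] * D for _ in range(V)]
    for h, p, pi in zip(EMB, P, PI):
        s = softmax(logits(W, h))
        loss -= pi * sum(pz * math.log(sz) for pz, sz in zip(p, s) if pz > 0)
        for z in range(V):
            for k in range(D):
                grad[z][k] += pi * (s[z] - p[z]) * h[k]
    return loss, grad

def data_entropy():
    """Lower bound H = sum_j pi_j * entropy(p_j)."""
    return -sum(pi * sum(pz * math.log(pz) for pz in p if pz > 0)
                for p, pi in zip(P, PI))

def frobenius(W):
    return math.sqrt(sum(x * x for row in W for x in row))

def run_gd(steps, lr=1.0):
    """Plain GD from W = 0; records ||W||_F at steps // 10 and at the last step."""
    W = [[0.0] * D for _ in range(V)]
    norms = {}
    for t in range(1, steps + 1):
        _, g = ce_and_grad(W)
        for z in range(V):
            for k in range(D):
                W[z][k] -= lr * g[z][k]
        if t in (steps // 10, steps):
            norms[t] = frobenius(W)
    return W, norms

if __name__ == "__main__":
    W, norms = run_gd(10_000)
    l0, l1 = logits(W, EMB[0]), logits(W, EMB[1])
    print("context 0 gap:", l0[0] - l0[1], "log-odds:", math.log(0.75 / 0.25))
    print("context 1 gap:", l1[2] - l1[1], "log-odds:", math.log(0.8 / 0.2))
    print("norms:", norms)
    print("CE - H:", ce_and_grad(W)[0] - data_entropy())
```

Running the module prints each in-support logit gap next to its log-odds target ($\log 3$ and $\log 4$), the Frobenius norm at two checkpoints, and the gap $\operatorname{CE}-\mathcal{H}$, which stays positive for every finite ${\bm{W}}$ because it equals a weighted sum of KL divergences.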

Paper Structure (22 sections, 8 theorems, 80 equations, 7 figures)

Key Result

Proposition 1

Assume training data $\mathcal{T}_m$ is NTP$_{\mathcal{H}}$-compatible and NTP-separable, with the respective matrices ${\bm{W}}^{\rm{p}}$ and ${\bm{W}}^{\rm{d}}$ satisfying conditions (eq:Wf condition) and (eq:Wmm condition). While all finite ${\bm{W}}$ satisfy $\operatorname{CE}({\bm{W}})>\mathcal{H}$, …
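Since the statement of Proposition 1 is truncated here, the sketch below only illustrates the mechanism the abstract describes: separability lets CE approach the data-entropy lower bound $\mathcal{H}$ without ever reaching it. It rests on stated assumptions: one-hot context embeddings (so ${\bm{W}}\bar{\bm{h}}_j$ is just the $j$-th logit vector), hypothetical data, and our reading of the two conditions, namely that ${\bm{W}}^{\rm{p}}\bar{\bm{h}}_j$ matches in-support log-odds while ${\bm{W}}^{\rm{d}}\bar{\bm{h}}_j$ equates in-support logits and places off-support logits at least 1 below them. The scaled path ${\bm{W}}^{\rm{p}}+\gamma{\bm{W}}^{\rm{d}}$ and the names `WP`, `WD`, `ce_along_path` are illustrative choices, not quoted from the proposition.

```python
import math

# Hypothetical toy data: m = 3 distinct contexts, vocabulary of 5 tokens.
# One-hot context embeddings are assumed, so W h_j is simply the j-th logit vector.
P = [[0.5, 0.3, 0.2, 0.0, 0.0],
     [0.0, 0.6, 0.0, 0.4, 0.0],
     [0.1, 0.0, 0.0, 0.0, 0.9]]
PI = [0.4, 0.3, 0.3]

def log_softmax(logits):
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]

def cross_entropy(logit_vectors):
    """CE = sum_j pi_j sum_{z in S_j} p_jz * (-log softmax(l_j)_z)."""
    return -sum(pi * sum(pz * lz for pz, lz in zip(p, log_softmax(l)) if pz > 0)
                for p, pi, l in zip(P, PI, logit_vectors))

def data_entropy():
    return -sum(pi * sum(pz * math.log(pz) for pz in p if pz > 0)
                for p, pi in zip(P, PI))

# Hypothetical stand-in for W^p h_j: in-support logit differences equal the log-odds.
WP = [[math.log(pz) if pz > 0 else 0.0 for pz in p] for p in P]
# Hypothetical stand-in for W^d h_j: equal in-support logits, off-support logits 1 below.
WD = [[0.0 if pz > 0 else -1.0 for pz in p] for p in P]

def ce_along_path(gamma):
    """CE of the illustrative path W^p + gamma * W^d."""
    return cross_entropy([[a + gamma * b for a, b in zip(rp, rd)]
                          for rp, rd in zip(WP, WD)])
```

For this construction the gap has the closed form $\operatorname{CE}(\gamma)-\mathcal{H}=\sum_j \pi_j\log(1+k_j e^{-\gamma})$, where $k_j$ is the number of off-support tokens of context $j$: strictly positive for every finite $\gamma$ and vanishing as $\gamma\to\infty$.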

Figures (7)

  • Figure 1: Visualization of NTP implicit optimization bias in a setting with $m=3$ distinct contexts, embedding dimension $d=2$, vocabulary of $|\mathcal{V}|=5$ words and support sets of length $|{\mathcal{S}}_j|=3,j\in[3]$. Left: Visualization of context embeddings ${\bar{{\bm{h}}}}_j$ as black circle markers (labeled A, B, C) and of their associated support sets ${\mathcal{S}}_j$ (colored text below each marker). Colored vectors (star markers) represent the max-NTP-margin vectors ${\bm{w}}_v^\top:=\bm{e}_v^\top{\bm{W}}^{\rm{mm}}, v\in[5]$ found by GD. Interpreting decoder vectors as word embeddings leads to intuitive findings on the geometry learned by NTP training. For example, word embedding ${\bm{w}}_3$ (almost) aligns with context embedding $A$, and the normal hyperplane it defines separates $A$ from $B$ and $C$, since word $3$ only appears after context $A$. Each of the remaining words follows two contexts, and its word representation naturally lies in the cone defined by the embeddings of those contexts. The wider the cone, the larger the magnitude of the word embedding, which compensates for the large angle between context representations that share the same next word. Note that the geometry of the depicted word embeddings depends only on the support sets; the conditional probabilities define another set of word representations on an orthogonal (matrix) subspace (see the text for details and a visualization). Right: The upper/lower graphs confirm the predictions of Lemma 2 (norm growth) and Theorem 2 (implicit bias of GD), respectively.
  • Figure 2: Same setup as Fig. 1. Left: Matrix ${\bm{P}}$ of conditional probabilities of words (columns) per context (rows). Each row is the conditional probability vector $\bm{p}_j, j\in[m]$; black entries correspond to off-support words. Middle: The rows ${\bm{w}}_z, z\in[5]$, of the NTP-SVM solution ${\bm{W}}^{\rm{mm}}$ to which GD directionally converges. Right: The rows ${\bm{w}}_z, z\in[5]$, of the finite parameter ${\bm{W}}^\star$ to which the GD iterates, projected onto $\mathscr{F}$, converge. The geometry of ${\bm{W}}^{\rm{mm}}$ depends only on the support set of ${\bm{P}}$, whereas the geometry of ${\bm{W}}^\star$ depends on the entries of ${\bm{P}}$ for in-support tokens. As the visualization of ${\bm{P}}$ shows, words $1$ and $5$ have the same support pattern (both follow contexts $A$ and $B$). Thus, ${\bm{w}}_1={\bm{w}}_5$ in the Middle plot. However, on the subspace $\mathscr{F}$ (Right plot), ${\bm{w}}_1\neq{\bm{w}}_5$, which allows matching the different conditional probabilities with which each word follows contexts $A$ and $B$.
  • Figure 3: Eight randomly picked contexts with their associated next-token empirical conditional probabilities $\hat{\bm{p}}_j$. The indices shown on the x-axis define the support set ${\mathcal{S}}_j$ of each context.
  • Figure 4: Experimental illustration of the implicit bias of GD in NTP on synthetic data with overparameterization. See the experiments appendix for a detailed description of the experimental setting. The upper two graphs confirm the predictions of Lemma 2 (norm growth), while the lower two graphs agree with the predictions of Theorem 2 (implicit bias of GD).
  • Figure : NGD
  • ...and 2 more figures
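Figure 2's caption separates what each limit depends on: ${\bm{W}}^{\rm{mm}}$ only on which tokens are in support, ${\bm{W}}^\star$ also on the in-support probabilities. The sketch below performs that bookkeeping on a hypothetical ${\bm{P}}$ (not the figure's actual values), with 0-indexed tokens so that "words 1 and 5" become indices 0 and 4. The helper names are illustrative; the code does not compute ${\bm{W}}^{\rm{mm}}$ or ${\bm{W}}^\star$ themselves.

```python
import math
from collections import defaultdict

# Hypothetical conditional-probability matrix: rows = contexts A, B, C; cols = tokens.
P = [[0.5, 0.0, 0.3, 0.0, 0.2],   # A
     [0.1, 0.6, 0.0, 0.0, 0.3],   # B
     [0.0, 0.5, 0.0, 0.5, 0.0]]   # C

def support_pattern(P):
    """For each token z, the tuple of contexts j whose support set S_j contains z."""
    return [tuple(j for j, p in enumerate(P) if p[z] > 0) for z in range(len(P[0]))]

def tokens_sharing_pattern(P):
    """Groups of tokens with identical support patterns
    (per the Figure 2 caption, such tokens get equal rows in W^mm)."""
    groups = defaultdict(list)
    for z, pattern in enumerate(support_pattern(P)):
        groups[pattern].append(z)
    return [zs for zs in groups.values() if len(zs) > 1]

def in_support_log_odds(P, z, zp):
    """log(p_jz / p_jz') for every context j whose support contains both z and z'."""
    return {j: math.log(p[z] / p[zp])
            for j, p in enumerate(P) if p[z] > 0 and p[zp] > 0}

if __name__ == "__main__":
    print(tokens_sharing_pattern(P))
    print(in_support_log_odds(P, 0, 4))
```

Here tokens 0 and 4 share the support pattern (contexts A and B), yet their log-odds have opposite signs in A and B, so rows that reproduce these logit differences on $\mathscr{F}$ cannot coincide.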

Theorems & Definitions (16)

  • Definition 1: NTP$_{\mathcal{H}}$-compatible
  • Definition 2: NTP-separable
  • Proposition 1
  • Lemma 1: Overparameterization implies NTP-separability
  • Remark 1
  • Definition 3: NTP-SVM
  • Theorem 1: Implicit bias of the regularization-path
  • Proof: sketch for Theorem 1 (details in the appendix)
  • Lemma 2: Norm growth
  • Theorem 2: Implicit bias of GD
  • ...and 6 more