
Understanding Transformers via N-gram Statistics

Timothy Nguyen

TL;DR

This paper obtains a simple method to detect overfitting during training without using a holdout set, a quantitative measure of how transformers progress from learning simple to more complex statistical rules over the course of training, a model-variance criterion governing when transformer predictions tend to be described by N-gram rules, and insights into how well transformers can be approximated by N-gram rulesets in the limit where these rulesets become increasingly complex.

Abstract

Transformer-based large language models (LLMs) display extreme proficiency with language, yet a precise understanding of how they work remains elusive. One way of demystifying transformer predictions would be to describe how they depend on their context in terms of simple template functions. This paper takes a first step in this direction by considering families of functions (i.e., rules) formed out of simple N-gram-based statistics of the training data. By studying how well these rulesets approximate transformer predictions, we obtain a variety of novel discoveries: a simple method to detect overfitting during training without using a holdout set, a quantitative measure of how transformers progress from learning simple to more complex statistical rules over the course of training, a model-variance criterion governing when transformer predictions tend to be described by N-gram rules, and insights into how well transformers can be approximated by N-gram rulesets in the limit where these rulesets become increasingly complex. In this latter direction, we find that for 79% and 68% of LLM next-token distributions on TinyStories and Wikipedia, respectively, their top-1 predictions agree with those provided by our N-gram rulesets.
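To make the idea of "rules formed out of N-gram statistics" concrete, here is a minimal sketch under stated assumptions. It takes a tokenized training corpus (lists of strings) and considers only suffix rules, where a rule keeps the last k context tokens and predicts from the matching (k+1)-gram counts. It then picks the rule whose next-token distribution has the smallest variational distance to the model's distribution (the distance used in the figure captions) and measures top-1 agreement. The function names, the toy corpus, and the "model" distribution are hypothetical. The paper's rulesets are richer than suffix rules, and details such as count thresholds or smoothing are not modeled here.

```python
from collections import Counter, defaultdict


def ngram_tables(corpus, max_n):
    """Build tables[n][context] -> Counter of next tokens, for n = 1..max_n.

    The table for n maps each (n-1)-token context to counts of the token that
    follows it in the training corpus (n = 1 uses the empty context).
    """
    tables = {}
    for n in range(1, max_n + 1):
        table = defaultdict(Counter)
        for seq in corpus:
            for i in range(len(seq) - n + 1):
                table[tuple(seq[i:i + n - 1])][seq[i + n - 1]] += 1
        tables[n] = table
    return tables


def suffix_rule(tables, context, k):
    """Next-token distribution of the rule that keeps only the last k context tokens."""
    key = tuple(context[len(context) - k:])
    counts = tables[k + 1].get(key)  # .get avoids inserting empty entries
    if not counts:
        return None  # context never seen in training data: rule undefined here
    total = sum(counts.values())
    return {tok: c / total for tok, c in counts.items()}


def variational_distance(p, q):
    """Total variation distance between two sparse distributions (dicts)."""
    support = set(p) | set(q)
    return 0.5 * sum(abs(p.get(t, 0.0) - q.get(t, 0.0)) for t in support)


def optimal_rule(tables, context, model_dist, max_k):
    """Return (k, distance, distribution) of the suffix rule closest to the model.

    Requires tables built with max_n >= max_k + 1.
    """
    best = None
    for k in range(min(max_k, len(context)) + 1):
        rule_dist = suffix_rule(tables, context, k)
        if rule_dist is None:
            continue
        d = variational_distance(model_dist, rule_dist)
        if best is None or d < best[1]:
            best = (k, d, rule_dist)
    return best


def top1(dist):
    """Most probable token of a distribution."""
    return max(dist, key=dist.get)


def top1_agreement(tables, examples, max_k):
    """Fraction of (context, model_dist) pairs whose optimal rule shares the model's top-1 token."""
    hits = total = 0
    for context, model_dist in examples:
        best = optimal_rule(tables, context, model_dist, max_k)
        if best is not None:
            total += 1
            hits += top1(best[2]) == top1(model_dist)
    return hits / total if total else float("nan")


# Toy usage with a made-up corpus and a made-up "transformer" prediction.
corpus = [s.split() for s in [
    "the cat sat on the mat",
    "the cat sat by the door",
    "the cat slept on the sofa",
    "a cat slept on a rug",
]]
tables = ngram_tables(corpus, max_n=3)
model_dist = {"sat": 0.7, "slept": 0.3}
k, d, rule_dist = optimal_rule(tables, ["the", "cat"], model_dist, max_k=2)
```

In this toy example the trigram rule on ("the", "cat") predicts {"sat": 2/3, "slept": 1/3}, which is closer to the model than the bigram rule on ("cat",) ({"sat": 0.5, "slept": 0.5}), so the longest suffix is selected and its top-1 token agrees with the model's.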

Paper Structure

This paper contains 24 sections, 12 equations, 13 figures, and 20 tables.

Figures (13)

  • Figure 1: Illustration of rule approximation. Given a context, different $N$-gram based rules formed out of the context will yield different next-token predictive distributions. In the above example, the context consists of three tokens. The first rule uses all three tokens of the context and makes a prediction based on the corresponding $4$-gram rule derived from the training data; the second rule uses only the first and last tokens to form a corresponding $3$-gram rule (so the next token "slept" is assigned less weight than under the first rule, since the "tired" token is ignored); and the third rule makes a prediction using the $N$-gram statistics obtained by aggregating over three-token contexts from the training data whose second token is arbitrary (i.e., the second token is marginalized). Given a list of such rules, one can ask which rule's predictive distribution best matches that of the transformer.
  • Figure 2: TinyStories $7$-grams. Every point in the above plots represents a $7$-gram context. Shaded regions are obtained by bucketing along the x-axis and computing one standard deviation about the mean along the y-axis. Slope and $R^2$ values are with respect to the linear fit of the data. Optimal rule distances and model variances are computed with respect to five model runs. (a): $d(p(t |C), p_{\textrm{full}}(t |C))$ vs. count of $C$. (b): $d(p(t |C), p_{\textrm{full}}(t |C))$ vs. model variance. (c): model variance vs. count of $C$. (d): similar to (b), but the y-axis is now the optimal rule distance, i.e., the distance to the optimal rule from $\mathcal{R}^{\textrm{suffix}}_7$. Model size: 160M.
  • Figure 3: Training Dynamics. Left: Models reach their lowest distance to more complex rules later in training. For rules with four tokens of context or fewer, the variational distance eventually starts increasing later in training. For six and seven tokens of context, the variational distance continues to decrease. Center & Right: The optimal rule selected always has nonincreasing distance and nondecreasing top-1 accuracy relative to the ground truth (interpreted as a one-hot distribution $p_{\mathrm{gt}}$), despite distances to model predictions eventually increasing or plateauing for rules with fewer than six tokens of context. This shows that the optimal rule selection improves with additional training even if the optimal rule distance with respect to model predictions does not. (One can imagine the rule predictions as a mesh in probability space, with LLM predictions navigating this space through training. The distance to the mesh may plateau, but which rule is closest can continue to change.) Model size: 160M.
  • Figure 4: Overfitting Detection. We plot both train loss (solid lines) and validation loss (dashed lines) for the full transformer and limited context length transformers (the latter are marked with an "x" for emphasis) on TinyStories. Unlike the full transformer which overfits, those with limited context length have train and validation loss curves closely following each other. Model size: 1.4B.
  • Figure 5: Rule selection for a TinyStories validation sequence. The above is a sequence from a held-out story. The second and third columns show the ground truth, token by token, along with the rule context (as defined in the paper's section on N-gram rules) associated with the optimal rule from $\mathcal{R}_7^{\textrm{all}}$. The heatmap indicates the variational distance between the optimal rule and LLM next-token distributions at the given token. The first column shows at most two tokens, chosen as follows: if the LLM top-1 prediction disagrees with the ground truth, the LLM prediction is shown. If, in addition, the selected rule makes a different top-1 prediction from the transformer, that token is shown as the second token and the corresponding ground truth token is colored red. Thus red tokens are precisely the locations of disagreement between LLM and optimal rule greedy predictions. The last column shows the number of contexts supporting the optimal rule. Model size: 160M.
  • ...and 8 more figures
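The three rule types in the Figure 1 caption can be sketched as context patterns matched against the training data. In this minimal sketch (an assumption about representation, not the paper's code), a pattern is a tuple of tokens in which `None` is a wildcard. The full rule matches all three tokens (4-gram statistics). The "drop the middle token" rule treats (c1, c3) as a contiguous 2-token context (3-gram statistics). The marginalized rule matches (c1, *, c3), aggregating over any second token. The sentences, the model distribution, and names such as `figure1_rules` and `drop_middle` are hypothetical; only the "tired"/"slept" tokens echo the caption.

```python
from collections import Counter


def rule_distribution(corpus, pattern):
    """Next-token distribution for a context pattern; None entries are wildcards.

    Every window of len(pattern) + 1 consecutive training tokens whose first
    len(pattern) tokens match the pattern contributes its last token.
    """
    m = len(pattern)
    counts = Counter()
    for seq in corpus:
        for i in range(len(seq) - m):
            window = seq[i:i + m]
            if all(p is None or p == w for p, w in zip(pattern, window)):
                counts[seq[i + m]] += 1
    total = sum(counts.values())
    return {t: c / total for t, c in counts.items()} if total else None


def variational_distance(p, q):
    """Total variation distance between two sparse distributions (dicts)."""
    support = set(p) | set(q)
    return 0.5 * sum(abs(p.get(t, 0.0) - q.get(t, 0.0)) for t in support)


def figure1_rules(corpus, context):
    """The three rule types described for a 3-token context (c1, c2, c3)."""
    c1, c2, c3 = context
    return {
        "full": rule_distribution(corpus, (c1, c2, c3)),            # 4-gram statistics
        "drop_middle": rule_distribution(corpus, (c1, c3)),         # 3-gram on (c1, c3)
        "marginalized": rule_distribution(corpus, (c1, None, c3)),  # c2 is a wildcard
    }


def closest_rule(rules, model_dist):
    """Name and distance of the defined rule whose prediction is closest to the model's."""
    scored = [(name, variational_distance(dist, model_dist))
              for name, dist in rules.items() if dist is not None]
    return min(scored, key=lambda item: item[1])


corpus = [s.split() for s in [
    "the tired dog slept",
    "the tired dog slept",
    "the dog ran",
    "the dog slept",
    "the happy dog barked",
]]
rules = figure1_rules(corpus, ("the", "tired", "dog"))
name, dist = closest_rule(rules, {"slept": 0.7, "barked": 0.3})
```

Here "slept" receives probability 1.0 under the full rule, 0.5 once "tired" is dropped, and 2/3 when the middle token is marginalized. For the made-up model distribution, the marginalized rule is the closest match.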
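Figure 2 relates, per 7-gram context $C$, the distance $d(p(t|C), p_{\textrm{full}}(t|C))$ to a model variance computed over five training runs, and reports the slope and $R^2$ of a linear fit. The exact definition of model variance is not given in this excerpt. The sketch below assumes it is the mean variational distance of each run's next-token distribution from the runs' averaged distribution, and averages the distance to the full rule over runs. Distributions are plain lists over a fixed vocabulary, and all function names are hypothetical.

```python
def variational_distance(p, q):
    """Total variation distance between two distributions over the same vocabulary."""
    return 0.5 * sum(abs(a - b) for a, b in zip(p, q))


def mean_distribution(run_dists):
    """Element-wise average of several runs' next-token distributions."""
    n = len(run_dists)
    return [sum(col) / n for col in zip(*run_dists)]


def model_variance(run_dists):
    """ASSUMED definition: mean distance of each run's prediction from the runs' average."""
    mean = mean_distribution(run_dists)
    return sum(variational_distance(p, mean) for p in run_dists) / len(run_dists)


def distance_to_full_rule(run_dists, full_rule):
    """Average over runs of d(p(t|C), p_full(t|C))."""
    return sum(variational_distance(p, full_rule) for p in run_dists) / len(run_dists)


def linear_fit(x, y):
    """Ordinary least-squares slope, intercept and R^2 of y ~ x.

    Assumes x and y are not constant (otherwise the ratios below divide by zero).
    """
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    sxx = sum((a - mx) ** 2 for a in x)
    syy = sum((b - my) ** 2 for b in y)
    sxy = sum((a - mx) * (b - my) for a, b in zip(x, y))
    slope = sxy / sxx
    intercept = my - slope * mx
    r2 = sxy ** 2 / (sxx * syy)  # for a simple linear fit, R^2 equals the squared correlation
    return slope, intercept, r2
```

Evaluating `model_variance` and `distance_to_full_rule` for every context and passing the two resulting lists to `linear_fit` would give the slope and $R^2$ of a panel-(b)-style plot, under the variance definition assumed above.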
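The center and right panels of Figure 3 score the selected rule against the ground truth interpreted as a one-hot distribution $p_{\mathrm{gt}}$. For a one-hot target on token $y$, the variational distance simplifies to $d(p, p_{\mathrm{gt}}) = \tfrac{1}{2}\big((1 - p(y)) + \sum_{t \neq y} p(t)\big) = 1 - p(y)$. The following sketch computes that distance and the top-1 accuracy of the selected rules at one checkpoint. It assumes the optimal rule has already been chosen at each position; the function names are hypothetical.

```python
def variational_distance(p, q):
    """Total variation distance between two sparse distributions (dicts)."""
    support = set(p) | set(q)
    return 0.5 * sum(abs(p.get(t, 0.0) - q.get(t, 0.0)) for t in support)


def distance_to_ground_truth(rule_dist, gt_token):
    """Variational distance to a one-hot ground truth, which equals 1 - p(gt_token)."""
    return 1.0 - rule_dist.get(gt_token, 0.0)


def ground_truth_metrics(selected_rule_dists, gt_tokens):
    """Mean distance to one-hot ground truth and top-1 accuracy of the selected rules.

    selected_rule_dists[i] is the distribution of the optimal rule chosen at
    position i for one checkpoint; gt_tokens[i] is the true next token there.
    """
    pairs = list(zip(selected_rule_dists, gt_tokens))
    mean_dist = sum(distance_to_ground_truth(p, y) for p, y in pairs) / len(pairs)
    accuracy = sum(max(p, key=p.get) == y for p, y in pairs) / len(pairs)
    return mean_dist, accuracy


# Sanity check of the one-hot identity on a toy distribution.
p = {"a": 0.6, "b": 0.4}
print(variational_distance(p, {"a": 1.0}), distance_to_ground_truth(p, "a"))
```

Repeating `ground_truth_metrics` at each saved checkpoint, with the optimal rule re-selected against that checkpoint's predictions, yields curves of the kind the caption describes. The identity also explains why a lower distance to $p_{\mathrm{gt}}$ simply means the selected rule places more mass on the true token.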