A distributional simplicity bias in the learning dynamics of transformers

Riccardo Rende, Federica Gerace, Alessandro Laio, Sebastian Goldt

TL;DR

The paper addresses whether transformers exhibit a distributional simplicity bias when trained on NLP tasks by developing a cloning framework that creates data sets with controlled many-body token interactions. Using multi-layer factored attention with a quadratic nonlinearity, the authors show that depth sets the maximum interaction order accessible to the model, enabling principled exploration of learning dynamics under MLM and next-token prediction. Empirically, transformers first learn lower-order interactions and progressively acquire higher-order dependencies, with distinct plateaus corresponding to approximations of increasing order in synthetic data and consistent sequential learning patterns observed in BERT and GPT on TinyStories and WikiText-103. The findings provide a mechanistic view of generalization in transformers and a versatile tool for probing how interaction order shapes learning across NLP tasks, with potential extensions to other data modalities and theoretical analyses of learning timescales.
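
The link between depth and interaction order can be illustrated with a simple degree-counting argument. The sketch below is an assumption, not a formula quoted from the paper: it supposes that one factored-attention layer is linear in the context tokens, that each $x^2$ activation doubles the polynomial degree, and that the predicted token adds one more body. The function name `max_interaction_order` is hypothetical. Under these assumptions the rule reproduces the "four layers, nine bodies" correspondence quoted in the Figure 3 caption.

```python
def max_interaction_order(num_layers: int) -> int:
    """Highest many-body order reachable with `num_layers` factored-attention
    layers and x^2 activations, under the degree-counting assumption above."""
    if num_layers < 1:
        raise ValueError("num_layers must be at least 1")
    degree = 1  # assumed: a single factored-attention layer is linear in the context tokens
    for _ in range(num_layers - 1):
        degree *= 2  # x^2 doubles the degree; the following attention layer is linear
    return degree + 1  # +1 accounts for the predicted (masked) token itself


for depth in (1, 2, 3, 4):
    print(f"{depth} layer(s): up to {max_interaction_order(depth)}-body interactions")
```

With this counting, the achievable order grows exponentially with depth, which would explain why a handful of layers is enough to build clones of quite high order.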

Abstract

The remarkable capability of over-parameterised neural networks to generalise effectively has been explained by invoking a "simplicity bias": neural networks prevent overfitting by initially learning simple classifiers before progressing to more complex, non-linear functions. While simplicity biases have been described theoretically and experimentally in feed-forward networks for supervised learning, the extent to which they also explain the remarkable success of transformers trained with self-supervised techniques remains unclear. In our study, we demonstrate that transformers, trained on natural language data, also display a simplicity bias. Specifically, they sequentially learn many-body interactions among input tokens, reaching a saturation point in the prediction error for low-degree interactions while continuing to learn high-degree interactions. To conduct this analysis, we develop a procedure to generate clones of a given natural language data set, which rigorously capture the interactions between tokens up to a specified order. This approach opens up the possibility of studying how interactions of different orders in the data affect learning, in natural language processing and beyond.

Paper Structure

This paper contains 24 sections, 17 equations, 5 figures, 1 table.

Figures (5)

  • Figure 1: Transformers learn increasingly higher-order interactions from their data. Left: We illustrate the idea of a statistical "clone" of a data set, which approximates the underlying data distribution by keeping only interactions between tokens up to a fixed degree (in this case, three-body interactions). We introduce a principled approach to create clones by training a transformer with multiple layers of factored self-attention (Rende et al., 2024) with an $x^2$ activation function between layers. The depth of the architecture controls the degree of the approximation. Clones can then be sampled from these models. Right: Test loss of a standard BERT-like transformer encoder (Devlin et al., 2019; Liu et al., 2019) with four attention blocks trained on the WikiText-103 data set (Merity et al., 2016) and tested on clones of this data set with a truncated maximum degree of many-body interactions between tokens. We show the average over five training runs starting from the same initial condition. The shaded area indicates one standard deviation.
  • Figure 2: a) Multi-layer factored self-attention architecture with $x^2$ activation function. b) Test loss learning curves of one, two and three factored self-attention layers with $x^2$ activation function. The models were trained on a synthetic data set generated from a four-body Hamiltonian. The dashed horizontal lines correspond to the convergence value of the loss for two-, three- and four-body energy-based models trained on the same data set. c) Mean square displacement of the weights across different layers in a three-layer factored-attention architecture. In these experiments, the size of the vocabulary was set to $|\mathbb{V}|=10$ and the sequence length to $L=20$. We used a training set of $M=25600$ samples and trained the models with SGD with a mini-batch size of $256$. The initial learning rate was $0.1$.
  • Figure 3: Three steps for cloning a data set using factored-attention-based generative models. a) Train factored-attention models on TinyStories. Test loss curves of different factored-attention-based architectures trained on TinyStories and tested on TinyStories. Specifically, we consider architectures with two, four and six factored self-attention layers with $x^2$ activation function. For comparison, the test loss of a four-layer BERT is also shown. b) Sample factored models. Mean score of a batch of sentences taken from the test set of the TinyStories data set and evolved with the Metropolis-Hastings sampling scheme described in the appendix on sampling. c) Check generated clones. Test loss curves of a standard four-layer transformer encoder, trained on TinyStories and tested on clones generated after $20$ and $70$ Metropolis-Hastings sweeps. The clones were generated from a standard four-layer BERT and from an architecture with four layers of factored self-attention and $x^2$ activation function (associated with a nine-body approximation of TinyStories).
  • Figure 4: BERT models trained on masked-language modelling learn increasingly higher-order interactions during training. Left panel: In an experiment analogous to the one shown in Figure 1, we show the test loss of a standard BERT-like transformer encoder trained on the TinyStories data set (Eldan and Li, 2023) and tested on clones of this data set with a truncated maximum degree of many-body interactions between tokens. The inset shows the corresponding test accuracy. We show the average over five different training runs, all starting from the same initial condition. The shaded area indicates one standard deviation. Right panel: An alternative way to visualise the data from the left panel is to plot the test loss at steps $10^4$, $3\times10^4$, and $10^5$ (blue, green and orange points respectively). This visualisation highlights the sequential learning of higher-order interactions: for the clones derived from two- and four-layer factored architectures the loss saturates after $3\times10^4$ training steps, while for the clones derived from a six-layer architecture, as well as for the clone sampled from a BERT model, the test loss continues to decrease, as indicated by the black arrows.
  • Figure 5: GPT models trained on next-token prediction tasks learn increasingly higher-order interactions during training. Left panel: Test loss of a two-layer GPT-Neo model from Eldan and Li (2023) trained on the TinyStories data set and tested on clones of this data set with a truncated maximum degree of many-body interactions between tokens. Right panel: We repeat the analysis of Figure 4 for the GPT-Neo model trained using next-token prediction, showing again clear signatures of sequential learning. The results are consistent across different random seeds.
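
The Figure 3 caption refers to sampling clones with Metropolis-Hastings sweeps. As a rough illustration of what one sweep involves, here is a minimal sketch on a toy two-body Potts-like energy. This is not the paper's sampling scheme, which is described in its appendix and operates on trained factored-attention models. The energy function, the random couplings `J`, the uniform single-site proposal and the `temperature` parameter are all assumptions made for the example; only the vocabulary size and sequence length are borrowed from the synthetic setup in the Figure 2 caption.

```python
import numpy as np


def energy(seq, J):
    """Toy two-body energy: E(x) = -sum_{i<j} J[i, j, x_i, x_j] (illustrative only)."""
    total = 0.0
    for i in range(len(seq)):
        for j in range(i + 1, len(seq)):
            total -= J[i, j, seq[i], seq[j]]
    return total


def metropolis_sweep(seq, J, vocab_size, rng, temperature=1.0):
    """One sweep: propose one single-token change per position, in random order."""
    seq = seq.copy()
    current = energy(seq, J)
    for pos in rng.permutation(len(seq)):
        proposal = seq.copy()
        proposal[pos] = rng.integers(vocab_size)  # symmetric uniform proposal
        candidate = energy(proposal, J)
        # Metropolis rule: accept with probability min(1, exp(-(E_new - E_old) / T)).
        if np.log(rng.random()) < -(candidate - current) / temperature:
            seq, current = proposal, candidate
    return seq


rng = np.random.default_rng(0)
vocab_size, length = 10, 20  # sizes taken from the Figure 2 caption
J = rng.normal(scale=0.5, size=(length, length, vocab_size, vocab_size))  # hypothetical couplings
seq = rng.integers(vocab_size, size=length)

initial_energy = energy(seq, J)
for _ in range(20):  # 20 sweeps, matching the smaller sweep count in Figure 3c
    seq = metropolis_sweep(seq, J, vocab_size, rng)
final_energy = energy(seq, J)
print(f"energy before: {initial_energy:.2f}, after 20 sweeps: {final_energy:.2f}")
```

Recomputing the full energy for every proposal keeps the sketch short; a practical sampler would evaluate only the terms that involve the changed position.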