A distributional simplicity bias in the learning dynamics of transformers
Riccardo Rende, Federica Gerace, Alessandro Laio, Sebastian Goldt
TL;DR
The paper asks whether transformers trained on NLP tasks exhibit a distributional simplicity bias. To answer this, the authors develop a cloning framework that generates data sets with controlled many-body interactions among tokens. Using multi-layer factored attention with a quadratic nonlinearity, they show that depth sets the maximum interaction order the model can represent. This makes it possible to explore learning dynamics under masked language modeling (MLM) and next-token prediction in a principled way. Empirically, transformers learn lower-order interactions first and acquire higher-order dependencies progressively. On synthetic data, learning curves show distinct plateaus, each corresponding to an approximation of increasing order, and BERT and GPT trained on TinyStories and WikiText-103 display the same sequential learning pattern. The findings offer a mechanistic view of generalization in transformers and a versatile tool for probing how interaction order shapes learning across NLP tasks, with possible extensions to other data modalities and to theoretical analyses of learning timescales.
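To make the link between depth and interaction order concrete, here is a minimal NumPy sketch built on our own assumptions. The attention weights `A` are indexed by pairs of positions and do not depend on the tokens (our reading of "factored"). There are no biases, residual connections, or output layer. The helper names `factored_attention_layer` and `stacked_model` are hypothetical. Each layer is linear in its input and is followed by an elementwise square, so the polynomial degree of the output in the input embeddings doubles with every layer and reaches 2^L after L layers. Because a polynomial of degree 2^L can couple at most 2^L input positions in any single term, depth caps the interaction order. This illustrates the mechanism only. It is not the paper's exact architecture or its exact order count.

```python
import numpy as np


def factored_attention_layer(X, A, V):
    """One factored-attention layer followed by a quadratic nonlinearity.

    X: (T, d) token embeddings for a sequence of length T.
    A: (T, T) attention weights that depend only on positions, not on X
       (an illustrative assumption about what "factored" means here).
    V: (d, d) value matrix.
    """
    return (A @ X @ V) ** 2  # elementwise square doubles the polynomial degree


def stacked_model(X, params):
    """Apply the layers in order; params is a list of (A, V) pairs."""
    for A, V in params:
        X = factored_attention_layer(X, A, V)
    return X


rng = np.random.default_rng(0)
T, d, L = 6, 4, 3
params = [
    (rng.normal(size=(T, T)) / np.sqrt(T), rng.normal(size=(d, d)) / np.sqrt(d))
    for _ in range(L)
]
X = rng.normal(size=(T, d))

# With no biases or residual connections, the output is a homogeneous polynomial
# of degree 2**L in X, so scaling the input by t scales every output by t**(2**L).
t = 2.0
out = stacked_model(X, params)
scaled = stacked_model(t * X, params)
print(np.allclose(scaled, t ** (2 ** L) * out, rtol=1e-6, atol=0.0))  # True
```

The scaling check is a cheap way to read off the degree. If residual connections were added, lower-degree terms would survive alongside the top-degree ones, and the output would no longer be homogeneous.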
Abstract
The remarkable capability of over-parameterised neural networks to generalise effectively has been explained by invoking a "simplicity bias": neural networks avoid overfitting by initially learning simple classifiers before progressing to more complex, non-linear functions. While simplicity biases have been described theoretically and experimentally in feed-forward networks for supervised learning, it remains unclear to what extent they also explain the remarkable success of transformers trained with self-supervised techniques. In our study, we demonstrate that transformers trained on natural language data also display a simplicity bias. Specifically, they learn many-body interactions among input tokens sequentially: the prediction error reaches a saturation point for low-degree interactions while the network continues to learn high-degree interactions. To conduct this analysis, we develop a procedure to generate "clones" of a given natural language data set, which rigorously capture the interactions between tokens up to a specified order. This approach opens up the possibility of studying how interactions of different orders in the data affect learning, in natural language processing and beyond.
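The simplest case of a data set whose token dependencies stop at a given order is one with at most pairwise (two-body) interactions. The sketch below draws such sequences by Gibbs sampling from a Potts-style model with single-site fields `h` and two-body couplings `J`. Every statistical dependency in its samples is therefore built from two-body terms. This is our own toy assumption, not the authors' cloning procedure: the parameters are hand-set rather than fitted to a natural language corpus, and `gibbs_sample_pairwise` is a hypothetical helper.

```python
import numpy as np


def gibbs_sample_pairwise(h, J, n_sweeps, rng):
    """Draw one sequence from p(s) proportional to
    exp(sum_i h[i, s_i] + sum_{i<j} J[i, j, s_i, s_j]) by Gibbs sampling.

    h: (T, q) single-site fields.
    J: (T, T, q, q) two-body couplings, assumed symmetric:
       J[i, j, a, b] == J[j, i, b, a].
    """
    T, q = h.shape
    s = rng.integers(q, size=T)  # random initial sequence of token ids
    for _ in range(n_sweeps):
        for i in range(T):
            # Conditional log-probability of each token value at position i,
            # given the current tokens at all other positions.
            logits = h[i] + sum(J[i, j, :, s[j]] for j in range(T) if j != i)
            p = np.exp(logits - logits.max())  # subtract max for stability
            s[i] = rng.choice(q, p=p / p.sum())
    return s


# Toy example: T = 4 positions, q = 3 token values, and couplings that strongly
# reward every pair of positions for holding the same token, so samples should
# almost always be constant sequences.
T, q = 4, 3
h = np.zeros((T, q))
J = np.zeros((T, T, q, q))
for i in range(T):
    for j in range(T):
        if i != j:
            J[i, j] = 10.0 * np.eye(q)

rng = np.random.default_rng(1)
samples = [gibbs_sample_pairwise(h, J, n_sweeps=20, rng=rng) for _ in range(5)]
print([s.tolist() for s in samples])
```

Replacing the pairwise couplings with terms that involve three or more positions would produce data with higher-order dependencies. A family of such data sets, ordered by interaction degree, is the kind of controlled input the abstract uses to study how interaction order affects learning.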
