Towards a theory of how the structure of language is acquired by deep neural networks

Francesco Cagnetta, Matthieu Wyart

TL;DR

This work determines token-token correlations analytically in a synthetic model of language and shows that they can be used to build a representation of the grammar's hidden variables. The authors also conjecture that the relationship between training set size and effective range of correlations holds beyond their synthetic datasets.

Abstract

How much data is required to learn the structure of a language via next-token prediction? We study this question for synthetic datasets generated via a Probabilistic Context-Free Grammar (PCFG), a tree-like generative model that captures many of the hierarchical structures found in natural languages. We determine token-token correlations analytically in our model and show that they can be used to build a representation of the grammar's hidden variables: the longer the range, the deeper the variable. In addition, a finite training set limits the resolution of correlations to an effective range, whose size grows with that of the training set. As a result, a Language Model trained with increasingly many examples can build a deeper representation of the grammar's structure, thus reaching good performance despite the high dimensionality of the problem. We conjecture that the relationship between training set size and effective range of correlations holds beyond our synthetic datasets. In particular, our conjecture predicts how the scaling law of the test loss with training set size depends on the length of the context window, which we confirm empirically on Shakespeare's plays and Wikipedia articles.
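The claim that a finite training set limits correlations to an effective range can be illustrated with a back-of-the-envelope criterion. An empirical correlation estimated from $P$ examples carries sampling noise of order $P^{-1/2}$ (the noise size used in the captions of Figures 4 and 5), so a correlation is only resolvable at distances where its true value exceeds that floor. The sketch below assumes, purely for illustration, a power-law correlation $C(t) = t^{-\beta}$ (a hypothetical choice; the RHM correlations in the paper decay stepwise) and defines the effective range $t^*(P)$ as the largest distance with $C(t) > P^{-1/2}$. Under this assumption $t^*(P) \sim P^{1/(2\beta)}$, i.e. $z = 2\beta$, consistent with the relation quoted in the Figure 5 caption. The function names `effective_range` and `estimate_z` are hypothetical.

```python
import math


def effective_range(P: int, beta: float, t_max: int = 10**6) -> int:
    """Largest distance t whose (hypothetical) correlation t**(-beta)
    exceeds the sampling-noise floor P**(-1/2).

    Assumption for illustration only: C(t) = t**(-beta), a power law.
    """
    noise_floor = P ** -0.5
    t_star = 0
    for t in range(1, t_max + 1):
        if t ** (-beta) > noise_floor:
            t_star = t  # correlation at distance t is still above the noise
        else:
            break  # C(t) is decreasing, so no larger t can be resolved
    return t_star


def estimate_z(P1: int, P2: int, beta: float) -> float:
    """Fit the exponent in t*(P) ~ P**(1/z) from two training-set sizes."""
    t1, t2 = effective_range(P1, beta), effective_range(P2, beta)
    return math.log(P2 / P1) / math.log(t2 / t1)


if __name__ == "__main__":
    beta = 1.5
    for P in (10**4, 10**5, 10**7, 10**8):
        print(P, effective_range(P, beta))
    # Under the power-law assumption, z should be close to 2 * beta = 3.
    print(estimate_z(10**4, 10**8, beta))
```

The range grows only as a small power of $P$: with $\beta = 1.5$, multiplying the training set by $10^4$ extends the resolvable distance by a factor of roughly $22$.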

Paper Structure

This paper contains 96 sections, 125 equations and 9 figures.

Figures (9)

  • Figure 1: Left: Example of data generation according to the RHM, with depth $L\,{=}\,3$ and branching factor $s\,{=}\,2$. Starting from the root with $\ell\,{=}\,3$ and following the arrows, each level-$\ell$ symbol is replaced with a pair of lower-level symbols, down to the leaves with $\ell\,{=}\,0$. Right: Empirical (coloured) and analytical (black dashed) correlation functions of RHM data, with $L\,{=}\,3$, $s\,{=}\,2$, $v\,{=}\,32$ and $m\,{=}\,8$. The stepwise decay mirrors the tree structure of the generative model. Empirical estimates obtained from $P$ examples initially follow the true correlation function, but then saturate due to the sampling noise (coloured dashed). As a result, a finite training set only allows one to measure correlations with tokens up to a certain distance $t^*(P)$. Graphically, $t^*(P)$ corresponds to the highest value of $t$ at which the empirical estimate matches the true correlation (e.g. $1$ for the orange and green curves, $3$ for the red curve).
  • Figure 2: Left: Learning curves of depth-$3$ transformers trained on RHM data with $L\,{=}\,3$, $s\,{=}\,2$, $v\,{=}\,32$ and $m\,{=}\,8$ (blue) or $m\,{=}\,11$ (orange); both curves are averaged over $8$ independent realisations of the dataset and initialisations of the network. The curves display a stepwise behaviour analogous to that of the correlation function. The vertical dashed lines mark the characteristic training set sizes $P_{k}$ at which the correlations with tokens at distances up to $t\,{=}\,s^k-1$ emerge from the sampling noise. Horizontal dashed lines represent (upper bounds on) the cross-entropy of the probability of the last token conditioned on the previous $s^k\,{-}\,1$, suggesting that the steps correspond to the model learning a progressively larger sub-tree of the data structure. Right: Learning curves of transformers for $m\,{=}\,8$ and different sizes $t$ of the context window. The loss stops decreasing once the finite context window is fully exploited, which highlights that the decay is entirely due to the model's ability to leverage a larger portion of the context window.
  • Figure 3: Relative sensitivity $r_{\ell}/s_{\ell}$ of the representation of trained depth-$4$ CNNs (sketched on the right panels) as a function of training set size $P$, for input transformations that reset the production rule emanating from a given level-$\ell$ variable ($\ell=1,2,3$ for the top, centre and bottom panels; the affected tokens are indicated by the black horizontal segments on the right panels). Colours represent the layer of the representation, as indicated in the key and by the squares on the right panels. The CNNs are trained on RHM data with $L\,{=}\,4$, $s\,{=}\,2$, $v\,{=}\,16$, $m\,{=}\,4$. Vertical dashed lines mark the sample complexities $P_{\ell}$ predicted in the main text. The drop of the curves from $\simeq 1$ to $\simeq 0$ around $P_{\ell}$ signals that the trained representations only encode the relevant level-$\ell$ symbol when $P\,{>}\,P_{\ell}$.
  • Figure 4: Top, Left: Test losses of $3$-layer transformers trained on $(t\,{+}\,1)$-character blocks of the tiny-Shakespeare dataset (karpathy2015shakespeare), with $t$ as in the key. The saturation of the loss at a $t$-dependent value indicates that performance improves with $P$ because the model can use information from a larger context window. Top, Right: Empirical estimates $\hat{C}_P(t)$ for different training set sizes $P$ as in the key. The curves initially follow the true correlation $\tilde{C}(t)$ (black dashed), but then saturate due to the sampling noise (coloured dashed). Bottom, Right: The empirical curves $\hat{C}_P(t)$ collapse when the correlations are rescaled by the sampling noise size $P^{-1/2}$ and $t$ by the characteristic distance $t^*(P)\sim P^{1/z}$, with $z\simeq 2.8$. Bottom, Left: As predicted by our conjecture, the losses collapse when rescaled according to the scaling relation of the main text, with the same $z$ as the correlation functions.
  • Figure 5: Top, Left: Test losses of $6$-layer transformers trained on $(t\,{+}\,1)$-character blocks of the WikiText-103 dataset (merity2017pointer), with $t$ as in the key. As in Figure 4, the loss saturates at a $t$-dependent value after reaching a characteristic training set size. Top, Right: Empirical correlation functions $\hat{C}_P(t)$ with $P$ as in the key, showing saturation at large $t$ due to the sampling noise (coloured dashed). Bottom, Right: The empirical curves $\hat{C}_P(t)$ collapse when the correlations are rescaled by the sampling noise size $P^{-1/2}$ and $t$ by the characteristic distance $t^*(P)\sim P^{1/z}$, with $z\,{=}\,2\beta\simeq 3.1$. Bottom, Left: As predicted by our conjecture, the losses collapse when rescaled according to the scaling relation of the main text, with the same $z$ as the correlation functions and $\alpha\simeq 0.095$.
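The generative process described in Figure 1 (left) can be sketched in a few lines. This is a minimal, hypothetical RHM-style generator, not the authors' code: it assumes $v$ symbols at every level, $m$ production rules per symbol, each rule an $s$-tuple of lower-level symbols drawn uniformly at random, no $s$-tuple shared between two symbols (so every tuple has a unique parent, which requires $m v \le v^s$), and rules chosen uniformly during sampling. The helper names `make_rules` and `sample_sequence` are illustrative.

```python
import random


def make_rules(v, m, s, L, rng):
    """For each level l = 1..L, map every level-l symbol to m distinct
    s-tuples of level-(l-1) symbols.

    Assumption: tuples are never shared between symbols, so each tuple
    has a unique parent; this requires m * v <= v ** s.
    """
    assert m * v <= v ** s, "not enough distinct s-tuples"
    rules = {}
    for level in range(1, L + 1):
        # Draw v * m distinct integer codes, then read each code as s base-v digits.
        codes = rng.sample(range(v ** s), v * m)
        tuples = [tuple((c // v ** i) % v for i in range(s)) for c in codes]
        rules[level] = {sym: tuples[sym * m:(sym + 1) * m] for sym in range(v)}
    return rules


def sample_sequence(rules, v, L, rng):
    """Expand a uniformly drawn root symbol down to level 0 (s**L tokens)."""
    symbols = [rng.randrange(v)]  # level-L root
    for level in range(L, 0, -1):
        next_symbols = []
        for sym in symbols:
            # Pick one of the m production rules of this symbol uniformly.
            next_symbols.extend(rng.choice(rules[level][sym]))
        symbols = next_symbols
    return symbols


if __name__ == "__main__":
    rng = random.Random(0)
    # Parameters of Figure 1 (right): L=3, s=2, v=32, m=8.
    rules = make_rules(v=32, m=8, s=2, L=3, rng=rng)
    print(sample_sequence(rules, v=32, L=3, rng=rng))  # 2**3 = 8 tokens
```

Because no tuple is shared between symbols, the hidden variables of a sampled sequence can be recovered exactly by grouping adjacent blocks of $s$ tokens and inverting the rules level by level. This is the tree structure that the stepwise correlation functions of Figure 1 (right) mirror.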