Understanding Scaling Laws with Statistical and Approximation Theory for Transformer Neural Networks on Intrinsically Low-dimensional Data

Alex Havrilla, Wenjing Liao

TL;DR

Novel statistical estimation and mathematical approximation theories for transformers when the input data are concentrated on a low-dimensional manifold. The results rigorously show that the intrinsic dimension of the data is a crucial quantity affecting transformer scaling laws in both theory and practice.

Abstract

When training deep neural networks, a model's generalization error is often observed to follow a power scaling law that depends on both the model size and the data size. Perhaps the best-known examples of such scaling laws are those for transformer-based large language models, where networks with billions of parameters are trained on trillions of tokens of text. Yet, despite sustained widespread interest, a rigorous understanding of why transformer scaling laws exist is still missing. To answer this question, we establish novel statistical estimation and mathematical approximation theories for transformers when the input data are concentrated on a low-dimensional manifold. Our theory predicts a power law between the generalization error and both the training data size and the network size for transformers, where the power depends on the intrinsic dimension $d$ of the training data. Notably, the constructed model architecture is shallow, requiring only logarithmic depth in $d$. By leveraging low-dimensional data structures under a manifold hypothesis, we are able to explain transformer scaling laws in a way that respects the data geometry. Moreover, we test our theory against empirical observation by training LLMs on natural language datasets. We find that the observed empirical data scaling laws closely agree with our theoretical predictions. Taken together, these results rigorously show the intrinsic dimension of data to be a crucial quantity affecting transformer scaling laws in both theory and practice.
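
As a rough illustration of what "following a power scaling law" means in practice, the sketch below fits a law of the form error ≈ c · n^(-alpha) by linear regression in log-log space. It is a minimal sketch under stated assumptions: the function name `fit_power_law`, the synthetic data, and the chosen exponent are all hypothetical and are not taken from the paper, and this is not necessarily the fitting procedure the authors used.

```python
import numpy as np


def fit_power_law(sizes, errors):
    """Fit errors ~ c * sizes**(-alpha) by least squares in log-log space.

    Returns (alpha, c). Hypothetical helper for illustration only.
    """
    log_n = np.log(np.asarray(sizes, dtype=float))
    log_e = np.log(np.asarray(errors, dtype=float))
    # A power law is a straight line in log-log coordinates:
    # log(error) = log(c) - alpha * log(n)
    slope, intercept = np.polyfit(log_n, log_e, deg=1)
    return -slope, float(np.exp(intercept))


# Synthetic data (assumed values, not results from the paper).
rng = np.random.default_rng(0)
n = np.array([1e3, 1e4, 1e5, 1e6, 1e7])
true_alpha, true_c = 0.1, 5.0
noise = np.exp(rng.normal(0.0, 0.01, size=n.shape))  # small multiplicative noise
loss = true_c * n ** (-true_alpha) * noise

alpha_hat, c_hat = fit_power_law(n, loss)
print(f"estimated exponent: {alpha_hat:.3f}")
```

An exponent estimated this way from training runs at several data sizes is the kind of empirical quantity that can then be compared against an exponent predicted from the intrinsic dimension.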

Paper Structure

This paper contains 25 sections, 10 theorems, 200 equations, 7 figures, 2 tables.

Key Result

Theorem 1

Let $M,\tau, R, H_f > 0$, $0<\beta\le 1$, $d, D \in \mathbb{N}$, $\mathcal{M}$ and $f$ satisfy Assumptions \ref{assumption:manifold_p} and \ref{assumption:function_p} respectively. Given $n$ training samples $\{(x_i, f(x_i))\}_{i=1}^n$ where $\{x_i\}_{i=1}^n$ are i.i.d. samples of a distribution $Q$ supported on in the empirical risk minimization \eqref{eq:erm}, where $O(\cdot)$ hides terms in $C_\mathcal{M}$ (the num

Figures (7)

  • Figure 1: Diagram of the transformer architecture constructed in Theorem \ref{thm:approximation}. $\textup{T}$ computes approximations of $f(x)$ on each local chart $U_n \subseteq \mathcal{M}$ by first projecting $x$ to the tangent coordinates in $\mathbb{R}^d$ via $\phi_n(x)$ and then approximating $f(x)$ with local Taylor polynomials. A shallow sub-network computes indicators $\textbf{1}_{U_n}$ for each local chart in parallel. The results of the two sub-networks are then multiplied together and summed to produce the final result. Here $H_i$ denotes the embedding matrix before the $i$th transformer block $\textup{B}_i$.
  • Figure 2: Observed and predicted data scaling laws on OpenWebText, The Stack-SQL, and Tiny Stories pretraining datasets. All estimates are close $(\pm 0.02)$ and appear to reflect varying levels of pretraining data complexity. Note: $\hat{\alpha}_D$ denotes the empirically observed data scaling exponent and $\alpha_D$ denotes the theoretically estimated exponent.
  • Figure 3: Observed and predicted scaling laws in model size on the GPT2 and Pythia scaling suites. $\alpha_N$ denotes the empirically observed scaling exponent, and $\hat{\alpha}_N$ denotes the theoretically predicted exponent. Note: we estimate $\alpha_N$ for GPT2 using OpenWebText.
  • Figure 4: Top left: Estimated ID vs. number of parameters. Top right: Estimated ID vs. the embedding dimension. Bottom left: Variation of estimated ID across model layers. Bottom right: Variation of estimated ID across context position.
  • Figure 5: Diagram of transformer block.
  • ...and 2 more figures
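
The figures above rely on an estimated intrinsic dimension (ID). A minimal sketch of one common estimator, the Levina-Bickel maximum likelihood estimator based on nearest-neighbor distances, is given below. Whether this matches the estimator behind the figures is an assumption; the function name `mle_intrinsic_dimension`, the neighbor count `k`, and the synthetic test data are hypothetical. In the setting of the figures the input points would be model representations, whereas here they are synthetic points on a 2-dimensional plane embedded in 10 dimensions.

```python
import numpy as np


def mle_intrinsic_dimension(points, k=20):
    """Levina-Bickel MLE of intrinsic dimension, averaged over all points.

    points: array of shape (num_points, ambient_dim), assumed to contain
    no duplicate rows. Hypothetical helper for illustration only.
    """
    x = np.asarray(points, dtype=float)
    # Pairwise Euclidean distances via the Gram matrix.
    sq_norms = np.sum(x**2, axis=1)
    sq_dist = sq_norms[:, None] + sq_norms[None, :] - 2.0 * (x @ x.T)
    np.maximum(sq_dist, 0.0, out=sq_dist)  # clip tiny negative round-off
    dist = np.sqrt(sq_dist)
    # Sort each row; column 0 is the point itself, so skip it.
    knn = np.sort(dist, axis=1)[:, 1 : k + 1]
    # Local estimate: inverse of the mean log ratio between the k-th
    # neighbor distance and each of the first k-1 neighbor distances.
    log_ratio = np.log(knn[:, -1:] / knn[:, :-1])
    local_dim = (k - 1) / np.sum(log_ratio, axis=1)
    return float(np.mean(local_dim))


# Synthetic check: a 2-D uniform sample isometrically embedded in R^10.
rng = np.random.default_rng(0)
latent = rng.uniform(size=(500, 2))
basis, _ = np.linalg.qr(rng.normal(size=(10, 2)))  # orthonormal columns
embedded = latent @ basis.T  # shape (500, 10)

id_hat = mle_intrinsic_dimension(embedded, k=20)
print(f"estimated intrinsic dimension: {id_hat:.2f}")
```

The estimate depends on `k` and on sample density, so it is best read as an approximate quantity rather than an exact dimension.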

Theorems & Definitions (26)

  • Definition 1: Transformer Neural Network
  • Definition 2: Transformer Network Class
  • Theorem 1
  • Theorem 2
  • Lemma 1
  • Definition 3: Transformer Block
  • Definition 4: Attention
  • Definition 5: Feed-forward Layer
  • Definition 6: Feed-forward Network Class
  • Definition 7: Multi-headed Attention and Transformer Block Classes
  • ...and 16 more