
An exactly solvable model for emergence and scaling laws in the multitask sparse parity problem

Yoonsoo Nam, Nayara Fonseca, Seok Hyeong Lee, Chris Mingard, Ard A. Louis

TL;DR

This work formulates an analytically tractable framework for emergence and scaling in multitask learning by representing skills as an orthogonal basis of functions and solving a multilinear model. It derives exact scaling laws for loss with respect to training time $T$, data $D$, parameters $N$, and compute $C=N\times T$, and demonstrates stage-like, sigmoidal skill emergence consistent with neural network observations. Calibrating the model on the first skill enables accurate prediction of subsequent skill emergences in a 2-layer MLP and a transformer, linking feature learning to emergent capabilities. The study further extends the model to account for data-shot learning ($D_c$-shot) and parameter-shot learning ($N_c$-shot), highlighting tradeoffs between data, compute, and representation capacity while acknowledging limitations due to decoupled dynamics. Overall, the results suggest that hierarchical, stage-like learning across a power-law distribution of skill frequencies can reproduce key qualitative and quantitative aspects of emergence in neural systems, offering a compact lens on how complex abilities arise with scale.

Abstract

Deep learning models can exhibit what appears to be a sudden ability to solve a new problem as training time, training data, or model size increases, a phenomenon known as emergence. In this paper, we present a framework where each new ability (a skill) is represented as a basis function. We solve a simple multi-linear model in this skill-basis, finding analytic expressions for the emergence of new skills, as well as for scaling laws of the loss with training time, data size, model size, and optimal compute. We compare our detailed calculations to direct simulations of a two-layer neural network trained on multitask sparse parity, where the tasks in the dataset are distributed according to a power-law. Our simple model captures, using a single fit parameter, the sigmoidal emergence of multiple new skills as training time, data size or model size increases in the neural network.
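To make the multitask sparse parity setup concrete, here is a minimal sketch of a data sampler. It rests on assumptions not stated in the abstract: each input is a one-hot control vector selecting a task, concatenated with uniformly random bits, and the label is the parity of a task-specific subset of those bits. The names `make_tasks`, `sample_example`, `n_bits`, and `subset_size` are hypothetical. Tasks are drawn with probability proportional to $k^{-(\alpha+1)}$, matching the power-law exponent quoted in the Figure 2 caption.

```python
import random

def make_tasks(n_tasks, n_bits, subset_size, rng):
    """Give each task a fixed random subset of bit positions (illustrative choice)."""
    return [sorted(rng.sample(range(n_bits), subset_size)) for _ in range(n_tasks)]

def sample_example(tasks, alpha, n_bits, rng):
    """Draw one (input, label, task_index) triple.

    Task k (1-indexed) is chosen with probability proportional to k^-(alpha+1).
    """
    n_tasks = len(tasks)
    weights = [k ** -(alpha + 1.0) for k in range(1, n_tasks + 1)]
    k = rng.choices(range(n_tasks), weights=weights, k=1)[0]

    control = [0] * n_tasks          # one-hot vector marking the active task
    control[k] = 1
    bits = [rng.randint(0, 1) for _ in range(n_bits)]

    # Label: parity of the bits in the active task's subset.
    parity = sum(bits[i] for i in tasks[k]) % 2
    return control + bits, parity, k

if __name__ == "__main__":
    rng = random.Random(0)
    tasks = make_tasks(n_tasks=5, n_bits=16, subset_size=3, rng=rng)
    x, y, k = sample_example(tasks, alpha=0.6, n_bits=16, rng=rng)
    print("task index:", k, "label:", y, "input length:", len(x))
```

Because frequent tasks dominate the sample, a learner sees many examples of the first skill long before it sees enough of the rarer ones, which is the mechanism behind the staggered emergence described above.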

Paper Structure (79 sections, 27 theorems, 258 equations, 12 figures, 6 tables)


Key Result

Lemma 1

Let the multilinear model eq:toy be trained with gradient flow on $D$ i.i.d samples for the setup in sec:setup (input distribution: eq:probs, target function: eq:target, and MSE loss: eq:loss). Let $k \leq N$ be a skill index in the multilinear model and the input distribution ($k \leq {n_{s}}$). Th and the skill loss is where $\eta$ is the learning rate and $d_k$ is the number of observations wi
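As a rough illustration of the kind of gradient-flow dynamics the lemma concerns, the sketch below works in an idealized infinite-data limit, which is an assumption made here (the lemma itself treats $D$ finite samples). It further assumes a decoupled loss $\tfrac{1}{2}\sum_k P_k (S - a_k b_k)^2$ with $a_k = b_k$ at initialization, so that the skill strength $\mathcal{R}_k = a_k b_k$ obeys the logistic equation $\dot{\mathcal{R}}_k = 2\eta P_k \mathcal{R}_k (S - \mathcal{R}_k)$. The function names, the initial value `r0`, and the loss normalization are hypothetical choices for illustration.

```python
import math

def skill_probs(n_skills, alpha):
    """Power-law skill frequencies P_k proportional to k^-(alpha+1)."""
    weights = [k ** -(alpha + 1.0) for k in range(1, n_skills + 1)]
    total = sum(weights)
    return [w / total for w in weights]

def skill_strength(t, p_k, S=1.0, r0=1e-3, eta=1.0):
    """Closed-form logistic solution of dR/dt = 2*eta*p_k*R*(S - R), R(0) = r0."""
    return S / (1.0 + (S / r0 - 1.0) * math.exp(-2.0 * eta * p_k * S * t))

def emergence_time(p_k, S=1.0, r0=1e-3, eta=1.0, threshold=0.05):
    """Time at which R/S reaches `threshold` under the closed form above."""
    return math.log((S / r0 - 1.0) / (1.0 / threshold - 1.0)) / (2.0 * eta * p_k * S)

def euler_strength(t_end, p_k, S=1.0, r0=1e-3, eta=1.0, steps=20_000):
    """Forward-Euler integration of the same ODE, as a numerical cross-check."""
    dt = t_end / steps
    r = r0
    for _ in range(steps):
        r += dt * 2.0 * eta * p_k * r * (S - r)
    return r

if __name__ == "__main__":
    probs = skill_probs(5, alpha=0.6)
    for k, p in enumerate(probs, start=1):
        print("skill", k, "emergence time", round(emergence_time(p), 2))
```

Under these assumptions the emergence time is inversely proportional to $P_k$, so it grows as $k^{\alpha+1}$, the same relationship the Figure 4 caption reports for the transformer.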

Figures (12)

  • Figure 1: Predicting emergence. The skill strength $\mathcal{R}_k$, defined as the $k^\textrm{th}$ coefficient when a model is expanded in the basis of the skill functions ($g_k$), measures how well the $k^\textrm{th}$ skill is learned, and is plotted against (a) time $T$, (b) dataset size $D$, and (c) number of parameters $N$ (width of the hidden layer). $\mathcal{R}_k$ is normalized by the target scale $S$ such that $\mathcal{R}_k/S = 1$ means zero skill loss. The dashed lines show the abrupt growth (emergence) of $5$ skills for a 2-layer MLP (\ref{app:methods}) trained on the multitask sparse parity problem with data power-law exponent $\alpha=0.6$ (shaded areas indicate 1 standard deviation over at least $10$ runs). Solid lines are the predictions (\ref{eq:toy_theo_extended}, \ref{eq:d_c_shot}, and \ref{eq:param_emergence}, respectively) from our multilinear model calibrated on the first skill (blue) only.
  • Figure 2: Scaling laws. The learning curve ($\mathcal{L}$ is the MSE loss) of the multilinear model (solid) and the theoretical power law (dotted) for (a) time $T$, (b) data $D$, and (c) parameters $N$. The lower-left legends show the condition (top) and the scaling law (bottom), where $\alpha+1$ is the exponent of the power-law input data (\ref{eq:probs}). See the appendices for 1) rigorous derivations of the theoretical scaling laws, including the exponents, prefactors (e.g., $\mathcal{A}_N$ for $\mathcal{L}=\mathcal{A}_NN^{-\alpha}$), and conditions (\ref{app:rigorous}); 2) simplified derivations of the exponents only (\ref{app:scaling}); 3) details of the experiment (\ref{app:scaling_detail}).
  • Figure 3: Scaling law for optimal compute. The solid lines are the learning curves of the multilinear model as a function of compute $C=T \times N$, with the number of parameters $N$ varying from $10^1$ (top plateau) to $10^4$ (bottom plateau). The dotted lines are optimal compute scaling laws with exponent $-\alpha/(\alpha+2)$ (\ref{derivation:compute_scaling}) and calculated prefactor constants (\ref{app:rigorous}). See \ref{app:scaling_detail} for details of the experiment. For a given $C$, we achieve the optimal tradeoff when $T$ is large enough to fit all $N$ skills (i.e., when the solid lines plateau). For the case $\alpha=0.3$, the optimal $C$ for the model decays faster than the power law; see \ref{derivation:time_scaling}.
  • Figure 4: Transformer on multitask sparse parity task. We trained a transformer on the multitask sparse parity task with $\alpha=0.9$; see \ref{app:methods} for details. Left: An example of the time emergence (measured in steps) for the transformer in the $n_s=5$ setup. See \ref{app:add_plots} for enlarged plots showing the saturation of each skill in linear scale. Right: The $k^\textrm{th}$ skill's emergent time $\tau_{emerge}(k)$ (i.e., $\mathcal{R}_k(\tau_{emerge}(k))/S = 0.05$) as a function of $k$ (error bars indicate 1 standard deviation over $5$ runs). The emergent times follow a power law of $k^{\alpha+1}$, the same relationship as in the multilinear model (\ref{eq:toy_theo}).
  • Figure 5: Nonlinear dynamics of linear neural networks. (a): A two-layer undercomplete linear neural network, which is a product of two matrices, where $d_2 <d_1$ and $d_2 <d_3$. (b): The $d_2$ independent modes of dynamics for the linear neural network (\ref{eq:saxe_decomposed}). The product $a_kb_k$ is learnable, and the vectors $u_k,v_k$ are obtained from the SVD of the input-output correlation matrix $\Sigma$ (\ref{eq:saxe_condition}). (c): The temporal evolution of $a_kb_k$ under gradient descent, which follows a sigmoidal growth (\ref{eq:saxe_analytic}). Note that a smaller $\lambda_k$ (the singular value of $\Sigma$) results in a more delayed saturation of $a_kb_k$.
  • ...and 7 more figures
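
The compute exponent quoted in the Figure 3 caption can be recovered heuristically from two facts stated in the captions above: the parameter scaling $\mathcal{L}\propto N^{-\alpha}$ and the emergence time growing as $k^{\alpha+1}$. The following is only a back-of-the-envelope sketch. It assumes all prefactors are dropped and that the time needed to fit all $N$ skills, written here as the hypothetical symbol $T_{\text{fit}}(N)$, is set by the emergence time of the $N^\textrm{th}$ skill. The rigorous version is the one the caption attributes to the appendices.

```latex
T_{\text{fit}}(N) \propto N^{\alpha+1}
\quad\Rightarrow\quad
C = N \times T_{\text{fit}}(N) \propto N^{\alpha+2}
\quad\Rightarrow\quad
N \propto C^{1/(\alpha+2)},
\qquad
\mathcal{L} \propto N^{-\alpha} \propto C^{-\alpha/(\alpha+2)}.
```

For instance, with $\alpha = 0.6$ this heuristic gives a compute exponent of $-0.6/2.6 \approx -0.23$.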

Theorems & Definitions (27)

  • Lemma 1
  • Corollary 1
  • Corollary 2
  • Lemma 2
  • Lemma 3
  • Proposition 1
  • Proposition 2
  • Lemma 4
  • Lemma 5
  • Proposition 3
  • ...and 17 more