Learning to grok: Emergence of in-context learning and skill composition in modular arithmetic tasks

Tianyu He, Darshil Doshi, Aritra Das, Andrey Gromov

TL;DR

This work empirically shows that a GPT-style transformer exhibits a transition from in-distribution to out-of-distribution generalization as the number of pre-training tasks increases, and finds an algorithmic shift in deeper models, as they go from few to many in-context examples.

Abstract

Large language models can solve tasks that were not present in the training set. This capability is believed to be due to in-context learning and skill composition. In this work, we study the emergence of in-context learning and skill composition in a collection of modular arithmetic tasks. Specifically, we consider a finite collection of linear modular functions $z = a \, x + b \, y \;\mathrm{mod}\; p$ labeled by the vector $(a, b) \in \mathbb{Z}_p^2$. We use some of these tasks for pre-training and the rest for out-of-distribution testing. We empirically show that a GPT-style transformer exhibits a transition from in-distribution to out-of-distribution generalization as the number of pre-training tasks increases. We find that the smallest model capable of out-of-distribution generalization requires two transformer blocks, while for deeper models, the out-of-distribution generalization phase is \emph{transient}, necessitating early stopping. Finally, we perform an interpretability study of the pre-trained models, revealing highly structured representations in both attention heads and MLPs, and discuss the learned algorithms. Notably, we find an algorithmic shift in deeper models as we go from few to many in-context examples.
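The task family in the abstract can be made concrete in a few lines of Python. This is a minimal sketch, not the authors' code: the helper names `make_task_split` and `apply_task` are hypothetical, the uniform random split is an assumption (Figure 2 refers to a structured selection of pre-training tasks, which this does not reproduce), and `n_id` plays the role of $n_{\mathrm{i.d.}}$, the number of pre-training tasks.

```python
import itertools
import random


def make_task_split(p, n_id, seed=0):
    """Split all task vectors (a, b) in Z_p^2 into pre-training and o.o.d. sets.

    Hypothetical helper: the uniform random split is an assumption.
    n_id is the number of pre-training tasks (n_{i.d.} in Figure 3).
    """
    tasks = list(itertools.product(range(p), repeat=2))  # all p**2 task vectors
    rng = random.Random(seed)
    rng.shuffle(tasks)
    return tasks[:n_id], tasks[n_id:]


def apply_task(task, x, y, p):
    """Evaluate the linear modular function z = a*x + b*y mod p."""
    a, b = task
    return (a * x + b * y) % p
```

With $p = 29$ (an illustrative value; this section does not state $p$) and `n_id = 2**8`, 256 of the $29^2 = 841$ task vectors go to pre-training and the remaining 585 are held out for o.o.d. testing.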

Paper Structure (44 sections, 4 equations, 29 figures)

Figures (29)

  • Figure 1: (a) The dataset. The tasks are labeled by vectors $(a,b)\in \mathbb Z_p^2$. Each table contains examples of $ax + by \,\, \textrm{mod}\,\, p$. A fraction $1 - \alpha$ of the examples is blacked out, while the remaining examples are flattened into a single "document" in the batch. Each document is organized as a collection of triples $(x,y,ax+by \,\, \textrm{mod}\,\, p)$ for $x,y$ from the training set (i.e., not blacked out in the table). Our training is similar to traditional (autoregressive) next-token prediction, with the main difference that we predict only every third token, marked in red ($x$ and $y$ are uncorrelated, so they are not prediction targets). Every task appears exactly the same number of times in each batch. (b) Phase diagram for a six-layer model. We find four different phases. (1) In-distribution memorization: the model only performs well on tasks $(a,b)$ and examples $(x,y)$ from the training set; it does not generalize to unseen examples or tasks. (2) In-distribution generalization: the model generalizes to unseen examples $(x,y)$ but not to unseen tasks $(a,b)$. (3) Out-of-distribution memorization: the model generalizes to unseen tasks $(a,b)$, but only for examples $(x,y)$ it has seen during training. (4) Out-of-distribution generalization: the model generalizes to unseen tasks $(a,b)$ for seen as well as unseen examples $(x,y)$. We investigate phase (4) in more detail. (c) In-context sample complexity: accuracy of the model in phase (4) as a function of the number of few-shot examples. (d) Representations developed by one of the attention heads in the first layer. These are projections of the embedding of a pair of numbers onto the two largest principal components (PCs) of the internal representation formed after passing through the attention layer and projection matrix. (e) The first 3 PCs of the embeddings separate $\log_{27}$-annotated numbers into even/odd planes, with 0 sandwiched between them.
  • Figure 2: Structured selection of pre-training tasks and sequences.
  • Figure 3: Phase diagram for the depth $d=6$ models. (a) Accuracy on all four sets used to plot the phase diagram in Figure 1, with early stopping applied. Notably, in the regions where models generalize to o.o.d. sets, the pre-training performance degrades. (b, c) $\alpha=0.6$: training accuracy and o.o.d. test accuracy (dotted line). For $n_{\mathrm{i.d.}}=2^8$, the o.o.d. generalization ability of the model first improves and then degrades as we train longer. (d, e) $\alpha=0.6$: loss and accuracy vs. context length, measured on $S^{\mathrm{o.o.d.}}_{\mathrm{test}}$ at the end of training; for $n_{\mathrm{i.d.}}=2^8$, the ICL ability fades away.
  • Figure 4: Phases of the depth $d=4$ and $d=2$ models. With decreasing model capacity, the performance on both sets degrades. At the same time, ICL no longer appears transient. (a, b) From left to right: accuracy phase diagrams on the pre-training set $S^{\mathrm{i.d.}}_{\mathrm{train}}$ and on the o.o.d. test set $S^{\mathrm{o.o.d.}}_{\mathrm{test}}$, with early stopping; loss and accuracy vs. context length on the o.o.d. test set $S^{\mathrm{o.o.d.}}_{\mathrm{test}}$ for $\alpha=0.6$.
  • Figure 5: Performance of the $d = 4$ and $d=2$ models on $k$-shot inference, on the grid of inputs $(x,y) \in \mathbb{Z}_p^2$ (task vector $(6,6)$). Row 1: the models' predictions on an o.o.d. task, with sequences of the form $(x_1 \;\; y_1 \;\; \textcolor{red}{z_1} \;\; \cdots \;\; x_k \;\; y_k \;\; \textcolor{red}{z_k} \;\; x \;\; y \;\; \textcolor{red}{?})$. Row 2: analytical plots showing predictions based solely on the Ratio Matching algorithm. Row 3: row 2 subtracted from row 1, using correct = 1 and incorrect = 0. The red points are examples where Ratio Matching does not give the correct prediction but the model predicts correctly; the blue points are examples the model missed even though Ratio Matching was applicable. This row shows the model's ability to implement Modular Regression by combining the in-context examples. Note that the $d=4$ model readily learns to combine previous examples, while its $d=2$ counterpart struggles due to its limited capacity.
  • ...and 24 more figures
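The document format described in Figure 1(a) can be sketched as follows. This is a hedged illustration, not the authors' pipeline. Assumptions not taken from the source: the kept fraction $\alpha$ of $(x, y)$ pairs is drawn uniformly at random and shared across tasks, examples are sampled with replacement from it, and the helper names `split_examples`, `make_document`, and `make_batch` are hypothetical. The `is_target` mask marks the $z$ token of each triple, i.e. every third token, as the only prediction target.

```python
import random


def split_examples(p, alpha, seed=0):
    """Keep a fraction alpha of all (x, y) pairs; the rest are "blacked out".

    Hypothetical helper: uniform random selection is an assumption.
    """
    rng = random.Random(seed)
    pairs = [(x, y) for x in range(p) for y in range(p)]
    rng.shuffle(pairs)
    return pairs[: int(alpha * len(pairs))]


def make_document(task, p, kept, n_examples, seed=0):
    """Flatten n_examples triples (x, y, a*x + b*y mod p) into one token list.

    Returns (tokens, is_target): is_target[i] is True only for the z token of
    each triple, so a next-token loss would be applied to every third token.
    """
    rng = random.Random(seed)
    a, b = task
    tokens = []
    for _ in range(n_examples):
        x, y = rng.choice(kept)  # sample a non-blacked-out example
        tokens.extend([x, y, (a * x + b * y) % p])
    is_target = [i % 3 == 2 for i in range(len(tokens))]
    return tokens, is_target


def make_batch(tasks, p, kept, n_examples, docs_per_task=1, seed=0):
    """Give every task the same number of documents in the batch."""
    batch = []
    for t, task in enumerate(tasks):
        for j in range(docs_per_task):
            doc_seed = seed + t * docs_per_task + j  # distinct seed per document
            batch.append(make_document(task, p, kept, n_examples, seed=doc_seed))
    return batch
```

Sharing one `kept` set across tasks is only one reading of the caption's "training set" of examples; passing a different `kept` per task would be a one-line change.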
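Figure 5 contrasts two in-context algorithms. The caption does not define them, so the following is a hedged reading rather than the authors' implementation: Ratio Matching is taken to rescale a single in-context example whose $(x_i, y_i)$ is proportional to the query, and Modular Regression to combine examples linearly. Because $z = ax + by \bmod p$ is linear, if $(x, y) \equiv c_1 (x_1, y_1) + c_2 (x_2, y_2) \pmod p$ then $z \equiv c_1 z_1 + c_2 z_2 \pmod p$. The function names are hypothetical, and the sketch assumes $p$ is prime so that every nonzero determinant is invertible.

```python
def ratio_matching(examples, x, y, p):
    """Predict z by rescaling one in-context example (assumed reading).

    If (x, y) == c * (x_i, y_i) mod p for some example i and scalar c,
    linearity gives z = c * z_i mod p. Returns None if no ratio matches.
    """
    for xi, yi, zi in examples:
        for c in range(p):
            if (c * xi) % p == x and (c * yi) % p == y:
                return (c * zi) % p
    return None


def modular_regression(examples, x, y, p):
    """Predict z by writing (x, y) as c1*v1 + c2*v2 mod p for two examples.

    Assumes p is prime, so a nonzero determinant has an inverse mod p.
    Returns None if every pair of examples is collinear mod p.
    """
    for i in range(len(examples)):
        for j in range(i + 1, len(examples)):
            x1, y1, z1 = examples[i]
            x2, y2, z2 = examples[j]
            det = (x1 * y2 - x2 * y1) % p
            if det == 0:
                continue  # v1 and v2 are collinear mod p; try another pair
            inv = pow(det, -1, p)  # modular inverse (Python 3.8+)
            # Cramer's rule mod p for c1*(x1, y1) + c2*(x2, y2) = (x, y)
            c1 = ((x * y2 - x2 * y) * inv) % p
            c2 = ((x1 * y - x * y1) * inv) % p
            return (c1 * z1 + c2 * z2) % p
    return None
```

Queries where `ratio_matching` returns `None` but a model answers correctly would correspond to the red points in row 3 of Figure 5. For example, with task $(3, 5)$, $p = 29$ (illustrative values), and in-context examples $(1, 2, 13)$ and $(2, 1, 11)$, the query $(3, 3)$ is not a multiple of either example, but it equals their sum, so `modular_regression` predicts $13 + 11 = 24 = 3 \cdot 3 + 5 \cdot 3$.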