
Small Models, Smarter Learning: The Power of Joint Task Training

Csaba Both, Benjamin Hoover, Hendrik Strobelt, Dmitry Krotov, Daniel Karl I. Weidele, Mauro Martino, Nima Dehmamy

TL;DR

This work probes how training curriculum shapes learning in tiny transformer models solving nested ListOps arithmetic. It finds that the hardest single operation, SUM modulo $n$, can be learned by smaller models when trained jointly with other operations such as MAX, MIN, and MED. Using a nanoGPT-like architecture, chain-of-thought (CoT) prompts, and detailed embedding analyses, the authors show that joint training induces number-like representations and parity sensitivity, and shifts the model's reliance from feedforward layers toward attention, suggesting that task diversity guides efficient algorithm discovery. A key finding is that pure SUM training often leads to memorization rather than extraction of number structure, while Mixed SUM training or pretraining on MAX+MED enables SUM learning at substantially smaller parameter counts. Overall, the paper demonstrates that emergent abilities in language models depend not only on model size but crucially on curriculum design, which can steer the model toward more efficient, generalizable algorithms.
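The TL;DR mentions nested ListOps expressions and chain-of-thought prompts but not their exact format. As a minimal sketch, assuming a parenthesized prefix syntax such as `(SUM 3 (MAX 1 5) 2)`, a modulus of 26 (the setting named in Figure 4), and a lower median for even-length lists, the code below produces an innermost-first reduction trace of the kind a CoT target might contain. The syntax, the `cot_trace` and `apply_op` helpers, and the median convention are hypothetical, not the authors' data format.

```python
import re
from typing import List

MOD = 26  # assumed modulus (Figure 4 uses "Mod 26")

# Matches an operation whose arguments are all plain numbers, i.e. an innermost subexpression.
INNERMOST = re.compile(r"\((MAX|MIN|MED|SUM)((?: \d+)+)\)")


def apply_op(op: str, args: List[int], mod: int = MOD) -> int:
    """Evaluate one ListOps operation on already-reduced integer arguments."""
    if op == "MAX":
        return max(args)
    if op == "MIN":
        return min(args)
    if op == "MED":
        s = sorted(args)
        return s[(len(s) - 1) // 2]  # assumption: lower median for even-length lists
    if op == "SUM":
        return sum(args) % mod
    raise ValueError(f"unknown operation {op!r}")


def cot_trace(expr: str, mod: int = MOD) -> List[str]:
    """Rewrite the leftmost innermost subexpression until a single number remains."""
    steps = [expr]
    while (m := INNERMOST.search(expr)) is not None:
        op = m.group(1)
        args = [int(a) for a in m.group(2).split()]
        expr = expr[: m.start()] + str(apply_op(op, args, mod)) + expr[m.end():]
        steps.append(expr)
    return steps


if __name__ == "__main__":
    for step in cot_trace("(SUM 20 (MED 7 1 9) (MAX 4 11))"):
        print(step)
```

Each line of the trace is one reduction step, so a CoT-style target could simply concatenate them. Reducing the leftmost innermost bracket first is one arbitrary choice; because sibling subexpressions are independent, any order that only evaluates fully reduced subexpressions reaches the same final value.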

Abstract

The ability of a model to learn a task depends strongly on both the task difficulty and the model size. We aim to understand how task difficulty relates to the minimum number of parameters required for learning specific tasks in small transformer models. Our study focuses on the ListOps dataset, which consists of nested mathematical operations. We gradually increase task difficulty by introducing new operations or combinations of operations into the training data. We observe that sum modulo $n$ is the hardest of these operations to learn. Curiously, when combined with other operations such as maximum and median, the sum operation becomes easier to learn and requires fewer parameters. We show that joint training not only improves performance but also leads to qualitatively different model behavior. We present evidence that models trained only on SUM might be memorizing and fail to capture the number structure in the embeddings. In contrast, models trained on a mixture of SUM and other operations exhibit number-like representations in the embedding space and a strong ability to distinguish parity. Furthermore, the SUM-only model relies more heavily on its feedforward layers, while the jointly trained model makes greater use of its attention mechanism. Finally, we show that pure SUM can be learned by models below the pure-SUM learning threshold if they are first pretrained on MAX+MED. Our findings indicate that emergent abilities in language models depend not only on model size but also on the training curriculum.
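The abstract frames task difficulty in terms of the minimum parameter count a model needs. For orientation, here is a minimal sketch of how parameters add up in a nanoGPT-style decoder (the TL;DR calls the architecture nanoGPT-like), assuming biases in every linear and LayerNorm layer, a 4x MLP expansion, learned positional embeddings, and an output head tied to the token embedding. The function name, vocabulary size, and context length are hypothetical, and the paper's own counting convention (for example, whether embeddings are included) is not stated here.

```python
def gpt_param_count(d_model: int, n_layers: int, vocab_size: int, context_len: int,
                    include_position_emb: bool = True) -> int:
    """Count parameters of a GPT-2/nanoGPT-style decoder under the stated assumptions."""
    token_emb = vocab_size * d_model                 # also serves as the tied output head
    pos_emb = context_len * d_model if include_position_emb else 0
    attn = d_model * 3 * d_model + 3 * d_model       # fused Q/K/V projection (+ bias)
    attn += d_model * d_model + d_model              # attention output projection (+ bias)
    mlp = d_model * 4 * d_model + 4 * d_model        # up-projection to 4d (+ bias)
    mlp += 4 * d_model * d_model + d_model           # down-projection back to d (+ bias)
    norms = 2 * 2 * d_model                          # two LayerNorms per block, weight + bias
    per_block = attn + mlp + norms                   # = 12 d^2 + 13 d
    final_norm = 2 * d_model                         # LayerNorm before the output head
    return token_emb + pos_emb + n_layers * per_block + final_norm


if __name__ == "__main__":
    VOCAB, CONTEXT = 40, 256  # hypothetical values; not given in the summary
    for d, layers in [(96, 3), (48, 2)]:  # configurations named in Figures 4 and 5
        print(d, layers, gpt_param_count(d, layers, VOCAB, CONTEXT))
```

Because the per-block term grows as 12 d², going from the Figure 4 configuration (embedding 96, 3 layers) to the Figure 5 one (embedding 48, 2 layers) cuts the transformer-block parameters by roughly a factor of six, before counting embeddings.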

Paper Structure

This paper contains 38 sections and 34 figures.

Figures (34)

  • Figure 1: Emergence of abilities in ListOps: Each plot shows the same group of small transformer models trained on a different mix of the four operations MAX, MIN, MED, and SUM. Red dots are models reaching more than 50% accuracy, and blue dots are models below 50%. The dashed green line is a logistic fit, and the yellow star indicates the transition point at 50%. The x-axis is the model size (number of parameters), and the plots are sorted in ascending order of transition points. The bottom panel shows a bar plot of the model sizes at the transition points, with each group distinguished by a different color.
  • Figure 2: PCA of embeddings: We select all models that reached over 90% test accuracy. Each row shows the average correlation matrix and top PCs for models trained on either a single operation, e.g. Pure MAX, or all mixtures involving a given operation, e.g. Mixed SUM. Interestingly, pure SUM does not show a discernible structure in the embeddings, whereas all other cases do. Notably, Mixed SUM models exhibit a prominent odd-even separation in PC5.
  • Figure 3: PC5 and PC6 of the average Mixed SUM embedding correlations perfectly separate odd and even numbers.
  • Figure 4: Evolution of training loss, accuracy, and PCs of the embedding layer, Mod 26: The top panel shows the evolution of training loss (solid lines) and test accuracy (dashed lines) for models with an embedding dimension of 96 and 3 layers, trained either on pure SUM (blue) or on mixed MAX+MED+SUM (red). Models trained on MAX+MED+SUM were evaluated separately on the pure MAX, MED, and SUM subsets; the corresponding accuracy curves are shown in green, purple, and orange, respectively. Curves are the mean over 3 runs, with a shaded $\pm\sigma$ band. All models were trained for 20000 iterations. The panels beneath the main plot show PCA of the embeddings: models trained on MAX+MED+SUM data progressively develop a structured representation of numerical concepts, accompanied by a steady decrease in loss. They also show a prominent parity separation emerging in PC2 and PC4, with the two parities colored red and blue. In contrast, models trained solely on SUM exhibit no clear structure in the embedding space and show long plateaus in the loss curve.
  • Figure 5: Learning SUM by finetuning MAX+MED: We train a model much smaller than the SUM-only learning transition (embedding dimension 48, 2 layers). By slowly switching the training data from MAX+MED to pure SUM (never showing expressions that mix all three operations), the model is able to learn SUM (blue) in this much lower parameter regime. In comparison, the pure SUM models (red) did not learn at this size.
  • ...and 29 more figures
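Figure 1 describes a logistic fit whose 50% crossing defines the transition point, but the caption does not say exactly what is fitted. A minimal sketch, assuming a logistic regression of the binary red/blue outcome (accuracy above 50% or not) on log10 of the parameter count, fitted by plain gradient descent with a small L2 penalty on the slope; the function name, hyperparameters, and toy data are hypothetical.

```python
import numpy as np


def transition_point(n_params, accuracy, threshold=0.5, lr=0.5, steps=5000, l2=1e-3):
    """Fit P(success) = sigmoid(w * z + b) on standardized log10(params); return the 50% crossing."""
    x = np.log10(np.asarray(n_params, dtype=float))
    y = (np.asarray(accuracy, dtype=float) > threshold).astype(float)  # red (1) vs blue (0) dots
    if y.min() == y.max():
        raise ValueError("need both successful and unsuccessful models to locate a transition")
    mu, sd = x.mean(), x.std()
    z = (x - mu) / sd  # standardize for stable gradient descent
    w, b = 0.0, 0.0
    for _ in range(steps):
        p = 1.0 / (1.0 + np.exp(-(w * z + b)))
        # Gradient of mean binary cross-entropy, with a small L2 penalty on the slope
        # so the fit stays finite when the two classes are perfectly separated.
        w -= lr * (np.mean((p - y) * z) + l2 * w)
        b -= lr * np.mean(p - y)
    # sigmoid(w * z + b) = 0.5 exactly where w * z + b = 0; map back to parameter count.
    return float(10 ** (-b / w * sd + mu))


if __name__ == "__main__":
    sizes = [1e4, 2e4, 5e4, 1e5, 2e5, 5e5, 1e6]    # hypothetical model sizes
    accs = [0.05, 0.1, 0.2, 0.4, 0.9, 0.97, 0.99]  # hypothetical accuracies
    print(f"transition at ~{transition_point(sizes, accs):.3g} parameters")
```

Fitting in log-parameter space matches the caption's use of model size on the x-axis across orders of magnitude; fitting raw accuracy values with a sigmoid curve would be an equally plausible reading of the caption.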
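Figures 2 and 3 describe PCA of an embedding correlation matrix averaged over models, and a parity split in particular PCs. A minimal sketch of that analysis, assuming the correlations are taken between the embedding vectors of the number tokens and that a PC's per-token coordinates are the entries of the corresponding eigenvector. The perceptron-based separability test is a stand-in chosen here for "perfectly separate", not a method named in the captions, and both function names are hypothetical.

```python
import numpy as np


def embedding_pcs(embeddings_per_model, k=6):
    """Average the token-by-token correlation matrix over models and return its top-k PCs.

    Each element of `embeddings_per_model` is an (n_tokens, d_model) array whose row t is
    the embedding of number token t (an assumed layout)."""
    corr = np.mean([np.corrcoef(np.asarray(E, dtype=float)) for E in embeddings_per_model], axis=0)
    eigvals, eigvecs = np.linalg.eigh(corr)   # symmetric matrix: eigenvalues in ascending order
    order = np.argsort(eigvals)[::-1][:k]     # largest first, so column 0 is PC1
    return eigvals[order], eigvecs[:, order]  # column j: each token's coordinate on PC(j+1)


def linearly_separates_parity(coords, tokens, max_epochs=1000):
    """Perceptron test for a line separating even from odd tokens in the given PC coordinates.

    An epoch with zero mistakes means a separator was found, so True is conclusive;
    False after `max_epochs` only means none was found within the cap."""
    X = np.hstack([np.asarray(coords, dtype=float), np.ones((len(tokens), 1))])  # bias column
    y = np.where(np.asarray(tokens) % 2 == 0, 1.0, -1.0)
    w = np.zeros(X.shape[1])
    for _ in range(max_epochs):
        mistakes = 0
        for xi, yi in zip(X, y):
            if yi * (xi @ w) <= 0:
                w += yi * xi
                mistakes += 1
        if mistakes == 0:
            return True
    return False


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    fake = [rng.normal(size=(26, 96)) for _ in range(3)]  # stand-in for trained embeddings
    vals, pcs = embedding_pcs(fake)
    print(vals.round(2))
    print(linearly_separates_parity(pcs[:, 4:6], np.arange(26)))  # PC5 and PC6
```

With real trained embeddings, passing columns 4 and 5 (PC5 and PC6) mirrors the Figure 3 check; random stand-in embeddings like those in the usage lines are not expected to show the parity split.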
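Figure 5 describes slowly switching the training data from MAX+MED to pure SUM without ever showing expressions that mix all three operations. The caption does not give the schedule; a minimal sketch, assuming a linear ramp in the probability of drawing a pure-SUM expression between two hypothetical step counts `start` and `end`, with expressions written in a hypothetical parenthesized prefix syntax. The helper names and pools are illustrative only.

```python
import random
from typing import List, Sequence


def sum_fraction(step: int, start: int, end: int) -> float:
    """Share of pure-SUM expressions at `step`: 0 before `start`, 1 after `end`, linear between."""
    if step <= start:
        return 0.0
    if step >= end:
        return 1.0
    return (step - start) / (end - start)


def sample_batch(step: int, maxmed_pool: Sequence[str], sum_pool: Sequence[str],
                 batch_size: int, start: int, end: int, rng: random.Random) -> List[str]:
    """Draw each example whole from either the MAX+MED pool or the pure-SUM pool.

    Expressions are never combined, so no example mixes MAX, MED, and SUM."""
    p = sum_fraction(step, start, end)
    return [rng.choice(sum_pool) if rng.random() < p else rng.choice(maxmed_pool)
            for _ in range(batch_size)]


if __name__ == "__main__":
    rng = random.Random(0)
    maxmed = ["(MAX 3 (MED 1 7 5))", "(MED 2 (MAX 4 9) 6)"]  # hypothetical MAX+MED pool
    sums = ["(SUM 3 (SUM 1 7) 5)", "(SUM 12 20)"]            # hypothetical pure-SUM pool
    for step in (0, 5000, 10000, 15000):
        batch = sample_batch(step, maxmed, sums, 64, start=5000, end=15000, rng=rng)
        print(step, sum(b.startswith("(SUM") for b in batch) / len(batch))
```

A linear ramp is only one reading of "slowly"; a step-wise or cosine schedule would fit the caption equally well, and the key constraint is that each sampled expression comes whole from one pool.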