On the Inductive Bias of Stacking Towards Improving Reasoning

Nikunj Saunshi, Stefani Karp, Shankar Krishnan, Sobhan Miryoosefi, Sashank J. Reddi, Sanjiv Kumar

TL;DR

An intriguing phenomenon is discovered: MIDAS is not only training-efficient but surprisingly also has an inductive bias towards improving downstream tasks, especially tasks that require reasoning abilities like reading comprehension and math problems, despite having similar or slightly worse perplexity compared to baseline training.

Abstract

Given the increasing scale of model sizes, novel training strategies like gradual stacking [Gong et al., 2019, Reddi et al., 2023] have garnered interest. Stacking enables efficient training by gradually growing the depth of a model in stages and using layers from a smaller model in an earlier stage to initialize the next stage. Although efficient for training, the model biases induced by such growing approaches are largely unexplored. In this work, we examine this fundamental aspect of gradual stacking, going beyond its efficiency benefits. We propose a variant of gradual stacking called MIDAS that can speed up language model training by up to 40%. Furthermore, we discover an intriguing phenomenon: MIDAS is not only training-efficient but surprisingly also has an inductive bias towards improving downstream tasks, especially tasks that require reasoning abilities like reading comprehension and math problems, despite having similar or slightly worse perplexity compared to baseline training. To further analyze this inductive bias, we construct reasoning primitives -- simple synthetic tasks that are building blocks for reasoning -- and find that a model pretrained with stacking is significantly better than standard pretraining on these primitives, with and without fine-tuning. This provides stronger and more robust evidence for this inductive bias towards reasoning. These findings of training efficiency and inductive bias towards reasoning are verified at 1B, 2B and 8B parameter language models. Finally, we conjecture the underlying reason for this inductive bias by exploring the connection of stacking to looped models and provide strong supporting empirical analysis.
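
To make the staged growth described in the abstract concrete, here is a minimal sketch of generic gradual stacking. It is not the authors' implementation and not MIDAS specifically. Everything named here is hypothetical: layers are plain lists of floats, `train_stage` is a stand-in for real training, and the choice to copy the topmost block when growing is an assumption made for illustration only.

```python
import copy


def init_layer(seed, width=4):
    # Hypothetical stand-in for a transformer layer: a flat list of weights.
    return [((seed * 31 + i * 17) % 11) / 10.0 for i in range(width)]


def train_stage(layers, steps):
    # Placeholder for real training: rescale every weight so stages differ.
    return [[w * (1.0 - 0.01 * steps) for w in layer] for layer in layers]


def grow(layers, block_size):
    # Initialize the deeper model from the shallower one: keep all trained
    # layers and append a copy of the last `block_size` layers.
    # (Assumption: which block is copied, and where it goes, varies by method.)
    new_block = copy.deepcopy(layers[-block_size:])
    return layers + new_block


def gradual_stacking(num_stages, block_size, steps_per_stage):
    # Stage 0 trains a shallow model; each later stage grows, then trains.
    layers = [init_layer(i) for i in range(block_size)]
    depths = []
    for stage in range(num_stages):
        if stage > 0:
            layers = grow(layers, block_size)
        layers = train_stage(layers, steps_per_stage)
        depths.append(len(layers))
    return layers, depths


if __name__ == "__main__":
    final_layers, depths = gradual_stacking(num_stages=3, block_size=2, steps_per_stage=5)
    print("depth per stage:", depths)
```

The efficiency benefit comes from the early stages: they train a shallower model, so each step is cheaper than a step on the full-depth network.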

Paper Structure (30 sections, 2 equations, 6 figures, 7 tables, 1 algorithm)

Figures (6)

  • Figure 1: (a) Pictorial depiction of gradual stacking and Midas. (b) Accuracy improvements (in %) for a model trained with Midas over baseline for various task groups, despite having the same perplexity. For 1B, 2B and 8B models, we see that improvements are mostly positive, and are much larger for tasks that require a lot of reasoning.
  • Figure 2: (a) For an ALBert model trained with weight sharing across all layers, we measure the functional similarity between layers by looking at the top 1% activated neurons in each MLP layer and computing the intersection-over-union (IoU) metric for each pair of layers. Despite all layers having the same parameters, a natural functional similarity structure emerges around the middle. (b) For a UL2 model trained with GradStack, we measure the cosine similarity between every pair of layer blocks for the first feedforward layer weights. (c) The same similarity measured for Midas. The cosine similarities for stacking-based models suggest a strong connection to looped models, and Midas has a closer similarity structure to ALBert-style looped models than GradStack.
  • Figure 3: Histogram of accuracy improvements for models trained with Midas over baseline. The data points are Midas 1B models listed in \ref{table:main_results}. The figure shows that Midas-based models have much larger improvements on the contextual version of TyDiQA than on the non-contextual version.
  • Figure 4: Downstream evaluation vs validation log perplexity isoplots as training proceeds for baseline and Midas 1B models trained on the same data (stacking is 24% faster here). On the y-axis we track the performance on various task groups -- closed book QA, open book QA, math word problems and our reasoning primitives from \ref{sec:reasoning_primitives}. On the x-axis the log perplexity is presented in reverse order, so downstream performance for both methods improves as log perplexity gets lower. For closed book QA (memorization) tasks, Midas has very similar trends to baseline. For open book QA tasks and math word problems, Midas has much better downstream performance at an equivalent log perplexity. This showcases the inductive bias of Midas towards better overall quality and better reasoning abilities.
  • Figure 5: Accuracy improvements for a model trained with Midas over baseline for representative reasoning primitives, despite having the same perplexity. We see clear improvements for stacking on almost all the primitives, both with 5-shot evaluation and after fine-tuning (FT) for the depth 2 primitive.
  • ...and 1 more figures
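
The two layer-similarity measures named in the Figure 2 caption (IoU over the top 1% activated neurons, and cosine similarity between weights) can be sketched as below. This is only an illustration under stated assumptions: the function names are hypothetical, each layer's activations are assumed to be already reduced to one score per neuron, and weights are assumed to be flattened into a single vector. The caption does not specify these details.

```python
import math


def top_fraction_indices(activations, fraction=0.01):
    # Indices of the most strongly activated neurons (at least one neuron).
    k = max(1, int(len(activations) * fraction))
    order = sorted(range(len(activations)), key=lambda i: activations[i], reverse=True)
    return set(order[:k])


def iou(a, b):
    # Intersection-over-union of two index sets.
    union = a | b
    return len(a & b) / len(union) if union else 0.0


def cosine_similarity(u, v):
    # Cosine of the angle between two flattened weight vectors.
    dot = sum(x * y for x, y in zip(u, v))
    norm_u = math.sqrt(sum(x * x for x in u))
    norm_v = math.sqrt(sum(y * y for y in v))
    return dot / (norm_u * norm_v) if norm_u and norm_v else 0.0


def pairwise(items, fn):
    # Full layer-by-layer similarity matrix, as plotted in a heatmap.
    return [[fn(a, b) for b in items] for a in items]


if __name__ == "__main__":
    # Toy per-neuron activation scores for three layers (made-up numbers).
    layer_acts = [
        [0.1, 0.9, 0.5, 0.3],
        [0.2, 0.8, 0.1, 0.7],
        [0.9, 0.1, 0.2, 0.3],
    ]
    top_sets = [top_fraction_indices(a, fraction=0.5) for a in layer_acts]
    for row in pairwise(top_sets, iou):
        print(row)
```

A matrix produced this way has ones on the diagonal. The caption's observation concerns the off-diagonal structure, where blocks of high similarity suggest groups of layers that behave alike.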

Theorems & Definitions (1)

  • Definition 3.1: Prop-$\alpha$ schedule