
Efficient Stagewise Pretraining via Progressive Subnetworks

Abhishek Panigrahi, Nikunj Saunshi, Kaifeng Lyu, Sobhan Miryoosefi, Sashank Reddi, Satyen Kale, Sanjiv Kumar

TL;DR

A principled stagewise training framework, progressive subnetwork training, that trains only subnetworks within the model and progressively increases the size of these subnetworks during training until it trains the full network. The paper also proposes an instantiation of this framework, Random Part Training (RAPTR), which selects and trains only a random subnetwork at each step.

Abstract

Recent developments in large language models have sparked interest in efficient pretraining methods. Stagewise training approaches to improve efficiency, like gradual stacking and layer dropping (Reddi et al., 2023; Zhang & He, 2020), have recently garnered attention. The prevailing view suggests that stagewise dropping strategies, such as layer dropping, are ineffective, especially when compared to stacking-based approaches. This paper challenges this notion by demonstrating that, with proper design, dropping strategies can be competitive with, if not better than, stacking methods. Specifically, we develop a principled stagewise training framework, progressive subnetwork training, which only trains subnetworks within the model and progressively increases the size of subnetworks during training, until it trains the full network. We propose an instantiation of this framework, Random Part Training (RAPTR), that selects and trains only a random subnetwork (e.g., depth-wise, width-wise) of the network at each step, progressively increasing the size in stages. We show that this approach not only generalizes prior works like layer dropping but also fixes their key issues. Furthermore, we establish a theoretical basis for such approaches and provide justification for (a) increasing the complexity of subnetworks in stages, conceptually diverging from prior works on layer dropping, and (b) stability in loss across stage transitions in the presence of key modern architecture components like residual connections and layer norms. Through comprehensive experiments, we demonstrate that RAPTR can significantly speed up training of standard benchmarks like BERT and UL2, by up to 33% compared to standard training, and, surprisingly, also shows better downstream performance on UL2, improving QA tasks and SuperGLUE by 1.5%, thereby providing evidence of better inductive bias.
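As a rough illustration of the stagewise idea, the sketch below (Python, standard library only) maps a training step to a subnetwork size, samples that many layers at random, and runs a residual forward pass in which skipped layers act as the identity. It is a minimal sketch under stated assumptions, not the authors' implementation. Several choices here are hypothetical: sampling exactly k layers uniformly, the always-kept set `always_on`, all function names, and the assumption that per-step cost is proportional to the number of active layers. The example schedule reuses the 6-8-10-12 stages and the 25k/50k/75k boundaries from Figure 5.

```python
import bisect
import random
from typing import Callable, List, Sequence, Set

def active_layers_at(step: int, boundaries: Sequence[int], sizes: Sequence[int]) -> int:
    """Subnetwork size at `step`; stage i (i >= 1) starts at step boundaries[i - 1]."""
    if len(sizes) != len(boundaries) + 1:
        raise ValueError("need exactly one more stage size than stage boundaries")
    return sizes[bisect.bisect_right(boundaries, step)]

def sample_subnetwork(num_layers: int, k: int, always_on: Set[int], rng: random.Random) -> List[int]:
    """Pick k distinct layer indices uniformly at random, always including `always_on`."""
    if not always_on <= set(range(num_layers)) or not len(always_on) <= k <= num_layers:
        raise ValueError("invalid always_on set or subnetwork size k")
    optional = [i for i in range(num_layers) if i not in always_on]
    return sorted(set(always_on) | set(rng.sample(optional, k - len(always_on))))

def forward_subnetwork(x: float, layers: Sequence[Callable[[float], float]], active: Sequence[int]) -> float:
    """Residual pass y <- y + f_l(y); a layer outside `active` is skipped, i.e. acts as the identity."""
    keep = set(active)
    y = x
    for idx, f in enumerate(layers):
        if idx in keep:
            y = y + f(y)
    return y

def relative_layer_flops(boundaries: Sequence[int], sizes: Sequence[int], total_steps: int, num_layers: int) -> float:
    """Layer-steps used by the schedule divided by those of full-depth training (cost-proportional proxy)."""
    edges = [0, *boundaries, total_steps]
    used = sum((edges[i + 1] - edges[i]) * sizes[i] for i in range(len(sizes)))
    return used / (total_steps * num_layers)

# Usage with the Figure 5 schedule: 12 layers, stages 6-8-10-12, boundaries at 25k/50k/75k steps.
BOUNDARIES, SIZES = [25_000, 50_000, 75_000], [6, 8, 10, 12]
rng = random.Random(0)
toy_layers = [lambda y, c=float(i): c for i in range(12)]  # hypothetical toy layer i adds the constant i
for step in (0, 30_000, 80_000):
    k = active_layers_at(step, BOUNDARIES, SIZES)
    active = sample_subnetwork(12, k, {0, 11}, rng)  # always keeping layers 0 and 11 is a hypothetical choice
    # In a real trainer, only the layers in `active` would run and receive gradients at this step.
    print(step, k, active, forward_subnetwork(0.0, toy_layers, active))
print(relative_layer_flops(BOUNDARIES, SIZES, 100_000, 12))
```

Under these assumptions, `relative_layer_flops` returns 0.75 for the 6-8-10-12 schedule. In other words, the schedule uses 25% fewer layer-steps than training all 12 layers throughout. This proxy ignores embeddings, the output head, and other overheads, so it is not the paper's measured speedup.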

Paper Structure (76 sections, 27 theorems, 98 equations, 7 figures, 11 tables, 3 algorithms)

Key Result

Lemma 3.1

For a small enough learning rate, 2-phase RaPTr first learns the lower-degree component and then the higher-degree component.
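The sketch below contrasts the direction of the schedules discussed here and in Figure 2. It compares a hypothetical 2-phase RaPTr schedule, in which a reduced fraction of layers is active and then the full network, with a PLD-style schedule whose keep probability decays during training. The exponential decay, its constants (`floor`, `gamma`), and the switch point are assumptions chosen for illustration. They are not values from the paper.

```python
import math
from typing import Callable

def raptr_two_phase(t: float, p_phase1: float = 0.5, switch: float = 0.4) -> float:
    """Fraction of layers active at normalized time t in [0, 1]: p_phase1, then the full network."""
    return p_phase1 if t < switch else 1.0

def pld_style(t: float, floor: float = 0.5, gamma: float = 5.0) -> float:
    """Keep probability decaying from 1 toward `floor` (assumed exponential form)."""
    return (1.0 - floor) * math.exp(-gamma * t) + floor

def average(schedule: Callable[[float], float], n: int = 10_000) -> float:
    """Midpoint-rule average of the active fraction over training (a layer-FLOP proxy)."""
    return sum(schedule((i + 0.5) / n) for i in range(n)) / n

print(raptr_two_phase(1.0), round(pld_style(1.0), 3))
print(average(raptr_two_phase), round(average(pld_style), 3))
```

With these defaults, `average(raptr_two_phase)` is 0.8, which corresponds to 20% fewer layer FLOPs under a proportional-cost assumption. At the end of training the full network is active, whereas the PLD-style keep probability stays near its floor of 0.5. This is the reduced late-training expressivity that the Figure 2 caption refers to.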

Figures (7)

  • Figure 1: Pictorial description of stagewise RaPTr where the number of layers being skipped progressively decreases over stages.
  • Figure 2: Evaluation loss (left) and component error (three panels on the right) on basis polynomials of different degrees for a $20$-layer residual network trained with different methods. Labels are generated from a composition of polynomials of degrees $1$ to $10$ (Eq. \ref{eq:true_label}). The schedules for RaPTr and PLD are selected to use 20% fewer FLOPs than the baseline (phase transitions for RaPTr are marked with dark vertical lines; more details in §\ref{sec:toy_setting}). Observations: (a) RaPTr reaches the same evaluation loss as the baseline, while PLD performs much worse. (b) RaPTr learns lower-order terms faster and picks up higher-order terms in the later stages. PLD is worse at capturing higher-degree terms owing to its reduced expressivity towards the end of training.
  • Figure 3: (a) Training trajectory of BERT under $4$-stage 6-8-10-12 RaPTr, with the stage transitions denoted by arrows; (b), (c), (d): stability study on BERT-Large trained for $50k$ steps by RaPTr with subnetworks of length $L-1$. Behavior of (b) the norms of the intermediate activations ${\bm{y}}^{(\ell)}$, (c) $\Psi_{\ell} / \Psi_{1}$ (Eq. \ref{eq:lipstacklayers}), and (d) the loss gap between the random subnetworks $F_{-\ell}$ and the full model $F$, given by $\mathcal{L}(F) - \frac{1}{L} \sum_{\ell=1}^{L} \mathcal{L}(F_{-\ell})$. Key observations: (b) the norms of the intermediate activations grow linearly with $\ell$; (c) $\Psi_{\ell}$ changes slowly with $\ell$, as $(\frac{L}{\ell})^{0.12}$, suggesting a worst-case bound of $\mathcal{O}(L^{-0.88})$ on $\mathcal{L}(F) - \frac{1}{L} \sum_{\ell=1}^{L} \mathcal{L}(F_{-\ell})$ based on Theorem \ref{thm:err_prop_inf}; (d) interestingly, $\mathcal{L}(F) \le \frac{1}{L} \sum_{\ell=1}^{L} \mathcal{L}(F_{-\ell})$, even when the model is trained with random subnetworks of $L-1$ layers.
  • Figure 4: Behavior of a linear residual network with normalization layers on $100$ random samples from $\mathbb{S}^{d-1}$, with dimension $d=100$. The parameters of each layer $\ell$ are represented as $\sqrt{\tau} {\bm{A}} + \sqrt{1-\tau} {\bm{G}}^{(\ell)}$ for a shared matrix ${\bm{A}} \in \mathbb{R}^{d \times d}$ with $\left\|{\bm{A}}\right\|_2 \le 1$ and ${\bm{G}}^{(\ell)} \sim \mathcal{N}\left(0, d^{-1/2}{\bm{I}}\right)$. Left to right: behavior of (a) the norms of the intermediate activations ${\bm{y}}^{(\ell)}$ with index $\ell$, (b) $\Psi_{\ell}$ (Eq. \ref{eq:lipstacklayers}) for each stack of layers $F_{\ell:L}$, and (c) $\frac{1}{L} \sum_{\ell=1}^{L} \Psi_{\ell} ({\bm{x}}) / \left\| F({\bm{x}}) \right\|_2$, which appears in our bounds in Theorem \ref{thm:err_prop_inf}.
  • Figure 5: Train and evaluation loss for a BERT-Base model trained with RaPTr for 100k steps, using $4$ stages with a 6-8-10-12 schedule (see §\ref{sec:experiments} for details). Stage boundaries are at 25k, 50k, and 75k steps. Key observation: the model's train and evaluation losses change smoothly across stage transitions, indicating that the model is stable under subnetwork training.
  • ...and 2 more figures
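The quantities in Figures 3 and 4 can be probed on a toy model. The NumPy sketch below builds a linear residual network of the kind described in Figure 4, with layers $\sqrt{\tau}{\bm{A}} + \sqrt{1-\tau}{\bm{G}}^{(\ell)}$. It records the activation norms $\|{\bm{y}}^{(\ell)}\|_2$ and computes the leave-one-layer-out gap $\mathcal{L}(F) - \frac{1}{L}\sum_{\ell=1}^{L}\mathcal{L}(F_{-\ell})$ from Figure 3(d) under a squared loss. Several details are assumptions:
- normalization is applied to each layer's input (pre-norm style);
- ${\bm{G}}^{(\ell)}$ has i.i.d. entries with standard deviation $d^{-1/2}$;
- the depth and the loss are hypothetical.

$\Psi_{\ell}$ is not computed because its definition is not reproduced here. This is not the paper's experimental code.

```python
from typing import List, Optional

import numpy as np

def make_layers(d: int, num_layers: int, tau: float, rng: np.random.Generator) -> List[np.ndarray]:
    """W_l = sqrt(tau) * A + sqrt(1 - tau) * G_l with one shared A scaled to ||A||_2 = 1."""
    A = rng.standard_normal((d, d))
    A /= np.linalg.norm(A, 2)  # ord=2 on a matrix is the spectral norm
    # ASSUMPTION: G_l has i.i.d. Gaussian entries with standard deviation d^{-1/2}
    return [np.sqrt(tau) * A + np.sqrt(1.0 - tau) * rng.standard_normal((d, d)) / np.sqrt(d)
            for _ in range(num_layers)]

def forward(X: np.ndarray, layers: List[np.ndarray], skip: Optional[int] = None) -> List[np.ndarray]:
    """Activations y^(0), ..., y^(L) for the columns of X; 0-based layer `skip` is bypassed (F_{-skip})."""
    ys = [X]
    for idx, W in enumerate(layers):
        Y = ys[-1]
        if idx != skip:
            # ASSUMPTION: pre-norm residual update y <- y + W (y / ||y||)
            Y = Y + W @ (Y / np.linalg.norm(Y, axis=0, keepdims=True))
        ys.append(Y)
    return ys

def loss_gap(X: np.ndarray, target: np.ndarray, layers: List[np.ndarray]) -> float:
    """L(F) - (1/L) * sum_l L(F_{-l}) with a per-sample squared loss averaged over samples."""
    def loss(out: np.ndarray) -> float:
        return float(np.mean(np.sum((out - target) ** 2, axis=0)))
    full = loss(forward(X, layers)[-1])
    dropped = [loss(forward(X, layers, skip=l)[-1]) for l in range(len(layers))]
    return full - float(np.mean(dropped))

rng = np.random.default_rng(0)
d, n = 100, 100
X = rng.standard_normal((d, n))
X /= np.linalg.norm(X, axis=0, keepdims=True)  # n samples on the unit sphere S^{d-1}
layers = make_layers(d, num_layers=24, tau=0.5, rng=rng)  # depth 24 is hypothetical
mean_norms = [float(np.linalg.norm(Y, axis=0).mean()) for Y in forward(X, layers)]
target = forward(X, layers)[-1]  # hypothetical targets: the full model's own outputs
print(mean_norms[0], mean_norms[-1], loss_gap(X, target, layers))
```

Setting the targets to the full model's outputs makes $\mathcal{L}(F) = 0$, so in this usage the gap is at most zero by construction. With other targets the sign of the gap is an empirical question, which Figure 3(d) examines for BERT-Large.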

Theorems & Definitions (52)

  • Definition 2.1: $L$-layer sequence-to-sequence model
  • Definition 2.2: $(p, \mathcal{I})$-subnetwork
  • Lemma 3.1: Informal, cf. \ref{thm:raptr_stagewise}
  • Definition 5.2
  • Theorem 5.3: Informal, cf. \ref{thm:err_prop}
  • Lemma 5.4
  • Theorem H.1
  • Proof: Proof of \ref{thm:err_prop_inf}
  • Lemma H.2: Norm of the output of linear layers at initialization
  • Proof
  • ...and 42 more
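Definition 2.2 appears in the list above only by name. The sketch below assumes one plausible reading, which is not taken from the paper: a $(p, \mathcal{I})$-subnetwork always keeps the layers in an index set $\mathcal{I}$ and keeps every other layer independently with probability $p$. Under that reading, the expected subnetwork size is $|\mathcal{I}| + p\,(L - |\mathcal{I}|)$, which gives a way to choose $p$ for a target average depth. The function names are hypothetical.

```python
import random
from typing import List, Set

def sample_keep_mask(num_layers: int, p: float, always_on: Set[int], rng: random.Random) -> List[bool]:
    """Keep layers in `always_on`; keep each other layer independently with probability p (assumed reading)."""
    if not 0.0 <= p <= 1.0:
        raise ValueError("p must lie in [0, 1]")
    return [layer in always_on or rng.random() < p for layer in range(num_layers)]

def expected_size(num_layers: int, p: float, always_on: Set[int]) -> float:
    """E[number of kept layers] = |I| + p * (L - |I|)."""
    fixed = len(always_on & set(range(num_layers)))
    return fixed + p * (num_layers - fixed)

def p_for_target(num_layers: int, target: float, always_on: Set[int]) -> float:
    """Solve |I| + p * (L - |I|) = target for p."""
    fixed = len(always_on & set(range(num_layers)))
    free = num_layers - fixed
    if free == 0:
        raise ValueError("every layer is always kept; p has no effect")
    p = (target - fixed) / free
    if not 0.0 <= p <= 1.0:
        raise ValueError("target size is not reachable")
    return p

# Usage: 12 layers, hypothetically always keeping the first and last, with 6 layers kept on average.
rng = random.Random(0)
p = p_for_target(12, 6, {0, 11})
mask = sample_keep_mask(12, p, {0, 11}, rng)
print(p, expected_size(12, p, {0, 11}), mask)
```

With independent per-layer sampling, the depth of each sampled subnetwork is random around its mean. If a fixed per-step cost is preferred, sampling exactly k layers is an equally simple alternative.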