Continual Pre-training of MoEs: How robust is your router?

Benjamin Thérien, Charles-Étienne Joseph, Zain Sarwar, Ashwinee Panda, Anirban Das, Shi-Xiong Zhang, Stephen Rawls, Sambit Sahu, Eugene Belilovsky, Irina Rish

TL;DR

This work investigates continual pre-training (CPT) of decoder-only MoE transformers under distribution shifts, evaluating two routing algorithms (Penalty-Balanced Top-$k$, PBT$k$, and Sinkhorn-Balanced Top-$k$, SBT$k$) across Granular and Switch MoE architectures. In a large-scale setup in which each model is trained on 600B tokens, the study compares MoEs to a FLOP-matched dense baseline and to full re-training, focusing on forgetting, on load balance as measured by a new Maximum Routing Imbalance (MRI) metric, and on downstream performance with replay and infinite learning-rate schedules. The key findings are that MoEs undergo CPT robustly with both routing methods, can match full re-training performance at substantially lower cost when replay and an infinite LR schedule are used, and preserve their sample efficiency relative to dense models; routing dynamics reveal that early layers drive adaptation while later layers stabilize. The results support deploying CPT on MoEs for scalable, adaptable foundation models, with Granular Penalty-Balanced (PB) MoEs offering strong performance and favorable compute characteristics across domains such as code and German text.
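The two routing families named above can be illustrated with a small pure-Python sketch. It assumes the textbook forms of each technique rather than the paper's exact implementation: Sinkhorn-Balanced Top-$k$ is modeled as alternating row/column normalization of exp(logits) followed by per-token top-$k$ selection, and Penalty-Balanced Top-$k$ as a Switch-style auxiliary load-balancing loss plus a router z-loss. The helper names (`sinkhorn_topk`, `penalty_losses`) and the coefficients `aux_coef=1e-2` and `z_coef=1e-3` are hypothetical, chosen only for illustration.

```python
import math

def softmax(row):
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    s = sum(exps)
    return [e / s for e in exps]

def logsumexp(row):
    m = max(row)
    return m + math.log(sum(math.exp(x - m) for x in row))

def top_k(row, k):
    # Indices of the k largest scores; ties go to the lower expert index.
    return sorted(range(len(row)), key=lambda i: (-row[i], i))[:k]

def sinkhorn_topk(logits, k, n_iters=20):
    """Sketch of Sinkhorn-balanced top-k: rescale exp(logits) so each token row
    sums to 1 and each expert column sums to T/E, then take top-k per token."""
    T, E = len(logits), len(logits[0])
    g = max(max(row) for row in logits)  # global shift for numerical stability
    m = [[math.exp(x - g) for x in row] for row in logits]
    for _ in range(n_iters):
        for row in m:  # normalize tokens (rows) to sum to 1
            s = sum(row)
            for j in range(E):
                row[j] /= s
        for j in range(E):  # normalize experts (columns) to T/E each
            s = sum(m[t][j] for t in range(T))
            for t in range(T):
                m[t][j] *= (T / E) / s
    return [top_k(row, k) for row in m]

def penalty_losses(logits, k, aux_coef=1e-2, z_coef=1e-3):
    """Sketch of penalty balancing: Switch-style auxiliary loss plus router z-loss.
    Returns (weighted total, raw aux loss, raw z-loss)."""
    T, E = len(logits), len(logits[0])
    probs = [softmax(row) for row in logits]
    counts = [0] * E
    for row in probs:
        for j in top_k(row, k):
            counts[j] += 1
    f = [c / (T * k) for c in counts]  # share of routed assignments per expert
    p = [sum(row[j] for row in probs) / T for j in range(E)]  # mean router prob
    aux = E * sum(fi * pi for fi, pi in zip(f, p))  # 1.0 when f and p are uniform
    z = sum(logsumexp(row) ** 2 for row in logits) / T  # penalizes large logits
    return aux_coef * aux + z_coef * z, aux, z

logits = [[2.0, 0.0], [1.5, 0.0], [1.0, 0.0], [0.5, 0.0]]
print([top_k(softmax(r), 1) for r in logits])  # -> [[0], [0], [0], [0]]
print(sinkhorn_topk(logits, k=1))              # -> [[0], [0], [1], [1]]
print(penalty_losses(logits, k=1))
```

In this toy batch every token prefers expert 0, so plain top-1 routing collapses onto one expert, while the Sinkhorn-rebalanced scores split the tokens two and two. The penalty approach instead leaves routing unconstrained and relies on the added loss terms to push the router toward balance during training.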

Abstract

Sparsely-activated Mixture of Experts (MoE) transformers are promising architectures for foundation models. Compared to dense transformers that require the same amount of floating-point operations (FLOPs) per forward pass, MoEs benefit from improved sample efficiency at training time and achieve much stronger performance. Many closed-source and open-source frontier language models have thus adopted an MoE architecture. Naturally, practitioners will want to extend the capabilities of these models with large amounts of newly collected data without completely re-training them. Prior work has shown that a simple combination of replay, learning rate re-warming, and re-decaying can enable the continual pre-training (CPT) of dense decoder-only transformers with minimal performance degradation compared to full re-training. In the case of decoder-only MoE transformers, however, it is unclear how the routing algorithm will affect continual pre-training performance: (1) Do the MoE transformer's routers exacerbate forgetting relative to a dense model? (2) Do the routers maintain a balanced load on previous distributions after CPT? (3) Are the same strategies applied to dense models sufficient to continually pre-train MoE LLMs? In what follows, we conduct a large-scale study training a 500M-parameter dense transformer and four 500M-active/2B-total-parameter MoE transformers, each trained on 600B tokens. Our results establish a surprising robustness to distribution shifts for MoEs using both Sinkhorn-Balanced and Z-and-Aux-loss-balanced routing algorithms, even in MoEs continually pre-trained without replay. Moreover, we show that MoE LLMs maintain their sample efficiency (relative to a FLOP-matched dense model) during CPT and that they can match the performance of a fully re-trained MoE at a fraction of the cost.
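The CPT recipe described in the abstract (replay plus learning-rate re-warming and re-decaying) can be sketched as follows. The linear re-warm from `min_lr`, the cosine re-decay, per-batch replay mixing, and all names and hyperparameter values (`cpt_lr`, `replay_batch`, the 40% replay share, the learning rates) are illustrative assumptions, not the authors' training code.

```python
import math
import random

def cpt_lr(step, total_steps, warmup_steps, max_lr, min_lr):
    """Re-warm linearly from min_lr to max_lr, then re-decay to min_lr with a cosine."""
    if step < warmup_steps:
        return min_lr + (max_lr - min_lr) * step / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

def replay_batch(new_pool, old_pool, batch_size, replay_frac, rng):
    """Build one batch in which a replay_frac share of samples is replayed from
    the previous distribution and the rest comes from the new one."""
    n_old = round(replay_frac * batch_size)
    batch = rng.sample(old_pool, n_old) + rng.sample(new_pool, batch_size - n_old)
    rng.shuffle(batch)  # avoid a fixed old/new ordering inside the batch
    return batch

rng = random.Random(0)
english = [f"en-{i}" for i in range(1000)]  # stand-in for the old distribution
german = [f"de-{i}" for i in range(1000)]   # stand-in for the new distribution
batch = replay_batch(german, english, batch_size=10, replay_frac=0.4, rng=rng)
print(sum(s.startswith("en-") for s in batch))  # -> 4
print([round(cpt_lr(s, 1000, 100, 3e-4, 3e-5), 7) for s in (0, 100, 1000)])
```

Replaying at the batch level keeps the old-distribution share fixed at every step; sampling each example independently with probability `replay_frac` would reach the same share only in expectation.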

Paper Structure

This paper contains 36 sections, 8 equations, 22 figures, 12 tables.

Figures (22)

  • Figure 1: Continually pre-trained (CPT) MoEs match the performance of full re-training across two dataset transitions: 400B English $\rightarrow$ 200B German ($40\%$ replay) and 400B English $\rightarrow$ 200B Stack ($30\%$ replay). We compare a fully re-trained Penalty-Balanced Top-$k$ MoE and dense baseline (i.e., trained on the union of 400B English tokens and either 200B Stack or 200B German tokens) to their CPT counterparts. Despite costing only a third as much as full re-training, the CPT MoEs match the performance of the fully re-trained models and in some cases even achieve a lower median Maximum Routing Imbalance (MRI). This shows that MoEs have CPT abilities on par with dense transformers. Note that subfigures (b), (c), and (f) evaluate the German and Stack models on different datasets, each corresponding to the model's training domain.
  • Figure 2: Ablating replay and decay strategy during continual pre-training on German data. We continually pre-train MoEs and a dense baseline from either fully decayed checkpoints (dotted curves, [D]) or a non-decayed checkpoint (solid curves). The figures report performance on task 1 (FineWeb) and task 2 (German) during CPT on task 2. We observe that adaptation to task 2 is similar within an architecture for both checkpoint types, that CPT from a non-decayed checkpoint reduces forgetting, and that replay mitigates forgetting.
  • Figure 3: FineWeb $\rightarrow$ German CPT checkpoint and replay ablation. We report the median Maximum Routing Imbalance (MRI) across MoE layers with min/max error bars. Sinkhorn-Balanced Top-$k$ (SBT$k$) MoEs show a slight MRI increase at the distribution shift, while Penalty-Balanced Top-$k$ (PBT$k$) MoEs experience a brief spike before recovering to balanced MRI levels that are below those of SBT$k$ and approach the uniform baseline. The uniform routing baseline (orange line) corresponds to the case where every expert in every layer receives the same number of tokens; it therefore represents perfect balance.
  • Figure 4: Layer-wise Maximum Routing Imbalance (MRI) for Granular MoEs. We report MRI (as defined in the paper's MRI equation) on each dataset's 20M-token test set. MRI is consistently lower for Penalty-Balanced MoEs than for Sinkhorn-Balanced MoEs. Measured on FineWeb, continual pre-training incurs only a minimal MRI increase, even with $0\%$ replay. MoEs are most unbalanced on out-of-distribution data (e.g., non-German models in (b) and non-code models in (c)).
  • Figure 5: Layer-wise analysis of routing changes during CPT. Our goal is to understand how routing decisions change from the pre-trained checkpoints to the final checkpoints after continual pre-training. To this end, we analyse changes in routing behaviour from three perspectives: which experts tend to be activated together (a), the tendency for certain vocabulary tokens to be routed to certain experts (b), and how close the pre-trained checkpoint's routing decisions are to those of the CPT checkpoints (c). To put these metrics in context, we also indicate the forgetting of each model shown in the plots (as reported in the paper's summary table). We observe that the no-replay baseline changes the most in early layers and forgets the most, suggesting that more drastic changes in the initial layers may be linked to forgetting.
  • ...and 17 more figures
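The MRI curves in Figures 3 and 4 summarize how unevenly tokens are spread over experts. The paper's exact MRI equation is not reproduced in this listing, so the sketch below assumes one plausible reading: for each MoE layer, MRI is the largest fraction of routed token assignments received by any single expert, and the uniform-routing baseline is 1/E for E experts. The function names `layer_mri` and `summarize_mri` are hypothetical.

```python
from statistics import median

def layer_mri(expert_choices, n_experts):
    """Assumed MRI for one layer: the largest share of routed token assignments
    received by any single expert (1/n_experts means perfectly balanced)."""
    counts = [0] * n_experts
    for chosen in expert_choices:  # chosen = the top-k expert ids for one token
        for e in chosen:
            counts[e] += 1
    return max(counts) / sum(counts)

def summarize_mri(per_layer_choices, n_experts):
    """Median, min and max MRI across layers, plus the uniform-routing baseline."""
    values = [layer_mri(choices, n_experts) for choices in per_layer_choices]
    return median(values), min(values), max(values), 1 / n_experts

layers = [
    [[0], [1], [2], [3]],  # perfectly balanced layer
    [[0], [0], [1], [2]],  # expert 0 gets half the tokens
    [[0], [0], [0], [0]],  # fully collapsed layer
]
print(summarize_mri(layers, n_experts=4))  # -> (0.5, 0.25, 1.0, 0.25)
```

Under this reading a lower MRI means better balance, and the returned median with min/max mirrors the "median MRI across MoE layers with min/max error bars" reported in Figure 3.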