
MoC-System: Efficient Fault Tolerance for Sparse Mixture-of-Experts Model Training

Weilin Cai, Le Qin, Jiayi Huang

TL;DR

MoC-System tackles the escalating fault-tolerance cost of training sparse MoE models by introducing Partial Experts Checkpointing (PEC), which saves only a subset of MoE experts at each checkpoint. PEC is combined with fully sharded checkpointing and a two-level asynchronous management scheme. Together, these substantially reduce checkpoint size and overhead (up to a 98.9% reduction in per-checkpoint overhead, along with throughput gains at scale) while preserving or even improving downstream task accuracy (an average gain of about $1.08\%$). A new metric, the Proportion of Lost Tokens (PLT), quantifies the accuracy loss caused by PEC and guides adaptive strategies such as Sequential versus Load-aware expert selection and Dynamic-K adjustment depending on the fault load. The results, obtained with Megatron-DeepSpeed using ZeRO-2 DP + EP on the GPT-350M-16E model, show that PEC and two-level, asynchronous, fully sharded checkpointing enable robust fault tolerance with markedly lower overhead, making MoE training more practical at scale. The work highlights the broader potential of algorithm-system co-design for fault tolerance and points toward sparsity-aware co-design strategies for LLM training and deployment.
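To make the size argument concrete, here is a back-of-the-envelope sketch of how saving only `k_pec` of `num_experts` experts per MoE layer shrinks a checkpoint toward the size of its non-expert (dense) part. The function name `pec_checkpoint_bytes` and every number in the example are hypothetical and illustrative, not measurements from the paper; `bytes_per_param` is assumed to cover weights plus optimizer states (14 bytes would correspond, for instance, to fp16 weights plus fp32 master weights and two fp32 Adam moments).

```python
from typing import Tuple


def pec_checkpoint_bytes(
    non_expert_params: int,
    params_per_expert: int,
    num_moe_layers: int,
    num_experts: int,
    k_pec: int,
    bytes_per_param: int,
) -> Tuple[int, int]:
    """Hypothetical estimate of (full checkpoint bytes, PEC checkpoint bytes).

    Full checkpointing saves every expert in every MoE layer; PEC saves only
    k_pec experts per MoE layer. Non-expert parameters are always saved.
    """
    if not 0 < k_pec <= num_experts:
        raise ValueError("k_pec must be in 1..num_experts")
    full_params = non_expert_params + num_moe_layers * num_experts * params_per_expert
    pec_params = non_expert_params + num_moe_layers * k_pec * params_per_expert
    return full_params * bytes_per_param, pec_params * bytes_per_param


# Made-up sizes for illustration only: 300M non-expert params, 4M params per
# expert, 12 MoE layers with 16 experts each, saving 1 expert per layer.
full, pec = pec_checkpoint_bytes(300_000_000, 4_000_000, 12, 16, 1, bytes_per_param=14)
print(f"full: {full / 1e9:.2f} GB, PEC: {pec / 1e9:.2f} GB, ratio: {pec / full:.3f}")
```

With these illustrative numbers the expert parameters dominate the full checkpoint, so keeping one expert per layer cuts the saved volume to roughly a third; the remaining size is mostly the non-expert part, which is why PEC brings MoE checkpoints close to dense-model scale.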

Abstract

As large language models continue to scale up, distributed training systems have expanded beyond 10k nodes, intensifying the importance of fault tolerance. Checkpointing has emerged as the predominant fault tolerance strategy, with extensive studies dedicated to optimizing its efficiency. However, the advent of the sparse Mixture-of-Experts (MoE) model presents new challenges due to the substantial increase in model size, despite computational demands comparable to those of dense models. In this work, we propose the Mixture-of-Checkpoint System (MoC-System) to orchestrate the vast array of checkpoint shards produced in distributed training systems. MoC-System features a novel Partial Experts Checkpointing (PEC) mechanism, an algorithm-system co-design that strategically saves a selected subset of experts, effectively reducing the MoE checkpoint size to levels comparable with dense models. Incorporating hybrid parallel strategies, MoC-System employs fully sharded checkpointing strategies to evenly distribute the workload across distributed ranks. Furthermore, MoC-System introduces a two-level checkpointing management method that asynchronously handles in-memory snapshots and persistence processes. We build MoC-System upon the Megatron-DeepSpeed framework, achieving up to a 98.9% reduction in overhead for each checkpointing process compared to the original method, during MoE model training with ZeRO-2 data parallelism and expert parallelism. Additionally, extensive empirical analyses substantiate that our methods enhance efficiency while maintaining comparable model accuracy, even achieving an average accuracy increase of 1.08% on downstream tasks.
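The two-level management described above (a short blocking in-memory snapshot, followed by persistence that overlaps with training, as in the Figure 3 timeline) can be illustrated with a minimal single-process Python sketch. It is not MoC-System's implementation: the class name `TwoLevelCheckpointer`, the use of `copy.deepcopy` as a stand-in for the GPU-to-CPU snapshot, and local `pickle` files as a stand-in for persistent storage are all assumptions made for illustration.

```python
import copy
import os
import pickle
import tempfile
import threading
from typing import Any, Dict, Optional


class TwoLevelCheckpointer:
    """Hypothetical sketch of two-level checkpoint management.

    Level 1: take an in-memory snapshot of the training state (stand-in for
    the GPU-to-CPU copy). Level 2: persist that snapshot to storage on a
    background thread so training can resume right after the snapshot.
    """

    def __init__(self, directory: str) -> None:
        self.directory = directory
        self._persist_thread: Optional[threading.Thread] = None
        self.latest_persisted: Optional[str] = None

    def snapshot(self, state: Dict[str, Any]) -> Dict[str, Any]:
        # Deep copy so later updates to `state` cannot change the snapshot.
        return copy.deepcopy(state)

    def _persist(self, snap: Dict[str, Any], path: str) -> None:
        tmp_path = path + ".tmp"
        with open(tmp_path, "wb") as f:
            pickle.dump(snap, f)
        # Atomic rename: a half-written file never appears as a completed checkpoint.
        os.replace(tmp_path, path)
        self.latest_persisted = path

    def save(self, state: Dict[str, Any], step: int) -> None:
        # If the previous persist is still running, training blocks here
        # (analogous to the "S" checkpoint stall in Figure 3).
        self.wait()
        snap = self.snapshot(state)
        path = os.path.join(self.directory, f"ckpt_{step}.pkl")
        self._persist_thread = threading.Thread(target=self._persist, args=(snap, path))
        self._persist_thread.start()

    def wait(self) -> None:
        if self._persist_thread is not None:
            self._persist_thread.join()
            self._persist_thread = None


# Toy training loop: checkpoint every 10 iterations, then restore the latest one.
with tempfile.TemporaryDirectory() as workdir:
    ckpt = TwoLevelCheckpointer(workdir)
    state = {"step": 0, "weights": [0.0, 0.0]}
    for step in range(1, 21):
        state["step"] = step
        state["weights"] = [w + 0.1 for w in state["weights"]]
        if step % 10 == 0:
            ckpt.save(state, step)  # returns once the snapshot is taken
    ckpt.wait()
    with open(ckpt.latest_persisted, "rb") as f:
        restored = pickle.load(f)

print(restored["step"])  # 20
```

The design choice being illustrated is that only the snapshot sits on the training critical path; the slower write to storage overlaps with subsequent iterations and only causes a stall if the next checkpoint arrives before the previous write has finished.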

Paper Structure

This paper contains 36 sections, 16 equations, 15 figures, and 4 tables.

Figures (15)

  • Figure 1: An illustration of the model states, including model parameters (a) and optimizer states (b), across three ranks in distributed training. Training uses the hybrid parallel strategy ZeRO-2 DP + EP, with parallel degrees DP = 3 and EP = 3. Non-expert parts are shown in green and expert parts in yellow, with different shades denoting different experts within the same MoE layer. In (b), the mix of white and green in the non-expert modules shows how ZeRO-2 DP partitions those states across ranks. "Atten0" and "FFN0" denote the Attention and FFN sublayers of transformer layer 0, while "Atten1" and the MoE layer, comprising "Expert(1-0, 1-1, 1-2)", belong to transformer layer 1.
  • Figure 2: An illustration of fault tolerance in model training through checkpoint mechanism. The checkpointing interval $I_{ckpt}$ is set to 10 iterations. A fault arises following the 30th iteration, before the completion of the third checkpoint. Therefore, the most recent completed checkpoint (ckpt2) is loaded to recover the training progress. The composition of a checkpoint is depicted on the left, with the size of each component reflecting its data volume, using the GPT-350M-16E model as an example.
  • Figure 3: The top half illustrates the two-phase checkpointing workflow (GPU-to-CPU snapshot + CPU-to-Storage persist) during distributed training. Training uses 4-degree DP across two nodes, each equipped with two GPUs, and data-parallel sharding minimizes the volume of data saved per DP rank. The bottom half presents a timeline for asynchronous checkpointing, where "F&B" denotes the forward and backward passes of an iteration, "U" a weight update, and "S" a checkpoint stall.
  • Figure 4: An illustration of our proposed partial experts checkpointing (PEC) with sequential selection. At the current checkpoint, "Expert(1-0, 3-1, 5-2, 7-0)" are saved, while those not saved are shown in white. Blue arrows indicate the rotation order of sequential selection, which will save "Expert(1-1, 3-2, 5-0, 7-1)" at the next checkpoint.
  • Figure 5: Correlation analysis between (a) the Proportion of Lost Tokens (PLT) and (b) the final validation loss. In (a), PLT values are centered on 3.75%, the value observed for the PEC configuration $K_{pec}=2$ and $I_{ckpt}=32$, which slightly degrades model accuracy compared with the fault-free case. Panel (b) shows the validation losses, taking the fault-free loss of 4.8851 as the center value to highlight accuracy deviations under various PEC configurations.
  • ...and 10 more figures
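The round-robin pattern in the Figure 4 caption can be reproduced in a few lines of Python. This is a hypothetical sketch: the function name `sequential_pec_selection`, the assumption that consecutive MoE layers are staggered by one selection window, and the behavior for `k_pec > 1` are extrapolations from the single-expert example in the caption, not the paper's exact rule.

```python
from typing import Dict, List


def sequential_pec_selection(
    moe_layers: List[int], num_experts: int, k_pec: int, ckpt_index: int
) -> Dict[int, List[int]]:
    """Return, for each MoE layer, the expert ids to save at one checkpoint.

    Hypothetical rule: the i-th MoE layer starts its window of k_pec experts
    at offset (i + ckpt_index) * k_pec, so consecutive layers are staggered
    and each new checkpoint advances every layer by k_pec experts.
    """
    if not 0 < k_pec <= num_experts:
        raise ValueError("k_pec must be in 1..num_experts")
    selection: Dict[int, List[int]] = {}
    for i, layer in enumerate(moe_layers):
        start = ((i + ckpt_index) * k_pec) % num_experts
        selection[layer] = [(start + j) % num_experts for j in range(k_pec)]
    return selection


# The Figure 4 setting: MoE layers 1, 3, 5, 7 with 3 experts each, 1 saved per layer.
print(sequential_pec_selection([1, 3, 5, 7], num_experts=3, k_pec=1, ckpt_index=0))
# {1: [0], 3: [1], 5: [2], 7: [0]}
print(sequential_pec_selection([1, 3, 5, 7], num_experts=3, k_pec=1, ckpt_index=1))
# {1: [1], 3: [2], 5: [0], 7: [1]}
```

Because successive windows for a layer are contiguous, every expert is saved at least once in any ceil(num_experts / k_pec) consecutive checkpoints, which bounds how stale an unsaved expert can become under this rule.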