AdaGC: Improving Training Stability for Large Language Model Pretraining

Guoxia Wang, Shuai Li, Congliang Chen, Jinle Zeng, Jiabin Yang, Tao Sun, Yanjun Ma, Dianhai Yu, Li Shen

TL;DR

AdaGC targets the training instability caused by loss spikes during large-scale pretraining. It introduces adaptive, per-parameter gradient clipping based on exponential moving averages of local gradient norms, enabling dynamic thresholds that respond to both temporal decay and parameter heterogeneity, while preserving the convergence rate $O(1/\sqrt{T})$ akin to Adam. Empirical validation across Llama-2 7B/13B and CLIP ViT-Base shows complete loss spike elimination, improved perplexities and convergence speed, and broad optimizer compatibility (e.g., AdamW and Lion). The method demonstrates strong cross-architecture and cross-modality generalization, with practical implications for more stable, cost-efficient large-scale pretraining. Overall, AdaGC provides a principled stabilization framework by coupling localized gradient control with EMA-based threshold adaptation, supported by theoretical convergence guarantees and extensive empirical evidence.
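The per-parameter, EMA-based clipping rule described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the authors' implementation.

- Each parameter's gradient is represented as a flat list of floats.
- A gradient is rescaled by min(1, λ_rel·γ/‖g‖), where γ is that parameter's EMA norm.
- The EMA is updated with the norm of the *clipped* gradient.
- The EMA is seeded from the first observed norm. The paper's actual warm-up handling may differ.

The class and attribute names (`AdaptiveClipper`, `lambda_rel`, `beta`, `ema`) are hypothetical. The defaults λ_rel = 1.05 and β = 0.98 are the values the hyperparameter study on Llama-2 Tiny reports as best.

```python
import math
from typing import Dict, List


class AdaptiveClipper:
    """Hypothetical per-parameter clipper with EMA-based thresholds (illustrative sketch)."""

    def __init__(self, lambda_rel: float = 1.05, beta: float = 0.98, eps: float = 1e-12):
        self.lambda_rel = lambda_rel      # relative clipping threshold
        self.beta = beta                  # EMA coefficient for the norm history
        self.eps = eps                    # guards against division by zero
        self.ema: Dict[str, float] = {}   # per-parameter EMA of clipped gradient norms

    def clip(self, grads: Dict[str, List[float]]) -> Dict[str, List[float]]:
        clipped = {}
        for name, g in grads.items():
            norm = math.sqrt(sum(x * x for x in g))
            if name not in self.ema:
                # Assumption: seed the EMA with the first observed norm.
                self.ema[name] = norm
            threshold = self.lambda_rel * self.ema[name]
            # Rescale only when this parameter exceeds its own threshold.
            scale = min(1.0, threshold / (norm + self.eps))
            clipped[name] = [scale * x for x in g]
            # Assumption: track the clipped norm so a single spike cannot inflate the EMA.
            self.ema[name] = self.beta * self.ema[name] + (1 - self.beta) * norm * scale
        return clipped


clipper = AdaptiveClipper()
clipper.clip({"w": [3.0, 4.0], "b": [0.1]})          # step 1 seeds the EMAs (5.0 and 0.1)
out = clipper.clip({"w": [30.0, 40.0], "b": [0.1]})  # step 2: spike in "w" only
# out["w"] is rescaled to norm ~5.25 (1.05 * 5.0) in its original direction; out["b"] is unchanged.
```

Because the factor for `b` is computed only from `b`'s own history, the spike in `w` leaves it untouched. Updating γ with the clipped norm lets the threshold rise by at most a bounded amount per step, so it can still follow a slowly changing (for example, decaying) gradient scale.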

Abstract

Large Language Models (LLMs) face increasing loss spikes as they scale, undermining training stability and final performance. While gradient clipping mitigates this issue, traditional global approaches handle poorly both parameter-specific gradient variations and gradient norms that decay over training. We propose **AdaGC**, an adaptive gradient clipping framework that automatically adjusts a local threshold for each parameter using an exponential moving average (EMA) of its gradient norm. Theoretical analysis proves AdaGC's convergence in the non-convex setting. Extensive experiments demonstrate significant improvements. On Llama-2 7B/13B, AdaGC completely eliminates loss spikes. Compared to global clipping, it reduces WikiText perplexity by 3.5% (+0.14pp LAMBADA accuracy) for 7B, and achieves 0.65% lower training loss and 1.47% lower validation perplexity for 13B. For CLIP ViT-Base, AdaGC converges 25% faster than StableAdamW with full spike elimination. The method is effective across architectures (Llama-2 7B/13B) and modalities (CLIP), and integrates successfully with diverse optimizers such as AdamW and Lion. Source code will be released on GitHub.
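The abstract's point that global clipping handles parameter-specific variations poorly can be made concrete with a toy comparison. Conventional clipping computes one total norm over all parameters and applies a single scale factor. Local clipping compares each parameter only against its own threshold.

The function names below are hypothetical, and the per-parameter thresholds are fixed constants for simplicity. In AdaGC, those thresholds would instead be adapted over time from an EMA of each parameter's gradient norm.

```python
import math
from typing import Dict, List

Grads = Dict[str, List[float]]


def l2(v: List[float]) -> float:
    return math.sqrt(sum(x * x for x in v))


def global_clip(grads: Grads, max_norm: float) -> Grads:
    """Conventional global clipping: one scale factor shared by every parameter."""
    total = math.sqrt(sum(l2(g) ** 2 for g in grads.values()))
    scale = min(1.0, max_norm / (total + 1e-12))
    return {name: [scale * x for x in g] for name, g in grads.items()}


def per_param_clip(grads: Grads, thresholds: Dict[str, float]) -> Grads:
    """Local clipping: each parameter is compared only against its own threshold."""
    out = {}
    for name, g in grads.items():
        scale = min(1.0, thresholds[name] / (l2(g) + 1e-12))
        out[name] = [scale * x for x in g]
    return out


# A spike in "w" only; "b" has an ordinary gradient.
grads = {"w": [30.0, 40.0], "b": [0.1]}
g_out = global_clip(grads, max_norm=1.0)
p_out = per_param_clip(grads, thresholds={"w": 1.0, "b": 1.0})
print(l2(g_out["b"]))  # about 0.002: "b" is shrunk ~50x because of the spike in "w"
print(l2(p_out["b"]))  # "b" keeps its original gradient
```

Under global clipping, one spiking tensor sets the scale for the whole model and suppresses healthy updates elsewhere. Local clipping confines the correction to the tensor that spiked.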

Paper Structure

This paper contains 24 sections, 7 theorems, 33 equations, 10 figures, 2 tables, 1 algorithm.

Key Result

Theorem 4.1

Under mild assumptions, by selecting $\alpha_t = \mathcal{O}(1/\sqrt{T})$, $\beta_2 = 1- \mathcal{O}(1/T)$ and $\beta_1 < \sqrt{\beta_2}$, when $\tau$ is randomly chosen from $\{1,2,\cdots,T\}$ with equal probabilities, it holds that

Figures (10)

  • Figure 1: Empirical motivation for AdaGC: (a) Temporal threshold decay necessitates adaptive clipping, (b) Parameter-specific gradient spikes demand localized control, (c) Fine-grained clipping outperforms global approaches.
  • Figure 2: Study of AdaGC's hyperparameters on Llama-2 Tiny. (a) Relative clipping threshold analysis demonstrates $\lambda_{\text{rel}}=1.05$ achieves optimal gradient regulation. (b) EMA coefficient analysis reveals $\beta=0.98$ best balances historical consistency with rapid adaptation.
  • Figure 3: Large language model training analysis: (a) 7B model comparison shows AdaGC's loss spike elimination and performance gains, (b) 13B results demonstrate method scalability.
  • Figure 4: CLIP ViT-Base results: (a) Training loss trajectory comparison. AdaGC achieves 25% faster convergence (15k vs 20k steps) with complete loss spike elimination compared to StableAdamW. (b) Zero-shot recognition performance. Final accuracy improves 0.27pp (39.84% vs 39.57%) on ImageNet-1K, demonstrating cross-modal generalization.
  • Figure 5: Locality granularity analysis: (a) Global vs parameter-wise clipping on GPT-2 345M model demonstrates spike elimination capability, (b) Distributed shard-wise clipping on 7B model reveals parameter integrity requirements.
  • ...and 5 more figures
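Figure 5(b) contrasts clipping on distributed shards with clipping on whole parameters. One plausible reading, stated here as an assumption, is this: when a parameter's gradient is split across devices, clipping each shard with its local norm applies different factors to different pieces of the same tensor. That changes the gradient's direction. Recovering the full-parameter norm as the square root of the summed squared shard norms (an all-reduce in a real cluster, simulated here with a list) and sharing one factor preserves the direction.

All function names below are hypothetical.

```python
import math
from typing import List

Shard = List[float]


def l2(v: List[float]) -> float:
    return math.sqrt(sum(x * x for x in v))


def shardwise_clip(shards: List[Shard], threshold: float) -> List[Shard]:
    """Each shard is clipped using only its local norm."""
    out = []
    for s in shards:
        scale = min(1.0, threshold / (l2(s) + 1e-12))
        out.append([scale * x for x in s])
    return out


def parameterwise_clip(shards: List[Shard], threshold: float) -> List[Shard]:
    """All shards share one factor computed from the full-parameter norm."""
    # Full norm = sqrt of the summed squared shard norms (an all-reduce across devices).
    full_norm = math.sqrt(sum(l2(s) ** 2 for s in shards))
    scale = min(1.0, threshold / (full_norm + 1e-12))
    return [[scale * x for x in s] for s in shards]


def cosine(a: List[float], b: List[float]) -> float:
    return sum(x * y for x, y in zip(a, b)) / (l2(a) * l2(b))


# One parameter's gradient split across two devices; shard 0 carries the spike.
shards = [[30.0, 40.0], [0.3, 0.4]]
full = [x for s in shards for x in s]
sw = [x for s in shardwise_clip(shards, threshold=1.0) for x in s]
pw = [x for s in parameterwise_clip(shards, threshold=1.0) for x in s]
print(cosine(full, sw))  # below 1: shard-wise clipping rotates the gradient
print(cosine(full, pw))  # ~1.0: parameter-wise clipping keeps the direction
```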

Theorems & Definitions (14)

  • Theorem 4.1
  • Remark 1.5
  • Lemma 1.6
  • proof
  • Lemma 1.7
  • proof
  • Lemma 1.8
  • proof
  • Lemma 1.9
  • proof
  • ...and 4 more