TetraJet-v2: Accurate NVFP4 Training for Large Language Models with Oscillation Suppression and Outlier Control

Yuxiang Chen, Xiaoming Xu, Pengle Zhang, Michael Beyer, Martin Rapp, Jun Zhu, Jianfei Chen

TL;DR

TetraJet-v2 tackles the high cost of pre-training large language models by delivering end-to-end FP4 training with NVFP4 for activations, weights, and gradients. The key innovations are unbiased double-block quantization for NVFP4 linear layers, OsciReset to suppress weight oscillation, and OutControl to handle activation outliers via Random Hadamard Transform and selective precision retention. Empirical results on LLMs up to 370M parameters and 200B tokens show consistent improvements over prior FP4 methods and a reduction of about 51% in the performance gap to full-precision training. This work advances practical FP4 pre-training by addressing both weight dynamics and outlier handling, with potential impact for hardware-aware, low-cost training of larger models.

Abstract

Training Large Language Models (LLMs) is prohibitively expensive, driving interest in low-precision fully-quantized training (FQT). While novel 4-bit formats like NVFP4 offer substantial efficiency gains, achieving near-lossless training at such low precision remains challenging. We introduce TetraJet-v2, an end-to-end 4-bit FQT method that leverages NVFP4 for activations, weights, and gradients in all linear layers. We identify two critical issues hindering low-precision LLM training: weight oscillation and outliers. To address these, we propose: 1) an unbiased double-block quantization method for NVFP4 linear layers, 2) OsciReset, an algorithm to suppress weight oscillation, and 3) OutControl, an algorithm to retain outlier accuracy. TetraJet-v2 consistently outperforms prior FP4 training methods on pre-training LLMs across varying model sizes up to 370M parameters and data sizes up to 200B tokens, reducing the performance gap to full-precision training by an average of 51.3%.
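To make "unbiased" block quantization concrete, here is a minimal NumPy sketch of stochastic rounding onto the FP4 E2M1 value grid with one scale per block of 16 elements. It is an illustration under simplifying assumptions, not the paper's double-block method: the per-block scale is kept in full precision (real NVFP4 stores it in FP8 alongside a per-tensor scale), and the names `stochastic_round_to_grid` and `quantize_blocks` are hypothetical.

```python
import numpy as np

# Non-negative magnitudes representable by an FP4 E2M1 element.
E2M1_GRID = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])
BLOCK = 16  # NVFP4 groups 16 consecutive elements under one scale


def stochastic_round_to_grid(x, rng):
    """Round each value to one of its two neighbouring grid points.

    The upper neighbour is chosen with probability proportional to how
    close the value is to it, so the expected result equals the input.
    """
    mag = np.minimum(np.abs(x), E2M1_GRID[-1])
    hi_idx = np.clip(np.searchsorted(E2M1_GRID, mag, side="left"),
                     1, len(E2M1_GRID) - 1)
    lo, hi = E2M1_GRID[hi_idx - 1], E2M1_GRID[hi_idx]
    p_up = (mag - lo) / (hi - lo)
    rounded = np.where(rng.random(x.shape) < p_up, hi, lo)
    return np.sign(x) * rounded


def quantize_blocks(x, rng):
    """Fake-quantize a 1-D array whose length is a multiple of BLOCK.

    Assumption: the block scale stays in float64 here; it is not itself
    quantized as it would be in a real NVFP4 kernel.
    """
    blocks = x.reshape(-1, BLOCK)
    amax = np.abs(blocks).max(axis=1, keepdims=True)
    # Map the largest magnitude in each block onto the top grid value (6).
    scale = np.where(amax > 0, amax / E2M1_GRID[-1], 1.0)
    q = stochastic_round_to_grid(blocks / scale, rng)
    return (q * scale).reshape(x.shape)
```

Averaging many independent calls to `quantize_blocks` on the same input converges to that input, which is the property that keeps quantized gradient estimates unbiased; deterministic round-to-nearest does not have it.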

Paper Structure

This paper contains 43 sections, 10 equations, 7 figures, 8 tables, 3 algorithms.

Figures (7)

  • Figure 1: The distribution of latent weight $w/s$ in OLMo2-150M blocks.11.att_proj in NVFP4 training without oscillation suppression.
  • Figure 2: Optimization trajectory of one oscillating weight in OLMo2-150M blocks.11.att_proj near the end of NVFP4 training.
  • Figure 3: Activation magnitudes of MLP input at layer 10 for different GSM8K samples across OLMo2-370M training checkpoints at different steps. Outliers consistently appear in specific channels.
  • Figure 4: Validation loss curves of OLMo2-370M trained on about 200B tokens, comparing different methods. (TJ: training method TetraJet-v2; OC: OutControl; OR: OsciReset)
  • Figure 5: Validation loss curves of OLMo2-150M trained on about 100B tokens, comparing different oscillation suppression techniques. We set $T_{\rm start}\approx \mathtt{65B}$ tokens for all methods to begin suppressing oscillation.
  • ...and 2 more figures
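
Figure 3 notes that outliers sit in specific channels, and the TL;DR names the Random Hadamard Transform as one of the tools OutControl uses. The sketch below shows, under simple assumptions, why such a transform helps: an orthogonal rotation built from a Hadamard matrix and random signs spreads one large channel across all channels while preserving the vector norm. The helper names `hadamard` and `random_hadamard` are hypothetical, and this is not the paper's implementation.

```python
import numpy as np


def hadamard(n):
    """Sylvester construction of an n x n Hadamard matrix (n a power of two)."""
    H = np.array([[1.0]])
    while H.shape[0] < n:
        H = np.block([[H, H], [H, -H]])
    return H


def random_hadamard(n, rng):
    """Orthogonal matrix diag(s) @ H / sqrt(n) with random signs s."""
    signs = rng.choice([-1.0, 1.0], size=n)
    return (signs[:, None] * hadamard(n)) / np.sqrt(n)


rng = np.random.default_rng(0)
n = 16
R = random_hadamard(n, rng)

# A toy activation row with one outlier channel.
x = np.full(n, 0.1)
x[3] = 50.0

y = x @ R  # rotated activations: the outlier is spread over all channels
print("max |x| =", np.abs(x).max(), " max |y| =", np.abs(y).max())
```

Because `R` is orthogonal, multiplying one operand of a matrix product by `R` and the other by `R.T` leaves the product unchanged, so the rotation can be applied before quantization without altering the full-precision result.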