Towards Fully FP8 GEMM LLM Training at Scale
Alejandro Hernández-Cano, Dhia Garbaya, Imanol Schlag, Martin Jaggi
TL;DR
This work addresses FP8 training of large language models at scale. It proposes the FOG architecture family, which enables FP8 GEMMs for all operations in the transformer block, including attention. FOG removes pre-normalization, introduces attention-time normalization, and applies post-normalization together with carefully chosen activation functions. With these changes it trains stably with FP8 dot-product attention (FP8DPA) and achieves throughput gains of up to 43% at the 8B scale while matching BF16 downstream performance. The paper also introduces kurtosis as a diagnostic for tracking long-term outlier dynamics and predicting divergences, which makes FP8 training more reliable over long training runs and in MoE variants. Empirically, FOG variants converge where prior FP8 approaches diverge, remain stable in training runs of up to 450B tokens, and are robust across activation functions and MoE configurations. Together these results mark a significant step toward fully FP8 GEMM training in modern LLMs.
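Kurtosis measures how heavy the tails of a distribution are, so a few massive activations drive it up sharply. The following is a minimal sketch that assumes the standard Pearson (non-excess) definition, kappa = m4 / m2^2, where mk is the k-th central moment, computed over one flattened activation vector. This section does not say how the paper aggregates kurtosis over tokens, channels, or layers, or which value it treats as alarming, so the sketch assumes neither.

```python
import random
from statistics import fmean
from typing import Sequence


def kurtosis(values: Sequence[float]) -> float:
    """Pearson (non-excess) kurtosis: m4 / m2**2 over a flat list of activations.

    A Gaussian gives roughly 3; a vector dominated by a few huge outliers
    gives a much larger value, which is what makes it useful as an outlier signal.
    """
    mu = fmean(values)
    m2 = fmean([(v - mu) ** 2 for v in values])  # second central moment (variance)
    m4 = fmean([(v - mu) ** 4 for v in values])  # fourth central moment
    if m2 == 0.0:
        raise ValueError("kurtosis is undefined for a constant vector")
    return m4 / (m2 * m2)


if __name__ == "__main__":
    rng = random.Random(0)
    gaussian = [rng.gauss(0.0, 1.0) for _ in range(100_000)]
    print(f"Gaussian activations: {kurtosis(gaussian):.2f}")  # close to 3
    spiky = [0.0] * 99 + [10.0]
    print(f"one massive outlier:  {kurtosis(spiky):.2f}")     # about 98
```

In a training loop, one would log this value per layer at regular intervals and watch its trend. A steady climb signals that outliers are building up well before the loss itself diverges.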
Abstract
Despite the significant potential of FP8 data formats for large language model (LLM) pre-training, their adoption has been limited due to challenges in maintaining stability at scale. Existing approaches often rely on suboptimal fine-grained FP8 kernels or fall back to higher-precision matrix multiplications (GEMMs) in sensitive components, such as attention projections, compromising potential throughput gains. We introduce a new class of LLM architectures that, for the first time, support FP8 computation for all GEMMs within transformer blocks during both forward and backward passes. This enables unprecedented throughput gains, particularly at scale, while matching the downstream performance of standard BF16 training. Our architecture design reduces large outlier activations, promoting stable long-term FP8 training. In addition, we identify key metrics to monitor low-precision training and predict potential future divergences.
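Large outlier activations threaten FP8 stability for a concrete reason: with per-tensor scaling, the single largest magnitude sets the scale for every element. The following simplified simulation illustrates this in pure Python. It assumes the OCP E4M3 variant (3 mantissa bits, largest finite value 448, subnormal spacing 2^-9) and plain per-tensor "current" scaling (scale = 448 / amax). The helpers `round_to_e4m3` and `fake_quant_per_tensor` are hypothetical names, and the code reproduces neither the FOG recipe nor any library's FP8 kernel. It also ignores NaN encodings.

```python
import math
from typing import List, Sequence

# Simplified model of the OCP FP8 E4M3 format (4 exponent bits, 3 mantissa bits).
E4M3_MAX = 448.0    # largest finite E4M3 magnitude
E4M3_MIN_EXP = -6   # exponent of the smallest normal number, 2**-6
MANTISSA_BITS = 3


def round_to_e4m3(x: float) -> float:
    """Round to the nearest E4M3 value, saturating at +/-448.

    Clamping the exponent at -6 models subnormals: spacing is 2**-9 near zero.
    """
    if x == 0.0:
        return 0.0
    sign = -1.0 if x < 0 else 1.0
    a = abs(x)
    if a >= E4M3_MAX:
        return sign * E4M3_MAX                      # saturate instead of overflowing
    exp = max(math.floor(math.log2(a)), E4M3_MIN_EXP)
    step = 2.0 ** (exp - MANTISSA_BITS)             # spacing of representable values here
    q = round(a / step) * step                      # round() is round-half-to-even
    return sign * min(q, E4M3_MAX)


def fake_quant_per_tensor(xs: Sequence[float]) -> List[float]:
    """Quantize to E4M3 with one scale for the whole tensor, then dequantize."""
    amax = max(abs(v) for v in xs)
    if amax == 0.0:
        return [0.0 for _ in xs]
    scale = E4M3_MAX / amax                         # map the largest magnitude onto 448
    return [round_to_e4m3(v * scale) / scale for v in xs]


if __name__ == "__main__":
    small = [0.011, 0.013, 0.017]
    print(fake_quant_per_tensor(small))               # close to the inputs
    print(fake_quant_per_tensor(small + [10_000.0]))  # small entries flush to 0.0
```

Without the outlier, each entry keeps a relative error of at most about 6%. With the outlier, the scale drops to 448 / 10000, and every small entry falls below half of the smallest subnormal and rounds to zero. Finer-grained (per-block) scaling avoids this by giving each block its own amax, at the cost of more complex kernels.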
