Scalify: scale propagation for efficient low-precision LLM training

Paul Balança, Sam Hosegood, Carlo Luschi, Andrew Fitzgibbon

TL;DR

Scalify tackles the challenge of training large models in low-precision formats by introducing an end-to-end scale propagation transform that carries a scale factor alongside the data across the full computation graph, including the optimizer state. The core idea is the scaled array representation $X = X_d \cdot X_s$, together with a set of propagation rules that keep the data component close to unit variance, $\mathbf{E}[X_d^2] \simeq 1$, while restricting scales to powers of two. Empirically, Scalify enables out-of-the-box FP8 matmul and gradient representations and FP16 optimizer-state storage with minimal dynamic rescaling, avoiding bespoke kernels and complex state management. The approach improves memory and compute efficiency while maintaining accuracy in GPT-2–style experiments, and an open-source JAX implementation is available for broader adoption.

Abstract

Low-precision formats such as float8 have been introduced in machine learning accelerated hardware to improve computational efficiency for large language model training and inference. Nevertheless, adoption by the ML community has been slowed by the complex, and sometimes brittle, techniques required to match higher-precision training accuracy. In this work, we present Scalify, an end-to-end scale propagation paradigm for computational graphs, generalizing and formalizing existing tensor scaling methods. Experimental results show that Scalify supports out-of-the-box float8 matrix multiplication and gradient representation, as well as float16 optimizer state storage. Our JAX implementation of Scalify is open-sourced at https://github.com/graphcore-research/jax-scalify
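
To make the scaled-array idea concrete, the following is a minimal plain-Python sketch of the representation $X = X_d \cdot X_s$ with a power-of-two scale. It is an illustration only and not the jax-scalify API: the helper names `to_scaled`, `scaled_mul` and `from_scaled` are hypothetical, the choice of rounding the root-mean-square to the nearest power of two is an assumption, and only a simple elementwise product is shown rather than the propagation rules of the paper.

```python
import math
import random


def to_scaled(values):
    """Split values into (data, scale): scale is a power of two, E[data^2] is near 1.

    Assumption: the scale is the RMS of the values rounded to the nearest
    power of two (in log space). This is an illustrative choice.
    """
    rms = math.sqrt(sum(v * v for v in values) / len(values))
    if rms == 0.0:
        return list(values), 1.0
    scale = 2.0 ** round(math.log2(rms))
    # Dividing by a power of two only shifts the exponent, so no mantissa
    # bits are lost (away from the underflow/overflow range).
    data = [v / scale for v in values]
    return data, scale


def scaled_mul(a, b):
    """Elementwise product of two scaled arrays.

    The data components are multiplied together and the scales are
    multiplied separately, so the result is again a (data, scale) pair.
    """
    a_data, a_scale = a
    b_data, b_scale = b
    return [x * y for x, y in zip(a_data, b_data)], a_scale * b_scale


def from_scaled(x):
    """Recombine a (data, scale) pair into plain values."""
    data, scale = x
    return [d * scale for d in data]


if __name__ == "__main__":
    random.seed(0)
    # Small-magnitude values, as might be seen in gradients.
    grads = [random.gauss(0.0, 1e-3) for _ in range(4096)]
    data, scale = to_scaled(grads)
    mean_square = sum(d * d for d in data) / len(data)
    print("scale =", scale, "E[data^2] =", mean_square)
```

Because the scale is a power of two, separating and recombining data and scale does not introduce rounding error in this sketch; the data component stays near unit variance, which is the property that makes a narrow format such as float8 usable for it.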

Paper Structure (16 sections, 7 equations, 1 figure, 3 tables, 7 algorithms)