MPX: Mixed Precision Training for JAX
Alexander Gräfe, Sebastian Trimpe
TL;DR
The paper presents MPX, a mixed-precision training toolbox for JAX that addresses JAX's lack of native mixed-precision support. MPX leverages JAX's type promotion, PyTree casting, dynamic loss scaling, and Equinox-compatible gradient transformations to convert full-precision pipelines to mixed precision with minimal modifications while preserving accuracy. It provides PyTree casting utilities, function wrappers, automatic loss scaling, gradient-calculation helpers, and optimizer wrappers that skip updates when gradients overflow. An empirical evaluation on ViT models shows substantial reductions in RAM usage and training time, supporting MPX as a practical solution for efficient large-scale training in the JAX ecosystem.
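The PyTree casting utilities and function wrappers mentioned above can be illustrated with a minimal plain-JAX sketch. The helper names `cast_tree` and `half_precision_fn` are hypothetical and are not MPX's actual API; the sketch assumes float16 as the half-precision type, casts only floating-point array leaves, and leaves integer leaves (such as step counters) untouched.

```python
import jax
import jax.numpy as jnp


def cast_tree(tree, dtype):
    """Cast every floating-point array leaf of a PyTree to `dtype`.

    Non-floating leaves (integer counters, booleans, Python objects) are
    returned unchanged. Hypothetical helper, not MPX's API.
    """
    def _cast(leaf):
        if isinstance(leaf, jax.Array) and jnp.issubdtype(leaf.dtype, jnp.floating):
            return leaf.astype(dtype)
        return leaf

    return jax.tree_util.tree_map(_cast, tree)


def half_precision_fn(fn, half_dtype=jnp.float16):
    """Hypothetical wrapper: run `fn` on half-precision inputs, return float32 outputs."""
    def wrapped(*args):
        out = fn(*cast_tree(args, half_dtype))  # compute in half precision
        return cast_tree(out, jnp.float32)      # hand float32 results back to the caller
    return wrapped


# Usage: a toy dense layer with float32 "master" parameters.
params = {"w": jnp.ones((3, 2)), "b": jnp.zeros(2), "step": jnp.array(0)}
half = cast_tree(params, jnp.float16)  # w, b become float16; step stays an integer


def dense(p, x):
    return x @ p["w"] + p["b"]


y = half_precision_fn(dense)(params, jnp.ones((4, 3)))  # matmul runs in float16, y is float32
```

Casting the outputs back to float32 means that downstream reductions such as the loss see full-precision values; this is one reasonable design choice for the sketch, not necessarily how MPX arranges it.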
Abstract
Mixed-precision training has emerged as an indispensable tool for enhancing the efficiency of neural network training in recent years. Concurrently, JAX has grown in popularity as a versatile machine learning toolbox. However, it currently lacks robust support for mixed-precision training. We propose MPX, a mixed-precision training toolbox for JAX that simplifies and accelerates the training of large-scale neural networks while preserving model accuracy. MPX seamlessly integrates with popular toolboxes such as Equinox and Flax, allowing users to convert full-precision pipelines to mixed-precision versions with minimal modifications. By casting both inputs and outputs to half precision and introducing a dynamic loss-scaling mechanism, MPX alleviates issues like gradient underflow and overflow that commonly arise in half-precision computations. Its design inherits critical features from JAX's type-promotion behavior, ensuring that operations take place in the correct precision and allowing for selective enforcement of full precision where needed (e.g., sums, means, or softmax). MPX further provides wrappers for automatic creation and management of mixed-precision gradients and optimizers, enabling straightforward integration into existing JAX training pipelines. MPX's source code, documentation, and usage examples are available at github.com/Data-Science-in-Mechanical-Engineering/mixed_precision_for_JAX.
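The dynamic loss-scaling mechanism described above can be sketched in plain JAX. This is an illustrative sketch under stated assumptions, not MPX's implementation: the names `DynamicLossScale`, `all_finite`, `update_scale` and `scaled_sgd_step`, the growth period of 2000 steps, the factor of 2, and the use of plain SGD are all hypothetical choices. The loss is multiplied by a scale factor before differentiation so that small half-precision gradients do not underflow; the gradients are then divided by the same factor in float32, and if any of them is non-finite, the parameter update is skipped and the scale is reduced.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class DynamicLossScale(NamedTuple):
    scale: jax.Array       # current loss-scaling factor (float32)
    good_steps: jax.Array  # consecutive steps with finite gradients


def all_finite(tree):
    """True iff every leaf of the gradient PyTree contains only finite values."""
    leaves = jax.tree_util.tree_leaves(tree)
    return jnp.all(jnp.stack([jnp.all(jnp.isfinite(x)) for x in leaves]))


def update_scale(state, finite, period=2000, factor=2.0, min_scale=1.0):
    """Grow the scale after `period` clean steps; shrink it on overflow (hypothetical policy)."""
    grow = finite & (state.good_steps + 1 >= period)
    scale = jnp.where(
        finite,
        jnp.where(grow, state.scale * factor, state.scale),
        jnp.maximum(state.scale / factor, min_scale),
    )
    good_steps = jnp.where(finite & ~grow, state.good_steps + 1, 0)
    return DynamicLossScale(scale, good_steps)


def scaled_sgd_step(loss_fn, params, batch, state, lr=1e-2):
    """One SGD step with dynamic loss scaling; skips the update if grads overflow."""
    def scaled_loss(p):
        # float16 loss * float32 scale promotes to float32 (JAX type promotion),
        # but the backward pass still sends `scale` into the float16 graph.
        return loss_fn(p, batch) * state.scale

    grads = jax.grad(scaled_loss)(params)
    # Unscale in full precision so small gradients are recovered accurately.
    grads = jax.tree_util.tree_map(lambda g: g.astype(jnp.float32) / state.scale, grads)
    finite = all_finite(grads)
    # Keep the old parameters whenever any gradient is inf or NaN.
    new_params = jax.tree_util.tree_map(
        lambda p, g: jnp.where(finite, p - lr * g, p), params, grads
    )
    return new_params, update_scale(state, finite)


def loss_fn(p, x):
    """Toy loss computed entirely in float16."""
    w16 = p["w"].astype(jnp.float16)
    return jnp.sum((w16 * x) ** 2)


params = {"w": jnp.array([1.0, 2.0])}  # float32 master weights
x = jnp.ones(2, dtype=jnp.float16)

ok_state = DynamicLossScale(jnp.float32(2.0**10), jnp.int32(0))
params_ok, ok_state = scaled_sgd_step(loss_fn, params, x, ok_state)

big_state = DynamicLossScale(jnp.float32(2.0**20), jnp.int32(0))
params_skip, big_state = scaled_sgd_step(loss_fn, params, x, big_state)
```

In the usage lines, a scale of 2^10 yields finite gradients and a normal update, whereas 2^20 exceeds the float16 maximum of 65504 during the backward pass, so the step is skipped and the scale is halved. Keeping float32 master weights and casting to float16 only inside the loss follows the common mixed-precision recipe; how MPX organizes this internally is not shown here.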
