MPX: Mixed Precision Training for JAX

Alexander Gräfe, Sebastian Trimpe

TL;DR

The paper presents MPX, a mixed-precision training toolbox for JAX that addresses JAX's lack of native mixed-precision support. MPX builds on JAX's type-promotion rules, PyTree casting, dynamic loss scaling, and Equinox-compatible gradient transformations to convert full-precision pipelines to mixed precision with minimal code changes while preserving accuracy. It provides PyTree casting utilities, function wrappers, automatic loss scaling, gradient-calculation helpers, and optimizer wrappers that skip updates when gradients overflow. An empirical evaluation on Vision Transformer (ViT) models shows substantial reductions in GPU memory (VRAM) usage and training time, supporting MPX as a practical tool for efficient large-scale training in the JAX ecosystem.
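
The PyTree casting and type-promotion ideas mentioned above can be illustrated in a few lines of plain JAX. This sketch is not MPX's API: `cast_tree` and `full_precision_softmax` are hypothetical helpers written for illustration. `cast_tree` casts every floating-point leaf of a PyTree (for example, a parameter dictionary) to a target dtype and leaves integer and boolean leaves untouched. `full_precision_softmax` shows one way to force a numerically sensitive operation into float32 and hand the result back in the input's dtype.

```python
import jax
import jax.numpy as jnp


def cast_tree(tree, dtype):
    """Hypothetical helper (not MPX's API): cast floating-point leaves of a PyTree to `dtype`.

    Integer, boolean, and non-array leaves are returned unchanged.
    """
    def _cast(leaf):
        if hasattr(leaf, "dtype") and jnp.issubdtype(leaf.dtype, jnp.floating):
            return leaf.astype(dtype)
        return leaf

    return jax.tree_util.tree_map(_cast, tree)


def full_precision_softmax(x, axis=-1):
    """Hypothetical helper: compute softmax in float32, return it in x's original dtype."""
    return jax.nn.softmax(x.astype(jnp.float32), axis=axis).astype(x.dtype)


# Usage: cast a parameter PyTree to float16; the integer step counter is left alone.
params = {"w": jnp.ones((2, 3), jnp.float32), "step": jnp.array(0, jnp.int32)}
half_params = cast_tree(params, jnp.float16)

# JAX type promotion: float16 combined with a float32 array promotes to float32,
# while a Python scalar is weakly typed and keeps the result in float16.
mixed = half_params["w"] + jnp.ones((2, 3), jnp.float32)  # dtype: float32
scaled = half_params["w"] * 2.0                            # dtype: float16

probs = full_precision_softmax(jnp.array([1.0, 2.0, 3.0], jnp.float16))  # dtype: float16
```

Because promotion always picks the wider floating-point type, values that are deliberately kept in float32 stay in float32 when they meet half-precision arrays, without explicit casts at every call site.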

Abstract

Mixed-precision training has emerged as an indispensable tool for enhancing the efficiency of neural network training in recent years. Concurrently, JAX has grown in popularity as a versatile machine learning toolbox. However, JAX currently lacks robust support for mixed-precision training. We propose MPX, a mixed-precision training toolbox for JAX that simplifies and accelerates the training of large-scale neural networks while preserving model accuracy. MPX seamlessly integrates with popular toolboxes such as Equinox and Flax, allowing users to convert full-precision pipelines to mixed-precision versions with minimal modifications. By casting both inputs and outputs to half precision and introducing a dynamic loss-scaling mechanism, MPX alleviates issues like gradient underflow and overflow that commonly arise in half-precision computations. Its design inherits critical features from JAX's type-promotion behavior, ensuring that operations take place in the correct precision and allowing for selective enforcement of full precision where needed (e.g., sums, means, or softmax). MPX further provides wrappers for automatic creation and management of mixed-precision gradients and optimizers, enabling straightforward integration into existing JAX training pipelines. MPX's source code, documentation, and usage examples are available at github.com/Data-Science-in-Mechanical-Engineering/mixed_precision_for_JAX.
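
The dynamic loss-scaling mechanism and the overflow-skipping optimizer described above can be sketched in plain JAX as follows. This is an illustrative reimplementation of the general technique, not MPX's code or API: `LossScaleState`, `scaled_value_and_grad`, `update_scale`, and `train_step` are hypothetical names, plain SGD stands in for a real optimizer, and the growth interval (2000 steps) and factor (2) are assumed values. The loss is multiplied by a scale before differentiation so that small half-precision gradients do not underflow, and the gradients are then divided by the same scale in float32. If any gradient is inf or NaN, the update is skipped and the scale is halved; after a run of finite steps, the scale is doubled.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class LossScaleState(NamedTuple):
    """Hypothetical dynamic loss-scaling state (not MPX's actual data structure)."""
    scale: jax.Array       # multiplier applied to the loss before differentiation
    good_steps: jax.Array  # consecutive steps with finite gradients


def all_finite(tree):
    """True iff every leaf of the PyTree contains only finite values."""
    leaves = jax.tree_util.tree_leaves(tree)
    return jnp.stack([jnp.all(jnp.isfinite(x)) for x in leaves]).all()


def scaled_value_and_grad(loss_fn, params, batch, state):
    """Differentiate loss * scale, then divide the gradients by the scale in float32."""
    def scaled_loss(p):
        return loss_fn(p, batch).astype(jnp.float32) * state.scale

    loss, grads = jax.value_and_grad(scaled_loss)(params)
    grads = jax.tree_util.tree_map(lambda g: g.astype(jnp.float32) / state.scale, grads)
    return loss / state.scale, grads


def update_scale(state, finite, growth_interval=2000, factor=2.0):
    """Halve the scale on overflow; double it after `growth_interval` finite steps."""
    grow = finite & (state.good_steps + 1 >= growth_interval)
    new_scale = jnp.where(finite, jnp.where(grow, state.scale * factor, state.scale),
                          state.scale / factor)
    new_good = jnp.where(finite & ~grow, state.good_steps + 1, 0)
    return LossScaleState(new_scale, new_good)


def train_step(params, state, batch, loss_fn, lr=1e-2):
    """One SGD step on float32 (master) params, skipped when the gradients overflow."""
    loss, grads = scaled_value_and_grad(loss_fn, params, batch, state)
    finite = all_finite(grads)
    # jnp.where keeps the old parameters whenever any gradient is inf or NaN.
    new_params = jax.tree_util.tree_map(
        lambda p, g: jnp.where(finite, p - lr * g, p), params, grads)
    return new_params, update_scale(state, finite), loss


# Usage: float16 forward pass with float32 master weights and a float32 loss.
def loss_fn(p, batch):
    x, y = batch
    pred = x.astype(jnp.float16) @ p["w"].astype(jnp.float16)
    return jnp.mean((pred.astype(jnp.float32) - y) ** 2)


params0 = {"w": jnp.ones((3, 1), jnp.float32)}
batch = (jnp.ones((4, 3), jnp.float32), jnp.zeros((4, 1), jnp.float32))
state0 = LossScaleState(jnp.float32(2.0 ** 10), jnp.int32(0))
params1, state1, loss1 = train_step(params0, state0, batch, loss_fn)
```

With a much larger scale such as 2^20, the scaled float16 cotangents in this example exceed float16's largest finite value (65504) and become inf, so the step leaves the parameters unchanged and the scale drops to 2^19.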

Paper Structure

This paper contains 13 sections and 2 figures.

Figures (2)

  • Figure 1: Comparison of GPU VRAM consumed for full precision and mixed precision as a function of the number of batches on the desktop PC.
  • Figure 2: Comparison of training step times for full precision and mixed precision as a function of the number of batches.