DrJAX: Scalable and Differentiable MapReduce Primitives in JAX

Keith Rush, Zachary Charles, Zachary Garrett, Sean Augenstein, Nicole Mitchell

TL;DR

DrJAX is a differentiable MapReduce framework embedded in JAX that treats partitioned data as a first-class citizen. It supports scalable sharding, JIT compilation, and automatic differentiation (AD) while preserving partition metadata for translation to other platforms. Because its MapReduce building blocks are implemented as JAX primitives and lowered through XLA, DrJAX decouples the logical partition size from the physical sharding and supports both forward- and reverse-mode AD, so complex distributed algorithms can be written succinctly and run efficiently. Experiments show near-constant weak scaling on large transformer models and show that DrJAX's internal sharding annotations are essential for performance, compared against baselines that rely on JIT compilation alone or on GSPMD alone. The work also emphasizes interpretability: partition information is preserved in jaxprs, which eases translation to federated and production systems. Finally, it discusses the self-tuning and hypergradient methods that AD through MapReduce computations makes possible. Overall, DrJAX makes distributed MapReduce computations scalable and differentiable, and offers a path toward cross-platform deployment and automated optimization of distributed ML workloads.
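
To make the broadcast, map, and reduce pattern concrete, here is a minimal sketch in plain JAX. It assumes a partitioned value is represented as an array with a leading group axis (one slice per group, as in Figure 1). The helpers `broadcast`, `map_fn`, and `reduce_sum` are hypothetical stand-ins written for this illustration, not DrJAX's actual API, and the sharding annotations that the paper identifies as essential for performance are omitted.

```python
import jax
import jax.numpy as jnp

N_GROUPS = 3  # hypothetical number of groups (partitions)


def broadcast(x, n_groups=N_GROUPS):
    """Hypothetical helper: replicate a non-partitioned array to every group
    by adding a leading group axis of length n_groups."""
    return jnp.broadcast_to(x, (n_groups,) + jnp.shape(x))


def map_fn(fn, *partitioned_args):
    """Hypothetical helper: apply fn to each group's slice independently."""
    return jax.vmap(fn)(*partitioned_args)


def reduce_sum(partitioned):
    """Hypothetical helper: sum over the group axis, giving a non-partitioned array."""
    return jnp.sum(partitioned, axis=0)


@jax.jit
def weighted_total(server_value, group_data):
    replicated = broadcast(server_value)                                # server -> groups
    per_group = map_fn(lambda a, b: a * b, replicated, group_data)      # per-group work
    return reduce_sum(per_group)                                        # groups -> server


server_value = jnp.array([1.0, 2.0])
group_data = jnp.arange(6.0).reshape(N_GROUPS, 2)  # one row per group
print(weighted_total(server_value, group_data))    # values 6.0 and 18.0
```

Using `jax.vmap` for the map step keeps the per-group computation written for a single group while JAX batches it over the group axis; in this sketch the group count is just the leading array dimension, whereas DrJAX additionally decouples it from how the data is physically sharded.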

Abstract

We present DrJAX, a JAX-based library designed to support large-scale distributed and parallel machine learning algorithms that use MapReduce-style operations. DrJAX leverages JAX's sharding mechanisms to target TPUs natively, along with state-of-the-art JAX runtimes such as Pathways. DrJAX embeds the building blocks of MapReduce computations as primitives in JAX, which yields three key benefits. First, DrJAX computations can be translated directly to XLA HLO, enabling flexible integration with a wide array of ML training platforms. Second, DrJAX computations are fully differentiable. Third, DrJAX computations can be interpreted for execution on existing batch-processing compute systems, including traditional MapReduce systems like Apache Beam and cross-device compute systems like those powering federated learning applications. We show that DrJAX provides an easily programmable, performant, and scalable framework for parallelized algorithm development. DrJAX is available at https://github.com/google-research/google-research/tree/master/drjax.
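
Since the building blocks are differentiable JAX operations, gradients flow through an entire broadcast, map, reduce computation. The sketch below is again plain JAX, with a leading group axis standing in for DrJAX's partitioned arrays (an assumption, not DrJAX's API). It differentiates a mean-over-groups loss with respect to non-partitioned weights and prints the traced program with `jax.make_jaxpr`. Unlike DrJAX, which the paper says preserves partition information in its jaxprs, this plain-JAX version records only generic array operations.

```python
import jax
import jax.numpy as jnp


def federated_loss(w, xs, ys):
    """Mean over groups of each group's mean squared error.

    Hypothetical setup: w is a non-partitioned weight vector, xs has shape
    (n_groups, n_examples, dim) and ys has shape (n_groups, n_examples).
    """
    n_groups = xs.shape[0]
    w_per_group = jnp.broadcast_to(w, (n_groups,) + w.shape)  # broadcast

    def group_loss(w_g, x_g, y_g):
        return jnp.mean((x_g @ w_g - y_g) ** 2)

    losses = jax.vmap(group_loss)(w_per_group, xs, ys)        # map
    return jnp.mean(losses)                                   # reduce


grad_fn = jax.jit(jax.grad(federated_loss))

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
xs = jax.random.normal(k1, (3, 4, 2))  # 3 groups, 4 examples each, 2 features
ys = jax.random.normal(k2, (3, 4))
w = jnp.array([1.0, -1.0])

print(grad_fn(w, xs, ys))
jaxpr = jax.make_jaxpr(jax.grad(federated_loss))(w, xs, ys)
print(jaxpr)  # the traced forward and backward program
```

In reverse mode, the transpose of the broadcast sums the per-group gradients, so the result is the mean of the group gradients; with equal-sized groups it equals the gradient of the loss pooled over all examples.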

Paper Structure (21 sections, 6 figures, 1 table)

Figures (6)

  • Figure 1: DrJAX's representation of a non-partitioned array (left) and an array partitioned over 3 groups (right).
  • Figure 2: A partitioned structure in DrJAX with 3 groups. Each leaf is a partitioned array.
  • Figure 3: A high-level depiction of DrJAX building blocks operating on and transforming non-partitioned and partitioned arrays.
  • Figure 4: Total training time for 100 rounds of local SGD on transformer language models of various sizes, with varying partition sizes.
  • Figure 5: Total training time for 100 rounds of local SGD, with varying partition sizes. We implement local SGD both with DrJAX and with a JIT-compiled Python for-loop.
  • ...and 1 more figure
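
Figures 4 and 5 time rounds of local SGD, comparing DrJAX against a JIT-compiled Python for-loop. As a rough illustration of what one such round computes, the sketch below writes it twice in plain JAX: once as broadcast, `jax.vmap` over groups, and mean, and once as a Python for-loop over groups that `jax.jit` unrolls. The linear model, `LR`, `LOCAL_STEPS`, and the synthetic data are hypothetical. The local steps are full-batch gradient steps rather than minibatch SGD to keep the example deterministic, and no sharding is applied, so the sketch says nothing about the performance differences the figures report.

```python
import jax
import jax.numpy as jnp

LR = 0.1         # hypothetical local learning rate
LOCAL_STEPS = 5  # hypothetical number of local steps per round


def loss(w, x, y):
    """Mean squared error of a linear model on one group's data."""
    return jnp.mean((x @ w - y) ** 2)


def local_train(w, x, y):
    """Run LOCAL_STEPS full-batch gradient steps on one group's data."""
    def step(_, w_curr):
        return w_curr - LR * jax.grad(loss)(w_curr, x, y)
    return jax.lax.fori_loop(0, LOCAL_STEPS, step, w)


@jax.jit
def local_sgd_round_vmap(w, xs, ys):
    # Broadcast the server model, train all groups in parallel, then average.
    ws = jnp.broadcast_to(w, (xs.shape[0],) + w.shape)
    return jnp.mean(jax.vmap(local_train)(ws, xs, ys), axis=0)


@jax.jit
def local_sgd_round_loop(w, xs, ys):
    # Same round written as a Python for-loop over groups; jit unrolls it.
    updated = [local_train(w, xs[g], ys[g]) for g in range(xs.shape[0])]
    return jnp.mean(jnp.stack(updated), axis=0)


true_w = jnp.array([2.0, -1.0])
xs = jax.random.normal(jax.random.PRNGKey(0), (3, 8, 2))  # 3 groups, 8 examples each
ys = xs @ true_w                                          # noiseless synthetic targets
w = jnp.zeros(2)
for _ in range(10):  # 10 rounds of local training
    w = local_sgd_round_vmap(w, xs, ys)
print(w)  # moves toward true_w
```

Both versions compute the same round; the loop version traces one copy of `local_train` per group, so its compiled program grows with the number of groups, while the `vmap` version traces a single batched copy.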