DrJAX: Scalable and Differentiable MapReduce Primitives in JAX
Keith Rush, Zachary Charles, Zachary Garrett, Sean Augenstein, Nicole Mitchell
TL;DR
DrJAX is a differentiable MapReduce framework embedded in JAX that treats partitioned data as a first-class citizen. It implements MapReduce primitives as JAX primitives and lowers them through XLA. As a result, DrJAX programs can be sharded, JIT-compiled, and automatically differentiated (AD) while retaining partition metadata for translation to other platforms. Because partition size is decoupled from physical sharding, and because both forward- and reverse-mode AD are supported, complex distributed algorithms can be written succinctly and run efficiently. Experiments on large transformer models show near-constant weak scaling, and comparisons against JIT-only and GSPMD-only baselines show that DrJAX's internal sharding annotations are essential for performance. The work also emphasizes interpretability: partition information is preserved in jaxprs, which makes translation to federated and production systems easier. Finally, it discusses the opportunities that differentiating through MapReduce computations opens up for self-tuning algorithms and hypergradients. Overall, DrJAX enables scalable, differentiable MapReduce computations and offers a path toward cross-platform deployment and automated optimization of distributed ML workloads.
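To make the broadcast, map, and reduce pattern concrete, here is a minimal pure-JAX sketch. It is not the DrJAX API: the helpers `broadcast`, `map_fn`, `reduce_mean`, and `mapreduce_loss`, the constant `NUM_PARTITIONS`, and the quadratic loss are hypothetical names chosen for illustration. The sketch assumes the logical partition axis is emulated with a leading array dimension and `jax.vmap`, and it uses `jax.lax.with_sharding_constraint` as a stand-in for internal sharding annotations. The logical partition count (8) is independent of how many devices the mesh spans. Both reverse-mode (`jax.grad`) and forward-mode (`jax.jvp`) differentiation flow through the whole computation.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

NUM_PARTITIONS = 8  # hypothetical logical partition count, independent of device count

# Build a 1-D mesh from as many devices as evenly divide the partition count
# (a single device on a plain CPU host). The logical partition axis is then
# laid out over the mesh axis named "partitions".
_devs = jax.devices()
_n = max(d for d in range(1, len(_devs) + 1) if NUM_PARTITIONS % d == 0)
MESH = Mesh(np.array(_devs[:_n]), ("partitions",))
PARTITIONED = NamedSharding(MESH, P("partitions"))


def broadcast(x):
    """Hypothetical helper: replicate a value to every partition."""
    y = jnp.broadcast_to(x, (NUM_PARTITIONS,) + jnp.shape(x))
    # Annotate the leading (partition) axis so the compiler shards it.
    return jax.lax.with_sharding_constraint(y, PARTITIONED)


def map_fn(fn, *partitioned_args):
    """Hypothetical helper: apply fn independently to each partition's slice."""
    return jax.vmap(fn)(*partitioned_args)


def reduce_mean(x):
    """Hypothetical helper: average a partitioned value over the partition axis."""
    return jnp.mean(x, axis=0)


@jax.jit
def mapreduce_loss(w, data):
    # data has shape (NUM_PARTITIONS, examples, features): one slice per partition.
    data = jax.lax.with_sharding_constraint(data, PARTITIONED)
    w_b = broadcast(w)
    per_partition = map_fn(lambda wi, xi: jnp.mean((xi @ wi) ** 2), w_b, data)
    return reduce_mean(per_partition)


w = jnp.ones(3)
data = jnp.arange(NUM_PARTITIONS * 4 * 3, dtype=jnp.float32).reshape(NUM_PARTITIONS, 4, 3) / 100.0

loss = mapreduce_loss(w, data)
grad_w = jax.grad(mapreduce_loss)(w, data)  # reverse mode through broadcast/map/reduce
_, tangent = jax.jvp(lambda v: mapreduce_loss(v, data), (w,), (jnp.ones(3),))  # forward mode
```

Every partition here holds the same number of examples, so the mean of the per-partition means equals the global mean. That makes the result easy to check against an unpartitioned computation. DrJAX itself implements these operations as JAX primitives rather than as Python helpers, which is what lets it keep partition metadata in the jaxpr. This sketch does not attempt that.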
Abstract
We present DrJAX, a JAX-based library designed to support large-scale distributed and parallel machine learning algorithms that use MapReduce-style operations. DrJAX leverages JAX's sharding mechanisms to enable native targeting of TPUs and state-of-the-art JAX runtimes, including Pathways. DrJAX embeds building blocks for MapReduce computations as primitives in JAX. This enables three key benefits. First, DrJAX computations can be translated directly to XLA HLO, enabling flexible integration with a wide array of ML training platforms. Second, DrJAX computations are fully differentiable. Last, DrJAX computations can be interpreted out to existing batch-processing compute systems, including traditional MapReduce systems like Apache Beam and cross-device compute systems like those powering federated learning applications. We show that DrJAX provides an easily programmable, performant, and scalable framework for parallelized algorithm development. DrJAX is available at https://github.com/google-research/google-research/tree/master/drjax.
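Standard JAX staging APIs can illustrate how a MapReduce-style computation written in JAX is handed to XLA. The sketch below uses only core JAX (`jax.make_jaxpr`, `jax.jit(...).lower`, `.compile`), not DrJAX itself. The function `averaging_round` and the constant `NUM_PARTITIONS` are hypothetical, and the partition axis is again emulated with a leading dimension and `jax.vmap`. In this plain version the partition structure dissolves into generic array operations in the jaxpr. Preserving that structure is the role of DrJAX's dedicated primitives.

```python
import jax
import jax.numpy as jnp

NUM_PARTITIONS = 4  # hypothetical logical partition count


def averaging_round(model, partitioned_data):
    """One hypothetical broadcast -> map -> reduce round written in plain JAX."""
    # Broadcast: replicate the model to every partition.
    models = jnp.broadcast_to(model, (NUM_PARTITIONS,) + model.shape)
    # Map: each partition computes an update from its own data slice.
    updates = jax.vmap(lambda m, x: jnp.mean(x, axis=0) - m)(models, partitioned_data)
    # Reduce: average the per-partition updates and apply them to the model.
    return model + jnp.mean(updates, axis=0)


model = jnp.zeros(3)
data = jnp.ones((NUM_PARTITIONS, 5, 3))

# Inspect the jaxpr, JAX's intermediate representation of the traced program.
jaxpr = jax.make_jaxpr(averaging_round)(model, data)
primitive_names = [eqn.primitive.name for eqn in jaxpr.jaxpr.eqns]

# Lower to the IR that XLA compiles (StableHLO text in recent JAX releases),
# then compile the lowered program and run it.
lowered = jax.jit(averaging_round).lower(model, data)
ir_text = lowered.as_text()
compiled = lowered.compile()
new_model = compiled(model, data)
```

The exact primitive names in `primitive_names` and the textual IR format vary across JAX versions. Treat them as something to inspect rather than as a stable interface.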
