MPAX: Mathematical Programming in JAX
Haihao Lu, Zedong Peng, Jinwen Yang
TL;DR
MPAX is a JAX-native first-order solver for large-scale LPs and convex QPs that uses XLA compilation, batching, and autodiff to integrate optimization directly into modern ML workflows. It implements two PDHG-based schemes, the restarted Reflected Halpern PDHG for LP ($r^2$HPDHG) and the restarted Accelerated PDHG for QP (rAPDHG), with diagonal preconditioning, adaptive restarts, feasibility polishing, and primal-weight updates. The framework supports cross-hardware execution (CPU/GPU/TPU), true batched solving, distributed optimization via SPMD data sharding, and differentiable optimization through surrogate gradient techniques or unrolled differentiation. Empirical results show substantial GPU speedups over CPU baselines, near-linear multi-GPU scaling for dense LPs, strong batched-solve performance, and competitive benchmark results relative to GPU-optimized solvers, all while enabling end-to-end differentiable decision-making in ML pipelines. MPAX is publicly available at https://github.com/MIT-Lu-Lab/MPAX.
Abstract
We present MPAX (Mathematical Programming in JAX), an open-source first-order solver for large-scale linear programming (LP) and convex quadratic programming (QP) built natively in JAX. The primary goal of MPAX is to exploit modern machine learning infrastructure for large-scale mathematical programming, while also providing advanced mathematical programming algorithms that are easy to integrate into machine learning workflows. MPAX implements two PDHG variants, $r^2$HPDHG for LP and rAPDHG for QP, together with diagonal preconditioning, adaptive restarts, adaptive step sizes, primal-weight updates, infeasibility detection, and feasibility polishing. Leveraging JAX's compilation and parallelization ecosystem, MPAX provides cross-hardware portability, batched solving, distributed optimization, and automatic differentiation. We evaluate MPAX on CPUs, NVIDIA GPUs, and Google TPUs, observing substantial GPU speedups over CPU baselines and competitive performance relative to GPU-based codebases on standard LP/QP benchmarks. Our numerical experiments further demonstrate MPAX's capabilities in high-throughput batched solving, near-linear multi-GPU scaling for dense LPs, and efficient end-to-end differentiable training. The solver is publicly available at https://github.com/MIT-Lu-Lab/MPAX.
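To make the combination of PDHG and JAX batching concrete, the following is a minimal sketch of plain (vanilla) PDHG for an equality-form LP, $\min c^\top x$ subject to $Ax = b$, $x \ge 0$, compiled with `jax.jit` and batched over cost vectors with `jax.vmap`. It is not MPAX's API or implementation: the function name `pdhg_lp`, the fixed step-size rule, the iteration count, and the toy problem data are assumptions made for illustration, and the sketch omits the restarts, Halpern reflection, preconditioning, adaptive step sizes, and primal-weight updates described above.

```python
import jax
import jax.numpy as jnp


def pdhg_lp(c, A, b, num_iters=2000):
    """Vanilla PDHG for  min c@x  s.t.  A@x = b, x >= 0  (illustrative only)."""
    # Fixed step sizes chosen so that tau * sigma * ||A||_2^2 < 1.
    step = 0.9 / jnp.linalg.norm(A, ord=2)
    tau = sigma = step

    def body(_, state):
        x, y = state
        # Primal step: move along -(c - A^T y), then project onto x >= 0.
        x_new = jnp.maximum(x - tau * (c - A.T @ y), 0.0)
        # Dual step evaluated at the extrapolated primal point 2*x_new - x.
        y_new = y + sigma * (b - A @ (2.0 * x_new - x))
        return x_new, y_new

    x0 = jnp.zeros(A.shape[1])
    y0 = jnp.zeros(A.shape[0])
    return jax.lax.fori_loop(0, num_iters, body, (x0, y0))


# Batch over cost vectors only; A and b are shared across the batch.
batched_pdhg = jax.jit(jax.vmap(pdhg_lp, in_axes=(0, None, None)))

# Toy data (hypothetical): minimize over the simplex x1 + x2 = 1, x >= 0.
A = jnp.array([[1.0, 1.0]])
b = jnp.array([1.0])
C = jnp.array([[1.0, 2.0], [3.0, 1.0]])  # two LPs with different costs

X, Y = batched_pdhg(C, A, b)
objectives = jnp.sum(C * X, axis=1)
```

Each row of `X` is the primal iterate for the corresponding row of `C`. Because the whole solve is a pure JAX function, the same code runs unchanged on CPU, GPU, or TPU, which is the kind of portability the abstract refers to.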
