
JaxMARL: Multi-Agent RL Environments and Algorithms in JAX

Alexander Rutherford, Benjamin Ellis, Matteo Gallici, Jonathan Cook, Andrei Lupu, Gardar Ingvarsson, Timon Willi, Ravi Hammond, Akbir Khan, Christian Schroeder de Witt, Alexandra Souly, Saptarashmi Bandyopadhyay, Mikayel Samvelyan, Minqi Jiang, Robert Tjarko Lange, Shimon Whiteson, Bruno Lacerda, Nick Hawes, Tim Rocktäschel, Chris Lu, Jakob Nicolaus Foerster

TL;DR

JaxMARL introduces an open-source, GPU-accelerated MARL library implemented in JAX, uniting a diverse set of environments (including two new suites, SMAX and STORM) with reusable, high-performance baselines (IPPO, MAPPO, IQL, VDN, QMIX). The framework enables end-to-end GPU pipelines and vectorized training, delivering substantial speedups over CPU-based counterparts and easing MARL's evaluation bottleneck by making rapid, broad benchmarking feasible. It also provides evaluation guidelines and validates its implementations across multiple environments, showing strong agreement with established baselines while drastically improving throughput. The work also discusses limitations and future directions, including handling off-policy methods and adding novel environments to push MARL benchmarks forward.
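To make "end-to-end GPU pipelines" concrete, here is a minimal sketch in JAX. It is not JaxMARL's API: it assumes a hypothetical two-agent toy environment whose `reset`/`step` functions pass per-agent dictionaries, loosely in the spirit of a multi-agent interface, and `reset`, `step`, `rollout` and all constants are illustrative names. Because the environment is pure JAX, a whole episode can be written with `jax.lax.scan`, batched across environments with `jax.vmap`, and compiled once with `jax.jit`, so no data has to move between host and accelerator during the rollout.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy environment (illustrative only, not part of JaxMARL):
# two agents jointly move a shared 1-D position and are rewarded for
# keeping it near zero.
AGENTS = ("agent_0", "agent_1")
N_ACTIONS = 3      # 0: move left, 1: stay, 2: move right
MAX_STEPS = 10     # fixed episode length, so no resets are needed here
NUM_ENVS = 1024    # number of environments simulated in parallel


def reset(key):
    pos = jax.random.uniform(key, (), minval=-1.0, maxval=1.0)
    state = {"pos": pos, "t": jnp.array(0, dtype=jnp.int32)}
    obs = {a: pos[None] for a in AGENTS}  # every agent observes the position
    return obs, state


def step(key, state, actions):
    del key  # deterministic toy dynamics; key kept only to mirror a keyed step()
    delta = sum((actions[a] - 1) * 0.1 for a in AGENTS)
    pos = state["pos"] + delta
    new_state = {"pos": pos, "t": state["t"] + 1}
    obs = {a: pos[None] for a in AGENTS}
    reward = -jnp.abs(pos)                 # shared cooperative reward
    rewards = {a: reward for a in AGENTS}
    done = new_state["t"] >= MAX_STEPS
    return obs, new_state, rewards, done


def rollout(key):
    """Run one episode with uniformly random actions; return the team return."""
    key, reset_key = jax.random.split(key)
    _, state = reset(reset_key)

    def body(carry, _):
        state, key = carry
        key, act_key, step_key = jax.random.split(key, 3)
        act_keys = jax.random.split(act_key, len(AGENTS))
        actions = {
            a: jax.random.randint(act_keys[i], (), 0, N_ACTIONS)
            for i, a in enumerate(AGENTS)
        }
        _, state, rewards, _ = step(step_key, state, actions)
        return (state, key), rewards[AGENTS[0]]

    _, rewards = jax.lax.scan(body, (state, key), None, length=MAX_STEPS)
    return rewards.sum()


# vmap batches the episode over environments; jit compiles the whole
# batched rollout into a single XLA program (CPU, GPU or TPU).
batched_rollout = jax.jit(jax.vmap(rollout))
keys = jax.random.split(jax.random.PRNGKey(0), NUM_ENVS)
returns = batched_rollout(keys)
print(returns.shape)  # (1024,)
```

A real training loop would replace the random policy with a learned one and auto-reset environments on `done`; both are omitted here to keep the sketch short.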

Abstract

Benchmarks are crucial in the development of machine learning algorithms, with available environments significantly influencing reinforcement learning (RL) research. Traditionally, RL environments run on the CPU, which limits their scalability with typical academic compute. However, recent advancements in JAX have enabled wider use of hardware acceleration, making massively parallel RL training pipelines and environments possible. While this has been successfully applied to single-agent RL, it has not yet been widely adopted for multi-agent scenarios. In this paper, we present JaxMARL, the first open-source, Python-based library that combines GPU-enabled efficiency with support for a large number of commonly used MARL environments and popular baseline algorithms. Our experiments show that, in terms of wall-clock time, our JAX-based training pipeline is around 14 times faster than existing approaches, and up to 12,500 times faster when multiple training runs are vectorized. This enables efficient and thorough evaluations, potentially alleviating the evaluation crisis in the field. We also introduce and benchmark SMAX, a JAX-based approximate reimplementation of the popular StarCraft Multi-Agent Challenge, which removes the need to run the StarCraft II game engine. This not only enables GPU acceleration, but also provides a more flexible MARL environment, unlocking the potential for self-play, meta-learning, and other future applications in MARL. The code is available at https://github.com/flairox/jaxmarl.
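The much larger speedup from vectorizing multiple training runs comes from batching entire training loops, not just environments. The sketch below illustrates the idea on a toy problem chosen purely for illustration: a 2x2 cooperative matrix game learned with stateless independent Q-learning. `train`, `PAYOFF` and the hyperparameters are hypothetical, and none of this is JaxMARL's IQL implementation. Since `train` is a pure function of its seed, `jax.vmap` turns 256 independent runs into one batched computation that `jax.jit` compiles once.

```python
import jax
import jax.numpy as jnp

# Hypothetical cooperative matrix game: both agents receive PAYOFF[a0, a1].
PAYOFF = jnp.array([[1.0, 0.0],
                    [0.0, 2.0]])
NUM_STEPS = 2000   # learning updates per training run
LR = 0.1           # Q-learning step size (illustrative)
EPSILON = 0.2      # exploration rate (illustrative)
NUM_SEEDS = 256    # independent training runs executed in parallel


def epsilon_greedy(key, q):
    """Pick a random action with probability EPSILON, else the greedy one."""
    k_explore, k_action = jax.random.split(key)
    random_action = jax.random.randint(k_action, (), 0, q.shape[0])
    greedy_action = jnp.argmax(q)
    explore = jax.random.uniform(k_explore) < EPSILON
    return jnp.where(explore, random_action, greedy_action)


def train(seed):
    """One full independent-Q-learning run; returns final Q-tables and mean reward."""
    key = jax.random.PRNGKey(seed)
    q_init = jnp.zeros((2, 2))  # row i holds agent i's Q-values over its own actions

    def update(q, key):
        k0, k1 = jax.random.split(key)
        a0 = epsilon_greedy(k0, q[0])
        a1 = epsilon_greedy(k1, q[1])
        r = PAYOFF[a0, a1]
        # Each agent treats the other as part of the environment (IQL-style).
        q = q.at[0, a0].add(LR * (r - q[0, a0]))
        q = q.at[1, a1].add(LR * (r - q[1, a1]))
        return q, r

    step_keys = jax.random.split(key, NUM_STEPS)
    q_final, rewards = jax.lax.scan(update, q_init, step_keys)
    return q_final, rewards.mean()


# Vectorize the entire training loop over seeds and compile it once.
seeds = jnp.arange(NUM_SEEDS)
q_tables, mean_rewards = jax.jit(jax.vmap(train))(seeds)
print(q_tables.shape, mean_rewards.shape)  # (256, 2, 2) (256,)
```

The only requirement for this pattern is that every run has the same array shapes, which is what lets `vmap` stack them along a new leading axis.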

Paper Structure

This paper contains 49 sections, 19 figures, and 8 tables.

Figures (19)

  • Figure 1: JaxMARL environments. We provide JAX-based implementations of a wide range of customizable MARL environments, covering continuous and discrete dynamics, variable numbers of agents, full and partial observability, and cooperative, competitive and mixed-incentive settings.
  • Figure 2: Our philosophy. JaxMARL combines a wide range of environments with ease of use and evaluation speed.
  • Figure 3: Speed of training an RNN agent using IPPO on a multi-particle environment in JaxMARL compared to two popular MARL libraries, see the Appendix for details.
  • Figure 4: Normalised scores aggregated over SMAX, MPE and Overcooked. PPO shows a clear advantage.
  • Figure 5: JaxMARL speed benchmarking results. (a) compares JaxMARL's returns in MPE over wall-clock time with PyMARL's when using Q-learning algorithms. (b) demonstrates the ability of JaxMARL algorithms to train many seeds in parallel: it compares training time (x-axis) for a varying number of training runs (y-axis), each trained with QMIX on MPE. The red dotted line represents the time taken to train a single agent with PyMARL. (c) illustrates the speedup of a JaxMARL IPPO training run using SMAX compared to PyMARL using SMAC across a varying number of environment rollout threads.