
minimax: Efficient Baselines for Autocurricula in JAX

Minqi Jiang, Michael Dennis, Edward Grefenstette, Tim Rocktäschel

TL;DR

This paper introduces minimax, a JAX-based library for fast autocurriculum research in unsupervised environment design (UED). By tensorizing environments and JIT-compiling the entire training loop, minimax enables hardware-accelerated, scalable experimentation. It also includes a fast, fully-tensorized maze benchmark (AMaze). It presents new parallelized and multi-device baselines (e.g., PLR$^\parallel$ and ACCEL$^\parallel$, plus variants with S5 policies) that reduce wall time by over two orders of magnitude in some settings, while matching or exceeding prior performance on out-of-distribution (OOD) benchmarks. The work shows that simply scaling batch sizes and exploiting parallelism can yield substantial gains in zero-shot transfer, making autocurriculum research practical for a broader set of labs. Overall, minimax offers a modular, extensible platform for UED and related autocurriculum methods that supports rapid iteration and reproducibility.
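To make "tensorizing environments and JIT-compiling the training loop" concrete, here is a minimal JAX sketch. It assumes a hypothetical toy grid world, not AMaze or the minimax API: each level is a boolean wall mask, the state is the agent's (row, col) position, and a pure `step` function is vectorized with `jax.vmap` and unrolled with `jax.lax.scan` inside `jax.jit`, so a whole batched rollout compiles into one accelerator program. The names `GRID`, `MOVES`, `step`, and `rollout` are illustrative only.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy grid world, not the minimax/AMaze API.
GRID = 5
# Action index -> (d_row, d_col): up, right, down, left.
MOVES = jnp.array([[-1, 0], [0, 1], [1, 0], [0, -1]], dtype=jnp.int32)


def step(walls, pos, action):
    """Pure transition for ONE environment: returns the next (row, col)."""
    proposal = jnp.clip(pos + MOVES[action], 0, GRID - 1)
    blocked = walls[proposal[0], proposal[1]]
    # Stay in place if the target cell holds a wall.
    return jnp.where(blocked, pos, proposal)


# Vectorize over a leading batch axis: walls (B, G, G), pos (B, 2), action (B,).
batched_step = jax.vmap(step)


@jax.jit
def rollout(walls, start, actions):
    """Run T steps for B environments; actions has shape (T, B)."""
    def body(pos, act):
        nxt = batched_step(walls, pos, act)
        return nxt, nxt  # new carry, plus the per-step output to stack

    final, trajectory = jax.lax.scan(body, start, actions)
    return final, trajectory  # shapes (B, 2) and (T, B, 2)


B, T = 4, 8
walls = jnp.zeros((B, GRID, GRID), dtype=bool)
walls = walls.at[:2, 0, 2].set(True)         # levels 0 and 1 get a wall at (0, 2)
start = jnp.zeros((B, 2), dtype=jnp.int32)   # every agent starts at (0, 0)
actions = jnp.ones((T, B), dtype=jnp.int32)  # always move right
final, trajectory = rollout(walls, start, actions)
print(final.tolist())  # [[0, 1], [0, 1], [0, 4], [0, 4]]
```

Because `step` keeps no Python-side state, the same function serves one environment or thousands; only the batch size `B` changes, and the compiled program never returns to the host between steps.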

Abstract

Unsupervised environment design (UED) is a form of automatic curriculum learning for training robust decision-making agents to zero-shot transfer into unseen environments. Such autocurricula have received much interest from the RL community. However, UED experiments, based on CPU rollouts and GPU model updates, have often required several weeks of training. This compute requirement is a major obstacle to rapid innovation for the field. This work introduces the minimax library for UED training on accelerated hardware. Using JAX to implement fully-tensorized environments and autocurriculum algorithms, minimax allows the entire training loop to be compiled for hardware acceleration. To provide a petri dish for rapid experimentation, minimax includes a tensorized grid-world based on MiniGrid, in addition to reusable abstractions for conducting autocurricula in procedurally-generated environments. With these components, minimax provides strong UED baselines, including new parallelized variants, which achieve over 120$\times$ speedups in wall time compared to previous implementations when training with equal batch sizes. The minimax library is available under the Apache 2.0 license at https://github.com/facebookresearch/minimax.
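The abstract notes that the autocurriculum algorithms, not just the environments, are tensorized. As a minimal sketch of what that can look like, the code below assumes a PLR-style level buffer with one score per level and implements rank-based prioritized replay as pure, jittable functions: each level's replay weight is $(1/\mathrm{rank})^{1/\beta}$ for a temperature $\beta$, normalized over the buffer. It omits the staleness term that PLR also mixes into its replay distribution, and the function names are hypothetical rather than minimax's own.

```python
from functools import partial

import jax
import jax.numpy as jnp


def rank_prioritization(scores, temperature):
    """Replay distribution P_i proportional to (1 / rank_i) ** (1 / temperature)."""
    order = jnp.argsort(-scores)    # buffer indices sorted by descending score
    ranks = jnp.argsort(order) + 1  # rank 1 = highest score
    weights = (1.0 / ranks) ** (1.0 / temperature)
    return weights / weights.sum()


# n sets an output shape, so it must be static under jit.
@partial(jax.jit, static_argnames=("n",))
def sample_replay_levels(key, scores, temperature, n):
    """Sample n distinct buffer indices according to rank prioritization."""
    probs = rank_prioritization(scores, temperature)
    return jax.random.choice(key, scores.shape[0], shape=(n,), replace=False, p=probs)


@jax.jit
def update_scores(scores, idx, new_scores):
    """Functional (out-of-place) update of the replayed levels' scores."""
    return scores.at[idx].set(new_scores)


scores = jnp.array([0.1, 0.9, 0.5, 0.3])
probs = rank_prioritization(scores, 1.0)  # ranks 4, 1, 2, 3 -> [0.12, 0.48, 0.24, 0.16]
idx = sample_replay_levels(jax.random.PRNGKey(0), scores, 1.0, n=2)
scores = update_scores(scores, idx, jnp.zeros(2))  # pretend the new scores are 0
```

Since the buffer is just arrays and every operation is a pure function, these steps can sit inside the same compiled training iteration as the rollouts.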

Paper Structure

This paper contains 20 sections, 9 figures, and 6 tables.

Figures (9)

  • Figure 1: Wall-time speed-up factors achieved by minimax relative to PyTorch reference implementations with equal batch sizes (mean and std of 10 runs). Experiments that took over 100 hours now finish in under 3 hours on a single GPU. Here, +S5 indicates the use of an S5 policy, and +P indicates parallel DCD (dual curriculum design). (See the baselines section for details.)
  • Figure 2: A high-level overview of the minimax library. The training iteration logic is fully-jitted.
  • Figure 3: Example training and test environments in AMaze, a fully-tensorized maze environment in JAX.
  • Figure 4: Shortest path lengths of training mazes per method (mean and std of 10 runs).
  • Figure 5: Left: The sequence of operations in standard implementations of PLR$^\perp$ and ACCEL. Right: PLR$^\parallel$ and ACCEL$^\parallel$ reduce wall time by executing rollouts for new levels, replay levels, and mutated levels in parallel.
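Figure 5 describes PLR$^\parallel$ and ACCEL$^\parallel$ executing rollouts for new levels, replay levels, and mutated levels in parallel. One way to express that batching in JAX, sketched here under stated assumptions, is to concatenate the three groups along the batch axis, evaluate them with a single vmapped call, and split the results back apart. `score_level` is a hypothetical stand-in for "roll out the policy and score this level", and `mutate` is a toy ACCEL-style edit that flips a few random cells of a wall mask; neither is minimax's implementation.

```python
import jax
import jax.numpy as jnp


def mutate(key, walls, n_edits=3):
    """Toy ACCEL-style edit (hypothetical): flip n_edits distinct cells."""
    flat = walls.reshape(-1)
    idx = jax.random.choice(key, flat.shape[0], shape=(n_edits,), replace=False)
    flat = flat.at[idx].set(~flat[idx])  # add or remove a wall at each chosen cell
    return flat.reshape(walls.shape)


def score_level(walls):
    """Hypothetical stand-in for 'roll out the policy and score this level';
    here the score is simply the level's wall density."""
    return walls.mean()


batched_mutate = jax.vmap(mutate)
batched_score = jax.vmap(score_level)


@jax.jit
def score_all_groups(new, replay, mutated):
    """Score new, replay, and mutated levels in ONE batched call."""
    levels = jnp.concatenate([new, replay, mutated], axis=0)
    scores = batched_score(levels)
    n_new, n_rep = new.shape[0], replay.shape[0]  # static Python ints under jit
    return scores[:n_new], scores[n_new:n_new + n_rep], scores[n_new + n_rep:]


k_new, k_rep, k_mut = jax.random.split(jax.random.PRNGKey(0), 3)
new = jax.random.bernoulli(k_new, 0.2, (4, 5, 5))
replay = jax.random.bernoulli(k_rep, 0.2, (4, 5, 5))
mutated = batched_mutate(jax.random.split(k_mut, 4), replay)
s_new, s_rep, s_mut = score_all_groups(new, replay, mutated)
```

The design choice illustrated is that the three groups share one compiled call, so their evaluations run concurrently on the accelerator instead of as separate passes issued one after another.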