EconoJax: A Fast & Scalable Economic Simulation in Jax
Koen Ponse, Aske Plaat, Niki van Stein, Thomas M. Moerland
TL;DR
This work tackles the computational bottlenecks of training reinforcement learning agents in multi-agent economic environments by introducing EconoJax, a GPU-accelerated economic simulator implemented entirely in JAX. EconoJax enables rapid, large-scale experiments (e.g., with 100 agents) and exhibits emergent, real-world-like behaviors such as specialization, the productivity-equality tradeoff, and progressive tax schedules. The paper also compares several multi-agent training strategies and finds that, even in larger action spaces, centralized training yields policy outcomes comparable to independent training while substantially reducing computational demands. By open-sourcing the code, EconoJax provides a practical platform for rapid research into economic policy and multi-agent RL, making broader experimentation with more realistic and larger-scale economies feasible.
Abstract
Accurate economic simulations often require many experimental runs, particularly when combined with reinforcement learning. Unfortunately, training reinforcement learning agents in multi-agent economic environments can be slow. This paper introduces EconoJax, a fast simulated economy based on the AI Economist. EconoJax and its training pipeline are written entirely in JAX. This allows EconoJax to scale to large population sizes and run large experiments while keeping training times within minutes. Through experiments with populations of 100 agents, we show how real-world economic behavior emerges after only 15 minutes of training, in contrast to previous work that required several days. We additionally run experiments with action spaces of varying size to test whether some multi-agent methods produce more diverse behavior than others. Our findings indicate no notable differences in the behavior produced by the different methods, contrary to what is sometimes suggested in earlier work. To aid further research, we open-source EconoJax on GitHub.
