JaxWildfire: A GPU-Accelerated Wildfire Simulator for Reinforcement Learning
Ufuk Çakır, Victor-Alexandru Darvariu, Bruno Lacerda, Nick Hawes
TL;DR
JaxWildfire is a GPU-accelerated, differentiable cellular automaton (CA) wildfire simulator implemented in JAX. It is designed to accelerate RL research by enabling fast, vectorized simulation and gradient-based tuning of model parameters. The framework supports real-world inputs from WorldCover land-cover maps and digital elevation models (DEMs), and integrates with Gymnax for RL workflows. Empirical results show 6-35x speedups over existing CA implementations, demonstrate that gradient-based calibration against real wildfire data is feasible, and include the successful training of a simple PPO agent for fire suppression in a synthetic environment. This work provides a practical, RL-ready benchmark for proactive wildfire management and for further research on RL-based natural hazard management.
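The idea of tuning simulator parameters by gradient descent can be illustrated with a minimal sketch. It assumes a single spread probability `p` on a homogeneous grid and, instead of sampling, propagates per-cell ignition probabilities with a mean-field relaxation (neighbours treated as independent), which makes the expected burnt area a smooth function of `p`. The relaxation, the function names (`no_ignition_prob`, `expected_burnt_cells`, `loss`) and the synthetic target are hypothetical illustrations, not JaxWildfire's API or its calibration procedure against real fire data.

```python
import jax
import jax.numpy as jnp


def no_ignition_prob(r, p):
    """P(no 4-neighbour ignites a cell), assuming independent neighbours.

    r holds per-cell probabilities of burning at the current step. This
    mean-field relaxation is an assumption of the sketch, not JaxWildfire's model.
    """
    q = jnp.pad(1.0 - p * r, 1, constant_values=1.0)  # off-grid cells never ignite
    return q[:-2, 1:-1] * q[2:, 1:-1] * q[1:-1, :-2] * q[1:-1, 2:]


def expected_burnt_cells(p, n_steps=5, size=15):
    """Expected number of cells ever ignited from one central ignition."""
    c = size // 2
    u = jnp.ones((size, size)).at[c, c].set(0.0)   # P(cell still unburnt)
    r = jnp.zeros((size, size)).at[c, c].set(1.0)  # P(cell burning now)

    def step(carry, _):
        u, r = carry
        new_r = u * (1.0 - no_ignition_prob(r, p))  # newly ignited probability mass
        return (u - new_r, new_r), None             # burning cells burn out

    (u, _), _ = jax.lax.scan(step, (u, r), None, length=n_steps)
    return jnp.sum(1.0 - u)


# Synthetic "observation" generated with p = 0.6 (a stand-in for real fire data).
target = expected_burnt_cells(0.6)


def loss(theta):
    p = jax.nn.sigmoid(theta)  # unconstrained parameter -> p in (0, 1)
    return (expected_burnt_cells(p) / target - 1.0) ** 2


loss_and_grad = jax.jit(jax.value_and_grad(loss))

theta = jnp.array(-1.5)  # initial guess, p ~= 0.18
for _ in range(500):
    _, g = loss_and_grad(theta)
    theta = theta - 0.1 * g  # plain gradient descent

p_hat = jax.nn.sigmoid(theta)
```

Parametrizing `p` through a sigmoid keeps every gradient step inside (0, 1) without clipping, and the same pattern extends to vectors of parameters, since `jax.grad` returns a gradient with the same shape as its input.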
Abstract
Artificial intelligence methods are increasingly being explored for managing wildfires and other natural hazards. In particular, reinforcement learning (RL) is a promising path towards improving outcomes in such uncertain decision-making scenarios and moving beyond reactive strategies. However, training RL agents requires many environment interactions, and the speed of existing wildfire simulators is a severely limiting factor. We introduce $\texttt{JaxWildfire}$, a simulator underpinned by a principled probabilistic fire spread model based on cellular automata. It is implemented in JAX and enables vectorized simulations using $\texttt{vmap}$, allowing high-throughput simulation on GPUs. We demonstrate that $\texttt{JaxWildfire}$ achieves a 6-35x speedup over existing software and enables gradient-based optimization of simulator parameters. Furthermore, we show that $\texttt{JaxWildfire}$ can be used to train RL agents to learn wildfire suppression policies. Our work is an important step towards advancing RL techniques for managing natural hazards.
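To illustrate how a probabilistic CA update and `vmap` combine, here is a minimal sketch. It assumes three cell states (unburnt, burning, burnt), a 4-cell (von Neumann) neighbourhood, and that each burning neighbour independently ignites an unburnt cell with probability `p_spread`, so a cell with k burning neighbours ignites with probability 1 - (1 - p_spread)^k. The state encoding and the names (`ca_step`, `count_burning_neighbours`, `make_batched_rollout`) are hypothetical and do not reflect JaxWildfire's actual interface or spread model.

```python
import jax
import jax.numpy as jnp

# Hypothetical cell states for this sketch (not JaxWildfire's encoding).
UNBURNT, BURNING, BURNT = 0, 1, 2


def count_burning_neighbours(state):
    """Number of burning von Neumann neighbours per cell (no wrap-around)."""
    burning = (state == BURNING).astype(jnp.float32)
    padded = jnp.pad(burning, 1)  # zero padding: off-grid cells never burn
    return (padded[:-2, 1:-1] + padded[2:, 1:-1]
            + padded[1:-1, :-2] + padded[1:-1, 2:])


def ca_step(state, key, p_spread):
    """One stochastic CA update under the independent-ignition assumption."""
    k = count_burning_neighbours(state)
    # P(at least one of the k burning neighbours ignites the cell).
    p_ignite = jnp.where(k > 0, 1.0 - (1.0 - p_spread) ** k, 0.0)
    u = jax.random.uniform(key, state.shape)
    ignites = (state == UNBURNT) & (u < p_ignite)
    new_state = jnp.where(state == BURNING, BURNT, state)  # burning cells burn out
    return jnp.where(ignites, BURNING, new_state)


def make_batched_rollout(n_steps):
    """Build a jitted function that runs many independent rollouts in parallel."""

    def rollout(key, init_state, p_spread):
        def body(state, step_key):
            return ca_step(state, step_key, p_spread), None

        final_state, _ = jax.lax.scan(body, init_state,
                                      jax.random.split(key, n_steps))
        return final_state

    # Map over the key axis only; the initial grid and p_spread are shared.
    return jax.jit(jax.vmap(rollout, in_axes=(0, None, None)))


# Usage: 8 independent rollouts of 3 steps from a single central ignition.
H = W = 32
init = jnp.zeros((H, W), dtype=jnp.int32).at[H // 2, W // 2].set(BURNING)
keys = jax.random.split(jax.random.PRNGKey(0), 8)
batched_rollout = make_batched_rollout(n_steps=3)
finals = batched_rollout(keys, init, 0.3)  # shape (8, 32, 32)
```

Because `vmap` maps only the PRNG key axis, every rollout shares the initial grid and `p_spread` but draws independent randomness, and `jit` compiles the whole `lax.scan` rollout once, so additional rollouts add parallel work on the accelerator rather than Python-level loop overhead.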
