JaxWildfire: A GPU-Accelerated Wildfire Simulator for Reinforcement Learning

Ufuk Çakır, Victor-Alexandru Darvariu, Bruno Lacerda, Nick Hawes

TL;DR

JaxWildfire delivers a GPU-accelerated, differentiable cellular automata wildfire simulator implemented in JAX, designed to accelerate RL research by enabling rapid, vectorized simulations and gradient-based tuning of model parameters. The framework supports real-world data through WorldCover and DEM, and integrates with Gymnax for RL workflows. Empirical results show substantial speedups over existing CA implementations, feasible gradient-based calibration against real wildfire data, and successful training of a simple PPO agent for fire suppression in a synthetic environment. This work provides a practical, RL-ready benchmark for proactive wildfire management and further RL-enabled hazard management research.

Abstract

Artificial intelligence methods are increasingly being explored for managing wildfires and other natural hazards. In particular, reinforcement learning (RL) is a promising path towards improving outcomes in such uncertain decision-making scenarios and moving beyond reactive strategies. However, training RL agents requires many environment interactions, and the speed of existing wildfire simulators is a severely limiting factor. We introduce JaxWildfire, a simulator underpinned by a principled probabilistic fire spread model based on cellular automata. It is implemented in JAX and enables vectorized simulations using vmap, allowing high throughput of simulations on GPUs. We demonstrate that JaxWildfire achieves 6-35x speedup over existing software and enables gradient-based optimization of simulator parameters. Furthermore, we show that JaxWildfire can be used to train RL agents to learn wildfire suppression policies. Our work is an important step towards enabling the advancement of RL techniques for managing natural hazards.
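To illustrate the kind of vectorization the abstract refers to, the following is a minimal sketch of batching stochastic rollouts over random seeds with `jax.vmap`. It does not use the JaxWildfire API: `toy_step`, `rollout`, `batched_rollout`, the 4-neighbour spread rule, and the value of `p_spread` are all assumptions made for illustration only.

```python
import jax
import jax.numpy as jnp


def toy_step(burning, key, p_spread=0.3):
    """Toy rule (assumption): a cell ignites with probability p_spread
    if at least one of its 4 neighbours is burning."""
    padded = jnp.pad(burning.astype(jnp.int32), 1)
    neighbours = (padded[:-2, 1:-1] + padded[2:, 1:-1]
                  + padded[1:-1, :-2] + padded[1:-1, 2:]) > 0
    ignite = jax.random.bernoulli(key, p_spread, burning.shape) & neighbours
    # In this toy model cells never stop burning.
    return burning | ignite


def rollout(key, init, n_steps=10):
    """Simulate n_steps from init; return the final grid and the
    number of burning cells after each step."""
    keys = jax.random.split(key, n_steps)

    def body(state, k):
        nxt = toy_step(state, k)
        return nxt, jnp.sum(nxt)

    return jax.lax.scan(body, init, keys)


# Vectorize over the PRNG key only; every rollout shares the same initial grid.
batched_rollout = jax.jit(jax.vmap(rollout, in_axes=(0, None)))

keys = jax.random.split(jax.random.PRNGKey(0), 16)
init = jnp.zeros((8, 8), dtype=bool).at[4, 4].set(True)
final, counts = batched_rollout(keys, init)
print(final.shape, counts.shape)
```

Here the 16 rollouts run as a single batched computation, so the same code can exploit GPU parallelism without an explicit Python loop over seeds.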

Paper Structure

This paper contains 6 sections, 6 equations, 6 figures.

Figures (6)

  • Figure 1: JaxWildfire simulation in Cape Town, SA using ESA WorldCover data.
  • Figure 2: Fire spread visualization. Our fire spread mechanism based on cellular automata propagates fire directionally. The arrival potential is accumulated from neighboring cells and used to determine the ignition probability. The state of the cell at the next timestep is sampled from a Bernoulli distribution parameterized by the ignition probability.
  • Figure 3: Calibration on a real wildfire. Comparison of predicted fire spread for the Bear 2020 wildfire [xia2024data], averaged over 5 seeds (variance was negligible). The panels show the spread of the fire as simulated by A) PyTorchFire after calibration; B) JaxWildfire after calibration; C) the real historical fire perimeter. The simulators achieve comparable results as measured by Intersection over Union, with a slight advantage for JaxWildfire.
  • Figure 4: Performance evaluation of JaxWildfire. Left: wall time (s) on GPU, showing JaxWildfire is an order of magnitude faster than PyTorchFire, the competing library. Right: throughput of JaxWildfire on GPU versus CPU, demonstrating that our framework benefits from increasing degrees of parallelism.
  • Figure 5: PPO training curve in a synthetic 20x20 environment. A reward of 10 is given for completely extinguishing the fire.
  • ...and 1 more figure
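
The spread mechanism described for Figure 2 and the Intersection over Union metric used in Figure 3 can be sketched as follows. This is an illustrative approximation, not the paper's model: the Moore neighbourhood, the `weights` and `rate` parameters, the mapping `1 - exp(-rate * potential)` from arrival potential to ignition probability, and the one-step burn-out rule are all assumptions, and the function names are hypothetical.

```python
import jax
import jax.numpy as jnp

# Moore neighbourhood: the 8 directions in which fire may spread (assumption).
OFFSETS = ((-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1))


def shift(grid, dy, dx):
    """Return s with s[i, j] = grid[i - dy, j - dx], using zeros outside the grid."""
    h, w = grid.shape
    padded = jnp.pad(grid, 1)
    return padded[1 - dy:1 - dy + h, 1 - dx:1 - dx + w]


def spread_step(key, burning, burnt, weights, rate):
    """One hypothetical CA step; burning and burnt are boolean (H, W) arrays."""
    fire = burning.astype(jnp.float32)
    # Accumulate arrival potential: weights[k] scales spread in direction OFFSETS[k].
    potential = sum(weights[k] * shift(fire, dy, dx)
                    for k, (dy, dx) in enumerate(OFFSETS))
    # Assumed mapping from potential to an ignition probability in [0, 1].
    p_ignite = 1.0 - jnp.exp(-rate * potential)
    # Cells that are burning or already burnt cannot ignite again.
    p_ignite = jnp.where(burning | burnt, 0.0, p_ignite)
    # Sample the next state of each cell from a Bernoulli distribution.
    ignited = jax.random.bernoulli(key, p_ignite)
    # Assumption: a burning cell burns out after a single step.
    return ignited, burnt | burning


def iou(pred, target):
    """Intersection over Union of two boolean burn masks."""
    inter = jnp.sum(pred & target)
    union = jnp.sum(pred | target)
    return inter / jnp.maximum(union, 1)


burning = jnp.zeros((5, 5), dtype=bool).at[2, 2].set(True)
burnt = jnp.zeros((5, 5), dtype=bool)
weights = jnp.ones(8)  # uniform weights, i.e. no preferred direction
ignited, burnt_next = spread_step(jax.random.PRNGKey(0), burning, burnt, weights, rate=0.5)
```

Unequal entries in `weights` would bias the spread towards particular directions, which is one simple way to represent directional effects such as wind or slope in a sketch like this.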