
gfnx: Fast and Scalable Library for Generative Flow Networks in JAX

Daniil Tiapkin, Artem Agarkov, Nikita Morozov, Ian Maksimov, Askar Tsyganov, Timofei Gritsaev, Sergey Samsonov

TL;DR

gfnx addresses the need for scalable, reproducible benchmarks for Generative Flow Networks (GFlowNets) by delivering a JAX-based library with end-to-end, JIT-compiled environments and single-file baselines. It decouples environment logic, reward models, and training objectives, enabling flexible on-device execution and rapid experimentation across a diverse suite of benchmarks. The paper demonstrates substantial wall-clock speedups over PyTorch-based baselines (up to 80x on GPU for Bayesian structure learning and up to 55x on CPU for sequence generation) while maintaining sampling quality, and provides a standardized benchmarking framework to accelerate GFlowNet research. Looking ahead, it identifies directions for extension, such as continuous actions, non-acyclic domains, multi-objective optimization, additional baselines, and vectorized training, to further broaden the library's applicability.
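
To make "end-to-end, JIT-compiled environments" with decoupled environment logic and rewards concrete, here is a minimal JAX sketch of a toy hypergrid. It assumes a design in which the transition and the reward are separate pure functions over arrays, so that jax.vmap can batch them and jax.jit can compile them for on-device execution. The names step, reward, DIM, and SIDE, and the reward constants, are hypothetical and chosen for illustration; this is not the gfnx API.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy hypergrid (illustrative only; NOT the gfnx API).
# State: integer position in {0, ..., SIDE-1}^DIM plus a boolean "done" flag.
# Actions 0..DIM-1 increment one coordinate; action DIM terminates.
DIM, SIDE = 2, 8


def step(pos, done, action):
    """Pure transition for a single environment (no Python side effects)."""
    is_stop = action == DIM
    # one_hot yields an all-zero vector for the stop action (index DIM).
    inc = jax.nn.one_hot(action, DIM, dtype=pos.dtype)
    # For simplicity we clip at the boundary; a real environment would
    # instead mask invalid actions in the policy.
    moved = jnp.minimum(pos + inc, SIDE - 1)
    # Finished or stopping environments keep their position.
    new_pos = jnp.where(done | is_stop, pos, moved)
    return new_pos, done | is_stop


def reward(pos, r0=1e-3, r1=0.5, r2=2.0):
    """Hypergrid-style reward, kept separate from the transition logic.
    The constants r0, r1, r2 are assumptions chosen for illustration."""
    ax = jnp.abs(pos / (SIDE - 1) - 0.5)
    outer = jnp.all(ax > 0.25).astype(jnp.float32)
    ring = jnp.all((ax > 0.3) & (ax < 0.4)).astype(jnp.float32)
    return r0 + r1 * outer + r2 * ring


# vmap turns the single-environment functions into batched ones; jit
# compiles them so a whole batch of environments steps on-device at once.
batched_step = jax.jit(jax.vmap(step))
batched_reward = jax.jit(jax.vmap(reward))

# Usage: four environments, one step each (the third one stops).
pos = jnp.zeros((4, DIM), dtype=jnp.int32)
done = jnp.zeros((4,), dtype=bool)
pos, done = batched_step(pos, done, jnp.array([0, 1, DIM, 0]))
```

Because step and reward never touch Python state, either one can in principle be swapped independently (for example, a different reward over the same transition logic) or rolled out inside jax.lax.scan without leaving the device.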

Abstract

In this paper, we present gfnx, a fast and scalable package, written in JAX, for training and evaluating Generative Flow Networks (GFlowNets). gfnx provides an extensive set of environments and metrics for benchmarking, accompanied by single-file implementations of the core GFlowNet training objectives. We include synthetic hypergrids; multiple sequence generation environments with various editing regimes and task-specific reward designs, including molecular generation; phylogenetic tree construction; Bayesian structure learning; and sampling from an Ising model energy. Across tasks, gfnx achieves significant wall-clock speedups over PyTorch-based benchmarks (such as the torchgfn library) and the original authors' implementations. For example, gfnx achieves up to a 55x speedup on CPU-based sequence generation environments and up to an 80x speedup on the GPU-based Bayesian network structure learning setup. Our package provides a diverse set of benchmarks and aims to standardize empirical evaluation and accelerate research on, and applications of, GFlowNets. The library is available on GitHub (https://github.com/d-tiapkin/gfnx) and on PyPI (https://pypi.org/project/gfnx/). Documentation is available at https://gfnx.readthedocs.io.
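
The abstract mentions single-file implementations of core training objectives. Trajectory balance (TB), one of the objectives named in the figure captions, is a convenient example. In its standard form from the GFlowNet literature, TB penalizes the squared residual $\log Z + \sum_t \log P_F(s_{t+1} \mid s_t) - \log R(x) - \sum_t \log P_B(s_t \mid s_{t+1})$ along each complete trajectory ending in $x$. The sketch below is a hypothetical, minimal JAX version. It assumes fixed-length padded trajectories with a boolean mask and is not taken from gfnx; the function name and argument layout are illustrative.

```python
import jax
import jax.numpy as jnp


def trajectory_balance_loss(log_z, log_pf, log_pb, log_reward, mask):
    """Mean squared trajectory-balance residual over a padded batch
    (hypothetical helper, not the gfnx implementation).

    log_z:      scalar, learnable estimate of log Z
    log_pf:     [B, T] log P_F(s_{t+1} | s_t) for the actions taken
    log_pb:     [B, T] log P_B(s_t | s_{t+1}) for the same transitions
    log_reward: [B]    log R(x) of each terminal object x
    mask:       [B, T] True for real transitions, False for padding
    """
    # jnp.where (rather than multiplying by the mask) keeps -inf padding
    # values from turning into NaN in the sums.
    sum_pf = jnp.sum(jnp.where(mask, log_pf, 0.0), axis=1)
    sum_pb = jnp.sum(jnp.where(mask, log_pb, 0.0), axis=1)
    residual = log_z + sum_pf - log_reward - sum_pb
    return jnp.mean(residual ** 2)


# Toy batch: one trajectory of two transitions, each taken with prob. 1/2,
# unique parents (log P_B = 0), R(x) = 1, and one padded slot holding -inf.
log_pf = jnp.log(jnp.array([[0.5, 0.5, 0.0]]))
log_pb = jnp.zeros((1, 3))
log_reward = jnp.zeros((1,))
mask = jnp.array([[True, True, False]])

loss_fn = jax.jit(trajectory_balance_loss)
# Gradient with respect to the first argument, log_z.
grad_fn = jax.jit(jax.grad(trajectory_balance_loss))
```

For this toy trajectory the residual vanishes at log_z = log 4, so the loss is zero there; in practice log_z is trained jointly with the policy parameters that produce log_pf.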

Paper Structure

This paper contains 27 sections, 29 equations, 7 figures, and 8 tables.

Figures (7)

  • Figure 1: Visualization of GFlowNet environments. The first panel shows four sample trajectories, each drawn as a distinctly colored path from the initial state to a terminal state. The second panel shows the sequential construction of a bit string: each row is a different environment state, with the first row the empty string (initial state), the last row the complete string (final state), and the intermediate rows showing the token-by-token generation process.
  • Figure 2: Total variation between the true reward distribution and the empirical sample distribution versus total training time (in seconds). All implementations were run for the same number of training iterations. The torch implementation is from lahlou2023torchgfn. Experiments were performed on CPU.
  • Figure 3: Comparison of gfnx and the torch implementation of tiapkin2024generative on the bit sequence generation task ($n = 120, k = 8$), measured by the Pearson correlation coefficient between the terminating-state log-probability and the log-reward on $7200$ randomly sampled bit sequences (higher is better). Curves were smoothed with a moving average for visual clarity. Experiments were performed on GPU using the TB and DB objectives.
  • Figure 4: Total variation between the true reward distribution and the empirical sample distribution versus total training time (in seconds) on the TFBind8 and QM9 environments. The torch implementation is from shen2023towards. Experiments were performed on CPU using the TB objective.
  • Figure 5: Top-100 reward and diversity versus total training time (in seconds) on the AMP environment. The torch implementation is from jain2022biological. The experiment was performed on GPU using the TB objective.
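
The figure captions report two evaluation metrics that can be stated precisely. The first is the total variation (TV) distance between the reward-proportional target distribution $R(x)/Z$ and the empirical distribution of samples (Figures 2 and 4). The second is the Pearson correlation between terminating-state log-probabilities and log-rewards (Figure 3). The following JAX sketch computes both under three assumptions: the terminal-state space is small enough to enumerate (as for a small hypergrid), samples are given as integer state indices, and model log-probabilities are already available. The function names are hypothetical, and how gfnx itself estimates these quantities is not specified here.

```python
import jax.numpy as jnp


def total_variation(rewards, sample_ids):
    """TV distance between R(x)/Z and the empirical sample distribution
    (hypothetical helper, assumes an enumerable terminal-state space).

    rewards:    [N] rewards of all N terminal states
    sample_ids: [M] integer indices of sampled terminal states
    """
    target = rewards / jnp.sum(rewards)
    # Count how often each terminal state was sampled, padded to length N.
    counts = jnp.bincount(sample_ids, length=rewards.shape[0])
    empirical = counts / sample_ids.shape[0]
    return 0.5 * jnp.sum(jnp.abs(target - empirical))


def pearson(log_probs, log_rewards):
    """Pearson correlation between model log-probabilities and log-rewards."""
    return jnp.corrcoef(log_probs, log_rewards)[0, 1]


rewards = jnp.array([1.0, 1.0, 2.0])
# Samples whose frequencies exactly match R/Z = [0.25, 0.25, 0.5].
tv_exact = total_variation(rewards, jnp.array([0, 1, 2, 2]))
# All samples on state 0: TV = 0.5 * (0.75 + 0.25 + 0.5) = 0.75.
tv_skewed = total_variation(rewards, jnp.array([0, 0, 0, 0]))
```

A perfectly trained sampler on an enumerable space drives the TV distance toward zero as the number of samples grows. The Pearson metric instead only checks that log-probabilities track log-rewards up to an affine shift, which is why it is convenient for spaces too large to enumerate.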