jFoF: GPU Cluster Finding with Gradient Propagation
Benjamin Horowitz, Adrian E. Bayer
TL;DR
The paper presents jFoF, a GPU-native Friends-of-Friends (FoF) halo finder implemented in JAX that performs all neighbor searches, label propagation, and group construction on accelerators, eliminating host-device transfers. It introduces two CUDA-friendly neighbor-search strategies, a k-d tree and a linked-cell grid, which achieve up to an order-of-magnitude speedup over CPU FoF implementations while preserving catalog fidelity. Beyond performance, jFoF supports differentiable halo finding through frozen assignment and REINFORCE-based topological optimization, including decorated frozen assignments that provide surrogate mass gradients, so that halo finding can take part in end-to-end gradient-based optimization of cosmological pipelines. This work lays the groundwork for coupling differentiable halo catalogs with GPU-accelerated simulators, which could enable joint inference and subgrid-model calibration within fully differentiable cosmology workflows.
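As a rough illustration of the label-propagation idea, the following sketch computes FoF group labels in JAX by repeatedly letting each particle adopt the smallest label among its friends (particles closer than the linking length) until no label changes. It is not jFoF's implementation: the neighbor search here is a dense O(N^2) distance matrix rather than a k-d tree or linked-cell grid, and the function name `fof_labels` and its parameters are hypothetical. Labels converge to the smallest particle index in each connected component.

```python
import jax
import jax.numpy as jnp

def fof_labels(pos, linking_length, box_size=None):
    """Hypothetical FoF sketch: min-label propagation on a dense friend matrix."""
    n = pos.shape[0]
    d = pos[:, None, :] - pos[None, :, :]
    if box_size is not None:
        # Minimum-image convention for a periodic cubic box.
        d = d - box_size * jnp.round(d / box_size)
    dist2 = jnp.sum(d * d, axis=-1)
    # Friend matrix; the diagonal is True, so each particle sees its own label.
    adj = dist2 < linking_length**2

    def propagate(labels):
        # Each particle adopts the smallest label among its friends.
        return jnp.min(jnp.where(adj, labels[None, :], n), axis=1)

    def cond(state):
        labels, prev = state
        return jnp.any(labels != prev)

    def body(state):
        labels, _ = state
        return propagate(labels), labels

    init = jnp.arange(n)
    labels, _ = jax.lax.while_loop(cond, body, (propagate(init), init))
    return labels

# Usage: a chain of three friends, a pair, and an isolated particle.
pos = jnp.array([[0.0, 0.0, 0.0], [0.5, 0.0, 0.0], [1.0, 0.0, 0.0],
                 [5.0, 5.0, 5.0], [5.4, 5.0, 5.0], [9.0, 9.0, 9.0]])
print(fof_labels(pos, 0.6).tolist())  # [0, 0, 0, 3, 3, 5]
```

Particles 0 and 2 are not direct friends but join the same group through particle 1, which is the transitive linking that defines FoF. The number of sweeps is bounded by the largest group's diameter measured in friend links, and the dense matrix limits this sketch to small N.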
Abstract
We present jFoF, a fully GPU-native Friends-of-Friends (FoF) halo finder designed for both high-performance simulation analysis and differentiable modeling. Implemented in JAX, jFoF achieves end-to-end acceleration by performing all neighbor searches, label propagation, and group construction directly on GPUs, eliminating costly host-device transfers. We introduce two complementary neighbor-search strategies, a standard k-d tree and a novel linked-cell grid, and demonstrate that jFoF attains up to an order-of-magnitude speedup compared to optimized CPU implementations while maintaining consistent halo catalogs. Beyond performance, jFoF enables gradient propagation through discrete halo-finding operations via both frozen-assignment and topological optimization modes. In the topological optimization mode, a REINFORCE-style estimator allows smooth optimization of halo connectivity and membership, bridging continuous simulation fields and discrete structure catalogs. These capabilities make jFoF a foundation for differentiable inference, enabling end-to-end, gradient-based optimization of structure formation models within GPU-accelerated astrophysical pipelines. We make our code publicly available at https://github.com/bhorowitz/jFOF/.
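A frozen-assignment gradient can be sketched as follows, assuming it means holding FoF membership fixed while differentiating continuous halo properties: integer labels from a prior, non-differentiable FoF pass act as constants, and gradients flow only through particle positions and masses. The functions `halo_centers` and `toy_loss` are hypothetical illustrations, not jFoF's API.

```python
import jax
import jax.numpy as jnp

def halo_centers(pos, mass, labels, num_groups):
    """Mass-weighted center of each group, with membership held fixed."""
    # Integer labels carry no gradient, so group membership is "frozen":
    # derivatives flow only through pos and mass.
    m_tot = jax.ops.segment_sum(mass, labels, num_segments=num_groups)
    m_pos = jax.ops.segment_sum(mass[:, None] * pos, labels,
                                num_segments=num_groups)
    return m_pos / m_tot[:, None]

def toy_loss(pos, mass, labels, num_groups):
    # Hypothetical objective: sum of the x-coordinates of all group centers.
    return jnp.sum(halo_centers(pos, mass, labels, num_groups)[:, 0])

pos = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0],
                 [4.0, 0.0, 0.0], [6.0, 0.0, 0.0]])
mass = jnp.ones(4)
labels = jnp.array([0, 0, 1, 1])  # e.g. the output of a prior FoF pass

# Gradient with respect to positions: each particle's x-derivative equals
# its mass fraction within its group (0.5 for these equal-mass pairs).
grad_pos = jax.grad(toy_loss)(pos, mass, labels, 2)
# Gradient with respect to masses: moving mass shifts the group centers.
grad_mass = jax.grad(toy_loss, argnums=1)(pos, mass, labels, 2)
```

In this reading, the frozen-assignment gradient describes how halo properties respond to perturbations that leave the friend graph unchanged; changes in connectivity itself are what the REINFORCE-style topological mode targets.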
