Hierarchically branched diffusion models leverage dataset structure for class-conditional generation
Alex M. Tseng, Max Shen, Tommaso Biancalani, Gabriele Scalia
TL;DR
This work addresses a limitation of flat class-conditional diffusion by introducing hierarchically branched diffusion models, which encode a class hierarchy through branch points in diffusion time. The core idea is to learn a separate reverse-diffusion head for each branch while keeping a single shared forward diffusion process. This enables efficient extension to new classes, transmutation between classes, and interpretable generative intermediates. The authors define a branch-point discovery procedure and evaluate the approach across several data modalities: MNIST, tabular letters, single-cell RNA-seq, and ZINC250K. They report that branched models achieve competitive or superior generation quality, with notable advantages in continual learning, analogy-based generation, and interpretability. The framework broadens the applicability of diffusion models in scientific domains by exploiting intrinsic dataset structure, with practical benefits for data discovery and hypothesis testing.
Abstract
Class-labeled datasets, particularly those common in scientific domains, are rife with internal structure, yet current class-conditional diffusion models ignore these relationships and implicitly diffuse on all classes in a flat fashion. To leverage this structure, we propose hierarchically branched diffusion models as a novel framework for class-conditional generation. Branched diffusion models rely on the same diffusion process as traditional models, but learn reverse diffusion separately for each branch of a hierarchy. We highlight several advantages of branched diffusion models over the current state-of-the-art methods for class-conditional diffusion, including extension to novel classes in a continual-learning setting, a more sophisticated form of analogy-based conditional generation (i.e., transmutation), and a novel form of interpretability into the generation process. We extensively evaluate branched diffusion models on several benchmark and large real-world scientific datasets spanning many data modalities.
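To make the idea of "learning reverse diffusion separately for each branch" more concrete, the following is a minimal sketch of how a branch could be looked up for a given class and diffusion time. It is an illustration only, not the authors' implementation: the `Branch` class, the three-class hierarchy, the branch-point times, and the helper names `branch_for` and `reverse_path` are all hypothetical, and the representation of a branch as a time interval plus a set of classes is an assumption.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Branch:
    """Hypothetical branch: a diffusion-time interval shared by a set of classes."""
    t_start: float        # less noisy end of the interval (exclusive)
    t_end: float          # noisier end of the interval (inclusive)
    classes: frozenset    # classes whose reverse diffusion uses this branch


# Assumed toy hierarchy on diffusion time (0, 1]. Classes 4 and 9 stay
# merged until t = 0.3, while class 0 splits off earlier, at t = 0.6.
BRANCHES = [
    Branch(0.6, 1.0, frozenset({0, 4, 9})),  # root: all classes share it
    Branch(0.3, 0.6, frozenset({4, 9})),     # intermediate: 4 and 9 only
    Branch(0.0, 0.6, frozenset({0})),        # leaf for class 0
    Branch(0.0, 0.3, frozenset({4})),        # leaf for class 4
    Branch(0.0, 0.3, frozenset({9})),        # leaf for class 9
]


def branch_for(cls, t, branches=BRANCHES):
    """Index of the branch (i.e. which reverse-diffusion head) for `cls` at time `t`."""
    for i, b in enumerate(branches):
        if cls in b.classes and b.t_start < t <= b.t_end:
            return i
    raise ValueError(f"no branch covers class {cls!r} at time {t!r}")


def reverse_path(cls, branches=BRANCHES):
    """Branch indices visited when denoising `cls` from t = 1 down to t = 0."""
    own = [i for i, b in enumerate(branches) if cls in b.classes]
    return sorted(own, key=lambda i: branches[i].t_end, reverse=True)
```

In this toy setup, every class starts reverse diffusion in the shared root branch and then descends into progressively more class-specific branches; for example, `reverse_path(4)` visits the root, the branch shared with class 9, and finally the leaf for class 4. The forward (noising) process is the same for all classes and does not depend on the branch structure.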
