Hierarchically branched diffusion models leverage dataset structure for class-conditional generation

Alex M. Tseng, Max Shen, Tommaso Biancalani, Gabriele Scalia

TL;DR

This work addresses a limitation of conventional class-conditional diffusion models, which treat all classes in a flat fashion, by introducing hierarchically branched diffusion models that encode a class hierarchy through branch points in diffusion time. The core idea is to learn reverse diffusion separately for each branch (one output task per branch of a multi-task network) while keeping a single shared diffusion process. This enables efficient extension to new classes, transmutation between classes, and interpretable generative intermediates. The authors define a branch-point discovery procedure and evaluate the approach across several data modalities: MNIST images, tabular letters, single-cell RNA-seq, and ZINC250K molecules. They report that branched models achieve competitive or superior generation quality, with notable advantages in continual learning, analogy-based generation, and interpretability. The framework broadens the applicability of diffusion models in scientific domains by exploiting intrinsic dataset structure, with practical benefits for data discovery and hypothesis testing.

Abstract

Class-labeled datasets, particularly those common in scientific domains, are rife with internal structure, yet current class-conditional diffusion models ignore these relationships and implicitly diffuse on all classes in a flat fashion. To leverage this structure, we propose hierarchically branched diffusion models as a novel framework for class-conditional generation. Branched diffusion models rely on the same diffusion process as traditional models, but learn reverse diffusion separately for each branch of a hierarchy. We highlight several advantages of branched diffusion models over the current state-of-the-art methods for class-conditional diffusion, including extension to novel classes in a continual-learning setting, a more sophisticated form of analogy-based conditional generation (i.e., transmutation), and a novel form of interpretability into the generation process. We extensively evaluate branched diffusion models on several benchmark and large real-world scientific datasets spanning many data modalities.
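
To make "learn reverse diffusion separately for each branch" concrete, the following is a minimal sketch of how a branch could be represented and looked up. It assumes each branch covers a subset of classes over a half-open interval of diffusion time, and that exactly one branch applies to any (class, time) pair. The names Branch, find_branch, and BRANCHES, and the branch-point value 0.4, are hypothetical and chosen for illustration; they are not taken from the paper.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Branch:
    """One branch: a set of classes sharing reverse diffusion on [t_start, t_end)."""
    t_start: float
    t_end: float
    classes: frozenset


def find_branch(branches, label, t):
    """Return the index of the single branch that covers class `label` at time `t`."""
    matches = [
        i for i, b in enumerate(branches)
        if label in b.classes and b.t_start <= t < b.t_end
    ]
    if len(matches) != 1:
        raise ValueError(
            f"expected exactly one branch for ({label}, {t}), found {len(matches)}"
        )
    return matches[0]


# Hypothetical two-class hierarchy; the branch point 0.4 is made up for illustration.
BRANCHES = [
    Branch(0.0, 0.4, frozenset({4})),     # index 0: class 4 alone
    Branch(0.0, 0.4, frozenset({9})),     # index 1: class 9 alone
    Branch(0.4, 1.0, frozenset({4, 9})),  # index 2: classes 4 and 9 share this branch
]
```

In a setup like this, the returned index would select which output of the model is trained or queried for a given sample, and a new class could in principle be added by appending one more Branch entry without changing the existing ones.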

Paper Structure (18 sections, 8 equations, 13 figures, 9 tables, 2 algorithms)

Figures (13)

  • Figure 1: Schematic of a branched diffusion model. a) After adding sufficient noise, similar classes become indistinguishable from each other at branch points. Branch points (purple dots) between all classes define a hierarchy. Although there is still only a single underlying diffusion process, the hierarchy separates diffusion time into branches, where each branch represents diffusion for a subset of classes and a subset of diffusion times. An example of one diffusion intermediate is highlighted in blue; this example is an MNIST digit that is a 4 or 9, and is at an intermediate diffusion time. b) A branched diffusion model is realized as a multi-task neural network (NN) that predicts reverse diffusion (one output task for each branch). The prediction path for the blue-highlighted MNIST digit in panel a) is also in blue. c) We show a progression of methods from traditional linear diffusion models to hierarchically branched models.
  • Figure 2: Extending a branched model to new classes. a) Schematic of the addition of a new digit class to an existing branched diffusion model on MNIST. The new class is introduced by adding a single new branch (purple dotted line). b) Examples of MNIST digits generated from a branched diffusion model (above) and a label-guided (linear) diffusion model (below), before and after fine-tuning on the new class. For the label-guided model, we also show examples of digits after fine-tuning on the whole dataset. c) On the MNIST dataset (above) and the single-cell RNA-seq dataset (below), we show the FID (i.e., generative performance) of each class, before and after fine-tuning on the new class. For the label-guided models, we also show the FIDs after fine-tuning on the whole dataset.
  • Figure 3: Transmutation between classes. a) From our branched diffusion model trained on MNIST, we show examples of 4s transmuted to 9s (left), and 9s transmuted to 4s (right). We also show the diffusion intermediate $x_{b}$ at the branch point. b) On the letters dataset, we show the scatterplots of some feature values before and after transmutation from Vs to Ys (left), or Ys to Vs (right). For each of the 16 features, we correlate the feature value before versus after transmutation and show a histogram of the correlations over all 16 features in either transmutation direction. c) On single-cell RNA-seq data, we transmuted between CD16+ NK cells and classical monocytes, and show the distribution of several marker genes before and after transmutation. The left column shows marker genes of classical monocytes, and the right column shows marker genes of CD16+ NK cells. d) On molecules from ZINC250K, we transmuted between acyclic and cyclic molecules, and between non-halogenated and halogenated molecules.
  • Figure 4: Interpretable hybrids at branch points. a) From our branched model trained on MNIST, we show examples of hybrids between the digit classes 4 and 9 (left), and between the digit classes 1 and 7 (right). Each hybrid in the middle row is the reverse-diffusion starting point for both images above and below it. We applied a small amount of Gaussian smoothing to the hybrids for ease of viewing. b) Averaging over many samples, the aggregate hybrids at branch points show the collective characteristics that are shared between MNIST classes. c) From our branched model trained on tabular letters, we show the distribution of some features for two pairs of classes, O and X (left) and E and F (right), together with the distribution of that feature in the generated hybrids from the corresponding branch point.
  • Figure S1: Examples of generated MNIST images. We show (uncurated) images of MNIST digits generated by branched diffusion models. Since branched diffusion models naturally output each class separately, generation of individual classes does not require supplying labels or pretrained classifiers. We show a sample of digits generated from a continuous-time (score-matching) diffusion model (Song et al., 2021), and a discrete-time diffusion model (denoising diffusion probabilistic model; Ho et al., 2020). Branched diffusion models for multi-class generation fit neatly into practically any diffusion-model framework.
  • ...and 8 more figures
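
The transmutation described in the Figure 3 caption, which passes through a diffusion intermediate at the branch point, can be sketched as a two-stage control flow: diffuse an object forward to the branch point shared with the target class, then run reverse diffusion down the target class's branch. This is one plausible reading, not the authors' implementation. The function transmute and the toy stand-ins toy_forward and toy_reverse_step are hypothetical; the stand-ins are deterministic placeholders, not a trained model, and real forward diffusion would also add noise.

```python
def transmute(x0, t_branch, target_class, forward_diffuse, reverse_step, n_steps=10):
    """Diffuse x0 forward to the branch point, then reverse along the target class."""
    x_b = forward_diffuse(x0, t_branch)  # intermediate at the branch point
    x, t = x_b, t_branch
    dt = t_branch / n_steps
    for _ in range(n_steps):
        x = reverse_step(x, t, target_class, dt)
        t -= dt
    return x_b, x


# Toy 1-D stand-ins (assumptions for illustration only, not a trained model).
CLASS_MEANS = {"A": -1.0, "B": 1.0}


def toy_forward(x, t):
    # Deterministic shrink toward 0; a real forward process would add Gaussian noise.
    return x * (1.0 - t)


def toy_reverse_step(x, t, label, dt):
    # Move halfway toward the target class mean; this toy ignores t and dt.
    return x + 0.5 * (CLASS_MEANS[label] - x)


# Start from a class-"A" object and transmute it to class "B".
x_b, x_new = transmute(-1.0, 0.5, "B", toy_forward, toy_reverse_step)
```

Here x_b plays the role of the branch-point intermediate shown in the figure, and x_new ends up near the class "B" mean in this toy setting.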