
Designing a Conditional Prior Distribution for Flow-Based Generative Models

Noam Issachar, Mohammad Salama, Raanan Fattal, Sagie Benaim

TL;DR

This work tackles inefficiencies in conditional flow-based generation by introducing a condition-specific prior distribution (CPD). The CPD maps a conditioning signal to an average data point in data space and uses a Gaussian Mixture Model over a latent representation to produce a conditioned prior $p_0(x_0|c)$ that better aligns with the target modes, reducing transport cost. Flow Matching from Conditional Prior Distribution (CPD-FM) then trains a velocity field to transport samples from this CPD to the conditional target, yielding shorter probability paths and lower global truncation error. Empirically, the method achieves faster convergence and better FID, KID, and CLIP scores on ImageNet-64 and MS-COCO at a low number of function evaluations (NFEs), outperforming CondOT, BatchOT, and DDPM baselines and enabling high-quality conditional generation with fewer sampling steps. The approach offers a flexible framework that can incorporate discrete or continuous conditioning and opens avenues for applying informative priors to broader conditional generation tasks.
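The CPD-FM training objective summarized above can be sketched as follows. This is a minimal sketch under stated assumptions, not the authors' implementation: precomputed per-class means stand in for the learned condition-to-data-space mapping, an isotropic Gaussian $\mathcal{N}(\mu_c, \sigma^2 I)$ per condition stands in for the GMM-based prior, and the straight-line (CondOT-style) path $x_t = (1-t)x_0 + t x_1$ with target velocity $x_1 - x_0$ is used. The helpers `sample_cpd_prior` and `cpd_fm_loss`, and the `velocity_fn(x_t, t, labels)` interface, are hypothetical.

```python
import numpy as np


def sample_cpd_prior(labels, class_means, sigma, rng):
    """Draw x0 ~ N(mu_c, sigma^2 I) for each label c.

    Assumption: an isotropic Gaussian around a precomputed class mean stands in
    for the paper's GMM-based conditional prior.
    """
    mu = class_means[labels]
    return mu + sigma * rng.standard_normal(mu.shape)


def cpd_fm_loss(velocity_fn, x1, labels, class_means, sigma, rng, t_min=1e-3):
    """Monte Carlo flow-matching loss with straight paths from the CPD prior.

    velocity_fn(x_t, t, labels) is a hypothetical model interface that returns
    an array shaped like x_t.
    """
    x0 = sample_cpd_prior(labels, class_means, sigma, rng)
    t = rng.uniform(t_min, 1.0, size=(x1.shape[0], 1))  # one time per sample
    x_t = (1.0 - t) * x0 + t * x1   # point on the straight path at time t
    target = x1 - x0                # velocity of that straight path
    pred = velocity_fn(x_t, t, labels)
    return float(np.mean(np.sum((pred - target) ** 2, axis=1)))


# Toy usage on hypothetical 2-D data with two classes.
rng = np.random.default_rng(0)
labels = rng.integers(0, 2, size=512)
x1 = np.array([[4.0, 4.0], [-4.0, 4.0]])[labels] + 0.3 * rng.standard_normal((512, 2))
class_means = np.stack([x1[labels == c].mean(axis=0) for c in range(2)])


def oracle(x_t, t, lab):
    # With sigma = 0 every x0 equals its class mean, so (x_t - mu_c) / t
    # recovers x1 - mu_c exactly and the loss should vanish.
    return (x_t - class_means[lab]) / t


def zero_model(x_t, t, lab):
    # A trivial "model" that predicts no motion at all.
    return np.zeros_like(x_t)


loss_oracle = cpd_fm_loss(oracle, x1, labels, class_means, 0.0, rng)
loss_zero = cpd_fm_loss(zero_model, x1, labels, class_means, 1.0, rng)
print(f"oracle loss: {loss_oracle:.2e}, zero-model loss: {loss_zero:.2f}")
```

In practice `velocity_fn` would be a neural network conditioned on the label or text embedding and trained by minimizing this loss; the oracle here only checks that the path and target are constructed consistently.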

Abstract

Flow-based generative models have recently shown impressive performance on conditional generation tasks such as text-to-image generation. However, current methods transform a general unimodal noise distribution into a specific mode of the target data distribution. As a result, every point in the source distribution can be mapped to every point in the target distribution, which leads to long average paths. In this work, we exploit an underused property of conditional flow-based models: the freedom to design a non-trivial prior distribution. Given an input condition, such as a text prompt, we first map it to a point in data space representing an "average" data point, i.e., the point with the minimal average distance to all data points of the same conditional mode (e.g., class). We then use the flow matching formulation to map samples from a parametric distribution centered around this point to the conditional target distribution. Experimentally, our method significantly improves training time and generation efficiency (FID, KID, and CLIP alignment scores) compared to baselines, producing high-quality samples with fewer sampling steps.
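The "average" data point in the abstract can be made concrete for squared Euclidean distance, where the point minimizing the average distance to a class's samples is the class mean (for unsquared distance it would instead be the geometric median). The sketch below, on hypothetical 2-D toy data, checks that property and compares the average transport cost $\mathbb{E}\|x_1 - x_0\|^2$ under independent (random) pairing from a standard Gaussian prior versus a Gaussian prior centered at the class mean. The squared distance, the prior width `sigma = 1.0`, the random pairing, and the helper names are illustrative assumptions, not the paper's settings.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy data: two classes far from the origin.
true_means = np.array([[5.0, 5.0], [-5.0, 5.0]])
n = 2000
labels = rng.integers(0, 2, size=n)
x1 = true_means[labels] + 0.5 * rng.standard_normal((n, 2))

# "Average" point per class: the empirical mean, which minimizes the
# average squared Euclidean distance to the class samples.
class_means = np.stack([x1[labels == c].mean(axis=0) for c in range(2)])


def avg_sq_dist(point, samples):
    """Average squared Euclidean distance from `point` to each row of `samples`."""
    return float(np.mean(np.sum((samples - point) ** 2, axis=1)))


def transport_cost(x0, x1):
    """Mean squared length of straight paths under the given (random) pairing."""
    return float(np.mean(np.sum((x1 - x0) ** 2, axis=1)))


sigma = 1.0
x0_gauss = rng.standard_normal((n, 2))                              # N(0, I)
x0_cpd = class_means[labels] + sigma * rng.standard_normal((n, 2))  # N(mu_c, sigma^2 I)

gaussian_cost = transport_cost(x0_gauss, x1)
cpd_cost = transport_cost(x0_cpd, x1)
print(f"standard prior: {gaussian_cost:.2f}, class-mean prior: {cpd_cost:.2f}")
```

Under these assumptions the expected cost is about $\|\mu_c\|^2 + 2(0.5^2) + 2 = 52.5$ for the standard prior versus $2(0.5^2) + 2\sigma^2 = 2.5$ for the class-mean prior; in this toy the whole gap comes from moving the prior's center onto the class.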

Paper Structure

This paper contains 18 sections, 23 equations, 9 figures, 4 tables.

Figures (9)

  • Figure 1: An illustration of our approach. The left-hand side (LHS) illustrates the standard flow matching paradigm, where every sample in the source Gaussian distribution (shown as a circular point) can be mapped to every sample in the conditional target mode (shown as a cross), with the samples of each class shown in a different color. In contrast, our method, shown on the right-hand side (RHS), constructs a class-specific conditional distribution as the source prior. Each sample in the source distribution is, on average, closer to its corresponding sample in the target mode.
  • Figure 2: Trajectory illustration. A toy example illustrating the trajectory from the source to the target distribution for our method and conditional flow matching using optimal transport (CondOT).
  • Figure 3: (a) NFE convergence illustration. A toy example illustrating convergence to the target distribution at different NFEs for our method compared to CondOT. (b) Generalization illustration. A toy example illustrating the generalization capabilities. LHS: source prior and target samples for training classes. RHS: as for the LHS, but for test classes.
  • Figure 4: Multi-modal classes. A toy example illustrating multi-modal classes whose priors intersect. Each color represents a class (class A or B), with samples shown as points and the prior distribution as contour lines. (a) shows a standard Gaussian prior (in black), while (b) and (c) show class-specific priors. Although the mean of each class falls on samples from the other class, our method yields an improved MMD score.
  • Figure 5: Numerical evaluation. (a) We compare our method to class-conditional flow matching using optimal transport paths (CondOT) lipman2022flow, BatchOT pooladian2023multisample, and DDPM Ho2020DDPM on the ImageNet-64 dataset. We report the FID score (LHS), KID score (middle), and CLIP score (RHS). (b) As in (a), but for text-to-image generation on the MS-COCO dataset. Our method shows a significant improvement per NFE, especially at low NFEs. For example, at 15 NFEs we obtain an FID of 13.62 on ImageNet-64 and 18.05 on MS-COCO, while no baseline reaches an FID below 16.10 and 28.32, respectively, at the same NFE. We consider up to 40 NFEs and note that DDPM converges to a superior result given more steps.
  • ...and 4 more figures
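Figures 3 and 5 report results per NFE, the number of velocity-field evaluations the ODE solver makes while transporting prior samples to the target. As a minimal sketch, assume a plain fixed-step Euler solver (this summary does not state which solver the paper uses) and a toy 1-D case with Gaussian source $\mathcal{N}(\mu_0, s_0^2)$, Gaussian target $\mathcal{N}(\mu_1, s_1^2)$, independent pairing, and straight paths $x_t = (1-t)x_0 + t x_1$. Here the exact marginal velocity is known in closed form, so the endpoint error at each NFE, i.e., the global truncation error, can be measured directly. The helper names are hypothetical.

```python
import numpy as np


def gaussian_marginal_velocity(x, t, mu0, s0, mu1, s1):
    """E[x1 - x0 | x_t = x] for independent x0 ~ N(mu0, s0^2), x1 ~ N(mu1, s1^2)
    and straight paths x_t = (1 - t) x0 + t x1 (1-D, closed form)."""
    var_t = (1 - t) ** 2 * s0 ** 2 + t ** 2 * s1 ** 2
    a = (t * s1 ** 2 - (1 - t) * s0 ** 2) / var_t  # Cov(x1 - x0, x_t) / Var(x_t)
    mean_t = (1 - t) * mu0 + t * mu1
    return (mu1 - mu0) + a * (x - mean_t)


def euler_sample(velocity, x0, nfe):
    """Integrate dx/dt = velocity(x, t) from t = 0 to t = 1 using `nfe` Euler steps."""
    x = np.asarray(x0, dtype=float)
    h = 1.0 / nfe
    for k in range(nfe):
        x = x + h * velocity(x, k * h)  # one velocity evaluation per step
    return x


mu0, s0, mu1, s1 = 0.0, 1.0, 3.0, 0.5
x0 = np.array([-2.0, -1.0, 0.5, 1.5])
# The exact flow of this velocity field maps x0 to mu1 + (s1 / s0) * (x0 - mu0).
exact = mu1 + (s1 / s0) * (x0 - mu0)


def v(x, t):
    return gaussian_marginal_velocity(x, t, mu0, s0, mu1, s1)


errors = {nfe: float(np.max(np.abs(euler_sample(v, x0, nfe) - exact)))
          for nfe in (1, 4, 15, 40)}
for nfe, err in errors.items():
    print(f"NFE={nfe:2d}  max endpoint error={err:.4f}")
```

In this toy setting Euler integrates the constant drift $\mu_1 - \mu_0$ exactly; all of the error comes from the curved term $a(t)(x - \bar{x}_t)$, consistent with the intuition that velocity fields with straighter trajectories need fewer steps.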