Designing a Conditional Prior Distribution for Flow-Based Generative Models
Noam Issachar, Mohammad Salama, Raanan Fattal, Sagie Benaim
TL;DR
This work tackles inefficiencies in conditional flow-based generation by introducing a condition-specific prior distribution (CPD). The CPD maps a conditioning signal to an average data point in data space and uses a Gaussian Mixture Model over a latent representation to produce a conditioned prior $p_0(x_0|c)$ that is better aligned with the target modes, which reduces transport cost. Flow Matching from Conditional Prior Distribution (CPD-FM) then trains a velocity field to transport samples from this CPD to the conditional target distribution, yielding shorter probability paths and lower global truncation error. Empirically, the method converges faster and achieves better FID, KID, and CLIP scores on ImageNet-64 and MS-COCO at low numbers of function evaluations (NFEs), outperforming the CondOT, BatchOT, and DDPM baselines and enabling high-quality conditional generation with fewer sampling steps. The approach offers a flexible framework that can incorporate discrete or continuous conditioning and opens avenues for applying informative priors to broader conditional generation tasks.
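The CPD-FM training step can be sketched for the simplest case of discrete class labels. This is a minimal NumPy sketch under stated assumptions, not the authors' implementation: it works directly in a toy 2-D data space rather than a latent space, places one isotropic Gaussian component at each class mean (the component of the mixture selected by the condition), uses a hypothetical fixed spread `sigma`, and builds the straight-line (CondOT-style) interpolant whose target velocity a conditional network `v(x_t, t, c)` would regress. The helper names `class_means`, `sample_conditional_prior`, and `flow_matching_pair` are illustrative.

```python
import numpy as np

def class_means(x1: np.ndarray, labels: np.ndarray, num_classes: int) -> np.ndarray:
    """Per-class mean of the data, used here as the 'average' data point of each class."""
    return np.stack([x1[labels == k].mean(axis=0) for k in range(num_classes)])

def sample_conditional_prior(means: np.ndarray, labels: np.ndarray,
                             sigma: float, rng: np.random.Generator) -> np.ndarray:
    """Draw x0 ~ N(mu_c, sigma^2 I), the mixture component selected by each condition c.
    sigma is a hypothetical fixed spread, not a value taken from the paper."""
    mu = means[labels]
    return mu + sigma * rng.standard_normal(mu.shape)

def flow_matching_pair(x0: np.ndarray, x1: np.ndarray, t: np.ndarray):
    """Straight-line path x_t = (1 - t) x0 + t x1 and its constant target velocity x1 - x0."""
    t = t[:, None]  # one time per sample, broadcast over the feature dimension
    xt = (1.0 - t) * x0 + t * x1
    ut = x1 - x0
    return xt, ut

# Toy conditional data: three well-separated classes in 2-D (synthetic, illustration only).
rng = np.random.default_rng(0)
centers = np.array([[5.0, 0.0], [-5.0, 0.0], [0.0, 5.0]])
labels = rng.integers(0, 3, size=600)
x1 = centers[labels] + 0.5 * rng.standard_normal((600, 2))

means = class_means(x1, labels, num_classes=3)
x0 = sample_conditional_prior(means, labels, sigma=0.5, rng=rng)
t = rng.uniform(size=600)
xt, ut = flow_matching_pair(x0, x1, t)

# A velocity network v(xt, t, c) would be trained to minimize this squared error;
# the all-zero prediction is only a stand-in to show how the loss is computed.
v_pred = np.zeros_like(ut)
loss = np.mean(np.sum((v_pred - ut) ** 2, axis=1))
```

At sampling time, the same conditional prior would supply the starting point `x0` for a given condition, and an ODE solver would integrate the learned velocity from `t = 0` to `t = 1`.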
Abstract
Flow-based generative models have recently shown impressive performance on conditional generation tasks such as text-to-image generation. However, current methods transform a general unimodal noise distribution into a specific mode of the target data distribution. As such, any point in the initial source distribution can be mapped to any point in the target distribution, which leads to long average paths. To address this, we tap into an underutilized property of conditional flow-based models: the ability to design a non-trivial prior distribution. Given an input condition, such as a text prompt, we first map it to a point in data space that represents an "average" data point, i.e., the point with the minimal average distance to all data points of the same conditional mode (e.g., class). We then use the flow matching formulation to map samples from a parametric distribution centered around this point to the conditional target distribution. Experimentally, our method significantly reduces training time and improves generation efficiency (FID, KID, and CLIP alignment scores) compared to baselines, producing high-quality samples with fewer sampling steps.
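The "long average paths" argument can be checked numerically on toy data. The sketch below is a synthetic NumPy illustration, not the paper's experiment: it pairs prior and target samples independently (as in flow matching without minibatch optimal transport) and compares the mean squared length of straight paths starting from a standard normal prior against paths starting from a Gaussian centered at each class's average point. It assumes squared Euclidean distance, for which the minimizing "average" point is the class mean; the data layout and the spread `sigma` are hypothetical.

```python
import numpy as np

def mean_sq_path_length(x0: np.ndarray, x1: np.ndarray) -> float:
    """Average squared length ||x1 - x0||^2 of straight paths between paired samples."""
    return float(np.mean(np.sum((x1 - x0) ** 2, axis=1)))

def avg_sq_distance(point: np.ndarray, x: np.ndarray) -> float:
    """Average squared Euclidean distance from a single point to every row of x."""
    return float(np.mean(np.sum((x - point) ** 2, axis=1)))

rng = np.random.default_rng(1)
n, dim, sigma = 3000, 2, 0.5
centers = np.array([[5.0, 0.0], [-5.0, 0.0], [0.0, 5.0]])
labels = rng.integers(0, len(centers), size=n)
x1 = centers[labels] + 0.5 * rng.standard_normal((n, dim))

# Class "average" points: the per-class mean minimizes the average squared distance.
means = np.stack([x1[labels == k].mean(axis=0) for k in range(len(centers))])

# Unconditional prior: every class is reached from the same N(0, I) blob.
x0_std = rng.standard_normal((n, dim))
# Condition-specific prior: N(mu_c, sigma^2 I) centered on each sample's class mean.
x0_cpd = means[labels] + sigma * rng.standard_normal((n, dim))

cost_std = mean_sq_path_length(x0_std, x1)
cost_cpd = mean_sq_path_length(x0_cpd, x1)
print(f"standard prior: {cost_std:.2f}, conditional prior: {cost_cpd:.2f}")
```

In this toy setup the paths from the condition-centered prior are much shorter on average, because each prior sample already starts near its own class instead of near the origin shared by all classes.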
