Joint Distillation for Fast Likelihood Evaluation and Sampling in Flow-based Models
Xinyue Ai, Yutong He, Albert Gu, Ruslan Salakhutdinov, J Zico Kolter, Nicholas Matthew Boffi, Max Simchowitz
TL;DR
The paper addresses the bottleneck of expensive log-likelihood evaluation in flow-based and diffusion models by introducing Fast Flow Joint Distillation (F2D2), a modular framework that jointly distills the sampling trajectory and the cumulative divergence, both of which are driven by a shared velocity field. By parameterizing a joint flow map with separate heads for velocity and divergence, F2D2 enables accurate few-step sampling and calibrated likelihood computation, and it is compatible with existing Shortcut and MeanFlow architectures. The authors provide two instantiations (Shortcut-F2D2 and MeanFlow-F2D2), along with practical training strategies and a new inference-time technique called Maximum Likelihood Self-Guidance, which improves sample quality at minimal extra cost. Empirical results on CIFAR-10 and ImageNet-64×64 demonstrate calibrated log-likelihoods with few NFEs while preserving competitive sample quality, and a 2-step MeanFlow model with self-guidance can outperform a 1024-step teacher, underscoring the practical impact of making likelihood evaluation and sampling fast at the same time.
Abstract
Log-likelihood evaluation enables important capabilities in generative models, including model comparison, certain fine-tuning objectives, and many downstream applications. Yet paradoxically, some of today's best generative models (diffusion and flow-based models) still require hundreds to thousands of neural function evaluations (NFEs) to compute a single likelihood. While recent distillation methods have successfully accelerated sampling to just a few steps, they achieve this at the cost of likelihood tractability: existing approaches either abandon likelihood computation entirely or still require expensive integration over full trajectories. We present fast flow joint distillation (F2D2), a framework that simultaneously reduces the number of NFEs required for both sampling and likelihood evaluation by two orders of magnitude. Our key insight is that in continuous normalizing flows, the coupled ODEs for sampling and likelihood are computed from a shared underlying velocity field, allowing us to jointly distill both the sampling trajectory and the cumulative divergence using a single model. F2D2 is modular, compatible with existing flow-based few-step sampling models, and requires only an additional divergence prediction head. Experiments demonstrate that F2D2 achieves accurate log-likelihoods with few-step evaluation while maintaining high sample quality, solving a long-standing computational bottleneck in flow-based generative models. As an application of our approach, we propose a lightweight self-guidance method that enables a 2-step MeanFlow model to outperform a 1024-step teacher model with only a single additional backward NFE.
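To make the "coupled ODEs" in the abstract concrete, the following is a minimal sketch of the standard continuous normalizing flow relations dx/dt = v(x, t) and d(log p)/dt = -div v(x, t), integrated together with Euler steps. It is an illustration only, not the authors' implementation: the linear velocity field, the constant `A`, and the names `velocity`, `divergence`, `integrate_joint`, and `log_standard_normal` are all hypothetical, and a real model would use a neural network for the velocity and typically an estimator for the divergence.

```python
import math

A = -0.5  # hypothetical coefficient of a toy linear velocity field


def velocity(x, t):
    """Toy velocity field v(x, t) = A * x (stand-in for a neural network)."""
    return [A * xi for xi in x]


def divergence(x, t):
    """Exact divergence of the toy field: sum_i d(A * x_i)/dx_i = A * dim."""
    return A * len(x)


def log_standard_normal(x):
    """Log-density of a standard Gaussian, used here as the base distribution."""
    return -0.5 * sum(xi * xi for xi in x) - 0.5 * len(x) * math.log(2 * math.pi)


def integrate_joint(x0, n_steps=1000, t0=0.0, t1=1.0):
    """Euler-integrate the state and the cumulative log-density change together.

    Both updates read the same velocity field: the state follows v, and the
    log-density change accumulates minus the divergence of v along the path.
    Each step costs one evaluation of the field, so the loop length plays the
    role of the NFE count.
    """
    h = (t1 - t0) / n_steps
    x, delta_logp, t = list(x0), 0.0, t0
    for _ in range(n_steps):
        v = velocity(x, t)
        delta_logp -= h * divergence(x, t)
        x = [xi + h * vi for xi, vi in zip(x, v)]
        t += h
    return x, delta_logp


if __name__ == "__main__":
    x0 = [1.0, -2.0]
    x1, delta_logp = integrate_joint(x0)
    # log p_1(x1) = log p_0(x0) + delta_logp along the trajectory.
    print(x1, log_standard_normal(x0) + delta_logp)
```

For this toy field the divergence is constant, so the accumulated change equals -A * dim * (t1 - t0) regardless of the step count; for a learned velocity field the integral generally needs many steps, which is the cost that jointly distilling the trajectory and the cumulative divergence is meant to remove.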
