ODE$_t$(ODE$_l$): Shortcutting the Time and the Length in Diffusion and Flow Models for Faster Sampling
Denis Gudovskiy, Wenzhao Zheng, Tomoyuki Okuno, Yohei Nakata, Kurt Keutzer
TL;DR
The paper tackles the high sampling cost of continuous normalizing flows and diffusion models by introducing ODE$_t$(ODE$_l$), which treats the inner network as a discretized ODE over its depth (length $l$) while keeping the outer time ODE solver-agnostic. A length-consistency training objective, together with architectural rewiring (residuals and length-embedded blocks), allows the number of active blocks to be chosen at sampling time without substantial overhead. Empirical results on CelebA-HQ-256 and ImageNet-256 show up to a $2\times$ reduction in latency in the most efficient sampling mode and an FID improvement of up to $2.8$ points in high-quality regimes, with adaptive-step solvers further boosting performance. The approach is complementary to existing methods that minimize the number of function evaluations (NFE), and the code and checkpoints are released at github.com/gudovskiy/odelt.
Abstract
Continuous normalizing flows (CNFs) and diffusion models (DMs) generate high-quality data from a noise distribution. However, their sampling process requires many iterations to solve an ordinary differential equation (ODE), each with high computational complexity. State-of-the-art methods focus on reducing the number of discrete time steps during sampling to improve efficiency. In this work, we explore a complementary direction in which the quality-complexity tradeoff can also be controlled through the length of the neural network. We achieve this by rewiring the blocks of the transformer-based architecture to solve an inner discretized ODE w.r.t. its depth. We then apply a length consistency term during flow matching training, so that sampling can be performed with an arbitrary number of time steps and transformer blocks. Unlike prior approaches, our ODE$_t$(ODE$_l$) approach is solver-agnostic in the time dimension and reduces both latency and, importantly, memory usage. CelebA-HQ and ImageNet generation experiments show a latency reduction of up to $2\times$ in the most efficient sampling mode, and an FID improvement of up to $2.8$ points for high-quality sampling when applied to prior methods. We open-source our code and checkpoints at github.com/gudovskiy/odelt.
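The nested structure described in the abstract, an outer ODE over time whose velocity is itself computed by an inner discretized ODE over depth, can be illustrated with a minimal toy sketch. Everything in it is an assumption made for illustration and is not taken from the released code: `NUM_BLOCKS`, `block`, `velocity`, `sample`, and `block_evaluations` are hypothetical names, the block is a fixed analytic function rather than a trained transformer block, the outer solver is plain fixed-step Euler, and reading the velocity off as the net change over depth is a simplification.

```python
import math

NUM_BLOCKS = 8  # hypothetical full depth of the toy network


def block(h, t, k, length):
    """Toy stand-in for one block, conditioned on time t and the active length.

    A real block would use learned weights plus time and length embeddings.
    """
    phase = t + k / NUM_BLOCKS + 0.1 * length
    return [math.tanh(math.sin(phase) * v) for v in h]


def velocity(x, t, length):
    """Inner ODE over depth: `length` residual Euler steps of size 1/length."""
    h = list(x)
    step = 1.0 / length
    for k in range(length):
        f = block(h, t, k, length)
        h = [hv + step * fv for hv, fv in zip(h, f)]
    # Assumption: the net change accumulated over depth is used as the velocity.
    return [hv - xv for hv, xv in zip(h, x)]


def sample(x0, time_steps, length):
    """Outer ODE over time: fixed-step Euler from t = 0 to t = 1."""
    x = list(x0)
    dt = 1.0 / time_steps
    for i in range(time_steps):
        v = velocity(x, i * dt, length)
        x = [xv + dt * vv for xv, vv in zip(x, v)]
    return x


def block_evaluations(time_steps, length):
    """Total block calls per sample, a rough proxy for sampling cost."""
    return time_steps * length


if __name__ == "__main__":
    noise = [0.5, -1.0, 0.25]
    for steps, length in [(4, NUM_BLOCKS), (4, NUM_BLOCKS // 2)]:
        out = sample(noise, steps, length)
        print(steps, length, block_evaluations(steps, length), out)
```

In this sketch the cost proxy is the product of time steps and active blocks, so halving either one halves the number of block calls. That is the sense in which shortening the length complements reducing the number of time steps. The toy blocks are untrained, so the outputs at different lengths carry no quality meaning; in the paper, agreement across lengths is what the length consistency term is trained to provide.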
