On the Wasserstein Convergence and Straightness of Rectified Flow

Vansh Bansal, Saptarshi Roy, Purnamrita Sarkar, Alessandro Rinaldo

TL;DR

This work analyzes Rectified Flow (RF) as a fast sampling paradigm that learns straight transport trajectories from noise to data by rectifying curved flows. It establishes a $W_2$ convergence bound between the RF-generated distribution and the target that depends on the discretization and drift-estimation errors, and introduces new straightness metrics that explain when a few discretization steps suffice. The authors prove existence and uniqueness of 1-RF under practical regularity conditions and connect 1-RF to Monge maps in several settings, including Gaussian-to-Gaussian and certain Gaussian mixtures, with experiments on synthetic and real datasets supporting the theory. Overall, the paper provides theoretical grounding for the empirical observation that RF can yield straight, efficient sampling trajectories and Monge-map transports in many practical cases.

Abstract

Diffusion models have emerged as a powerful tool for image generation and denoising. Typically, generative models learn a trajectory between the starting noise distribution and the target data distribution. Recently, Liu et al. (2023b) proposed Rectified Flow (RF), a generative model that aims to learn straight flow trajectories from noise to data using a sequence of convex optimization problems with close ties to optimal transport. If the trajectory is curved, one must use many Euler discretization steps or novel strategies, such as exponential integrators, to achieve satisfactory generation quality. In contrast, RF has been shown to theoretically straighten the trajectory through successive rectifications, reducing the number of function evaluations (NFEs) while sampling. It has also been shown empirically that RF may improve the straightness in two rectifications if one can solve the underlying optimization problem within a sufficiently small error. In this paper, we make two contributions. First, we provide a theoretical analysis of the Wasserstein distance between the sampling distribution of RF and the target distribution. Our error rate is characterized by the number of discretization steps and a novel formulation of straightness stronger than that in the original work. Second, we present general conditions guaranteeing uniqueness and straightness of 1-RF, which is in line with previous empirical findings. As a byproduct of our analysis, we show that, in one dimension, RF started at the standard Gaussian distribution yields the Monge map. We also present empirical results on both simulated and real datasets to validate our theoretical findings. The code is available at https://github.com/bansal-vansh/rectified-flow.
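To make the role of Euler discretization concrete, here is a minimal sketch and not the authors' implementation. It assumes a one-dimensional toy setting chosen purely for illustration: $X_0 \sim N(0,1)$ and $X_1 \sim N(m, s^2)$ drawn independently, with the linear interpolation $X_t = (1-t)X_0 + tX_1$. Under these assumptions the RF drift $\mathbb{E}[X_1 - X_0 \mid X_t = x]$ has a Gaussian closed form, and the exact flow ends at $m + s z$, the monotone (Monge) map between the two Gaussians. The names `velocity` and `euler_sample` and the values `m=3.0`, `s=2.0` are hypothetical.

```python
import numpy as np

def velocity(t, x, m=3.0, s=2.0):
    """Closed-form drift E[X1 - X0 | X_t = x] in the assumed Gaussian toy setting."""
    var_t = (1.0 - t) ** 2 + (t * s) ** 2   # Var(X_t) for X_t = (1 - t) X0 + t X1
    cov = t * s ** 2 - (1.0 - t)            # Cov(X1 - X0, X_t)
    return m + (cov / var_t) * (x - t * m)  # Gaussian conditional expectation

def euler_sample(z0, n_steps, m=3.0, s=2.0):
    """Integrate dx/dt = velocity(t, x) from t = 0 to t = 1 with forward Euler."""
    x = np.array(z0, dtype=float)
    h = 1.0 / n_steps
    for k in range(n_steps):
        x = x + h * velocity(k * h, x, m, s)
    return x

z0 = np.linspace(-2.0, 2.0, 9)   # a few starting points from the noise side
monge = 3.0 + 2.0 * z0           # monotone map from N(0, 1) to N(3, 2^2)

# Maximum endpoint error of the Euler sampler for several step counts.
errors = {n: float(np.max(np.abs(euler_sample(z0, n) - monge)))
          for n in (1, 10, 100, 1000)}
for n, err in errors.items():
    print(f"{n:5d} Euler steps: max |error| = {err:.4f}")
```

In this toy case the exact trajectory is $x_t = tm + \sigma_t z$ with $\sigma_t^2 = (1-t)^2 + t^2 s^2$, which is not linear in $t$. A single Euler step therefore misses the endpoint, and the error shrinks as the number of steps grows.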

Paper Structure

This paper contains 51 sections, 16 theorems, 117 equations, 8 figures.

Key Result

Theorem 3.2

Let the conditions of the main assumption and of the Lipschitz-condition assumption hold, and also assume that $\rho_1$ is absolutely continuous with respect to the Lebesgue measure in $\mathbb{R}^d$. Also, write $b(t)= \mathbb{E}_{X_t\sim \rho_t} \Vert v_t(X_t) - \widehat{v}_t(X_t)\Vert_2^2$ for $t \in
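The drift-estimation error $b(t)$ defined above is an expectation, so it can be approximated by Monte Carlo. The sketch below is an illustration only. It assumes that $\rho_t$ is the law of $X_t = (1-t)X_0 + tX_1$ with $X_0 \sim N(0, I_2)$ and $X_1 \sim N(m, s^2 I_2)$ independent. The functions `v_true` and `v_hat` are hypothetical stand-ins for $v_t$ and $\widehat{v}_t$, not the paper's drift or any fitted model.

```python
import numpy as np

rng = np.random.default_rng(0)

def v_true(t, x):
    """Hypothetical placeholder for the drift v_t (not the paper's drift)."""
    return -x

def v_hat(t, x):
    """Hypothetical estimate: the placeholder drift plus a small linear error."""
    return v_true(t, x) + 0.1 * x

def estimate_b(t, n=200_000, m=(1.0, -1.0), s=2.0):
    """Monte Carlo estimate of b(t) = E || v_t(X_t) - v_hat_t(X_t) ||_2^2."""
    m = np.asarray(m, dtype=float)
    x0 = rng.standard_normal((n, m.size))           # X0 ~ N(0, I)
    x1 = m + s * rng.standard_normal((n, m.size))   # X1 ~ N(m, s^2 I), assumed target
    xt = (1.0 - t) * x0 + t * x1                    # sample from the assumed rho_t
    diff = v_true(t, xt) - v_hat(t, xt)
    return float(np.mean(np.sum(diff ** 2, axis=1)))

b_half = estimate_b(0.5)
print(f"b(0.5) is approximately {b_half:.4f}")
```

With these stand-ins, $b(t) = 0.01\,\mathbb{E}\Vert X_t\Vert_2^2$, which equals $0.03$ at $t = 0.5$ for the chosen $m$ and $s$, so the estimate can be checked against a closed form.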

Figures (8)

  • Figure 1: Flow of $Z_t = Z_0 + (t, 50 N^{-2}\sin(2 \pi N t))^\top$ for different choices of $N$.
  • Figure 2: The above plots show the evolution of $0.5\,\lambda_{\min}(J_t^{z_0} + J_t^{z_0\top})$ for 100 different samples $z_0 \sim N(0, I_2)$ under $\rho_1 = \sum_{i=1}^K \pi_i N(\mu_i, \sigma^2 I_2)$ for $K \in \{2, 3, 4\}$. The mean and variance parameters are as follows: (a) $\mu_1 = (5,1), \mu_2 = (-7, -2), \sigma = 6.5$ and $\pi_1 = 1-\pi_2 = 0.6$, (b) $\mu_1 = (1,2), \mu_2 = (2,0), \mu_3 = (-1, -2), \sigma = 2.5$ and $\pi_1 = \pi_2 = 0.4, \pi_3 = 0.2$, (c) $\mu_1 = (1,3), \mu_2 = (2,0), \mu_3 = (-1, -2), \mu_4 = (0, -2), \sigma = 3$ and $\pi_1 = 0.3, \pi_2 = 0.4, \pi_3 = 0.2, \pi_4 = 0.1$.
  • Figure 3: The above plots show the evolution of the mean of $\Vert J_t^{z_0} - J_t^{z_0 \top}\Vert_{op}$ and of $0.5\,\lambda_{\min}(J_t^{z_0} + J_t^{z_0 \top})$ over 100 different random initial points $z_0\sim N(0, I_2)$ for $\rho_1 = \sum_{k = 1}^4 \pi_k N(\mu_k, \sigma^2 I_2)$ with weights $\pi_k = k/10$ for $k \in \{1,2,3,4\}$, and $\mu_1 = (0,-2), \mu_2 = (-1,-2), \mu_3 = (1, 3), \mu_4 = (2,0)$ and $\sigma = 0.1$.
  • Figure 4: (a) $W_2^2 (\widehat{\rho}_{\rm{data}}, \rho_1)$ vs $T$ (in log-log scale) for mixtures of Gaussians with varying components. (b) The straightness parameter $\gamma_{2,T}(\mathcal{Z})$ vs $T$ (in log-log scale) for the same respective distributions. (c) $W_2^2 (\widehat{\rho}_{\rm{data}}, \rho_1)$ vs $T$ (in log-log scale) for the second rectification on the Gaussian mixtures. (d) $W_2^2 (\widehat{\rho}_{\rm{data}}, \rho_1)$ vs $T$ for the FashionMNIST dataset with varying components. We observe that the straightness of the flow decreases with an increasing number of mixture components.
  • Figure A.1: (a) shows the flow of the minimum, median, and maximum values of a set of points, initially distributed according to a standard Gaussian. (b) shows the flow of points from a standard Gaussian to a symmetric mixture of two Gaussians and the black line represents $y=0$.
  • ...and 3 more figures
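
The two matrix quantities tracked in Figures 2 and 3, $0.5\,\lambda_{\min}(J + J^\top)$ and $\Vert J - J^\top\Vert_{op}$, can be computed for any square matrix $J$ as sketched below. This excerpt does not define $J_t^{z_0}$, so the code takes a generic matrix as input. The function names and the example matrix `J_example` are hypothetical and only illustrate the arithmetic.

```python
import numpy as np

def symmetric_part_min_eig(J):
    """Return 0.5 * lambda_min(J + J^T), the smallest eigenvalue of the symmetric part."""
    J = np.asarray(J, dtype=float)
    # eigvalsh is for symmetric matrices and returns eigenvalues in ascending order.
    return float(0.5 * np.linalg.eigvalsh(J + J.T)[0])

def asymmetry_op_norm(J):
    """Return the operator (spectral) norm of J - J^T; it is zero iff J is symmetric."""
    J = np.asarray(J, dtype=float)
    return float(np.linalg.norm(J - J.T, ord=2))

# Hypothetical 2x2 example matrix, not taken from the paper.
J_example = np.array([[2.0, 1.0],
                      [0.0, 3.0]])
min_eig = symmetric_part_min_eig(J_example)
asym = asymmetry_op_norm(J_example)
print(f"0.5 * lambda_min(J + J^T) = {min_eig:.4f}")
print(f"||J - J^T||_op            = {asym:.4f}")
```

For this example, $J + J^\top$ has eigenvalues $5 \pm \sqrt{2}$, so the first quantity is $(5 - \sqrt{2})/2 \approx 1.7929$ and the second is $1$.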

Theorems & Definitions (28)

  • Definition 2.1
  • Theorem 3.2
  • Remark 1
  • Definition 3.3
  • Lemma 3.4
  • Example 1
  • Example 2
  • Theorem 3.5
  • Corollary 3.6
  • Definition 4.1
  • ...and 18 more