Online Reward-Weighted Fine-Tuning of Flow Matching with Wasserstein Regularization

Jiajun Fan, Shuaike Shen, Chaoran Cheng, Yuxin Chen, Chumeng Liang, Ge Liu

TL;DR

This paper tackles the challenge of fine-tuning continuous flow-based generative models to align with arbitrary reward objectives without requiring reward gradients or filtered datasets. It introduces Online Reward-Weighted Conditional Flow Matching with Wasserstein-2 Regularization (ORW-CFM-W2), an online RL framework that combines reward-weighted flow matching with a tractable upper bound on the $W_2$ distance to prevent policy collapse and preserve diversity. The authors provide theoretical results on induced data distributions, convergence behavior, and connections to KL-regularized RL, and they validate the approach on target image generation, image compression, and text-image alignment tasks, including large-scale models such as Stable Diffusion 3. The key contribution is a practical, computationally efficient, and theoretically grounded method that converges to the optimal policy while offering a controllable trade-off between reward optimization and diversity, enabling robust reward-driven fine-tuning of continuous normalizing flows (CNFs) across downstream tasks.

Abstract

Recent advancements in reinforcement learning (RL) have achieved great success in fine-tuning diffusion-based generative models. However, fine-tuning continuous flow-based generative models to align with arbitrary user-defined reward functions remains challenging, particularly due to issues such as policy collapse from overoptimization and the prohibitively high computational cost of likelihoods in continuous-time flows. In this paper, we propose an easy-to-use and theoretically sound RL fine-tuning method, which we term Online Reward-Weighted Conditional Flow Matching with Wasserstein-2 Regularization (ORW-CFM-W2). Our method integrates RL into the flow matching framework to fine-tune generative models with arbitrary reward functions, without relying on gradients of rewards or filtered datasets. By introducing an online reward-weighting mechanism, our approach guides the model to prioritize high-reward regions in the data manifold. To prevent policy collapse and maintain diversity, we incorporate Wasserstein-2 (W2) distance regularization into our method and derive a tractable upper bound for it in flow matching, effectively balancing exploration and exploitation of policy optimization. We provide theoretical analyses to demonstrate the convergence properties and induced data distributions of our method, establishing connections with traditional RL algorithms featuring Kullback-Leibler (KL) regularization and offering a more comprehensive understanding of the underlying mechanisms and learning behavior of our approach. Extensive experiments on tasks including target image generation, image compression, and text-image alignment demonstrate the effectiveness of our method, where our method achieves optimal policy convergence while allowing controllable trade-offs between reward maximization and diversity preservation.
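The objective described in the abstract can be illustrated with a minimal NumPy sketch. It is not the authors' implementation: the linear interpolation path, the exponential weighting `exp(tau * r)`, the use of a squared difference between the fine-tuned and reference vector fields as the W2-style penalty, and all names (`orw_cfm_w2_loss`, `v_theta`, `v_ref`) are assumptions made for illustration.

```python
import numpy as np

def orw_cfm_w2_loss(v_theta, v_ref, x1, rewards, tau=1.0, alpha=1.0, rng=None):
    """Monte Carlo estimate of a reward-weighted CFM loss plus a W2-style penalty.

    v_theta, v_ref: callables (t, x) -> velocity, both hypothetical stand-ins
                    for the fine-tuned and the frozen pre-trained vector fields.
    x1:             samples of shape (n, d), drawn online from the current model.
    rewards:        rewards r(x1) of shape (n,); no reward gradient is needed.
    """
    rng = np.random.default_rng() if rng is None else rng
    n, d = x1.shape
    t = rng.uniform(size=(n, 1))          # t ~ U(0, 1)
    x0 = rng.standard_normal((n, d))      # noise endpoint of the path
    xt = (1.0 - t) * x0 + t * x1          # assumed linear conditional path
    u = x1 - x0                           # target velocity for that path
    w = np.exp(tau * rewards)             # assumed exponential reward weighting
    v = v_theta(t, xt)
    # Reward-weighted flow matching term: high-reward samples count more.
    fm_term = np.mean(w * np.sum((v - u) ** 2, axis=1))
    # Penalty that keeps the fine-tuned field close to the reference field,
    # used here as a stand-in for the tractable W2 upper bound.
    reg_term = np.mean(np.sum((v - v_ref(t, xt)) ** 2, axis=1))
    return fm_term + alpha * reg_term

# Usage with toy vector fields (purely illustrative).
rng = np.random.default_rng(0)
x1 = rng.standard_normal((8, 2))
r = rng.uniform(size=8)
loss = orw_cfm_w2_loss(lambda t, x: np.ones_like(x),
                       lambda t, x: np.zeros_like(x),
                       x1, r, tau=1.0, alpha=0.5, rng=rng)
```

In this sketch `tau` controls how strongly rewards reshape the objective and `alpha` controls how strongly the model is pulled back toward the reference, mirroring the reward-diversity trade-off the abstract describes.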

Paper Structure

This paper contains 64 sections, 44 theorems, 180 equations, 14 figures, 1 table, 3 algorithms.

Key Result

Theorem 1

Let $w: \mathcal{X} \rightarrow [0, \infty)$ be a measurable weighting function such that $0 < Z = \int_{\mathcal{X}} w(x_1)\, q(x_1)\, dx_1 < \infty$ and $w(x_1) \propto r(x_1)$. Define the reward weighted CFM loss function: where $w(x_1)$ is a weighting function, $q(x_1)$ is the original data distribution (e.g., data distribution induced by well-learned pre-t
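
The loss named in the statement is not reproduced in this excerpt. A plausible form, assuming standard conditional flow matching notation (the symbols $v_\theta$, $u_t$, and $p_t$ are assumptions, not taken from the excerpt), would be:

$$
\mathcal{L}_{\text{RW-CFM}}(\theta) = \mathbb{E}_{t \sim \mathcal{U}(0,1),\; x_1 \sim q(x_1),\; x \sim p_t(x \mid x_1)} \left[ w(x_1)\, \left\| v_\theta(t, x) - u_t(x \mid x_1) \right\|^2 \right]
$$

Under this reading, the normalizer $Z$ suggests that the weighted objective targets the reweighted distribution $w(x_1)\, q(x_1) / Z$, although the excerpt is cut off before the conclusion of the theorem.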

Figures (14)

  • Figure 1: A General Architecture of Our Method.
  • Figure 2: Learning curves and generated images in the target image generation task for different $\tau$ with $\alpha=0$. As $\tau$ increases, the converged policy becomes increasingly greedy and diversity decreases. With $\tau=0$, the distribution stays similar to that of the pre-trained model (see Theorem \ref{theorem: exp case}).
  • Figure 3: Learning curves and generated images in the target image generation task for different $\alpha$ with $\tau=10$. From $\alpha=0$ to $\alpha=0.8$, the diversity of the converged generative distribution increases without sacrificing much performance. With $\alpha=0$, the policy collapses to a delta distribution, as stated in Lemma \ref{lemma: limiting case}.
  • Figure 4: Reward-diversity trade-off and W2 distance control via $\alpha$ in the image compression task [ddpo] on CIFAR-10 [cifar10] with $\tau=1$. The W2 distance is estimated by its upper bound (see App. \ref{app sec: W2 distance}). Each point in (e) and (f) corresponds to one group of experiments with $\alpha$ varying from 0 to 1. As $\alpha$ increases, the final reward decreases while the distance between the fine-tuned model and the reference model shrinks, giving a controllable fine-tuning process.
  • Figure 5: General Comparison of Different Fine-tuning Methods on SD3 via CLIP Rewards.
  • ...and 9 more figures
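
The effect of $\tau$ described in the Figure 2 caption can be illustrated with a small sketch. It assumes an exponential weighting $w(x_1) = \exp(\tau\, r(x_1))$, which the caption's reference to an "exp case" theorem suggests but this excerpt does not state; the function name and the toy rewards are hypothetical, and only a single reweighting step is shown rather than the full online procedure.

```python
import numpy as np

def normalized_weights(rewards, tau):
    """Normalized weights proportional to exp(tau * r) over a batch of samples."""
    logits = tau * np.asarray(rewards, dtype=float)
    logits = logits - logits.max()   # subtract the max for numerical stability
    w = np.exp(logits)
    return w / w.sum()

rewards = [0.1, 0.5, 0.9]            # toy rewards for three samples
for tau in (0.0, 1.0, 10.0):
    # tau = 0 gives uniform weights; larger tau concentrates mass
    # on the highest-reward sample.
    print(tau, normalized_weights(rewards, tau))
```

With `tau = 0` every sample is weighted equally, so the training signal matches plain flow matching on the model's own samples; as `tau` grows, the weights concentrate on the highest-reward sample, which is consistent with the greedier, less diverse policies the caption reports.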

Theorems & Definitions (95)

  • Theorem 1
  • Theorem 2: Online Reward Weighted CFM
  • Lemma 1: Limiting Case
  • Corollary 1: Policy Collapse/Overoptimization
  • Definition 4.1: Wasserstein-2 Distance
  • Theorem 3: W2 Bound for Flow Matching
  • Theorem 4
  • Theorem 5
  • Theorem 6: Induced Data Distribution
  • Proof of Theorem \ref{theorem: Induced Data Distribution}, from RL Perspective
  • ...and 85 more