Physics-Informed Design of Input Convex Neural Networks for Consistency Optimal Transport Flow Matching

Fanghui Song, Zhongjian Wang, Jiebao Sun

TL;DR

This work introduces COFM, a physics-informed flow-matching framework for learning optimal transport (OT) maps by optimizing a time-dependent convex potential parameterized by a time-conditioned partially input-convex neural network (PICNN). By combining a flow-matching objective with a Hamilton-Jacobi (HJ) residual (or a path-consistency alternative), COFM enforces displacement-interpolating trajectories while avoiding inner optimization subproblems, enabling both one-step (Brenier-like) and multi-step sampling from the same learned potential. The method is explicitly tied to OT duality, so that minimizing the training loss aligns with the OT objective, and it demonstrates stability and scalability across toy problems, high-dimensional benchmarks, and real-valued image-to-image tasks in latent spaces. Empirically, COFM achieves state-of-the-art performance among flow-matching methods, with favorable training efficiency and robustness to increasing dimensionality, highlighting its practical utility for high-dimensional OT and generative modeling.
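
To make "flow-matching objective plus HJ residual" concrete, here is a minimal NumPy sketch under stated assumptions: the velocity is taken to be the spatial gradient of a scalar potential $\phi(t,x)$, and the residual is the standard quadratic-cost dynamic-OT form $\partial_t\phi + \tfrac12\|\nabla_x\phi\|^2$. This is a generic illustration, not the paper's PICNN parameterization; the names `fm_hj_loss`, `grad_x`, the weight `lam`, and the toy scaling map $x_1 = s\,x_0$ are all hypothetical, and central differences stand in for autograd.

```python
import numpy as np

def grad_x(phi, t, x, h=1e-5):
    """Central-difference gradient of phi(t, x) w.r.t. x (stand-in for autograd)."""
    g = np.zeros_like(x)
    for i in range(x.shape[1]):
        e = np.zeros(x.shape[1])
        e[i] = h
        g[:, i] = (phi(t, x + e) - phi(t, x - e)) / (2 * h)
    return g

def fm_hj_loss(phi, x0, x1, t, lam=1.0, h=1e-5):
    """Hypothetical loss: flow matching + lam * mean squared HJ residual.

    phi(t, x) maps t of shape (n,) and x of shape (n, d) to shape (n,);
    the velocity is ASSUMED to be v = grad_x phi.
    """
    xt = (1 - t)[:, None] * x0 + t[:, None] * x1          # linear interpolant
    v = grad_x(phi, t, xt, h)
    fm = np.mean(np.sum((v - (x1 - x0)) ** 2, axis=1))    # flow-matching term
    dphi_dt = (phi(t + h, xt) - phi(t - h, xt)) / (2 * h)
    hj = dphi_dt + 0.5 * np.sum(v ** 2, axis=1)           # HJ residual
    return fm + lam * np.mean(hj ** 2)

s = 2.0  # hypothetical OT map x1 = s * x0 (gradient of a convex quadratic)

def phi_exact(t, x):
    # Potential of the displacement interpolation x_t = (1 - t + t*s) * x0.
    return (s - 1) * np.sum(x ** 2, axis=1) / (2 * (1 - t + t * s))

def phi_wrong(t, x):
    # Time-independent quadratic: wrong velocity and nonzero HJ residual.
    return 0.5 * np.sum(x ** 2, axis=1)

rng = np.random.default_rng(0)
x0 = rng.normal(size=(256, 2))
x1 = s * x0
t = rng.uniform(0.0, 1.0, size=256)
loss_exact = fm_hj_loss(phi_exact, x0, x1, t)
loss_wrong = fm_hj_loss(phi_wrong, x0, x1, t)
```

For this toy coupling, the displacement-interpolation potential drives both terms to (numerically) zero, while the time-independent quadratic is penalized by both. In a real training loop the potential would be a network and the derivatives would come from autograd.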

Abstract

We propose a consistency model based on the optimal transport (OT) flow. A physics-informed design of partially input-convex neural networks (PICNN) plays a central role in constructing the flow field that emulates the displacement interpolation. During training, we couple the Hamilton-Jacobi (HJ) residual of the OT formulation with the original flow-matching loss. Our approach avoids the inner optimization subproblems present in previous one-step OFM approaches. During prediction, our approach supports both one-step (Brenier-map) and multi-step ODE sampling from the same learned potential, leveraging the straightness of the OT flow. We validate scalability and performance on standard OT benchmarks.
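
Why straightness lets one step suffice can be seen with forward Euler on a uniform time grid. The sketch below uses hypothetical analytic velocity fields in place of a learned one: along the straight, constant-speed paths of a displacement interpolation, Euler is exact for any step count $N$, whereas on curved paths (a rotation field) the one-step result is visibly wrong. The scaling map `s`, `euler_sample`, and both fields are illustrative assumptions.

```python
import numpy as np

def euler_sample(velocity, x0, n_steps):
    """Forward Euler for dx/dt = velocity(t, x) on a uniform grid over [0, 1]."""
    x = x0.copy()
    h = 1.0 / n_steps
    for k in range(n_steps):
        x = x + h * velocity(k * h, x)
    return x

s = 2.0  # hypothetical OT target: the scaling map x -> s * x

def v_straight(t, x):
    # Velocity of the displacement interpolation x_t = (1 - t + t*s) * x0;
    # along each trajectory it equals the constant displacement (s - 1) * x0.
    return (s - 1.0) * x / (1.0 - t + t * s)

def v_curved(t, x):
    # Rotation field: trajectories are circular arcs, not straight lines.
    return np.stack([-x[:, 1], x[:, 0]], axis=1)

rng = np.random.default_rng(0)
x0 = rng.normal(size=(5, 2))  # five sample trajectories

# Straight flow: every step count recovers s * x0 (up to rounding).
straight = {n: euler_sample(v_straight, x0, n) for n in (1, 2, 10, 100)}

# Curved flow: the exact time-1 solution is a rotation by 1 radian.
c, sn = np.cos(1.0), np.sin(1.0)
exact_rot = x0 @ np.array([[c, sn], [-sn, c]])
err = {n: np.abs(euler_sample(v_curved, x0, n) - exact_rot).max() for n in (1, 100)}
```

With a learned field the trajectories are only approximately straight, so $N=1$ is an approximation whose quality depends on how well the model matches the displacement interpolation.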

Paper Structure

This paper contains 21 sections, 5 theorems, 73 equations, 4 figures, 1 table.

Key Result

Proposition 2.1

For any $x_0,x_1\in \mathbb{R}^d$ and a convex function $\Psi$, the following equality holds true,

Figures (4)

  • Figure 1: Time-dependent PICNN. Dashed paths denote values produced by MLPs (e.g., $S_l(t)$, $\alpha(t)$). Solid links indicate affine transforms: blue links are regular affine maps, orange links are positive-weight affine maps, and black links are identity connections. The inputs are $(x,t)$. Each layer receives $x$ through $W_x^{(l)}$ together with a time bias $S_l(t)$, while the non-negative weights $W_z^{(l)}$ propagate $z^{(l-1)}$ forward. The last hidden output $z^{(L)}(t,x)$ enters the CEL, where it is gated by $(1-t)$ and combined with a time-dependent quadratic $\alpha(t)\|x\|^2$ (with $\alpha(t)$ generated by MLP$(t)$). One forward pass yields $\Psi^\theta$, and an autograd call then yields the velocity used in training or sampling.
  • Figure 2: Effect of the number of sampling steps. Qualitative performance of a fixed trained model with $N=1,2,10,100$ on the Gaussian-to-Eight-Gaussians setup in 2D. Five realizations of the generation are shown, from the initial point (triangle) to the generated point (square); $+$ marks the intermediate steps of the ODE generation ($N>1$) with uniform time discretization.
  • Figure 3: Comparison of solvers across different dimensions $D$. *Metrics are taken from OFM.
  • Figure 4: Unpaired image-to-image translation (male→female) using our COFM solver and the OFM solver in the FFHQ $1024\times1024$ ALAE latent space. All results shown use $N=1$.
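
The forward pass described in the Figure 1 caption can be sketched in NumPy as below. Every concrete choice here is an assumption for illustration: softplus activations, non-negativity of $W_z^{(l)}$ enforced by a softplus reparameterization, tanh-based stand-ins for the MLPs producing $S_l(t)$ and $\alpha(t)$ (with $\alpha(t)\ge 0$ via softplus), hypothetical layer widths, and the CEL modeled simply as $(1-t)\,z^{(L)}(t,x) + \alpha(t)\|x\|^2$. Central differences stand in for the autograd call, and the sketch returns $\nabla_x\Psi^\theta$ without claiming it is the paper's exact velocity formula.

```python
import numpy as np

def softplus(u):
    return np.logaddexp(0.0, u)  # convex and nondecreasing

class TinyTimePICNN:
    """Hypothetical time-conditioned PICNN; Psi(t, x) is convex in x for t in [0, 1]."""

    def __init__(self, d, widths=(16, 16, 1), seed=0):
        rng = np.random.default_rng(seed)
        # W_x^(l): unconstrained input connections, one per layer.
        self.Wx = [rng.normal(scale=0.5, size=(w, d)) for w in widths]
        # Raw parameters for W_z^(l); softplus makes the used weights non-negative.
        self.Wz_raw = [rng.normal(size=(w_out, w_in))
                       for w_in, w_out in zip(widths[:-1], widths[1:])]
        # Stand-ins for the MLPs producing S_l(t) and alpha(t).
        self.S_par = [rng.normal(size=(2, w)) for w in widths]
        self.alpha_par = rng.normal(size=2)

    def S(self, l, t):
        a, b = self.S_par[l]
        return np.tanh(a * t + b)             # time bias S_l(t), shape (w_l,)

    def alpha(self, t):
        p, q = self.alpha_par
        return softplus(p * t + q)            # alpha(t) >= 0 keeps the quadratic convex

    def __call__(self, t, x):
        """t: float in [0, 1]; x: array of shape (n, d). Returns Psi(t, x), shape (n,)."""
        z = softplus(x @ self.Wx[0].T + self.S(0, t))
        for l in range(1, len(self.Wx)):
            Wz = softplus(self.Wz_raw[l - 1])  # non-negative weights on z^(l-1)
            z = softplus(z @ Wz.T + x @ self.Wx[l].T + self.S(l, t))
        z_L = z[:, 0]                          # last hidden output z^(L)(t, x)
        return (1.0 - t) * z_L + self.alpha(t) * np.sum(x ** 2, axis=1)

def grad_x(f, t, x, h=1e-5):
    """Central differences standing in for the autograd call."""
    g = np.zeros_like(x)
    for i in range(x.shape[1]):
        e = np.zeros(x.shape[1])
        e[i] = h
        g[:, i] = (f(t, x + e) - f(t, x - e)) / (2 * h)
    return g

net = TinyTimePICNN(d=2)
rng = np.random.default_rng(1)
xa, xb = rng.normal(size=(100, 2)), rng.normal(size=(100, 2))
t = 0.3
# Midpoint convexity check: Psi((a + b) / 2) <= (Psi(a) + Psi(b)) / 2.
gap = 0.5 * (net(t, xa) + net(t, xb)) - net(t, 0.5 * (xa + xb))
v = grad_x(net, t, xa)  # gradient of Psi^theta in x, shape (100, 2)
```

Convexity in $x$ follows from the usual ICNN argument: non-negative combinations of convex functions plus affine terms, passed through a convex nondecreasing activation, stay convex, and the $(1-t)$ gate and $\alpha(t)\ge 0$ preserve this for $t\in[0,1]$.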

Theorems & Definitions (10)

  • Proposition 2.1
  • Lemma 3.1: Parametric Fenchel-Young equivalences
  • proof
  • Theorem 3.1: HJ consistency and OT duality
  • proof
  • Theorem 3.2: Dual-gap bound with FM and HJ residual
  • proof
  • Remark 3.1
  • Theorem 3.3
  • proof