
Schrödinger bridge based deep conditional generative learning

Hanwen Huang

TL;DR

Although the proposed Schrödinger bridge based deep generative method for learning conditional distributions does not directly provide conditional density estimates, the samples it generates are of higher quality than those obtained by several existing methods.

Abstract

Conditional generative models represent a significant advancement in the field of machine learning, allowing for the controlled synthesis of data by incorporating additional information into the generation process. In this work we introduce a novel Schrödinger bridge based deep generative method for learning conditional distributions. We start from a unit-time diffusion process governed by a stochastic differential equation (SDE) that transforms a fixed point at time $0$ into a desired target conditional distribution at time $1$. For effective implementation, we discretize the SDE with the Euler-Maruyama method and estimate the drift term nonparametrically using a deep neural network. We apply our method to both low-dimensional and high-dimensional conditional generation problems. Numerical studies demonstrate that, although our method does not directly provide conditional density estimates, the samples it generates are of higher quality than those obtained by several existing methods. Moreover, the generated samples can be effectively used to estimate the conditional density and related statistical quantities, such as the conditional mean and conditional standard deviation.
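
The sampling step described above (start at a fixed point at time $0$, integrate the SDE to time $1$ with Euler-Maruyama) can be sketched as follows. This is a minimal illustration under stated assumptions, not the author's implementation: it assumes a unit-diffusion reference SDE $d{\bf x}_t = d{\bf w}_t$ started at ${\bf x}_0 = 0$, and in place of a trained deep network it plugs in the closed-form drift $(E[{\bf x}_1 \mid {\bf x}_t] - {\bf x}_t)/(1-t)$ for a hypothetical one-dimensional target $x_1 \mid y \sim N(2y, 0.5^2)$. The names `gaussian_bridge_drift` and `euler_maruyama_sample` are illustrative only.

```python
import numpy as np


def gaussian_bridge_drift(x, t, y, slope=2.0, s=0.5):
    """Closed-form drift for a hypothetical target x_1 | y ~ N(slope*y, s^2).

    Assumes the reference SDE dx_t = dw_t with x_0 = 0, so that
    x_t | x_1 ~ N(t*x_1, t*(1-t)); the drift is (E[x_1 | x_t = x] - x) / (1 - t).
    """
    m = slope * y
    # Gaussian regression of x_1 on x_t (Cov = t*s^2, Var(x_t) = t^2*s^2 + t*(1-t))
    cond_mean = m + s**2 / (t * s**2 + 1.0 - t) * (x - t * m)
    return (cond_mean - x) / (1.0 - t)


def euler_maruyama_sample(drift, y, n_samples, n_steps=200, seed=0):
    """Simulate dx_t = drift(x_t, t, y) dt + dw_t on [0, 1] from x_0 = 0."""
    rng = np.random.default_rng(seed)
    dt = 1.0 / n_steps
    x = np.zeros(n_samples)  # every path starts at the fixed point 0
    for k in range(n_steps):
        t = k * dt  # left endpoint of the step, so t < 1 throughout
        x = x + drift(x, t, y) * dt + np.sqrt(dt) * rng.standard_normal(n_samples)
    return x


samples = euler_maruyama_sample(gaussian_bridge_drift, y=1.0, n_samples=20000)
print(samples.mean(), samples.std())  # close to the assumed target's 2.0 and 0.5
```

Evaluating the drift at the left endpoint of each step keeps $t < 1$, so the division by $1-t$ never hits zero. In the method itself the drift would be the deep network's estimate rather than this analytic formula.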

Paper Structure

This paper contains 20 sections, 2 theorems, 40 equations, 4 figures, 2 tables, and 2 algorithms.

Key Result

Proposition 1

For the reference SDE (ref1), the drift term ${\bf u}^\star({\bf x}_t,t)$ in (sde) is the time-varying vector field ${\bf u}({\bf x}_t,t)$ that minimizes the quadratic objective function …, where ${\bf Q}=[t\sim{\cal U}(0,1)]\otimes\mu_1({\bf x}_1)\otimes\pi({\bf x}_t|{\bf x}_1)$ and the conditional distribution $\pi({\bf x}_t|{\bf x}_1)$ is defined through $\pi({\bf x}_t|{\bf x}_1)\sim N(\ldots)$ …
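
Proposition 1 characterizes the drift as the minimizer of a quadratic objective under ${\bf Q}$. A minimal sketch of minimizing such an objective empirically, under assumptions this summary does not state: a unit-diffusion reference so that $\pi(x_t|x_1) = N(t x_1, t(1-t))$, the regression target $(x_1 - x_t)/(1-t)$ (a common choice for a bridge started at a fixed point), a hypothetical covariate $y \sim N(0,1)$, and the same hypothetical conditional target $x_1 \mid y \sim N(2y, 0.5^2)$. Linear least squares over two hand-picked features stands in for the deep network; it recovers the drift only because the features were chosen to make the assumed true drift exactly representable.

```python
import numpy as np

rng = np.random.default_rng(1)
n, slope, s = 200_000, 2.0, 0.5

# Draw (t, y, x_1, x_t) from the assumed Q. t is kept below 1 - eps because
# the variance of the regression target grows like 1 / (1 - t) near t = 1.
eps = 1e-2
t = rng.uniform(0.0, 1.0 - eps, n)
y = rng.standard_normal(n)                      # hypothetical covariate
x1 = slope * y + s * rng.standard_normal(n)     # x_1 | y ~ N(slope*y, s^2)
xt = t * x1 + np.sqrt(t * (1.0 - t)) * rng.standard_normal(n)  # bridge draw

# Assumed regression target of the quadratic objective
target = (x1 - xt) / (1.0 - t)

# Stand-in for the deep network: features chosen so that the assumed true drift
# (slope*y + (s^2 - 1)*x_t) / (t*s^2 + 1 - t) is exactly linear in them
denom = t * s**2 + 1.0 - t
features = np.column_stack([y / denom, xt / denom])

# Minimizing the empirical quadratic loss = ordinary least squares
coef, *_ = np.linalg.lstsq(features, target, rcond=None)
print(coef)  # close to [slope, s**2 - 1] = [2.0, -0.75]
```

Because the least-squares minimizer is the conditional expectation of the target given $(t, y, x_t)$, the fitted coefficients approach those of the analytic drift; a deep network plays the same role without needing the features to be chosen by hand.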

Figures (4)

  • Figure 1: Sampling performance of SBCG on three two-dimensional datasets generated from examples (example1). The conditional distributions of $x$ are sampled given $z=-1.2$ (red), $z=0$ (green), and $z=1.2$ (blue); the solid curves are the corresponding true densities. From left to right: Example 1, Example 2, and Example 3.
  • Figure 2: Scatter plots of joint distributions generated by SBCG and the ground truth. From left to right: checkerboard, moons, pinwheel, and swissroll. Top: samples generated by the proposed SBCG algorithm. Bottom: samples drawn from the ground-truth distributions of the synthetic problems.
  • Figure 3: MNIST dataset: real images (left panel) and generated images given the labels (right panel).
  • Figure 4: Reconstructed images given partial images from the MNIST dataset. In each panel, the first column shows the true images, the second column shows the associated conditions (the partially covered images), and the remaining columns show the reconstructed images. In the left panel, the lower-right 1/4 of the image is given; in the middle panel, the right 1/2 of the image is given; in the right panel, 3/4 of the image is given.
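
Figure 1 compares sampled conditional distributions with the true densities, and the abstract notes that generated samples can be used to estimate the conditional density, mean, and standard deviation. A minimal sketch of that post-processing for a one-dimensional output, under assumptions: draws from a hypothetical conditional $N(2, 0.5^2)$ stand in for SBCG output at one condition value, and the density is a Gaussian kernel density estimate with Silverman's rule-of-thumb bandwidth (an assumed choice; the estimator used in the paper is not specified here). The function name `conditional_summary` is illustrative.

```python
import numpy as np


def conditional_summary(samples, grid):
    """Estimate mean, std, and a Gaussian-KDE density from generated samples."""
    samples = np.asarray(samples, dtype=float)
    n = samples.size
    mean = samples.mean()
    std = samples.std(ddof=1)
    # Silverman's rule-of-thumb bandwidth for a 1-D Gaussian kernel
    h = 1.06 * std * n ** (-1 / 5)
    # Average of Gaussian kernels centred at each sample, evaluated on the grid
    z = (grid[:, None] - samples[None, :]) / h
    density = np.exp(-0.5 * z**2).sum(axis=1) / (n * h * np.sqrt(2 * np.pi))
    return mean, std, density


# Stand-in for generated samples at one condition (hypothetical target N(2, 0.5^2))
rng = np.random.default_rng(2)
generated = 2.0 + 0.5 * rng.standard_normal(20000)
grid = np.linspace(0.0, 4.0, 81)  # spacing 0.05
mean, std, density = conditional_summary(generated, grid)
print(mean, std, density[40])  # density[40] is the estimate at x = 2.0
```

The KDE slightly flattens the peak (it estimates the target convolved with the kernel), so with a finite sample the estimate at the mode sits a little below the true value $1/(0.5\sqrt{2\pi}) \approx 0.798$.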

Theorems & Definitions (4)

  • Proposition 1
  • Proposition 2
  • Proof
  • Proof