Efficient Generative Modeling via Penalized Optimal Transport Network

Wenhui Sophia Lu, Chenyang Zhong, Wing Hung Wong

TL;DR

The paper tackles the instability and mode-collapse issues of traditional Wasserstein-based generative models for high-dimensional data by introducing the Marginally-Penalized Wasserstein (MPW) distance and the Penalized Optimal Transport Network (POTNet). By optimizing a primal MPW objective that combines joint transport with coordinate-wise marginal penalties, POTNet eliminates the need for a critic network, supports mixed data types, and exploits the fast convergence of low-dimensional marginals to mitigate mode dropping and tail shrinkage. The authors establish non-asymptotic generalization bounds, show that POTNet attenuates both Type I and Type II mode collapse in theory, and report substantial empirical gains on synthetic benchmarks and real data, including accurate recovery of data structure and orders-of-magnitude speedups in sampling. The approach performs robustly on tabular data, scales to image generation, and is competitive on inference tasks, highlighting the practical value of marginal information in high-dimensional generative modeling.
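
A minimal sketch of how a primal MPW objective could be evaluated directly between two equal-size minibatches, with no critic network. It assumes the form $W_1(P,Q)+\sum_{j}\lambda_j W_1(P_j,Q_j)$, with a Euclidean ground cost for the joint term, absolute difference for each one-dimensional marginal, and uniform sample weights; Definition 3.2 in the paper is authoritative, and the function names (`joint_w1`, `marginal_w1`, `mpw_loss`) are hypothetical. With equal sample sizes and uniform weights, an optimal joint coupling is a permutation, so the joint term reduces to an assignment problem, while each marginal term reduces to matching sorted values.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def joint_w1(x: np.ndarray, y: np.ndarray) -> float:
    """Exact empirical W1 between two equal-size samples with uniform weights."""
    # Pairwise Euclidean distances between every point of x and every point of y.
    cost = np.linalg.norm(x[:, None, :] - y[None, :, :], axis=-1)
    # With uniform weights and equal sizes, an optimal coupling is a permutation.
    rows, cols = linear_sum_assignment(cost)
    return float(cost[rows, cols].mean())


def marginal_w1(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Per-coordinate empirical W1: match sorted values within each column."""
    return np.abs(np.sort(x, axis=0) - np.sort(y, axis=0)).mean(axis=0)


def mpw_loss(x: np.ndarray, y: np.ndarray, lam: np.ndarray) -> float:
    """Hypothetical MPW objective: joint W1 plus lambda-weighted marginal W1 terms."""
    if x.shape != y.shape:
        raise ValueError("this sketch assumes equal-size samples of the same dimension")
    return joint_w1(x, y) + float(np.dot(lam, marginal_w1(x, y)))


# Usage: compare a "real" minibatch with a shifted "generated" minibatch.
rng = np.random.default_rng(0)
real = rng.normal(size=(128, 5))
fake = rng.normal(loc=0.5, size=(128, 5))
print(mpw_loss(real, fake, lam=np.full(5, 1.0)))
```

Because every term is computed in closed form from the two samples, the loss can be evaluated directly instead of being estimated through an adversarially trained critic; in an actual generator this computation would be written in a differentiable framework.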

Abstract

The generation of synthetic data whose distribution faithfully emulates the underlying data-generating mechanism is of paramount importance. Wasserstein Generative Adversarial Networks (WGANs) have emerged as a prominent tool for this task; however, owing to the delicate equilibrium of the minimax formulation and the instability of the Wasserstein distance in high dimensions, WGANs often exhibit the pathological phenomenon of mode collapse. As a result, generated samples converge to a restricted set of outputs and fail to adequately capture the tail behavior of the true distribution. Such limitations can lead to serious downstream consequences. To this end, we propose the Penalized Optimal Transport Network (POTNet), a versatile deep generative model based on the marginally-penalized Wasserstein (MPW) distance. Through the MPW distance, POTNet leverages low-dimensional marginal information to guide the overall alignment of joint distributions. Furthermore, our primal-based framework enables direct evaluation of the MPW distance, eliminating the need for a critic network. This formulation circumvents the training instabilities inherent in adversarial approaches and avoids extensive parameter tuning. We derive a non-asymptotic bound on the generalization error of the MPW loss and establish convergence rates for the generative distribution learned by POTNet. Our theoretical analysis, together with extensive empirical evaluations, demonstrates the superior performance of POTNet in accurately capturing underlying data structures, including their tail behavior and minor modes. Moreover, POTNet achieves an orders-of-magnitude speedup during the sampling stage compared with state-of-the-art alternatives, enabling computationally efficient large-scale synthetic data generation.
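
The contrast between the instability of the Wasserstein distance in high dimensions and the usefulness of low-dimensional marginal information can be illustrated numerically. The sketch below draws two independent samples from the same standard Gaussian and compares the empirical joint $W_1$ with the average per-coordinate $W_1$; the helper name `empirical_w1_gap`, the dimensions, the sample size, and the seed are arbitrary illustrative choices, not settings from the paper. Both quantities would vanish with infinite data, but at a fixed sample size the joint estimate typically stays far from zero as the dimension grows, while the one-dimensional marginal estimates are already small.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def empirical_w1_gap(d: int, n: int, seed: int = 0) -> tuple[float, float]:
    """Return (joint W1, mean marginal W1) between two i.i.d. N(0, I_d) samples."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal((n, d))
    y = rng.standard_normal((n, d))
    # Joint term: exact optimal matching under the Euclidean cost.
    cost = np.linalg.norm(x[:, None, :] - y[None, :, :], axis=-1)
    rows, cols = linear_sum_assignment(cost)
    joint = float(cost[rows, cols].mean())
    # Marginal terms: 1D W1 per coordinate via sorted samples, averaged over coordinates.
    marginal = float(np.abs(np.sort(x, axis=0) - np.sort(y, axis=0)).mean())
    return joint, marginal


for d in (1, 5, 20):
    joint, marginal = empirical_w1_gap(d, n=256)
    print(f"d={d:2d}  joint W1={joint:.3f}  mean marginal W1={marginal:.3f}")
```

In one dimension the two numbers coincide, since sorted matching is the optimal coupling on the real line; the gap that opens up as `d` increases is the finite-sample effect the marginal penalties are meant to counteract.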

Paper Structure

This paper contains 56 sections, 9 theorems, 96 equations, 10 figures, 4 tables, and 2 algorithms.

Key Result

Theorem 4.1

Assume that the support of $P_X$ is contained in $B_M=\{\mathbf{x}\in\mathbb{R}^d:\|\mathbf{x}\|_2\leq M\}$, where $M$ is a positive deterministic constant. Then for any $\delta\in (0,1/2)$, with probability at least $1-2\delta$, where $\pmb{\lambda}$ is defined as in Definition 3.2, $\mathcal{F}_{\pmb{\lambda}}$ is the set of functions $\phi:\mathbb{R}^d\rightarrow\mathbb{R}$ such …

Figures (10)

  • Figure 1: Performance of five methods for estimating the approximate posterior distribution of $\pmb{\phi}$. (a) Marginal distributions for each $\phi_j$, $j=1,\dots,5$, by column. Cyan: ground truth; orange: synthetic samples. (b) Bivariate density contour plots: $\phi_4$ vs $\phi_3$ (top panel) and $\phi_2$ vs $\phi_5$ (bottom panel). (c) Heatmap of absolute deviation of sample covariance matrix from estimated covariance matrix of the ground truth dataset (lighter color indicates smaller deviation). Bias: spectral norm of difference between synthetic and real covariance (lower is better). Condition number: $\lambda_{\mathrm{max}} / \lambda_{\mathrm{min}}$ of covariance matrix.
  • Figure 2: Bivariate contour plots of the first two dimensions of a 20D Gaussian mixture model with 3 components. WGAN shows type I (mode dampening) mode collapse, while OT exhibits both type I and type II (support shrinkage) mode collapse.
  • Figure 3: Comparison of synthetic data generated by each method in 2D (top) and 3D (bottom) when the underlying manifold is complex and lower-dimensional.
  • Figure 4: A comparison of the bivariate joint distribution of Longitude and HouseAge for the California Housing dataset. Red lines: upper and lower bounds for contour lines of the original data. OT fails to adequately capture the variance of HouseAge.
  • Figure 5: Comparison of generated MNIST handwritten digits after 100 epochs of training using a convolutional network architecture. (a) POTNet; (b) WGAN; (c) OT.
  • ...and 5 more figures
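
The covariance diagnostics described in the Figure 1 caption can be computed directly from samples. The sketch below follows the caption's definitions: the heatmap entries are absolute deviations between the synthetic and real sample covariance matrices, the bias is the spectral norm of their difference, and the condition number is $\lambda_{\mathrm{max}} / \lambda_{\mathrm{min}}$ of a covariance matrix (here assumed to be the synthetic one). The function name `covariance_diagnostics` and the use of `np.cov` with its default unbiased normalization are assumptions for illustration.

```python
import numpy as np


def covariance_diagnostics(real: np.ndarray, synthetic: np.ndarray) -> dict:
    """Covariance-based diagnostics in the spirit of the Figure 1 caption."""
    # Sample covariances; rows are observations, columns are variables.
    cov_real = np.cov(real, rowvar=False)
    cov_syn = np.cov(synthetic, rowvar=False)
    diff = cov_syn - cov_real
    # Spectral norm = largest singular value of the difference (ord=2 for matrices).
    bias = float(np.linalg.norm(diff, ord=2))
    # Covariances are symmetric, so eigvalsh returns real eigenvalues in ascending order.
    eig = np.linalg.eigvalsh(cov_syn)
    cond = float(eig[-1] / eig[0])
    return {"abs_deviation": np.abs(diff), "bias": bias, "condition_number": cond}


# Usage: a hypothetical generator that shrinks the spread of the last coordinate.
rng = np.random.default_rng(1)
real = rng.multivariate_normal(np.zeros(3), np.diag([1.0, 2.0, 4.0]), size=2000)
synthetic = real * np.array([1.0, 1.0, 0.5])
stats = covariance_diagnostics(real, synthetic)
print(stats["bias"], stats["condition_number"])
```

A lower bias means the synthetic covariance is closer to the real one; the condition number gives a quick check on whether the generated covariance has become nearly singular, as happens when some directions of variation collapse.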

Theorems & Definitions (30)

  • Definition 3.1: Marginal Distribution
  • Definition 3.2: Marginally-Penalized Wasserstein (MPW) Distance
  • Remark 3.1
  • Definition 3.3: Type I Mode Collapse
  • Definition 3.4: Type II Mode Collapse
  • Theorem 4.1
  • Remark 4.1
  • Theorem 4.2
  • Remark 4.2
  • Remark 4.3
  • ...and 20 more