Worst-case generation via minimax optimization in Wasserstein space

Xiuyuan Cheng, Yao Xie, Linglingzhi Zhu, Yunqin Zhu

TL;DR

This work tackles worst-case sample generation under distribution shifts by formulating a minimax problem over Wasserstein space and recasting the inner maximization as a transport-map pushforward from a reference measure. It proposes a single-loop Gradient Descent-Ascent (GDA) algorithm with a neural transport map, providing convergence guarantees across various nonconvex regimes and enabling scalable, out-of-sample worst-case generation through $L^2$ transport-map matching. Theoretical results cover NC-PL, NC-SC, and NC-NC settings, complemented by practical particle-optimization and neural transport-map algorithms implemented on finite samples. Empirical validation on synthetic 2D data and image datasets (MNIST and CIFAR-10) demonstrates meaningful worst-case distributions and effective generalization via the learned transport map, highlighting the method’s robustness and scalability for stress-testing and robustness certification in high-dimensional settings.

Abstract

Worst-case generation plays a critical role in evaluating robustness and stress-testing systems under distribution shifts, in applications ranging from machine learning models to power grids and medical prediction systems. We develop a generative modeling framework for worst-case generation for a pre-specified risk, based on min-max optimization over continuous probability distributions, namely the Wasserstein space. Unlike traditional discrete distributionally robust optimization approaches, which often suffer from scalability issues, limited generalization, and costly worst-case inference, our framework exploits the Brenier theorem to characterize the least favorable (worst-case) distribution as the pushforward of a transport map from a continuous reference measure, enabling a continuous and expressive notion of risk-induced generation beyond classical discrete DRO formulations. Based on the min-max formulation, we propose a Gradient Descent Ascent (GDA)-type scheme that updates the decision model and the transport map in a single loop, establishing global convergence guarantees under mild regularity assumptions and possibly without convexity-concavity. We also propose to parameterize the transport map using a neural network that can be trained simultaneously with the GDA iterations by matching the transported training samples, thereby achieving a simulation-free approach. The efficiency of the proposed method as a risk-induced worst-case generator is validated by numerical experiments on synthetic and image data.
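The single-loop scheme described in the abstract can be illustrated with a minimal particle sketch. Everything specific here is an assumption made for illustration and is not taken from the paper: the penalized objective $L(\theta, T) = \mathbb{E}\,[\ell(\theta, T(x))] - \frac{1}{2\gamma}\,\mathbb{E}\,\|T(x) - x\|^2$, the scalar regression loss, the step-size roles of `eta` and `tau`, and the helper names. An affine least-squares fit stands in for the neural transport map in the matching step.

```python
import numpy as np

GAMMA = 0.2  # assumed penalty parameter; small enough that the inner problem is strongly concave here

def loss(theta, v):
    """Toy L2 regression loss on samples v = (x, y): 0.5 * (theta * x - y)^2."""
    return 0.5 * (theta * v[:, 0] - v[:, 1]) ** 2

def grads(theta, v, z, gamma):
    """Gradients of the assumed objective w.r.t. theta and each particle v_i = T(z_i)."""
    r = theta * v[:, 0] - v[:, 1]                      # residual at the transported samples
    g_theta = np.mean(r * v[:, 0])                     # d/dtheta of the mean loss
    g_v = np.stack([r * theta, -r], axis=1) - (v - z) / gamma  # per-particle ascent direction
    return g_theta, g_v

def gda_worst_case(z, gamma=GAMMA, eta=0.05, tau=0.1, iters=2000):
    """Single-loop GDA: descent on theta and ascent on the particles in the same iteration."""
    theta, v = 0.0, z.copy()                           # T starts as the identity map
    for _ in range(iters):
        g_theta, g_v = grads(theta, v, z, gamma)
        theta = theta - eta * g_theta                  # descent step for the decision variable
        v = v + tau * g_v                              # ascent step for the transported samples
    g_theta, g_v = grads(theta, v, z, gamma)
    gn_theta = abs(g_theta)                            # gradient norms on the finite sample
    gn_v = np.sqrt(np.mean(np.sum(g_v ** 2, axis=1)))
    return theta, v, gn_theta, gn_v

def fit_affine_map(z, v):
    """L2 matching of the transported samples with an affine map (stand-in for a neural map)."""
    X = np.hstack([z, np.ones((len(z), 1))])
    coef, *_ = np.linalg.lstsq(X, v, rcond=None)
    return coef

def apply_map(coef, z_new):
    """Push new samples through the fitted map to generate out-of-sample worst-case points."""
    return np.hstack([z_new, np.ones((len(z_new), 1))]) @ coef

rng = np.random.default_rng(0)
x = rng.normal(size=200)
z = np.stack([x, x + 0.3 * rng.normal(size=200)], axis=1)   # original samples (x_i, y_i)

theta, v, gn_theta, gn_v = gda_worst_case(z)
coef = fit_affine_map(z, v)
print("theta:", theta)
print("nominal loss:", loss(theta, z).mean(), "worst-case loss:", loss(theta, v).mean())
print("gradient norms:", gn_theta, gn_v)
```

In this toy problem the worst-case map happens to be affine, so the least-squares fit reproduces the transported samples; in general a more expressive map, such as the neural network the abstract proposes, would be needed. Once fitted, `apply_map(coef, z_new)` generates worst-case points for unseen samples without rerunning the optimization.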

Paper Structure

This paper contains 47 sections, 12 theorems, 165 equations, 12 figures, 2 tables, 1 algorithm.

Key Result

Lemma 2.1

Under Assumption \ref{assump:l0-smooth-ell}, $L(\theta, T)$ is coordinate $l$-smooth on $\mathbb{R}^p \times L^2(P)$.

Figures (12)

  • Figure 1: $L^2$ regression loss on 2D data, GDA in the $\theta$-fast setting. Upper: The saddle point solution found by the algorithm, where open circles are original samples $x_i$, solid circles are converged sample locations $v_i = T(x_i)$, and the color bar indicates the loss $\ell(\theta_k, \cdot)$ evaluated on $v_i$ in the last iteration. Lower: Gradient norms of $\theta$ and $T$ computed on finite samples, see \ref{eq:GN-on-batch}, along the iterations by GDA.
  • Figure 2:
  • Figure 3:
  • Figure 4:
  • Figure 6: GDA with momentum on MNIST with different batch sizes $m$, where $k$ stands for the number of batches. $\gamma=8.0$, $\eta=\tau=0.01$, and momentum 0.7. (a) Gradient norm ${\rm GN}_\theta^k$, and (b) gradient norm ${\rm GN}_T^k$, computed on batches as defined in \ref{eq:GN-on-batch}. (c) Matching loss ${\cal L}$ for the $k$-th batch (averaged over the $m/m'$ smaller batches, $m'=50$). All errors are shown on a log scale and subsampled at intervals of 20 iterations for better visualization.
  • ...and 7 more figures

Theorems & Definitions (29)

  • Definition 1
  • Remark 2.1
  • Lemma 2.1
  • Lemma 2.2
  • Lemma 2.3
  • Definition 2
  • Theorem 2.4: $T$-fast NC-PL
  • Remark 2.2: Small and large $\gamma$
  • Example 1: $L^2$ regression
  • Lemma 3.1
  • ...and 19 more