Squared Wasserstein-2 Distance for Efficient Reconstruction of Stochastic Differential Equations

Mingtao Xia, Xiangting Li, Qijing Shen, Tom Chou

TL;DR

This work analyzes the squared $W_2$ distance between probability measures induced by solutions to stochastic differential equations and uses it to formulate loss functions for reconstructing SDEs from noisy data. It derives a fundamental bound linking $W_2(\boldsymbol{\mu}, \hat{\boldsymbol{\mu}})$ to errors in the drift and diffusion terms, and introduces finite-dimensional projections and a time-decoupled loss that are computationally efficient. Theoretical results (Theorems 1–4) establish convergence of the finite-dimensional and time-decoupled losses to the true $W_2$ as the time grid is refined. Numerical experiments across CIR, OU, and 2D geometric Brownian motion demonstrate that the proposed time-decoupled $W_2$ loss outperforms traditional losses (MSE, KL, MMD) and other SDE-reconstruction approaches in both accuracy and efficiency, while remaining robust to initial-condition uncertainty. The framework suggests promising extensions to high-dimensional SDEs and to processes with jumps (Lévy) and provides a principled, transport-based surrogate for inverse problems in stochastic dynamics.

Abstract

We provide an analysis of the squared Wasserstein-2 ($W_2$) distance between two probability distributions associated with two stochastic differential equations (SDEs). Based on this analysis, we propose the use of squared $W_2$ distance-based loss functions in the reconstruction of SDEs from noisy data. To demonstrate the practicality of our Wasserstein distance-based loss functions, we performed numerical experiments that show the efficiency of our method in reconstructing SDEs that arise across a number of applications.
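
To make the idea of a time-decoupled squared $W_2$ loss concrete, the following is a minimal sketch for the univariate case, not the authors' implementation. It assumes both trajectory sets are sampled on the same uniform time grid with the same number of trajectories, so that the squared $W_2$ distance between the two empirical marginals at each time point reduces to the mean squared difference of the sorted samples. The function names (`w2_squared_1d`, `time_decoupled_w2_loss`, `simulate_ou`), the plain Riemann sum over all grid points, and the Euler–Maruyama simulator for an Ornstein-Uhlenbeck process are illustrative assumptions.

```python
import numpy as np


def w2_squared_1d(x, y):
    """Squared W2 distance between two 1D empirical distributions.

    Assumes equal sample counts, so the optimal coupling pairs the
    sorted samples (order statistics) of x and y.
    """
    x_sorted = np.sort(np.asarray(x, dtype=float))
    y_sorted = np.sort(np.asarray(y, dtype=float))
    if x_sorted.shape != y_sorted.shape:
        raise ValueError("x and y must contain the same number of samples")
    return float(np.mean((x_sorted - y_sorted) ** 2))


def time_decoupled_w2_loss(traj_true, traj_model, dt):
    """Riemann sum over time of the per-time-point squared W2 distance.

    Both inputs have shape (n_samples, n_times) on the same time grid.
    Including every grid point in the sum is an assumption of this sketch.
    """
    traj_true = np.asarray(traj_true, dtype=float)
    traj_model = np.asarray(traj_model, dtype=float)
    if traj_true.shape != traj_model.shape:
        raise ValueError("trajectory arrays must have the same shape")
    per_time = [
        w2_squared_1d(traj_true[:, i], traj_model[:, i])
        for i in range(traj_true.shape[1])
    ]
    return dt * float(np.sum(per_time))


def simulate_ou(theta, mu, sigma, x0, dt, n_steps, n_samples, rng):
    """Euler-Maruyama paths of dX = theta * (mu - X) dt + sigma dB (hypothetical test SDE)."""
    x = np.full(n_samples, float(x0))
    paths = np.empty((n_samples, n_steps + 1))
    paths[:, 0] = x
    for k in range(n_steps):
        dw = rng.normal(0.0, np.sqrt(dt), size=n_samples)  # Brownian increments
        x = x + theta * (mu - x) * dt + sigma * dw
        paths[:, k + 1] = x
    return paths


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    truth = simulate_ou(1.0, 0.0, 0.5, 1.0, 0.01, 200, 500, rng)
    close = simulate_ou(1.0, 0.0, 0.5, 1.0, 0.01, 200, 500, rng)
    far = simulate_ou(1.0, 0.0, 1.5, 1.0, 0.01, 200, 500, rng)
    print("same parameters:     ", time_decoupled_w2_loss(truth, close, 0.01))
    print("different diffusion: ", time_decoupled_w2_loss(truth, far, 0.01))
```

Because the loss compares marginal distributions one time point at a time, it needs only a sort per time point rather than a coupling over whole trajectories. In a training setting the same computation would be written with a differentiable array library so that gradients can flow to the parameters of the drift and diffusion models.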

Paper Structure (16 sections, 3 theorems, 81 equations, 5 figures, 5 tables)

This paper contains 16 sections, 3 theorems, 81 equations, 5 figures, 5 tables.

Key Result

Theorem 1

If $\{X(t)\}_{t=0}^T$ and $\{\hat{X}(t)\}_{t=0}^T$ have the same initial-condition distribution and are solutions to Eq. SDE_representation and Eq. approximate_sde, respectively, in the univariate case ($d=1$ in Eq. sde_dimension), and the following conditions hold: […] then […], where $\tilde{X}(t)$ satisfies […], with $h$ defined as […].

Figures (5)

  • Figure 1: (a) Ground-truth trajectories. (b) Reconstructed trajectories from nSDE using MSE loss. (c) Reconstructed trajectories from nSDE using mean$^2$+variance loss. (d) Reconstructed trajectories from nSDE using the finite-time-point time-decoupled $W_2$ loss. (e) Reconstructed trajectories from nSDE using a max-log-likelihood loss, which yields the worst approximation.
  • Figure 2: (a) Ground-truth trajectories and reconstructed trajectories by nSDE using the finite-time-point time-decoupled squared $W_2$ loss with $\sigma_0 = 0.5$. (b-c) Errors with respect to the numbers of ground-truth trajectories for $\sigma_0 = 0.5$. (d) Comparison of the reconstructed $\hat{f}_{\Theta_1}(u), \hat{\sigma}_{\Theta_2}(u)$ to the ground-truth functions $f(u), \sigma(u)$ for $\sigma_0 = 0.5$. (e-f) Errors with respect to noise level $\sigma_{0}$ with 200 training samples. Legends for panels (c, e, f) are the same as the one in (b).
  • Figure 3: (a) Ground-truth and reconstructed trajectories using the squared $W_2$ loss in Eq. \ref{approximation}. Black and red curves are ground-truth and reconstructed trajectories, respectively. Black and red arrows indicate $f(x,t)$ and the reconstructed $\hat{f}(x,t)$ at fixed $(x,t)$, respectively. (b) Relative errors in reconstructed $\hat{f}$ and $\hat{\sigma}$, repeated 10 times. Error bars show the standard deviation. (c) Resource consumption with respect to the number of training samples $N_{\rm samples}$. Memory usage is measured by the torch profiler and represents peak memory usage during training. The legend in panel (c) is the same as the one in (b).
  • Figure 4: (a) Black dots and red squares are the ground-truth $(X_1(2), X_2(2))$ and the reconstructed $(\hat{X}_1(2), \hat{X}_2(2))$ found using the rotated squared $W_2$ loss function, respectively. Black and red arrows indicate, respectively, the vectors ${\bm f}(X_1,X_2)$ and $\hat{{\bm f}}(X_1,X_2)$. (b) Relative errors of the reconstructed ${\bm f}$ and $\boldsymbol{\sigma}$. Error bars indicate the standard deviation across ten reconstructions. (c) Runtime of different loss functions with respect to $N_{\rm samples}$. (d) The decrease of different loss functions with respect to training epochs. The legend for the panel (d) is the same as the one in (c).
  • Figure 5: (a) The change in Eq. \ref{time_discretize} and Eq. \ref{approximation} when minimizing Eq. \ref{time_discretize} over training epochs. (b) The change in Eq. \ref{time_discretize} and Eq. \ref{approximation} when minimizing Eq. \ref{approximation} over training epochs.

Theorems & Definitions (8)

  • Definition 1
  • Theorem 1
  • Theorem 2
  • Theorem 3
  • Example 1
  • Example 2
  • Example 3
  • Example 4