Implicit Bias of the JKO Scheme

Peter Halmos, Boris Hanin

TL;DR

The paper analyzes the implicit bias of the Jordan–Kinderlehrer–Otto (JKO) scheme for Wasserstein gradient flows. It derives a second-order modification of the objective, $J^\eta(\rho)=J(\rho)-\frac{\eta}{4}|\partial J(\rho)|^2$, where $|\partial J|$ is the metric slope of $J$, such that Wasserstein gradient flow on $J^\eta$ matches the JKO iterates to order $\eta^2$. This yields a family of implicit biases: for entropy the bias is the Fisher information, for KL divergence the Fisher-Hyvärinen divergence, and for Riemannian gradient descent the kinetic energy in the metric $g$; the free-energy (Langevin) case is treated as well. The authors use a strategy based on backward error analysis (BEA) to identify the modified flow, and they study it, under the name JKO-Flow, in numerical experiments on the Bures–Wasserstein space and on Langevin sampling from a one-dimensional quartic potential. In the sampling experiment, several values of $\eta$ give a consistent improvement in KL distance to the target over the $\eta=0$ baseline. The results clarify the geometry of JKO updates and are relevant to stable discretizations of gradient flows on manifolds and to stochastic sampling.
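As a worked illustration of how the listed biases follow from the modified energy, the penalty $\int_M \lVert \nabla_g \frac{\delta J}{\delta\rho}(\rho)\rVert^2 \rho(dx)$ can be evaluated for each functional. This is a sketch, not the paper's derivation; it assumes smooth, strictly positive densities on $M=\mathbb{R}^d$ for the first two lines and a smooth function $f$ on $M$ for the third, and the symbols $I(\rho)$ and $f$ are notation introduced here.

\[
\begin{aligned}
J(\rho)=\int \rho\log\rho\,dx:\quad & \frac{\delta J}{\delta\rho}=\log\rho+1, \qquad \int \lVert\nabla\log\rho\rVert^2\,\rho\,dx = I(\rho) \quad\text{(Fisher information)},\\
J(\rho)=\mathrm{KL}(\rho\,\|\,\pi):\quad & \frac{\delta J}{\delta\rho}=\log\frac{\rho}{\pi}+1, \qquad \int \Big\lVert\nabla\log\frac{\rho}{\pi}\Big\rVert^2\rho\,dx \quad\text{(relative Fisher information)},\\
J(\rho)=\int_M f\,d\rho,\ \rho=\delta_x:\quad & \frac{\delta J}{\delta\rho}=f, \qquad J^\eta(\delta_x)=f(x)-\frac{\eta}{4}\,\lVert\nabla_g f(x)\rVert_g^2 .
\end{aligned}
\]

The second quantity is the Fisher-Hyvärinen divergence up to the normalization convention in use, and the third line is the point-mass case that corresponds to Riemannian gradient descent on $f$.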

Abstract

Wasserstein gradient flow provides a general framework for minimizing an energy functional $J$ over the space of probability measures on a Riemannian manifold $(M,g)$. Its canonical time-discretization, the Jordan-Kinderlehrer-Otto (JKO) scheme, produces for any step size $\eta>0$ a sequence of probability distributions $\rho_k^\eta$ that approximate to first order in $\eta$ Wasserstein gradient flow on $J$. But the JKO scheme also has many other remarkable properties not shared by other first order integrators, e.g. it preserves energy dissipation and exhibits unconditional stability for $\lambda$-geodesically convex functionals $J$. To better understand the JKO scheme we characterize its implicit bias at second order in $\eta$. We show that $\rho_k^\eta$ are approximated to order $\eta^2$ by Wasserstein gradient flow on a modified energy \[ J^\eta(\rho) = J(\rho) - \frac{\eta}{4}\int_M \Big\lVert \nabla_g \frac{\delta J}{\delta\rho} (\rho) \Big\rVert_{2}^{2} \,\rho(dx), \] obtained by subtracting from $J$ the squared metric curvature of $J$ times $\eta/4$. The JKO scheme therefore adds at second order in $\eta$ a deceleration in directions where the metric curvature of $J$ is rapidly changing. This corresponds to canonical implicit biases for common functionals: for entropy the implicit bias is the Fisher information, for KL-divergence it is the Fisher-Hyvärinen divergence, and for Riemannian gradient descent it is the kinetic energy in the metric $g$. To understand the differences between minimizing $J$ and $J^\eta$ we study JKO-Flow, Wasserstein gradient flow on $J^\eta$, in several simple numerical examples. These include exactly solvable Langevin dynamics on the Bures-Wasserstein space and Langevin sampling from a quartic potential in 1D.
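The order claim can be checked by hand in the simplest setting. The sketch below is an illustration, not one of the paper's experiments. It assumes a point mass on the real line with quadratic energy $f(x)=a x^2/2$, for which the JKO step reduces to the proximal map $x \mapsto x/(1+a\eta)$ and the modified energy is $f^\eta(x)=f(x)-\frac{\eta}{4}f'(x)^2$. All function names are hypothetical.

```python
import math


def jko_step(x, a, eta):
    # Proximal (JKO) step for f(y) = a*y^2/2 restricted to point masses:
    # argmin_y a*y^2/2 + (y - x)^2 / (2*eta)  =>  y = x / (1 + a*eta).
    return x / (1.0 + a * eta)


def gradient_flow(x, a, t):
    # Exact gradient flow on f: dx/dt = -a*x.
    return x * math.exp(-a * t)


def modified_flow(x, a, eta, t):
    # Exact gradient flow on f_eta(x) = a*x^2/2 - (eta/4)*(a*x)^2,
    # whose gradient is (a - eta*a^2/2) * x.
    rate = a - 0.5 * eta * a * a
    return x * math.exp(-rate * t)


def one_step_errors(a=1.0, x0=1.0, eta=0.1):
    # Distance of each flow at time t = eta from a single JKO step.
    target = jko_step(x0, a, eta)
    err_plain = abs(gradient_flow(x0, a, eta) - target)
    err_modified = abs(modified_flow(x0, a, eta, eta) - target)
    return err_plain, err_modified


if __name__ == "__main__":
    for eta in (0.1, 0.05, 0.025):
        plain, modified = one_step_errors(eta=eta)
        print(f"eta={eta}: plain={plain:.3e}, modified={modified:.3e}")
```

In this toy case, halving $\eta$ reduces the one-step error of the plain flow by roughly a factor of 4 and that of the modified flow by roughly a factor of 8. That is the behavior expected when local errors are $O(\eta^2)$ and $O(\eta^3)$ respectively.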

Paper Structure

This paper contains 24 sections, 14 theorems, 308 equations, 4 figures, 1 table.

Key Result

Theorem 1

Suppose $J$ is proper, l.s.c., coercive, and $\lambda$-geodesically convex along generalized geodesics. Then the JKO Scheme is well-posed and stable, with piecewise interpolation $\rho_{\lfloor t/\eta \rfloor}^{\mathrm{JKO}}$. Moreover, the Wasserstein gradient flow $(\rho_{t})_{t \ge 0}$ defined by exists, is unique, and satisfies for any $T>0$

Figures (4)

  • Figure 1: Wasserstein distance of $(\bm{\mu}, \Sigma)$ to analytic $(\bm{\mu}_{\mathrm{JKO}}, \Sigma_{\mathrm{JKO}})$ from Halder2017 of Wasserstein gradient flow and the modified second-order JKO-flow (Top left). Plot of mean and $2\sigma$-covariance isocontours of an example rotational flow (Top right). Mean (Bottom left) and Covariance (Bottom right) error of Wasserstein gradient flow and second-order modified flow to JKO.
  • Figure 2: Densities of one step of forward Euler discretization of Wasserstein gradient flow on $J=\mathrm{KL}(\rho||\pi)$ with $-\log \pi(x) = U(x)=x^2/2+ x^4/4+\text{const}$. Comparison with densities after one step of forward Euler on $J^\eta$ with $\eta\in \left\{0.1, 0.5\right\}$. Initial condition in all cases is the standard Gaussian. Theory predicts that the density obtained using $J^\eta$ is smooth if and only if $\eta > .3$.
  • Figure 3: Comparison of one forward Euler (FE) step of size $h$ on $J$ with the time $t=h$ JKO-Flow on $J^\eta$, for $h=\eta=1.0$.
  • Figure 4: KL-distance between target measure and both JKO-flow and Wasserstein gradient flow numerically integrated for $2000$ steps with $h=2e-3$ and a range of $\eta$ values. Mean and standard deviation are over $500$ random seeds. Several JKO strengths $\eta$ in the range $[1e-6, 1e-4]$ show a consistent improvement over the $\eta=0$ baseline.

Theorems & Definitions (37)

  • Definition 1: Order-$k$ Integrator for the JKO Scheme.
  • Theorem 1: AGS2008 Theorem 4.0.4, 11.2.1
  • Definition 2: Metric Slope
  • Theorem 2: The Implicit Bias of the JKO Scheme
  • Theorem 3: smith2021on, barrett2021implicit
  • Remark 1
  • Proposition 1: Riemannian Gradient Descent Bias
  • proof: Proof of Proposition \ref{prop:implicit_riemann_cov}
  • Proposition 2: Implicit Bias of the JKO Scheme on the Free-Energy (Langevin).
  • proof
  • ...and 27 more