Implicit Bias of the JKO Scheme
Peter Halmos, Boris Hanin
TL;DR
The paper analyzes the implicit bias of the Jordan–Kinderlehrer–Otto (JKO) scheme for Wasserstein gradient flows. It derives a modified objective, $J^\eta(\rho)=J(\rho)-\frac{\eta}{4}|\partial J(\rho)|^2$ with $|\partial J(\rho)|^2=\int_M \lVert \nabla_g \frac{\delta J}{\delta\rho}(\rho) \rVert^2 \,\rho(dx)$, whose Wasserstein gradient flow matches the JKO updates up to $O(\eta^2)$. The correction term reduces to familiar quantities for common functionals: for entropy it is the Fisher information, for KL divergence the Fisher–Hyvärinen divergence, and for Riemannian gradient descent a kinetic-energy-type penalty in the metric $g$, with extensions to free-energy/Langevin dynamics. The authors develop a strategy based on backward error analysis (BEA) to identify the modified flow, which they call JKO-Flow, and validate the theory through numerical experiments in Bures–Wasserstein space and one-dimensional quartic KL scenarios, demonstrating improved accuracy and stability of JKO-Flow. These results clarify the geometry of JKO updates and have practical implications for stable discretizations of gradient flows on manifolds and in stochastic sampling.
Abstract
Wasserstein gradient flow provides a general framework for minimizing an energy functional $J$ over the space of probability measures on a Riemannian manifold $(M,g)$. Its canonical time-discretization, the Jordan-Kinderlehrer-Otto (JKO) scheme, produces for any step size $\eta>0$ a sequence of probability distributions $\rho_k^\eta$ that approximate, to first order in $\eta$, Wasserstein gradient flow on $J$. But the JKO scheme also has many other remarkable properties not shared by other first-order integrators, e.g. it preserves energy dissipation and exhibits unconditional stability for $\lambda$-geodesically convex functionals $J$. To better understand the JKO scheme, we characterize its implicit bias at second order in $\eta$. We show that the $\rho_k^\eta$ are approximated to order $\eta^2$ by Wasserstein gradient flow on a modified energy \[ J^\eta(\rho) = J(\rho) - \frac{\eta}{4}\int_M \Big\lVert \nabla_g \frac{\delta J}{\delta\rho} (\rho) \Big\rVert_{2}^{2} \,\rho(dx), \] obtained by subtracting from $J$ the squared metric curvature of $J$ times $\eta/4$. The JKO scheme therefore adds at second order in $\eta$ a deceleration in directions where the metric curvature of $J$ is rapidly changing. This corresponds to canonical implicit biases for common functionals: for entropy the implicit bias is the Fisher information, for KL-divergence it is the Fisher-Hyvärinen divergence, and for Riemannian gradient descent it is the kinetic energy in the metric $g$. To understand the differences between minimizing $J$ and $J^\eta$, we study JKO-Flow, Wasserstein gradient flow on $J^\eta$, in several simple numerical examples. These include exactly solvable Langevin dynamics on the Bures-Wasserstein space and Langevin sampling from a quartic potential in 1D.
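As a rough illustration of the second-order claim, the sketch below checks the simplest special case by hand. It is not taken from the paper. It assumes a potential energy $J(\rho)=\int f\,d\rho$ on the flat real line with $f(x)=a x^2/2$ and a point-mass initial condition, so that a JKO step reduces to a proximal-point (implicit Euler) step on $f$, and the modified energy becomes $f^\eta(x)=f(x)-\frac{\eta}{4}f'(x)^2$. The names `jko_step`, `flow_step`, and `one_step_errors` are hypothetical helpers introduced only for this example.

```python
import math


def jko_step(x, eta, a):
    """One proximal-point (implicit Euler) step for f(x) = a * x**2 / 2.

    Solves argmin_y f(y) + (y - x)**2 / (2 * eta) in closed form.
    Assumption: for a point mass and a potential energy, the JKO step
    reduces to this proximal step.
    """
    return x / (1.0 + eta * a)


def flow_step(x, eta, rate):
    """Exact time-eta solution of the linear gradient flow x' = -rate * x."""
    return x * math.exp(-eta * rate)


def one_step_errors(eta, a=1.0, x0=1.0):
    """Return (error of flow on f, error of flow on f^eta) after one step.

    f'(x) = a * x, and f^eta(x) = f(x) - (eta / 4) * (a * x)**2 has
    derivative (a - eta * a**2 / 2) * x, so both flows are linear.
    """
    target = jko_step(x0, eta, a)
    err_plain = abs(flow_step(x0, eta, a) - target)
    err_modified = abs(flow_step(x0, eta, a - 0.5 * eta * a**2) - target)
    return err_plain, err_modified


if __name__ == "__main__":
    for eta in (0.1, 0.05, 0.025):
        plain, modified = one_step_errors(eta)
        print(f"eta={eta:<6} plain={plain:.3e} modified={modified:.3e}")
```

In this toy case, halving $\eta$ should shrink the one-step error of the flow on $f$ by roughly a factor of 4 and that of the flow on $f^\eta$ by roughly a factor of 8, consistent with the modified energy capturing one additional order in $\eta$. The example says nothing about the general measure-valued setting treated in the paper.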
