Quantitative Convergence of Wasserstein Gradient Flows of Kernel Mean Discrepancies

Lénaïc Chizat, Maria Colombo, Roberto Colombo, Xavier Fernández-Real

Abstract

We study the quantitative convergence of Wasserstein gradient flows of Kernel Mean Discrepancy (KMD) (also known as Maximum Mean Discrepancy (MMD)) functionals. Our setting covers in particular the training dynamics of shallow neural networks in the infinite-width and continuous time limit, as well as interacting particle systems with pairwise Riesz kernel interaction in the mean-field and overdamped limit. Our main analysis concerns the model case of KMD functionals given by the squared Sobolev distance $ \mathscr{E}^ν_{s}(μ)= \frac{1}{2}\lVert μ-ν\rVert_{\dot H^{-s}}^{2}$ for any $s\geq 1 $ and $ν$ a fixed probability measure on the $d$-dimensional torus. First, inspired by Yudovich theory for the $2d$-Euler equation, we establish existence and uniqueness in natural weak regularity classes. Next, we show that for $s=1$ the flow converges globally at an exponential rate under minimal assumptions, while for $s>1$ we prove local convergence at polynomial rates that depend explicitly on $s$ and on the Sobolev regularity of $μ$ and $ν$. These rates hold both at the energy level and in higher regularity classes and are tight for $ν$ uniform. We then consider the gradient flow of the population loss for shallow neural networks with ReLU activation, which can be cast as a Wasserstein--Fisher--Rao gradient flow on the space of nonnegative measures on the sphere $\mathbb{S}^d$. Exploiting a correspondence with the Sobolev energy case with $s=(d+3)/2$, we derive an explicit polynomial local convergence rate for this dynamics. Except for the special case $s=1$, even non-quantitative convergence was previously open in all these settings. We also include numerical experiments in dimension $d=1$ using both PDE and particle methods which illustrate our analysis.
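
To make the model energy concrete, the following is a minimal sketch of how $\mathscr{E}^ν_{s}(μ)= \frac{1}{2}\lVert μ-ν\rVert_{\dot H^{-s}}^{2}$ could be evaluated for $d=1$ from densities sampled on a uniform periodic grid. It is not taken from the paper. The function name `sobolev_energy`, the torus length parameter `length`, and the normalization $\lVert f\rVert_{\dot H^{-s}}^{2} = L\sum_{k\neq 0}|k|^{-2s}|\hat f_k|^{2}$ (with angular wavenumbers $k$ and $L$ the torus length) are assumptions made for illustration; the paper's convention may differ by constant factors.

```python
import numpy as np

def sobolev_energy(mu, nu, s, length=1.0):
    """Approximate 0.5 * ||mu - nu||^2 in homogeneous H^{-s} on a 1-d torus.

    mu, nu: density values on a uniform periodic grid (same shape).
    The normalization is an assumed convention, not the paper's.
    """
    n = mu.size
    # Fourier coefficients of the difference, approximated by the DFT.
    coeffs = np.fft.fft(mu - nu) / n
    # Angular wavenumbers 2*pi*j/length matching the FFT ordering.
    k = 2.0 * np.pi * np.fft.fftfreq(n, d=length / n)
    nonzero = k != 0
    # Homogeneous norm: the zero mode is dropped (it vanishes anyway
    # when mu and nu carry the same total mass).
    weights = np.abs(k[nonzero]) ** (-2.0 * s)
    return 0.5 * length * np.sum(weights * np.abs(coeffs[nonzero]) ** 2)
```

Under this convention, a single-mode perturbation $μ = 1 + a\cos(2\pi x)$ of the uniform target on the unit torus has energy $a^{2}/(4(2\pi)^{2s})$, which shows how strongly larger $s$ damps high frequencies in the energy.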

Paper Structure (38 sections, 27 theorems, 334 equations, 4 figures)

Key Result

Proposition 1.1

Let $s\ge 1$ and let $\mathscr{X}_{s}(\mathbb{T}^{d})$ be given by eq:def-yudovich-class. Then, for every $\bar{\mu}, \nu \in \mathscr{P}\cap \mathscr{X}_{s}(\mathbb{T}^{d})$ there exist a maximal time of existence $T>0$ and a unique maximal solution $\mu \in L_{\mathrm{loc}}^{\cdots}$ […] Finally, solutions propagate Hölder and Sobolev regularity (see prop:propagation-regularity and pro…

Figures (4)

  • Figure 1: Regularity regimes of Riesz kernels $K_s$ on the $d$-torus (bracket orientation indicates inclusion/exclusion of the endpoint) and some values of particular interest (arrows). The negative distance and ReLU neural network cases are not strictly speaking Riesz kernels on the $d$-torus, but their regularity on the diagonal matches that of the indicated exponent. For $d=1$, the kernel is never singular for $s\geq 1$ (as then the Coulomb case coincides with the negative distance case).
  • Figure 2: The case $s=1$ and $d=1$ integrated with finite volume discretization with upwind scheme. (a) For $s=1$, the regularity of $\bar{\mu}$ and $\nu$ does not impact the exponential convergence rate. Here $\gamma_0$ (resp. $\gamma_\nu$) is the largest scalar such that $\bar{\mu} \in \dot{H}^{\gamma_0}$ (resp. $\nu \in \dot{H}^{\gamma_\nu}$) and the densities are lower-bounded by $0.2$ so our theoretical rate is $O(e^{-0.2t})$. (b) As theory predicts, the convergence rate is upper-bounded by the minimum density of $\nu$, and independent of the minimum density of $\bar{\mu}$. Here $m_0$ (resp. $m_\nu$) indicates the minimum of $\bar{\mu}$ (resp. $\nu$). In this experiment, $(\gamma_0,\gamma_\nu)=(1,1)$. (c) When $\nu$ has a positive lower bound, the zero-density areas of $\bar{\mu}$ (in black) shrink exponentially fast, as theory predicts (here $(\gamma_0,\gamma_\nu)=(1,1)$); see \ref{rmk:exponential-filling} and \ref{lem:exponential-filling-holes}.
  • Figure 3: The case $s=2$ and $d=1$ integrated with finite volume discretization with upwind scheme. We write $\gamma_0$ (resp. $\gamma_\nu$) for the largest scalar such that $\bar{\mu} \in \dot{H}^{\gamma_0}$ (resp. $\nu \in \dot{H}^{\gamma_\nu}$) and the densities are lower-bounded by $0.2$. (a) Snapshots of the density of $\mu_t$ along the evolution (here $(\gamma_0,\gamma_\nu)=(2,2)$). We can observe the absence of a maximum principle (cf. $x=0.9$), and the appearance of high frequency components (cf. $x=0.8$). Both phenomena are absent from the case $s=1$. (b) Convergence rate in $\dot{H}^{-2}$ norm. The dashed lines show approximate fits of the asymptotic rates.
  • Figure 4: The case of a shallow ReLU neural network with $d=1$, implemented via gradient descent on the population loss with a "student" and a "teacher" neural network of width $800$ that discretize the measures $\mu_t$ and $\nu$ respectively (in other words, this is an interacting particle system approximation of the PDE with $800$ particles for each measure). We initialize $\bar{\mu}$ with a uniform density ($\gamma_0=+\infty$) and vary the target regularity (we indicate the largest $\gamma_\nu$ such that $\nu \in \dot{H}^{\gamma_\nu}$). We compare the Wasserstein (W) and Wasserstein--Fisher--Rao (WFR) dynamics. Although the energy decay of (WFR) is slightly faster at initialization (thanks to the extra term in the energy dissipation formula \ref{eq:energy-dissipation-Fisher-Rao}), it does not converge faster in general. The dashed lines show approximate fits of the asymptotic rates.
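
The captions of Figures 2 and 3 mention a finite volume discretization with an upwind scheme. As a rough illustration, and not the authors' implementation, one explicit step of such a scheme for $s=1$, $d=1$ might look as follows. The sketch assumes the flow takes the standard Wasserstein gradient flow form $\partial_t \mu + \partial_x(\mu v) = 0$ with $v = -\partial_x(-\Delta)^{-1}(\mu-\nu)$, so that on the unit torus $v$ is the zero-mean antiderivative of $\mu-\nu$. The function name `upwind_step`, the cell-average layout, and the time step are hypothetical choices.

```python
import numpy as np

def upwind_step(mu, nu, dt):
    """One explicit upwind finite-volume step for s = 1 on the unit 1-d torus.

    mu, nu: cell averages on a uniform periodic grid with equal total mass.
    Assumed dynamics: d/dt mu + d/dx (mu * v) = 0 with v' = mu - nu.
    """
    n = mu.size
    h = 1.0 / n
    # Velocity at the right interface of each cell: antiderivative of mu - nu,
    # shifted to have zero mean (v is the derivative of a periodic potential).
    v = h * np.cumsum(mu - nu)
    v -= v.mean()
    # Upwind flux: take the density from the cell the velocity comes from.
    flux = np.maximum(v, 0.0) * mu + np.minimum(v, 0.0) * np.roll(mu, -1)
    # Conservative update: flux through the right face minus the left face.
    return mu - dt / h * (flux - np.roll(flux, 1))
```

The update conserves total mass by construction, and it keeps the density nonnegative provided the time step satisfies a CFL-type restriction of the form `dt * max|v| / h` sufficiently small. Cases with $s>1$ would need a different velocity computation, for instance a spectral inversion of $(-\Delta)^{s}$.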

Theorems & Definitions (68)

  • Proposition 1.1: Local well-posedness
  • Theorem 1.2: Global convergence to the target: $s=1$
  • Remark 1.3: Sharpness of the lower bounds
  • Theorem 1.4: Local convergence to the target: $s>1$
  • Remark 1.5: On the locality assumption
  • Remark 1.6: Sharpness of the polynomial decay rate
  • Theorem 1.7: Convergence for neural networks
  • Remark 1.8
  • Remark 1.9
  • Corollary 1.10
  • ...and 58 more