Generative Sliced MMD Flows with Riesz Kernels
Johannes Hertrich, Christian Wald, Fabian Altekrüger, Paul Hagemann
TL;DR
The paper addresses the computational bottleneck of maximum mean discrepancy (MMD) flows in high dimensions by exploiting Riesz kernels. For these kernels, the MMD coincides with its sliced version, so gradients can be computed in one dimension. For $r=1$, a sorting-based method reduces gradient evaluation to $O((M+N)\log(M+N))$, and using a finite number $P$ of projections yields a stochastic gradient estimate with error $O(\sqrt{d/P})$, which makes large-scale gradient-flow training tractable. The authors formulate Generative MMD Flows as a discretized gradient flow with optional momentum and train a sequence of neural networks to approximate its steps. They also bound the MMD of compactly supported measures from above and below via the Wasserstein-1 distance, with explicit constants, and report experiments on MNIST, FashionMNIST, CIFAR10, and CelebA.
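The slicing statement can be made concrete in formulas. The notation below is our own and is not fixed by this summary: $\xi$ is a direction drawn uniformly from the unit sphere $\mathbb{S}^{d-1}$, $\mathrm{k}(s,t) = -|s-t|^r$ is the one-dimensional Riesz kernel, and $c_{d,r}$ is a dimension-dependent constant. Averaging the 1D kernel over all directions recovers the $d$-dimensional Riesz kernel up to that constant:

$$
\mathbb{E}_{\xi}\big[\mathrm{k}(\langle \xi, x\rangle, \langle \xi, y\rangle)\big] = -\,c_{d,r}^{-1}\,\|x-y\|^r,
\qquad
c_{d,r} = \frac{\sqrt{\pi}\,\Gamma\big(\tfrac{d+r}{2}\big)}{\Gamma\big(\tfrac{d}{2}\big)\,\Gamma\big(\tfrac{r+1}{2}\big)}.
$$

Because the squared MMD is linear in the kernel, the same constant relates the full and the sliced discrepancy, where $\langle \xi, \cdot\rangle_{\#}\mu$ denotes the projection of $\mu$ onto the line spanned by $\xi$:

$$
\mathrm{MMD}_{K}^2(\mu,\nu) = c_{d,r}\,\mathbb{E}_{\xi}\Big[\mathrm{MMD}_{\mathrm{k}}^2\big(\langle \xi, \cdot\rangle_{\#}\mu,\ \langle \xi, \cdot\rangle_{\#}\nu\big)\Big].
$$

As a sanity check, $d=1$ gives $c_{1,r}=1$, and $d=2$, $r=1$ gives $c_{2,1}=\pi/2$, matching $\mathbb{E}|\cos\theta| = 2/\pi$ for a uniform angle $\theta$.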
Abstract
Maximum mean discrepancy (MMD) flows suffer from high computational costs in large-scale computations. In this paper, we show that MMD flows with Riesz kernels $K(x,y) = - \|x-y\|^r$, $r \in (0,2)$, have exceptional properties that allow their efficient computation. We prove that the MMD of Riesz kernels, which is also known as the energy distance, coincides with the MMD of their sliced version. As a consequence, the computation of gradients of MMDs can be performed in the one-dimensional setting. Here, for $r=1$, a simple sorting algorithm can be applied to reduce the complexity from $O(MN+N^2)$ to $O((M+N)\log(M+N))$ for two measures with $M$ and $N$ support points. As another interesting follow-up result, the MMD of compactly supported measures can be estimated from above and below by the Wasserstein-1 distance. For the implementation, we approximate the gradient of the sliced MMD by using only a finite number $P$ of slices. We show that the resulting error is of order $O(\sqrt{d/P})$, where $d$ is the data dimension. These results enable us to train generative models by approximating MMD gradient flows by neural networks even for image applications. We demonstrate the efficiency of our model by image generation on MNIST, FashionMNIST and CIFAR10.
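The two computational ideas in the abstract, sorting in one dimension for $r=1$ and averaging over $P$ random slices, can be illustrated with a short NumPy sketch. It is not the authors' implementation. The function names (`sign_sum`, `grad_1d`, `riesz_constant`, `sliced_grad`), the normalization of the functional (squared MMD with a factor $1/2$, dropping the term that does not depend on the particles), and the use of two separate sorts instead of one merged sort are all our assumptions. The sketch computes only the particle gradient and leaves out the momentum step and the neural networks.

```python
import numpy as np
from math import exp, lgamma, pi, sqrt


def sign_sum(points, query):
    """Return s with s[i] = sum_j sign(query[i] - points[j]), using a sort."""
    pts = np.sort(points)
    below = np.searchsorted(pts, query, side="left")            # number of points < query[i]
    below_or_equal = np.searchsorted(pts, query, side="right")  # number of points <= query[i]
    above = pts.size - below_or_equal                           # number of points > query[i]
    return (below - above).astype(float)


def grad_1d(x, y):
    """Gradient w.r.t. the particles x of the 1D functional (assumed normalization)
    F(x) = -1/(2 N^2) sum_ij |x_i - x_j| + 1/(N M) sum_ij |x_i - y_j|."""
    n, m = x.size, y.size
    return -sign_sum(x, x) / n**2 + sign_sum(y, x) / (n * m)


def riesz_constant(d):
    """Constant c_{d,1} with E_xi |<xi, v>| = ||v|| / c_{d,1}; lgamma avoids overflow."""
    return sqrt(pi) * exp(lgamma((d + 1) / 2) - lgamma(d / 2))


def sliced_grad(x, y, num_projections, rng):
    """Monte Carlo estimate of the d-dimensional gradient from random 1D slices."""
    d = x.shape[1]
    grad = np.zeros(x.shape, dtype=float)
    for _ in range(num_projections):
        xi = rng.normal(size=d)
        xi /= np.linalg.norm(xi)  # uniform direction on the unit sphere
        # The 1D gradient along xi is lifted back to R^d by multiplying with xi.
        grad += np.outer(grad_1d(x @ xi, y @ xi), xi)
    return riesz_constant(d) * grad / num_projections


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    particles = rng.normal(size=(5, 3))
    targets = rng.normal(size=(6, 3)) + 1.0
    print(sliced_grad(particles, targets, 1000, rng))
```

Each slice costs two sorts and three binary-search passes, which is the same order as the $O((M+N)\log(M+N))$ bound quoted above, whereas the pairwise sums would cost $O(MN+N^2)$. A single explicit Euler step of the flow would then read `particles -= step_size * sliced_grad(...)`, with `step_size` a hypothetical parameter chosen by the user.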
