DDEQs: Distributional Deep Equilibrium Models through Wasserstein Gradient Flows

Jonathan Geuter, Clément Bonet, Anna Korba, David Alvarez-Melis

TL;DR

DDEQs introduce Distributional Deep Equilibrium Models that process discrete probability measures, extending traditional DEQs to permutation-agnostic inputs such as sets and point clouds. By formulating the forward pass as a Wasserstein gradient flow that minimizes a squared Maximum Mean Discrepancy (MMD) between a latent measure and its pushforward, they obtain fixed points in the space of measures while keeping an efficient implicit backward pass via phantom and implicit gradients. The framework combines bilevel optimization, MMD-based inner objectives, and EI-compliant transformer-like architectures to achieve competitive point cloud classification and completion with significantly fewer parameters. While the forward flow can be slow, the approach performs well on realistic distributional data and lays the groundwork for broader applications in distributional learning and implicit neural networks.
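To make the forward pass concrete, here is a minimal NumPy sketch of a particle discretization of a Wasserstein gradient flow on F(μ) = MMD²(μ, f_#μ). It is not the authors' implementation. Assumptions: a Gaussian kernel with bandwidth `sigma`, a hypothetical map `f_theta` (a simple contraction standing in for the learned DDEQ network), finite-difference gradients instead of autodiff, and explicit Euler steps.

```python
import numpy as np

def gaussian_gram(X, Y, sigma=1.0):
    # Pairwise Gaussian kernel k(x, y) = exp(-||x - y||^2 / (2 sigma^2))
    sq = ((X[:, None, :] - Y[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-sq / (2.0 * sigma**2))

def mmd2(X, Y, sigma=1.0):
    # Squared MMD between uniform empirical measures on the rows of X and Y
    return (gaussian_gram(X, X, sigma).mean()
            + gaussian_gram(Y, Y, sigma).mean()
            - 2.0 * gaussian_gram(X, Y, sigma).mean())

def f_theta(Z, X):
    # HYPOTHETICAL stand-in for the learned network: contraction toward the input's mean
    return 0.5 * Z + 0.5 * X.mean(axis=0)

def objective(Z, X):
    # F(mu) = MMD^2(mu, (f_theta)_# mu); zero iff the pushforward leaves mu unchanged
    return mmd2(Z, f_theta(Z, X))

def numerical_grad(fun, Z, eps=1e-5):
    # Central finite differences over every particle coordinate (fine for tiny clouds)
    G = np.zeros_like(Z)
    for idx in np.ndindex(*Z.shape):
        Zp, Zm = Z.copy(), Z.copy()
        Zp[idx] += eps
        Zm[idx] -= eps
        G[idx] = (fun(Zp) - fun(Zm)) / (2.0 * eps)
    return G

def particle_flow(Z0, X, step=0.05, n_steps=200):
    # Explicit Euler steps of the particle flow. dF/dz_i equals (1/n) times the
    # Wasserstein gradient of F at z_i, so we rescale the gradient by n.
    Z = Z0.copy()
    n = Z.shape[0]
    for _ in range(n_steps):
        Z = Z - step * n * numerical_grad(lambda Zc: objective(Zc, X), Z)
    return Z

rng = np.random.default_rng(0)
X = rng.normal(size=(6, 2))          # input point cloud (conditions f_theta)
Z0 = rng.normal(size=(6, 2)) * 2.0   # latent particles, initialized independently of X
Z_star = particle_flow(Z0, X)
print(objective(Z0, X), objective(Z_star, X))
```

The printed MMD objective should shrink along the flow; because the Gaussian kernel is characteristic, a value of exactly zero would mean the latent measure is a fixed point of the pushforward.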

Abstract

Deep Equilibrium Models (DEQs) are a class of implicit neural networks that solve for a fixed point of a neural network in their forward pass. Traditionally, DEQs take sequences as inputs, but have since been applied to a variety of data. In this work, we present Distributional Deep Equilibrium Models (DDEQs), extending DEQs to discrete measure inputs, such as sets or point clouds. We provide a theoretically grounded framework for DDEQs. Leveraging Wasserstein gradient flows, we show how the forward pass of the DEQ can be adapted to find fixed points of discrete measures under permutation-invariance, and derive adequate network architectures for DDEQs. In experiments, we show that they can compete with state-of-the-art models in tasks such as point cloud classification and point cloud completion, while being significantly more parameter-efficient.

Paper Structure

This paper contains 36 sections, 20 theorems, 110 equations, 17 figures, 6 tables, 1 algorithm.

Key Result

Proposition 1

Let $\mu\in\mathcal{P}_2(\mathbb{R}^p)$, $\mathcal{F}:\mathcal{P}_2(\mathbb{R}^p)\to\mathbb{R}$, $T:\mathbb{R}^p\to\mathbb{R}^p\in L^2(\mu)$ a $\mu$-a.e. differentiable map and define $\tilde{\mathcal{F}}(\mu):=\mathcal{F}(T_\#\mu)$. Assume $\sup_x\ \|\nabla T(x)\|_{\mathrm{op}} < +\infty$. If the W
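The proposition's statement is cut off above. As a standard special case that does not rely on the missing text, take a potential energy $\mathcal{F}(\nu)=\int V\,d\nu$. By the definition of the pushforward, $\tilde{\mathcal{F}}(\mu)=\int V\circ T\,d\mu$, so its Wasserstein gradient is $\nabla(V\circ T)(x)=J_T(x)^\top\nabla V(T(x))$ by the ordinary chain rule, where $J_T$ is the Jacobian of $T$. The NumPy sketch below checks this on an empirical measure with finite differences, using the hypothetical choices $V(y)=\tfrac12\|y\|^2$ and $T(x)=Ax+\sin(x)$, whose Jacobian $A+\mathrm{diag}(\cos x)$ has bounded operator norm as the proposition assumes.

```python
import numpy as np

rng = np.random.default_rng(0)
p, n = 3, 5
A = rng.normal(size=(p, p))

def T(x):
    # HYPOTHETICAL smooth map with bounded Jacobian A + diag(cos x)
    return A @ x + np.sin(x)

def jac_T(x):
    return A + np.diag(np.cos(x))

def grad_V(y):
    # V(y) = 0.5 * ||y||^2, so grad V(y) = y
    return y

def F_tilde(X):
    # F(T_# mu_n) for F(nu) = int V d nu, with mu_n uniform on the rows of X
    return np.mean([0.5 * T(x) @ T(x) for x in X])

def chain_rule_grad(X):
    # Wasserstein gradient J_T(x)^T grad V(T(x)) evaluated at each particle
    return np.stack([jac_T(x).T @ grad_V(T(x)) for x in X])

def finite_diff_grad(X, eps=1e-5):
    # Central finite differences of F_tilde with respect to each particle coordinate
    G = np.zeros_like(X)
    for idx in np.ndindex(*X.shape):
        Xp, Xm = X.copy(), X.copy()
        Xp[idx] += eps
        Xm[idx] -= eps
        G[idx] = (F_tilde(Xp) - F_tilde(Xm)) / (2.0 * eps)
    return G

X = rng.normal(size=(n, p))
# dF_tilde/dx_i equals (1/n) times the Wasserstein gradient at x_i, hence the factor n
err = np.max(np.abs(n * finite_diff_grad(X) - chain_rule_grad(X)))
print(err)
```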

Figures (17)

  • Figure 1: Point cloud completion with DDEQs for an airplane. We add random particles to the partial input point cloud $\mathbf{X}$ (red) to create the DDEQ input $\tilde{\mathbf{X}}$ (orange), which we upscale by an invertible layer $q$. The DDEQ outputs a prediction $\mathbf{Y}^*$ (blue), which is compared against the target (green) with an MMD loss.
  • Figure 2: DDEQ network architecture. There are two residual connections to the second and third layer norms in the top row. All encoders are standard multi-head attention networks (see appendix for more details).
  • Figure 3: The point cloud classification pipeline, where $\mathbf{Z}$ is initialized independently of $\mathbf{X}$, and the DDEQ is followed by a max pool (purple) and linear (blue) layer.
  • Figure 4: MNIST-pc-partial samples.
  • Figure 5: ModelNet40-s-partial samples.
  • ...and 12 more figures
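Figure 3 describes the classification readout as a max pool followed by a linear layer. A minimal NumPy sketch of why this readout ignores particle order, assuming a random array `Z_star` as a stand-in for the DDEQ's equilibrium particles and hypothetical random weights `W`, `b` (sizes chosen for illustration; 40 classes as in ModelNet40):

```python
import numpy as np

def classification_head(Z_star, W, b):
    # Max pool over the particle axis (order-independent), then a linear layer to logits
    pooled = Z_star.max(axis=0)      # shape (d,)
    return W @ pooled + b            # shape (num_classes,)

rng = np.random.default_rng(0)
n, d, num_classes = 128, 16, 40
Z_star = rng.normal(size=(n, d))     # HYPOTHETICAL stand-in for the DDEQ output particles
W = rng.normal(size=(num_classes, d))
b = np.zeros(num_classes)

logits = classification_head(Z_star, W, b)
logits_perm = classification_head(Z_star[rng.permutation(n)], W, b)      # reordered particles
logits_dup = classification_head(np.concatenate([Z_star, Z_star]), W, b)  # every particle duplicated
```

Because the max depends only on the set of particle positions, duplicating every particle, which leaves the uniform empirical measure unchanged, also leaves the logits unchanged.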

Theorems & Definitions (49)

  • Proposition 1
  • Corollary 2
  • Theorem 3
  • Definition 4: EI Property
  • Proposition 5
  • Proposition 6
  • Definition A.1: Space of Measures with bounded second Moments
  • Definition A.2: Pushforward
  • Definition A.3: Coupling
  • Definition A.4: Optimal Transport Problem
  • ...and 39 more
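Definitions A.3 and A.4 (coupling, optimal transport problem) are listed by title only. As a standard illustration not taken from the paper: for two uniform empirical measures with the same number $n$ of points, an optimal coupling for the squared Euclidean cost can be chosen to be a permutation (Birkhoff-von Neumann), so $W_2$ reduces to an assignment problem. The brute-force sketch below uses a hypothetical helper `w2_uniform_bruteforce` that enumerates all $n!$ permutations, so it is only meant for tiny $n$; in practice a Hungarian-type solver such as `scipy.optimize.linear_sum_assignment` would be used.

```python
import itertools
import numpy as np

def w2_uniform_bruteforce(X, Y):
    # HYPOTHETICAL helper: W2 between (1/n) sum delta_{x_i} and (1/n) sum delta_{y_j}, equal n.
    # An optimal coupling can be taken to be a permutation, so minimize over all n! of them.
    n = X.shape[0]
    cost = ((X[:, None, :] - Y[None, :, :]) ** 2).sum(axis=-1)   # squared Euclidean costs
    best = min(cost[np.arange(n), list(perm)].mean()
               for perm in itertools.permutations(range(n)))
    return np.sqrt(best)

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 2))
Y = rng.normal(size=(5, 2))
v = np.array([1.0, -2.0])
d_xy = w2_uniform_bruteforce(X, Y)
d_shift = w2_uniform_bruteforce(X, X + v)   # a pure translation has W2 = ||v||
```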