DDEQs: Distributional Deep Equilibrium Models through Wasserstein Gradient Flows
Jonathan Geuter, Clément Bonet, Anna Korba, David Alvarez-Melis
TL;DR
Distributional Deep Equilibrium Models (DDEQs) process discrete probability measures, extending traditional DEQs to permutation-agnostic inputs such as sets and point clouds. They formulate the forward pass as a Wasserstein gradient flow that minimizes the squared Maximum Mean Discrepancy (MMD) between a latent measure and its pushforward under the network. This yields fixed points in the space of measures, while the backward pass stays implicit and efficient through implicit differentiation or phantom gradients. The framework combines bilevel optimization, MMD-based inner objectives, and EI-compliant transformer-like architectures, and it achieves competitive point cloud classification and completion with significantly fewer parameters. Although the forward flow can be slow, the approach performs well on realistic distributional data and lays groundwork for broader applications in distributional learning and implicit neural networks.
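To make the particle view of this forward pass concrete, here is a minimal NumPy sketch under assumptions that are not taken from the paper. It uses a Gaussian kernel with bandwidth sigma, and a hypothetical affine contraction f(z) = Az + b stands in for the learned network f_θ(·, x). The latent measure is a uniform empirical measure over n particles. The sketch differentiates MMD² between that measure and its pushforward with respect to the particle positions, including through f, and then takes explicit Euler steps. All function names (gaussian_kernel, pushforward, mmd2, mmd2_grad, mmd_flow) are illustrative.

```python
import numpy as np

def gaussian_kernel(a, b, sigma):
    """K[i, j] = exp(-|a_i - b_j|^2 / (2 sigma^2)) and the differences a_i - b_j."""
    diff = a[:, None, :] - b[None, :, :]                       # (n, m, d)
    K = np.exp(-np.sum(diff ** 2, axis=-1) / (2 * sigma ** 2))
    return K, diff

def pushforward(z, A, b):
    """Hypothetical stand-in for the network: apply f(z) = A z + b to every particle."""
    return z @ A.T + b

def mmd2(z, y, sigma):
    """Squared MMD between the uniform empirical measures on the rows of z and y."""
    Kzz, _ = gaussian_kernel(z, z, sigma)
    Kyy, _ = gaussian_kernel(y, y, sigma)
    Kzy, _ = gaussian_kernel(z, y, sigma)
    return Kzz.mean() + Kyy.mean() - 2.0 * Kzy.mean()

def mmd2_grad(z, A, b, sigma):
    """Gradient of mmd2(z, f(z)) with respect to the particle positions z."""
    n = z.shape[0]
    y = pushforward(z, A, b)
    Kzz, Dzz = gaussian_kernel(z, z, sigma)
    Kyy, Dyy = gaussian_kernel(y, y, sigma)
    Kzy, Dzy = gaussian_kernel(z, y, sigma)                    # Dzy[i, j] = z_i - y_j
    c = 2.0 / (n ** 2 * sigma ** 2)
    # Direct dependence on z: the z-z term and the z-argument of the cross term.
    g_z = -c * np.einsum("ij,ijd->id", Kzz, Dzz) + c * np.einsum("ij,ijd->id", Kzy, Dzy)
    # Dependence through y: the y-y term and the y-argument of the cross term.
    g_y = -c * np.einsum("ij,ijd->id", Kyy, Dyy) - c * np.einsum("ij,ijd->jd", Kzy, Dzy)
    # Chain rule through y_j = A z_j + b.
    return g_z + g_y @ A

def mmd_flow(z0, A, b, sigma=1.0, lr=0.05, steps=200):
    """Explicit Euler steps of the particle flow; n * grad is the Wasserstein
    gradient evaluated at each particle when all weights are 1/n."""
    z = np.array(z0, dtype=float)
    n = z.shape[0]
    for _ in range(steps):
        z = z - lr * n * mmd2_grad(z, A, b, sigma)
    return z

rng = np.random.default_rng(0)
A, b = 0.5 * np.eye(2), np.array([1.0, -1.0])   # contraction whose fixed point is (2, -2)
z0 = rng.normal(size=(20, 2))                   # 20 latent particles in 2-D
z = mmd_flow(z0, A, b)
print(mmd2(z0, pushforward(z0, A, b), 1.0), mmd2(z, pushforward(z, A, b), 1.0))
```

For a contraction like this toy f, the only fixed-point measure is a single Dirac at the map's fixed point. A learned f_θ(·, x) conditioned on an input set would generally have a richer fixed point. Propagating gradients through the pushforward term, as done here, is a design choice of this sketch.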
Abstract
Deep Equilibrium Models (DEQs) are a class of implicit neural networks that solve for a fixed point of a neural network in their forward pass. Traditionally, DEQs take sequences as inputs but have since been applied to a variety of data. In this work, we present Distributional Deep Equilibrium Models (DDEQs), extending DEQs to discrete measure inputs, such as sets or point clouds. We provide a theoretically grounded framework for DDEQs. Leveraging Wasserstein gradient flows, we show how the forward pass of the DEQ can be adapted to find fixed points of discrete measures under permutation-invariance, and derive suitable network architectures for DDEQs. In experiments, we show that DDEQs can compete with state-of-the-art models in tasks such as point cloud classification and point cloud completion, while being significantly more parameter-efficient.
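The permutation symmetry the abstract refers to can be illustrated with a small hypothetical NumPy sketch. This is not the paper's architecture. It builds a layer f(z, x) from self-attention over latent particles z and cross-attention to input points x, with no positional encodings. Permuting the latent particles permutes the output rows (equivariance), and permuting the input points leaves the output unchanged (invariance). This is the property that lets a fixed point be defined on measures rather than on ordered arrays. The names attention, layer, and readout, and the random weights, are assumptions made for illustration.

```python
import numpy as np

def softmax(s, axis=-1):
    s = s - s.max(axis=axis, keepdims=True)   # subtract max for numerical stability
    e = np.exp(s)
    return e / e.sum(axis=axis, keepdims=True)

def attention(q_in, kv_in, Wq, Wk, Wv):
    """Single-head dot-product attention with no positional encodings."""
    Q, K, V = q_in @ Wq, kv_in @ Wk, kv_in @ Wv
    weights = softmax(Q @ K.T / np.sqrt(Q.shape[-1]), axis=-1)   # (n_q, n_kv)
    return weights @ V

def layer(z, x, params):
    """Hypothetical f(z, x): latent particles attend to each other and to the input points."""
    z = z + attention(z, z, *params["self"])    # permutation-equivariant in z
    z = z + attention(z, x, *params["cross"])   # each row sees x only as an unordered set
    return np.tanh(z)                           # elementwise, so equivariance is preserved

def readout(z):
    """Mean pooling over particles: a permutation-invariant summary of the latent measure."""
    return z.mean(axis=0)

rng = np.random.default_rng(0)
d = 4
params = {k: [rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(3)] for k in ("self", "cross")}
z = rng.normal(size=(6, d))     # 6 latent particles
x = rng.normal(size=(10, d))    # input point cloud with 10 points
p, q = rng.permutation(6), rng.permutation(10)
out = layer(z, x, params)
print(np.allclose(layer(z[p], x[q], params), out[p]))                # equivariant in z, invariant in x
print(np.allclose(readout(layer(z[p], x[q], params)), readout(out)))  # pooled output is invariant
```

Iterating such a layer, or flowing toward its fixed point as in a DDEQ, therefore never depends on the order in which particles or input points are stored.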
