
Deep Kuratowski Embedding Neural Networks for Wasserstein Metric Learning

Andrew Qing He

Abstract

Computing pairwise Wasserstein distances is a fundamental bottleneck in data analysis pipelines. Motivated by the classical Kuratowski embedding theorem, we propose two neural architectures for learning to approximate the Wasserstein-2 distance ($W_2$) from data. The first, DeepKENN, aggregates distances across all intermediate feature maps of a CNN using learnable positive weights. The second, ODE-KENN, replaces the discrete layer stack with a Neural ODE, embedding each input into the infinite-dimensional Banach space $C^1([0,1], \mathbb{R}^d)$ and providing implicit regularization via trajectory smoothness. Experiments on MNIST with exact precomputed $W_2$ distances show that ODE-KENN achieves a 28% lower test MSE than the single-layer baseline and 18% lower than DeepKENN under matched parameter counts, while exhibiting a smaller generalization gap. The resulting fast surrogate can replace the expensive $W_2$ oracle in downstream pairwise distance computations.
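The abstract's description of DeepKENN, which aggregates per-layer feature distances with learnable positive weights, can be sketched as follows. This is a minimal illustration, not the paper's implementation; the function names, the softplus parameterization of the positive weights, and the toy feature shapes are all assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

def softplus(t):
    # Smooth reparameterization keeping the layer weights strictly positive.
    return np.log1p(np.exp(t))

def deep_kenn_distance(feats_x, feats_y, theta):
    """Weighted aggregation of per-layer Euclidean distances (sketch).

    feats_x, feats_y: lists of flattened feature vectors, one per CNN layer.
    theta: unconstrained parameters; lambda_k = softplus(theta_k) > 0.
    """
    lam = softplus(theta)                      # positive layer weights lambda_k
    sq = np.array([np.sum((fx - fy) ** 2)      # squared distance at each layer
                   for fx, fy in zip(feats_x, feats_y)])
    return np.sqrt(np.dot(lam, sq))            # aggregated surrogate distance

# Toy example: three layers with random feature vectors.
feats_x = [rng.normal(size=64), rng.normal(size=128), rng.normal(size=10)]
feats_y = [rng.normal(size=64), rng.normal(size=128), rng.normal(size=10)]
theta = np.zeros(3)                            # lambda_k = softplus(0) ~ 0.693
print(deep_kenn_distance(feats_x, feats_y, theta))
```

With all weights positive, the aggregate is zero exactly when every layer's features coincide, and it is symmetric in its two arguments, so it behaves like a distance on the embedded features.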


Paper Structure

This paper contains 41 sections, 1 theorem, 10 equations, 1 figure, 2 tables.

Key Result

Theorem 1

Every bounded metric space $(M, d)$ embeds isometrically as a closed subset of a convex set in the Banach space $\ell^\infty(M)$.
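The isometric part of the theorem can be checked numerically on a finite metric space: the Kuratowski map sends each point $x_i$ to the distance function $d(x_i, \cdot)$, i.e. to row $i$ of the distance matrix, and sup-norm distances between rows reproduce the original metric. The example below (points on the real line, absolute-value metric) is an illustrative sketch, not taken from the paper.

```python
import numpy as np

# Finite metric space: points on the real line with d(x, y) = |x - y|.
pts = np.array([0.0, 1.0, 3.0, 7.0])
D = np.abs(pts[:, None] - pts[None, :])   # distance matrix; row i is the embedding of x_i

# Isometry check: ||D[i] - D[j]||_inf == D[i, j] for every pair.
# The reverse triangle inequality gives <=, and coordinate k = i attains equality.
for i in range(len(pts)):
    for j in range(len(pts)):
        assert np.isclose(np.max(np.abs(D[i] - D[j])), D[i, j])
print("Kuratowski map is isometric on this finite space")
```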

Figures (1)

  • Figure 1: Experimental results for all three models trained for 2,000 epochs on MNIST $W_2$ distance learning. Top left: Training (dashed) and validation (solid) MSE loss curves on a log scale. ODE-KENN (green) converges fastest and maintains the smallest train/val gap throughout training. Top right: Predicted vs. true $W_2$ distances on the test set (2,750 pairs). ODE-KENN produces the tightest scatter around the ideal diagonal, particularly at small $W_2$ values. Bottom left: Learned layer weights $\lambda_k$ for DeepKENN. The first two convolutional layers are nearly fully suppressed ($\lambda \approx 0$), while FC1 dominates ($\lambda \approx 1.62$), indicating that abstract compressed representations are most informative for $W_2$. Bottom right: Learned time weights $\lambda_t$ for ODE-KENN over $t \in [0,1]$. The bell-shaped profile peaks near $t \approx 0.35$, downweighting the raw feature ($t=0$) and the late trajectory, and assigning greatest importance to the early-to-mid ODE evolution.
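The time-weight panel suggests a distance of the form $d(x,y)^2 \approx \sum_t \lambda_t \, \|z_x(t) - z_y(t)\|^2$ over a discretized ODE trajectory. The sketch below illustrates this shape with a fixed toy vector field and a forward-Euler rollout; the dynamics, the bell-shaped weight profile, and all names are illustrative assumptions, not the paper's model.

```python
import numpy as np

def trajectory(z0, n_steps=20):
    """Forward-Euler rollout of a fixed toy vector field z' = tanh(A z)."""
    A = np.array([[0.0, -1.0], [1.0, 0.0]])   # illustrative rotation-like dynamics
    dt = 1.0 / n_steps
    zs = [z0]
    z = z0
    for _ in range(n_steps):
        z = z + dt * np.tanh(A @ z)
        zs.append(z)
    return np.stack(zs)                        # shape (n_steps + 1, d)

def ode_kenn_distance(x, y, lam):
    """Time-weighted trajectory distance between the embeddings of x and y."""
    zx, zy = trajectory(x), trajectory(y)
    sq = np.sum((zx - zy) ** 2, axis=1)        # squared distance per time step
    return np.sqrt(np.sum(lam * sq))

# Bell-shaped time weights peaking near t = 0.35, as in the figure's bottom-right panel.
t = np.linspace(0.0, 1.0, 21)
lam = np.exp(-((t - 0.35) ** 2) / 0.02)
lam /= lam.sum()

x, y = np.array([1.0, 0.0]), np.array([0.0, 1.0])
print(ode_kenn_distance(x, y, lam))
```

Because the weights are nonnegative and the trajectory map is deterministic, the resulting surrogate is symmetric and vanishes when the two inputs coincide.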

Theorems & Definitions (1)

  • Theorem 1: Kuratowski–Wojdysławski