Learning Elastic Costs to Shape Monge Displacements

Michal Klein, Aram-Alexandre Pooladian, Pierre Ablin, Eugène Ndiaye, Jonathan Niles-Weed, Marco Cuturi

TL;DR

This work extends optimal transport by introducing elastic costs $h(z)=\tfrac{1}{2}\|z\|^2+\gamma\tau(z)$ to shape Monge map displacements via the proximal operator of $\tau$. It provides a practical route to compute OT maps for any elastic cost using the MBO estimator, and introduces a bilevel learning scheme to infer the regularizer parameter $\theta$ that enforces low-dimensional displacement subspaces. By focusing on subspace elastic costs $\tau_{A^\perp}$, the authors establish statistical guarantees, connect to the spiked transport model, and show that the effective estimation rate can depend on the subspace dimension rather than the ambient dimension. Through synthetic and single-cell data experiments, they demonstrate ground-truth map generation, subspace recovery, and improved predictive performance when learning displacement structure, highlighting the method's potential for structured OT in high dimensions.

Abstract

Given a source and a target probability measure supported on $\mathbb{R}^d$, the Monge problem asks to find the most efficient way to map one distribution to the other. This efficiency is quantified by defining a cost function between source and target data. Such a cost is often set by default in the machine learning literature to the squared-Euclidean distance, $\ell^2_2(\mathbf{x},\mathbf{y})=\tfrac12\|\mathbf{x}-\mathbf{y}\|_2^2$. Recently, Cuturi et al. ('23) highlighted the benefits of using elastic costs, defined through a regularizer $\tau$ as $c(\mathbf{x},\mathbf{y})=\ell^2_2(\mathbf{x},\mathbf{y})+\tau(\mathbf{x}-\mathbf{y})$. Such costs shape the displacements of Monge maps $T$, i.e., the difference $T(\mathbf{x})-\mathbf{x}$ between a source point and its image, by giving them a structure that matches that of the proximal operator of $\tau$. In this work, we make two important contributions to the study of elastic costs: (i) For any elastic cost, we propose a numerical method to compute Monge maps that are provably optimal. This provides a much-needed routine to create synthetic problems where the ground truth OT map is known, by analogy to the Brenier theorem, which states that the gradient of any convex potential is always a valid Monge map for the $\ell_2^2$ cost; (ii) We propose a loss to learn the parameter $\theta$ of a parameterized regularizer $\tau_\theta$, and apply it in the case where $\tau_{A}(\mathbf{z})=\|A^\perp \mathbf{z}\|^2_2$. This regularizer promotes displacements that lie on a low-dimensional subspace of $\mathbb{R}^d$, spanned by the $p$ rows of $A\in\mathbb{R}^{p\times d}$.
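
To make the role of the proximal operator concrete, here is a minimal NumPy sketch of an elastic cost and of the proximal operators of two regularizers named on this page, the $\ell_1$ norm and the subspace regularizer $\tau_A$. It is an illustration under stated assumptions, not the authors' implementation: the function names and the weight `gamma` (taken from the TL;DR's $h(z)=\tfrac12\|z\|^2+\gamma\tau(z)$) are hypothetical, $A$ is assumed to have orthonormal rows so that $A^\perp\mathbf{z}=\mathbf{z}-A^\top A\mathbf{z}$, and $\tau_A$ is used exactly as written in the abstract, without a $\tfrac12$ factor.

```python
import numpy as np

def elastic_cost(x, y, tau, gamma=1.0):
    """Elastic cost 0.5*||x - y||^2 + gamma * tau(x - y)."""
    z = x - y
    return 0.5 * np.dot(z, z) + gamma * tau(z)

def prox_l1(w, gamma):
    """Prox of gamma*||.||_1: coordinate-wise soft-thresholding."""
    return np.sign(w) * np.maximum(np.abs(w) - gamma, 0.0)

def tau_subspace(z, A):
    """tau_A(z) = ||A_perp z||^2, assuming A (p x d) has orthonormal rows."""
    residual = z - A.T @ (A @ z)  # component of z orthogonal to the rows of A
    return np.dot(residual, residual)

def prox_subspace(w, A, gamma):
    """Prox of gamma*tau_A: shrink the orthogonal component by 1/(1 + 2*gamma)."""
    residual = w - A.T @ (A @ w)
    return w - (2.0 * gamma / (1.0 + 2.0 * gamma)) * residual

if __name__ == "__main__":
    w = np.array([3.0, -0.5, 1.0])
    A = np.array([[1.0, 0.0, 0.0]])  # hypothetical subspace: first axis only
    print("l1 prox:      ", prox_l1(w, gamma=1.0))
    print("subspace prox:", prox_subspace(w, A, gamma=1.0))
```

The $\ell_1$ prox zeroes small coordinates, which is why the corresponding displacements are sparse. The subspace prox leaves the component along the rows of $A$ untouched and shrinks the rest, so displacements concentrate on the span of $A$ as `gamma` grows.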

Paper Structure (22 sections, 7 theorems, 44 equations, 5 figures)

Key Result

Proposition 1

Let $\mu$ be a measure in $\mathcal{P}(\mathbb{R}^d)$. Consider a potential $g:\mathbb{R}^d\rightarrow \mathbb{R}$ and its $h$-transform $\bar{g}^h$, as defined in \ref{eq:htransf}. Additionally, set $T_g^h\coloneqq \mathrm{Id} - \nabla h^\star\circ \nabla \bar{g}^h$. Then $T_g^h$ is the OT Monge map ...
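
The link between this map and the proximal operator of $\tau$ can be made explicit with a short derivation. It is a sketch that assumes the elastic form $h(\mathbf{z})=\tfrac12\|\mathbf{z}\|^2+\gamma\tau(\mathbf{z})$ from the TL;DR, with $\tau$ convex so that $h$ is strongly convex and $h^\star$ is differentiable; the variable $\mathbf{w}$ is introduced here for illustration only.

```latex
\begin{aligned}
\nabla h^\star(\mathbf{w})
  &= \arg\max_{\mathbf{z}} \; \langle \mathbf{w}, \mathbf{z}\rangle - \tfrac12\|\mathbf{z}\|^2 - \gamma\tau(\mathbf{z}) \\
  &= \arg\min_{\mathbf{z}} \; \tfrac12\|\mathbf{z}-\mathbf{w}\|^2 + \gamma\tau(\mathbf{z})
   \;=\; \mathrm{prox}_{\gamma\tau}(\mathbf{w}), \\[4pt]
T_g^h(\mathbf{x})
  &= \mathbf{x} - \mathrm{prox}_{\gamma\tau}\!\big(\nabla \bar{g}^h(\mathbf{x})\big).
\end{aligned}
```

Under these assumptions, every displacement $T_g^h(\mathbf{x})-\mathbf{x}$ is, up to sign, an output of $\mathrm{prox}_{\gamma\tau}$, which is the sense in which the regularizer shapes the displacements.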

Figures (5)

  • Figure 1: Illustration of ground truth optimal transport maps with different costs $h$, for the same base function $g$. In this experiment, $g$ is the negative of a random ICNN with 2-dimensional inputs, 3 layers and hidden dimensions of sizes $[8,8,8]$. All plots display the level lines of $g$. The optimal transport maps $T_g^h$ are recomputed four times using Prop. \ref{prop:pushforward}, with four different costs $h$, displayed above each plot. (left) When $h$ is the usual $\ell_2^2$ cost, we observe a typical OT map that follows, from each $\mathbf{x}_i$, minus the gradient of $g$. With a sparsity-inducing regularizer (middle-left), we obtain sparse displacements: most arrows follow either of the two canonical axes, and some points do not move at all. (middle-right) With a cost that penalizes displacements that are orthogonal to a vector $\mathbf{b}$, we obtain displacements that push further to the bottom than in the (left) plot. When the penalization strength is increased (right), the displacements are increasingly parallel to $\mathbf{b}$. When $\mathbf{b}$ is not known beforehand, and both source and target measures are given, we present in § \ref{sec:learning_structure} a general procedure to learn such a parameter adaptively.
  • Figure 2: Illustration of the $h$-transform computation in $2$d. (left): base concave potential $g$, here a negative quadratic. (other figures) Level lines of the corresponding $h$-transform $\bar{g}^h$ for different choices of $h$. The $h$-transform is computed using the iterations described in Prop. \ref{prop: prox_descent}.
  • Figure 3: Performance of the MBO estimator on two ground-truth tasks involving the $\tau=\ell_1$ and $\tau_{A^\perp}=\|A^\perp\mathbf{z}\|^2_2$ structured costs, where $p=2$ in dimension $d=5$ (two figures to the left) and dimension $d=10$ (two figures to the right). We display the ratio between the MSE obtained with a regularizer strength $\gamma > 0$ and the MSE obtained in the absence of regularization (i.e., $\gamma=0$). The level of regularization used for generating the ground truth data is $\gamma^*$, whereas performance is shown as a function of $\gamma$. We display curves $\pm$ one standard deviation, estimated over 10 random seeds.
  • Figure 4: Error averaged over 5 seeded runs (lower is better) in $[0,1]$ of the $\hat{p}\times d$ orthogonal matrix $\hat{A}$ recovered by our algorithm, compared to the ground truth $p^*\times d$ cost matrix $A^*$. Error bars are not shown for compactness, but are negligible since all quantities are bounded below and close to $0$. Dimensions $d, p^*$ vary in each of these 6 plots, whereas $\hat{p}$ is fixed to either $p^*$ (top row) or $1.25 p^*$ (bottom row). Error is quantified as the normalized squared-residual error obtained when projecting the $p^*$ basis vectors of $A^*$ onto the span of $\hat{A}$. From left to right, the regularization strength $\gamma^*$ increases to ensure that 50%, 70% and 90% of the total inertia of all displacements generated by the ground-truth Monge map are borne by the $p^*$ highest singular values. As expected, recovery is easier when $\hat{p}$ is slightly larger than $p^*$ (bottom) compared to being exactly equal (top). It is also easier as the share of inertia captured by $p^*$ increases.
  • Figure 5: Predictive performance of the MBO estimator on single-cell datasets, $d=256$, using either the naive baseline $\ell_2^2$ cost (black dotted line) or elastic subspace cost \ref{eq:tauorth}, with varying $\gamma$ and $\hat{p}$. Remarkably, promoting displacements to happen in a subspace of much lower dimension improves predictions, even when measured in the squared-Euclidean distance.
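
The subspace recovery error described in the Figure 4 caption can be sketched as follows. This is one plausible reading of "normalized squared-residual error", not the authors' code: the function name `subspace_recovery_error` is hypothetical, both matrices are assumed to have orthonormal rows, and the normalization by $\|A^*\|_F^2=p^*$ is an assumption chosen so that the value lies in $[0,1]$.

```python
import numpy as np

def subspace_recovery_error(A_hat, A_star):
    """Share of the rows of A_star left unexplained by the span of A_hat.

    A_hat:  (p_hat, d) array, assumed to have orthonormal rows.
    A_star: (p_star, d) array, assumed to have orthonormal rows.
    Returns a value in [0, 1]; 0 means span(A_star) lies inside span(A_hat).
    """
    # Project each row of A_star onto the row space of A_hat.
    projected = A_star @ A_hat.T @ A_hat
    residual = A_star - projected
    return np.sum(residual ** 2) / np.sum(A_star ** 2)

if __name__ == "__main__":
    A_star = np.array([[1.0, 0.0, 0.0]])
    A_hat = np.array([[1.0, 1.0, 0.0]]) / np.sqrt(2.0)  # 45 degrees off
    print(subspace_recovery_error(A_hat, A_star))
```

With this definition, choosing $\hat{p}$ larger than $p^*$ can only enlarge the span of $\hat{A}$, which is consistent with the caption's remark that recovery is easier in the bottom row.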

Theorems & Definitions (14)

  • Definition 1: MBO Estimator
  • Proposition 1
  • Proposition 2
  • proof
  • Theorem 1
  • Theorem 2: divol2022optimal
  • Definition 2: Elastic Costs Loss
  • Proposition 3
  • proof: Proof of \ref{prop: opt_characterization}
  • Lemma 1
  • ...and 4 more