Tree-Wasserstein Distance for High Dimensional Data with a Latent Feature Hierarchy
Ya-Wei Eileen Lin, Ronald R. Coifman, Gal Mishne, Ronen Talmon
TL;DR
This work tackles the challenge of measuring meaningful distances between high-dimensional samples whose features possess a latent hierarchical structure. It learns a latent feature tree by embedding the features in a multi-scale hyperbolic space using diffusion geometry and then decoding a binary feature tree via a hyperbolic diffusion lowest common ancestor (HD-LCA). The result is a data-driven tree-Wasserstein distance $\texttt{TW}(\boldsymbol{x}_i,\boldsymbol{x}_{i'},B)$ that approximates the TWD induced by the ground-truth latent tree metric $d_T$. The authors prove that the decoded tree yields a distance that is bilipschitz equivalent to the true latent TWD, and they demonstrate scalable, linear-time computation using diffusion landmarks. Empirical results on synthetic data, word-document classification, and single-cell RNA sequencing show clear advantages over existing TWD methods and pre-trained baselines. The approach enables unsupervised ground-metric learning from observations and provides a differentiable, geometry-aware distance that improves downstream tasks while remaining efficient for large feature sets.
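For a weighted tree, the TWD has a well-known closed form, $\texttt{TW}(\mu,\nu)=\sum_{v\neq r} w_v\,\lvert\mu(\Gamma_v)-\nu(\Gamma_v)\rvert$, where $\Gamma_v$ is the set of nodes in the subtree rooted at $v$, $r$ is the root, and $w_v$ is the length of the edge from $v$ to its parent. This is what makes the distance cheap to evaluate once a tree is available. The Python sketch below evaluates that formula in a single leaves-to-root pass. It assumes each sample has been normalized into a probability vector over the features, which sit at the leaves, and it encodes the tree with hypothetical `parent` and `weight` arrays (node 0 is the root, and parents are listed before their children). It does not reproduce the paper's HD-LCA tree decoding; it only shows how a distance is computed once some tree $B$ is given.

```python
import numpy as np


def tree_wasserstein(mu, nu, parent, weight):
    """Closed-form TWD between two distributions on the nodes of a rooted tree.

    mu, nu : nonnegative mass per node, each summing to 1 (features sit at the leaves).
    parent : parent[v] is the parent of node v; node 0 is the root with parent -1,
             and every other node satisfies parent[v] < v (assumed encoding).
    weight : weight[v] is the length of the edge (v, parent[v]); weight[0] is unused.
    """
    # subtree[v] accumulates mu(Gamma_v) - nu(Gamma_v), the net mass below node v.
    subtree = np.asarray(mu, dtype=float) - np.asarray(nu, dtype=float)
    total = 0.0
    # Children have larger indices than their parents, so a reverse sweep
    # finishes each subtree before its net mass is pushed up to the parent.
    for v in range(len(parent) - 1, 0, -1):
        total += weight[v] * abs(subtree[v])
        subtree[parent[v]] += subtree[v]
    return total


# Toy binary tree: root 0 -> {1, 2}, 1 -> {3, 4}, 2 -> {5, 6}; unit edge lengths.
parent = [-1, 0, 0, 1, 1, 2, 2]
weight = [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]


def point_mass(node, n=7):
    """Distribution with all of its mass on a single node."""
    m = np.zeros(n)
    m[node] = 1.0
    return m


d_siblings = tree_wasserstein(point_mass(3), point_mass(4), parent, weight)  # 2.0
d_cousins = tree_wasserstein(point_mass(3), point_mass(5), parent, weight)  # 4.0
```

On this toy tree with unit edge lengths, moving all mass between sibling leaves costs 2 (the path through their shared parent), while moving it to a leaf in the other subtree costs 4, which matches the tree metric between those leaves.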
Abstract
Finding meaningful distances between high-dimensional data samples is an important scientific task. To this end, we propose a new tree-Wasserstein distance (TWD) for high-dimensional data with two key aspects. First, our TWD is specifically designed for data with a latent feature hierarchy, i.e., the features lie in a hierarchical space, in contrast to the usual focus on embedding samples in hyperbolic space. Second, while the conventional use of TWD is to speed up the computation of the Wasserstein distance, we use its inherent tree as a means to learn the latent feature hierarchy. The key idea of our method is to embed the features into a multi-scale hyperbolic space using diffusion geometry and then present a new tree decoding method by establishing analogies between the hyperbolic embedding and trees. We show that our TWD computed based on data observations provably recovers the TWD defined with the latent feature hierarchy and that its computation is efficient and scalable. We showcase the usefulness of the proposed TWD in applications to word-document and single-cell RNA-sequencing datasets, demonstrating its advantages over existing TWDs and methods based on pre-trained models.
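The diffusion-geometry step can be illustrated with the standard construction of a diffusion operator on the features: a Gaussian affinity between feature vectors, row-normalized into a Markov matrix $P$, together with its dyadic powers $P^{2^k}$, which describe the feature geometry at progressively coarser scales. The sketch below is a minimal version under our own assumptions (a Gaussian kernel, a median bandwidth heuristic, and a hypothetical function name `multiscale_diffusion`). The paper's specific kernel, normalization, and the hyperbolic embedding built from these operators are not reproduced here.

```python
import numpy as np


def multiscale_diffusion(X, eps=None, K=4):
    """Row-stochastic diffusion operators P^(2^k), k = 0..K, on the features of X.

    X : array of shape (n_samples, n_features); each feature is described by
        its values across the samples (assumed setup, for illustration only).
    """
    F = np.asarray(X, dtype=float).T  # one row per feature
    sq = np.sum(F ** 2, axis=1)
    # Pairwise squared Euclidean distances between features, clipped at 0
    # to remove small negative values caused by floating-point round-off.
    D2 = np.maximum(sq[:, None] + sq[None, :] - 2.0 * F @ F.T, 0.0)
    if eps is None:
        eps = np.median(D2[D2 > 0])  # median heuristic for the kernel bandwidth
    W = np.exp(-D2 / eps)  # Gaussian affinity between features
    P = W / W.sum(axis=1, keepdims=True)  # row-normalize into a Markov matrix
    ops = [P]
    for _ in range(K):
        ops.append(ops[-1] @ ops[-1])  # squaring P^(2^k) gives P^(2^(k+1))
    return ops


rng = np.random.default_rng(0)
X = rng.normal(size=(50, 8))  # 50 samples, 8 features
ops = multiscale_diffusion(X, K=3)  # [P, P^2, P^4, P^8]
```

Repeated squaring reaches scale $2^K$ with only $K$ matrix products, and each operator remains a Markov matrix, so its rows can be read as diffusion profiles of the features at that scale.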
