Fast unsupervised ground metric learning with tree-Wasserstein distance

Kira M. Düsterwald, Samo Hromadka, Makoto Yamada

TL;DR

The paper addresses learning an informative, unsupervised ground metric for clustering by reframing the problem with a tree-based, low-rank approximation of Wasserstein singular vectors. It introduces Tree-WSV, embedding samples and features as leaves on trees and using the tree-Wasserstein distance to approximate full WSVs with a theoretical and practical complexity reduction to $O(n^3+m^3+mn)$ per power iteration. The authors establish existence and uniqueness results for the tree-WSV solutions, provide a fast recursive method to compute the leaf-path basis, and validate the approach on toy data and large-scale single-cell RNA-sequencing datasets, demonstrating favorable speed and competitive clustering accuracy. This work suggests a scalable, general-purpose tool for unsupervised ground metric learning with broad applications in high-dimensional, unlabeled data analysis.
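To give a feel for the size of this reduction, the sketch below evaluates the leading-order terms of the two costs for a given $n$ and $m$. The full-WSV expression is the one quoted in the abstract; the function names are hypothetical, constants hidden by the $\mathcal{O}$ notation are ignored, and the natural logarithm is an arbitrary choice, so the resulting ratio is only indicative.

```python
import math


def wsv_cost(n: int, m: int) -> float:
    """Leading-order term of the full WSV cost: n^2 m^2 (n log n + m log m)."""
    return n**2 * m**2 * (n * math.log(n) + m * math.log(m))


def tree_wsv_cost(n: int, m: int) -> int:
    """Leading-order term of the tree-WSV cost per power iteration: n^3 + m^3 + mn."""
    return n**3 + m**3 + m * n


if __name__ == "__main__":
    n, m = 1000, 500  # hypothetical dataset: 1000 samples, 500 features
    ratio = wsv_cost(n, m) / tree_wsv_cost(n, m)
    print(f"tree-WSV term: {tree_wsv_cost(n, m):.3e}")
    print(f"full WSV term: {wsv_cost(n, m):.3e}")
    print(f"indicative ratio: {ratio:.1e}")
```

Because both expressions drop constant factors, the printed ratio says how the gap grows with $n$ and $m$ rather than how many seconds are saved on a particular machine.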

Abstract

The performance of unsupervised methods such as clustering depends on the choice of distance metric between features, or ground metric. Commonly, ground metrics are decided with heuristics or learned via supervised algorithms. However, since many interesting datasets are unlabelled, unsupervised ground metric learning approaches have been introduced. One promising option employs Wasserstein singular vectors (WSVs), which emerge when computing optimal transport distances between features and samples simultaneously. WSVs are effective, but can be prohibitively computationally expensive in some applications: $\mathcal{O}(n^2m^2(n \log(n) + m \log(m)))$ for $n$ samples and $m$ features. In this work, we propose to augment the WSV method by embedding samples and features on trees, on which we compute the tree-Wasserstein distance (TWD). We demonstrate theoretically and empirically that the algorithm converges to a better approximation of the standard WSV approach than the best known alternatives, and does so with $\mathcal{O}(n^3+m^3+mn)$ complexity. In addition, we prove that the initial tree structure can be chosen flexibly, since tree geometry does not constrain the richness of the approximation up to the number of edge weights. This proof suggests a fast and recursive algorithm for computing the tree parameter basis set, which we find crucial to realising the efficiency gains at scale. Finally, we apply the tree-WSV algorithm to several single-cell RNA sequencing genomics datasets, demonstrating its scalability and utility for unsupervised cell-type clustering problems. These results position unsupervised ground metric learning with TWD as a low-rank approximation of WSV with the potential for widespread application.
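For intuition about why trees make the computation cheap: on a tree, the optimal transport cost has a closed form, namely the sum over edges of the edge weight times the absolute difference in mass that the two distributions place in the subtree below that edge. The following is a minimal sketch of that standard formula, not the tree-WSV algorithm itself; the parent-array encoding, the function name, and the toy tree are assumptions made for illustration.

```python
def tree_wasserstein(parent, weight, mu, nu):
    """Tree-Wasserstein distance between two mass distributions on a rooted tree.

    parent[v] is the parent of node v (the root, node 0, has parent -1).
    weight[v] is the weight of the edge between v and parent[v].
    mu and nu map node index -> mass and must have equal total mass.
    Assumes nodes are numbered so that parent[v] < v for every non-root v.
    """
    n_nodes = len(parent)
    # Net mass difference sitting at each node.
    diff = [mu.get(v, 0.0) - nu.get(v, 0.0) for v in range(n_nodes)]
    total = 0.0
    # Visit children before parents so subtree sums accumulate bottom-up.
    for v in range(n_nodes - 1, 0, -1):
        total += weight[v] * abs(diff[v])  # mass that must cross edge (v, parent[v])
        diff[parent[v]] += diff[v]
    return total


# Toy tree (hypothetical): root 0, internal nodes 1 and 2,
# leaves 3 and 4 under node 1, leaves 5 and 6 under node 2.
PARENT = [-1, 0, 0, 1, 1, 2, 2]
WEIGHT = [0.0, 1.0, 2.0, 0.5, 0.5, 1.0, 1.0]  # WEIGHT[0] is unused (root has no edge)

if __name__ == "__main__":
    mu = {3: 1.0}  # all mass on leaf 3
    nu = {5: 1.0}  # all mass on leaf 5
    print(tree_wasserstein(PARENT, WEIGHT, mu, nu))
```

In this toy case, moving one unit of mass from leaf 3 to leaf 5 costs the length of the path 3-1-0-2-5, which is 0.5 + 1 + 2 + 1 = 4.5. The single bottom-up pass is linear in the number of nodes, in contrast to solving a general optimal transport problem.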

Paper Structure

This paper contains 26 sections, 4 theorems, 19 equations, 4 figures, 2 tables, 2 algorithms.

Key Result

Proposition 2.1

The WSV fixed point equations (Eq. 1) can be expressed on the trees $\mathcal{T}_A, \mathcal{T}_B$ as: where the singular vector update is to find $\boldsymbol{w_A}$ (and symmetrically $\boldsymbol{w_B}$) such that $\forall i,j$, where $\boldsymbol{z_i^{(A)}}$ is the $i$th column of $\boldsymbol{Z^{(A)}}$, $\circ$ denotes the element-wise product and $(\lambda_A,\lambda_B) \in (\mathbb{R}_+^*)^2$.
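The ingredients named here (columns $\boldsymbol{z_i}$ of a binary leaf-path matrix, an edge-weight vector, and the element-wise product) are enough to write the tree metric between two leaves. An edge lies on the path between leaves $i$ and $j$ exactly when it is on one root-to-leaf path but not the other, which for 0/1 vectors is $\boldsymbol{z_i} + \boldsymbol{z_j} - 2\,\boldsymbol{z_i} \circ \boldsymbol{z_j}$. The sketch below illustrates this identity on a toy tree; it assumes that $\boldsymbol{Z}$ has one row per edge and one column per leaf, uses hypothetical helper names, and should not be read as the proposition's exact statement.

```python
def leaf_path_matrix(parent, leaves):
    """Build Z with Z[e][c] = 1 if edge e lies on the root-to-leaf path of leaves[c].

    Edges are identified with their child node: edge e joins node e + 1 to its
    parent (node 0 is the root and has parent -1).
    """
    n_edges = len(parent) - 1
    Z = [[0] * len(leaves) for _ in range(n_edges)]
    for col, leaf in enumerate(leaves):
        v = leaf
        while parent[v] != -1:  # walk up to the root, marking each edge crossed
            Z[v - 1][col] = 1
            v = parent[v]
    return Z


def tree_distance(Z, w, i, j):
    """Leaf-to-leaf tree metric: w^T (z_i + z_j - 2 z_i * z_j), element-wise."""
    return sum(
        w_e * (row[i] + row[j] - 2 * row[i] * row[j])
        for w_e, row in zip(w, Z)
    )


# Toy tree (hypothetical): root 0, internal nodes 1 and 2,
# leaves 3 and 4 under node 1, leaves 5 and 6 under node 2.
PARENT = [-1, 0, 0, 1, 1, 2, 2]
LEAVES = [3, 4, 5, 6]
EDGE_WEIGHTS = [1.0, 2.0, 0.5, 0.5, 1.0, 1.0]  # one weight per non-root node 1..6

Z = leaf_path_matrix(PARENT, LEAVES)

if __name__ == "__main__":
    print(tree_distance(Z, EDGE_WEIGHTS, 0, 1))  # siblings: leaves 3 and 4
    print(tree_distance(Z, EDGE_WEIGHTS, 0, 2))  # across the root: leaves 3 and 5
```

Shared ancestors cancel in the expression, so only the edges below the lowest common ancestor of the two leaves contribute to the distance.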

Figures (4)

  • Figure 1: Tree embeddings for samples $\boldsymbol{A}$ as leaves in $\mathcal{T}_{\boldsymbol{A}}$ (left) and features $\boldsymbol{B}$ as leaves in $\mathcal{T}_{\boldsymbol{B}}$ (right). The tree metric $d_{\mathcal{T}_{\boldsymbol{A}}}(\boldsymbol{a_2}, \boldsymbol{a_3})$ is shown as the shortest path between these leaves in orange on the left. Equivalently, we can use the relative expression of the features $\{\boldsymbol{b_1},\boldsymbol{b_2},\boldsymbol{b_3},\boldsymbol{b_4},\boldsymbol{b_5}\}$ that represent $\boldsymbol{a_2},\boldsymbol{a_3}$ respectively (as shown in different hues of green on the right) to compute a TWD, $\mathcal{W}_{\mathcal{T}_{\boldsymbol{B}}}(\boldsymbol{a_2},\boldsymbol{a_3})$, in $\mathcal{T}_{\boldsymbol{B}}$. We assume $d_{\mathcal{T}_{\boldsymbol{A}}}(\boldsymbol{a_2}, \boldsymbol{a_3})$ is equal to $\mathcal{W}_{\mathcal{T}_{\boldsymbol{B}}}(\boldsymbol{a_2},\boldsymbol{a_3})$ to learn a good embedding.
  • Figure 2: Comparison of (A) mean time complexity and (B) Frobenius norm error for different ground metric learning algorithms (lower is better). Tree algorithms are specified by restrictions on the minimum number of children per node in the tree. Standard error of the mean is shown with black bars. "Many children" refers to ClusterTree initialised with 10 for the min. children parameter.
  • Figure 3: UMAP embeddings based on the learned tree-singular vectors for cells, coloured by provided annotation: (A) PBMCs (Wolf et al., 2018); (B) lung tissue (Sikkema et al., 2023); (C) mouse V1 neural tissue, by broad types (Tasic et al., 2018); and (D) GABAergic neurons in the same mouse V1 neural tissue set as C, by subtype (Tasic et al., 2018).
  • Figure 4: Empirical convergence observed for different toy torus dataset sizes across power iterations. The smaller dimension is shown in blue: (A) 60 x 80 (with SVD basis set calculation); (B) 100 x 200 (with SVD basis set calculation); (C) 1500 x 300 (with recursive basis set calculation).

Theorems & Definitions (4)

  • Proposition 2.1
  • Lemma 2.2
  • Lemma 2.3
  • Theorem 2.4