Joint Hierarchical Representation Learning of Samples and Features via Informed Tree-Wasserstein Distance

Ya-Wei Eileen Lin, Ronald R. Coifman, Gal Mishne, Ronen Talmon

TL;DR

This work addresses hierarchical representation learning for high-dimensional data in which both samples and features exhibit hierarchical structure. It introduces an unsupervised, iterative framework that alternates between the two data modes: a tree learned for one mode defines a Tree-Wasserstein Distance (TWD) on the other mode, and that TWD is then used to build the other mode's tree, so both hierarchies are refined together. A data-driven Haar wavelet filtering step, constructed from the learned trees, further enhances the representations, and the procedure is supported by a theoretical convergence analysis. The method can serve as a pre-processing step for hyperbolic graph convolutional networks, improving link prediction and node classification, and it outperforms baselines in sparse approximation and unsupervised Wasserstein distance learning on word-document and scRNA-seq datasets. Overall, the method jointly uncovers and exploits hierarchical structure across both rows and columns of high-dimensional data, with practical benefit for downstream hyperbolic architectures.

Abstract

High-dimensional data often exhibit hierarchical structures in both modes: samples and features. Yet, most existing approaches for hierarchical representation learning consider only one mode at a time. In this work, we propose an unsupervised method for jointly learning hierarchical representations of samples and features via Tree-Wasserstein Distance (TWD). Our method alternates between the two data modes. It first constructs a tree for one mode, then computes a TWD for the other mode based on that tree, and finally uses the resulting TWD to build the second mode's tree. By repeatedly alternating through these steps, the method gradually refines both trees and the corresponding TWDs, capturing meaningful hierarchical representations of the data. We provide a theoretical analysis showing that our method converges. We show that our method can be integrated into hyperbolic graph convolutional networks as a pre-processing technique, improving performance in link prediction and node classification tasks. In addition, our method outperforms baselines in sparse approximation and unsupervised Wasserstein distance learning tasks on word-document and single-cell RNA-sequencing datasets.
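
The alternating procedure described in the abstract can be sketched as follows. This is a minimal illustration under stated assumptions, not the authors' implementation: SciPy average-linkage clustering stands in for the paper's tree construction, the initial feature distance is Euclidean, and the regularization parameters $\gamma_r, \gamma_c$ and the Haar wavelet filtering step are omitted. The function names `build_tree`, `twd`, and `joint_trees` are hypothetical. The TWD uses the standard closed form $\sum_e w_e \, |\mu(\Gamma_e) - \nu(\Gamma_e)|$, where $\Gamma_e$ is the set of leaves below edge $e$ and $w_e$ is its weight.

```python
import numpy as np
from scipy.cluster.hierarchy import linkage
from scipy.spatial.distance import pdist, squareform


def build_tree(dist):
    """Average-linkage tree from a square distance matrix (assumed stand-in).

    Returns (member, weights): member[k, i] is 1 if leaf i lies below node k,
    and weights[k] is the length of the edge from node k to its parent.
    """
    n = dist.shape[0]
    Z = linkage(squareform(dist, checks=False), method="average")
    n_nodes = 2 * n - 1
    member = np.zeros((n_nodes, n))
    member[np.arange(n), np.arange(n)] = 1.0  # each leaf contains itself
    height = np.zeros(n_nodes)
    parent = np.full(n_nodes, -1)
    for k, (a, b, h, _) in enumerate(Z):
        node, a, b = n + k, int(a), int(b)
        member[node] = member[a] + member[b]
        height[node] = h
        parent[a] = parent[b] = node
    weights = np.zeros(n_nodes)  # the root keeps weight 0
    has_parent = parent >= 0
    weights[has_parent] = np.maximum(
        height[parent[has_parent]] - height[has_parent], 0.0
    )
    return member, weights


def twd(X, member, weights):
    """Pairwise TWD between rows of X, viewed as histograms over the leaves."""
    P = X / X.sum(axis=1, keepdims=True)  # normalize each row to unit mass
    S = (P @ member.T) * weights          # weighted mass below every edge
    return squareform(pdist(S, metric="cityblock"))


def joint_trees(X, n_iter=5):
    """Alternate between feature and sample trees on a nonnegative matrix X."""
    feat_dist = squareform(pdist(X.T))    # assumed Euclidean initialization
    for _ in range(n_iter):
        f_member, f_w = build_tree(feat_dist)
        samp_dist = twd(X, f_member, f_w)      # samples: histograms over features
        s_member, s_w = build_tree(samp_dist)
        feat_dist = twd(X.T, s_member, s_w)    # features: histograms over samples
    return samp_dist, feat_dist


rng = np.random.default_rng(0)
X = rng.random((6, 4)) + 0.1  # toy samples-by-features matrix
sample_twd, feature_twd = joint_trees(X, n_iter=3)
```

Each pass reuses the distance from the previous step, so the two trees inform each other; every row and column of `X` must have positive total mass for the normalization to be defined.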

Paper Structure

This paper contains 87 sections, 10 theorems, 78 equations, 11 figures, 17 tables, 4 algorithms.

Key Result

Theorem 4.1

The sequences $\mathbf{W}_r^{(l)}$ and $\mathbf{W}_c^{(l)}$ generated by Alg. \ref{alg:Co_HD} have at least one limit point, and all limit points are fixed points if $\gamma_r, \gamma_c > 0$.

Figures (11)

  • Figure 1: Overview of learning hierarchical representations across samples and features using TWD. Consider a word-document data matrix. We construct an initial tree for one data mode (e.g., words). This tree is then used to compute the TWD in the other mode (e.g., documents). The newly computed TWD informs a tree update of that mode, and the updated tree is subsequently used to compute the TWD in the cross-mode. This alternating procedure continues iteratively, refining both trees.
  • Figure 2: The $L_1$ norm of the Haar coefficients from the sample tree (left) and the feature tree (right) during the sparse approximation task across iterations on the ZEISEL dataset.
  • Figure 3: An illustration of a Haar basis induced by a binary tree.
  • Figure 4: Hierarchical structure used in the toy video recommendation example. The right tree depicts the user hierarchy based on device type and viewing context. The left tree represents the feature hierarchy of videos, branching by genre and subgenre (e.g., fiction → action, drama, sci-fi). Users and videos are colored according to their first-level subcategory.
  • Figure 5: Learned tree representations for the toy video recommendation example using Alg. \ref{alg:Co_HD} and Alg. \ref{alg:wavelet_co-HD}. Users and videos are colored according to their first-level subcategory.
  • ...and 6 more figures
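
The Haar basis induced by a binary tree (Figure 3) and the $L_1$ norm of Haar coefficients (Figure 2) can be illustrated with a short sketch. It assumes the standard tree-induced Haar construction: one constant vector plus, for each internal node with left and right leaf sets $L$ and $R$, a vector that is constant on $L$, constant with opposite sign on $R$, zero elsewhere, and scaled to unit norm. The nested-tuple tree encoding and the names `leaves` and `haar_basis` are hypothetical choices for this example, not taken from the paper.

```python
import numpy as np


def leaves(tree):
    """Leaf indices of a binary tree given as nested 2-tuples of ints."""
    if isinstance(tree, int):
        return [tree]
    return leaves(tree[0]) + leaves(tree[1])


def haar_basis(tree):
    """Rows form an orthonormal Haar-like basis induced by the binary tree."""
    n = len(leaves(tree))
    basis = [np.full(n, 1.0 / np.sqrt(n))]  # constant (scaling) vector
    stack = [tree]
    while stack:
        node = stack.pop()
        if isinstance(node, int):
            continue  # leaves contribute no wavelet
        left, right = leaves(node[0]), leaves(node[1])
        # The scale gives unit norm; the two constants give zero mean.
        scale = np.sqrt(len(left) * len(right) / (len(left) + len(right)))
        h = np.zeros(n)
        h[left] = scale / len(left)
        h[right] = -scale / len(right)
        basis.append(h)
        stack.extend(node)
    return np.stack(basis)


toy_tree = ((0, 1), (2, (3, 4)))  # hypothetical 5-leaf tree
B = haar_basis(toy_tree)
x = np.array([1.0, 1.0, 3.0, 3.0, 3.0])  # constant within each top-level branch
coeffs = B @ x
l1_norm = np.abs(coeffs).sum()
```

A signal that is piecewise constant along the tree, like `x` here, has few nonzero coefficients, which is why a small $L_1$ norm of the coefficients indicates that a tree fits the data well.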

Theorems & Definitions (14)

  • Theorem 4.1
  • Theorem 4.2
  • Lemma A.1
  • Theorem A.1 (Theorem 1 in lin2023hyperbolic)
  • Theorem A.2 (Theorem 1 in lin2025tree)
  • Proposition A.1 (Function Smoothness and Coefficient Decay, gavish2010multiscale)
  • Proposition A.2 ($L_1$ Sparsity, gavish2010multiscale)
  • Remark A.1
  • proof
  • Proposition C.1
  • ...and 4 more