An Optimal Transport Approach for Network Regression
Alex G. Zalles, Kai M. Hung, Ann E. Finneran, Lydia Beaudrot, César A. Uribe
TL;DR
We address network regression where the response is a graph and the covariates are Euclidean vectors. The method represents graphs as multivariate Gaussians via the Laplacian pseudoinverse and performs regression in Wasserstein space through Fréchet means, which reduces the problem to a weighted Wasserstein barycenter computation. The approach uses entropy-regularized fixed-point iterations and covariance shifts to handle degeneracies. It shows empirical evidence of convergence across synthetic and real data, and clear improvements over Frobenius-based regression in prediction accuracy and scalability. This work motivates further development of convergence theory for Fréchet means in Wasserstein spaces and extensions to Gromov-Wasserstein distances for graphs of varying sizes.
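The following is a minimal NumPy sketch of the graph-to-Gaussian representation. It rests on several assumptions that are not taken from the paper. The input is a weighted undirected adjacency matrix, the Laplacian is the combinatorial L = D - A, each graph maps to a zero-mean Gaussian, and the covariance shift has the hypothetical form $L^{+} + \varepsilon I$ (the authors' exact shift may differ). It then evaluates the closed-form 2-Wasserstein (Bures) distance between two such Gaussians. The function names and the default $\varepsilon = 10^{-3}$ are illustrative choices.

```python
import numpy as np

def laplacian(adj):
    """Combinatorial Laplacian L = D - A of a weighted, undirected graph."""
    adj = np.asarray(adj, dtype=float)
    return np.diag(adj.sum(axis=1)) - adj

def psd_sqrt(m):
    """Square root of a symmetric positive semidefinite matrix via eigh."""
    vals, vecs = np.linalg.eigh((m + m.T) / 2)
    vals = np.clip(vals, 0.0, None)  # drop tiny negative round-off
    return (vecs * np.sqrt(vals)) @ vecs.T  # V diag(sqrt(vals)) V^T

def graph_to_covariance(adj, shift=1e-3):
    """Hypothetical covariance L^+ + shift * I; the shift makes it positive definite."""
    lap = laplacian(adj)
    return np.linalg.pinv(lap) + shift * np.eye(lap.shape[0])

def w2_gaussian(s1, s2):
    """2-Wasserstein distance between N(0, s1) and N(0, s2) (Bures formula)."""
    r1 = psd_sqrt(s1)
    cross = psd_sqrt(r1 @ s2 @ r1)
    val = np.trace(s1) + np.trace(s2) - 2.0 * np.trace(cross)
    return float(np.sqrt(max(val, 0.0)))  # clamp round-off below zero
```

For example, comparing `graph_to_covariance` of a 3-node path with that of a 3-node triangle gives a strictly positive distance, while comparing a graph with itself gives zero up to round-off. An eigendecomposition-based square root is used because every matrix involved is symmetric positive semidefinite, so the result stays real and symmetric.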
Abstract
We study the problem of network regression, where one is interested in how the topology of a network changes as a function of Euclidean covariates. We build upon recent developments in generalized regression models on metric spaces based on Fréchet means and propose a network regression method using the Wasserstein metric. We show that, when graphs are represented as multivariate Gaussian distributions, the network regression problem requires the computation of a Riemannian center of mass (i.e., a Fréchet mean). With non-negative weights, the Fréchet mean reduces to a barycenter problem that can be computed efficiently using fixed-point iterations. Although convergence guarantees for fixed-point iterations that compute Wasserstein affine averages remain an open problem, we provide evidence of convergence in a large number of synthetic and real-data scenarios. Extensive numerical results show that the proposed approach improves upon existing procedures by accurately accounting for graph size, topology, and sparsity in synthetic experiments. Additionally, in real-world experiments the proposed approach achieves higher coefficient of determination ($R^{2}$) values and lower mean squared prediction error (MSPE), confirming its improved prediction capabilities in practice.
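With non-negative weights, the Fréchet mean of zero-mean Gaussians with positive definite covariances $C_i$ is a weighted Wasserstein barycenter. The sketch below uses the classical, unregularized fixed-point iteration of Álvarez-Esteban et al. (2016), $S \leftarrow S^{-1/2} \big(\sum_i w_i (S^{1/2} C_i S^{1/2})^{1/2}\big)^{2} S^{-1/2}$, as a stand-in. It is not the authors' implementation. It omits the entropy regularization mentioned in the TL;DR and rejects negative (affine) weights. The names `gaussian_barycenter`, `tol`, and `max_iter` are hypothetical.

```python
import numpy as np

def psd_sqrt(m):
    """Square root of a symmetric positive semidefinite matrix via eigh."""
    vals, vecs = np.linalg.eigh((m + m.T) / 2)
    return (vecs * np.sqrt(np.clip(vals, 0.0, None))) @ vecs.T

def gaussian_barycenter(covs, weights, tol=1e-10, max_iter=500):
    """Weighted W2 barycenter of N(0, covs[i]) by fixed-point iteration.

    Assumes positive definite covariances and non-negative weights, which
    are normalized to sum to 1. Negative weights raise ValueError.
    """
    weights = np.asarray(weights, dtype=float)
    if np.any(weights < 0):
        raise ValueError("this sketch only handles non-negative weights")
    weights = weights / weights.sum()
    # Start from the Euclidean (Frobenius) weighted mean of the covariances.
    s = sum(w * c for w, c in zip(weights, covs))
    for _ in range(max_iter):
        r = psd_sqrt(s)
        r_inv = np.linalg.inv(r)
        # T = sum_i w_i (S^{1/2} C_i S^{1/2})^{1/2}
        t = sum(w * psd_sqrt(r @ c @ r) for w, c in zip(weights, covs))
        s_new = r_inv @ t @ t @ r_inv
        s_new = (s_new + s_new.T) / 2  # re-symmetrize against round-off
        if np.linalg.norm(s_new - s, ord="fro") < tol:
            return s_new
        s = s_new
    return s
```

For diagonal (commuting) covariances, the iteration reaches the known closed form $(\sum_i w_i C_i^{1/2})^{2}$ after a single step, which makes a convenient sanity check. At convergence, $S$ satisfies $S = \sum_i w_i (S^{1/2} C_i S^{1/2})^{1/2}$.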
