Learning with Shared Representations: Statistical Rates and Efficient Algorithms
Xiaochun Niu, Lili Su, Jiaming Xu, Pengkun Yang
TL;DR
This work studies how to learn a shared, low-dimensional representation across heterogeneous clients for personalized and transfer learning. It introduces a spectral estimator that uses two independent replicas of local averages to approximate the solution of a non-convex least-squares problem for recovering the shared subspace $B^{\star}$, and it proves upper bounds and nearly matching minimax lower bounds that reveal a two-phase optimal rate depending on the number of clients and the local dataset sizes. The results extend to nonlinear models, including generalized linear models and one-hidden-layer ReLU networks, and characterize when collaboration helps and when independent learning may suffice. Because clients exchange only local averages, the approach also offers privacy advantages. The estimator attains the optimal rate when the shared representation is well covered across clients, with implications for transfer learning and private fine-tuning.
Abstract
Collaborative learning through latent shared feature representations enables heterogeneous clients to train personalized models with improved performance and reduced sample complexity. Despite empirical success and extensive study, the theoretical understanding of such methods remains incomplete, even for representations restricted to low-dimensional linear subspaces. In this work, we establish new upper and lower bounds on the statistical error in learning low-dimensional shared representations across clients. Our analysis captures both statistical heterogeneity (including covariate and concept shifts) and variation in local dataset sizes, aspects often overlooked in prior work. We further extend these results to nonlinear models, including logistic regression and one-hidden-layer ReLU networks. Specifically, we design a spectral estimator that leverages independent replicas of local averages to approximate the non-convex least-squares solution and derive a nearly matching minimax lower bound. Our estimator achieves the optimal statistical rate when the shared representation is well covered across clients, i.e., when no direction is severely underrepresented. Our results reveal two distinct phases of the optimal rate: a standard parameter-counting regime and a penalized regime that arises when the number of clients is large or local datasets are small. These findings precisely characterize when collaboration benefits the overall system or individual clients in transfer learning and private fine-tuning.
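
To make the idea of a spectral estimator built from independent replicas of local averages concrete, the following is a minimal sketch and not the paper's exact procedure. It assumes a linear model $y = x^{\top} B^{\star} \alpha_i + \text{noise}$ with isotropic covariates for client $i$, so that the local average of $y\,x$ estimates $B^{\star}\alpha_i$. Each client splits its data into two halves to form two independent averages, and the server aggregates their symmetrized outer products. The function names, the half-and-half split, and the weighting by local sample size are illustrative assumptions.

```python
import numpy as np


def spectral_subspace_estimate(X_list, y_list, k):
    """Estimate a k-dimensional shared subspace from per-client data.

    X_list[i] is an (n_i, d) array of covariates and y_list[i] is an (n_i,)
    array of responses for client i. Assumes isotropic covariates
    (hypothetical modeling assumption for this sketch).
    """
    d = X_list[0].shape[1]
    M = np.zeros((d, d))
    total = 0
    for X, y in zip(X_list, y_list):
        n = X.shape[0]
        half = n // 2
        if half == 0:
            continue  # a client needs at least two samples to form two replicas
        # Two independent replicas of the local average of y * x.
        z1 = X[:half].T @ y[:half] / half
        z2 = X[half:].T @ y[half:] / (n - half)
        # The cross product of independent replicas avoids the noise bias
        # that z1 z1^T would carry; symmetrize before aggregating.
        M += n * (np.outer(z1, z2) + np.outer(z2, z1)) / 2
        total += n
    M /= total  # weights proportional to n_i (an assumption of this sketch)
    eigvals, eigvecs = np.linalg.eigh(M)
    top = np.argsort(eigvals)[::-1][:k]  # indices of the k largest eigenvalues
    return eigvecs[:, top]


def sin_theta_distance(B_hat, B_star):
    """Spectral-norm sin-theta distance between two orthonormal bases."""
    d = B_star.shape[0]
    return np.linalg.norm((np.eye(d) - B_star @ B_star.T) @ B_hat, 2)
```

In this sketch only the two $d$-dimensional averages `z1` and `z2` leave each client, which mirrors the observation that only local averages are exchanged. The returned matrix has orthonormal columns, and `sin_theta_distance` gives one common way to compare it with a known subspace in simulation.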
