On-Demand Sampling: Learning Optimally from Multiple Distributions
Nika Haghtalab, Michael I. Jordan, Eric Zhao
TL;DR
The paper studies learning across multiple data distributions and introduces on-demand sampling, which achieves an optimal additive sample complexity overhead relative to single-distribution learning. It casts multi-distribution learning as a stochastic convex-concave zero-sum game between a hypothesis player and a distribution-loss auditor, and solves the game with no-regret online dynamics that use biased but bounded stochastic gradient estimates. This yields tight upper bounds and matching lower bounds for collaborative learning, group DRO, and agnostic federated learning. The key contributions are optimal bounds of the form $O\bigl(ε^{-2}(\log|\mathcal{H}| + n\log(n/δ))\bigr)$ for finite hypothesis classes, an extension of these results to group DRO with a single smooth loss, and the first sample complexity bounds for group DRO in this framework. The work also carries these theoretical insights into practice via R-MDL, a resampling-based on-demand multi-distribution learning approach that outperforms GDRO and ERM in worst-group accuracy across several datasets and remains robust to regularization choices, underscoring the practical value of adaptive data collection in multi-distribution settings.
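For reference, the zero-sum game described above can be written as a min-max problem. The notation here is an assumption chosen for illustration and is not taken from the summary: $D_1, \dots, D_n$ are the data distributions, $\ell$ is a loss bounded in $[0,1]$, $\Delta(\mathcal{H})$ is the set of mixtures over hypotheses, and $\Delta_n$ is the probability simplex over the $n$ distributions.

```latex
\min_{p \in \Delta(\mathcal{H})} \; \max_{q \in \Delta_n} \;
  \sum_{i=1}^{n} q_i \,
  \mathbb{E}_{h \sim p} \, \mathbb{E}_{z \sim D_i} \bigl[ \ell(h, z) \bigr]
```

Because the objective is linear in $q$, the inner maximum is attained at a single distribution, so the game value equals the smallest achievable worst-case expected loss $\max_{i} \mathbb{E}_{h \sim p}\,\mathbb{E}_{z \sim D_i}[\ell(h, z)]$.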
Abstract
Social and real-world considerations such as robustness, fairness, social welfare and multi-agent tradeoffs have given rise to multi-distribution learning paradigms, such as collaborative learning, group distributionally robust optimization, and fair federated learning. In each of these settings, a learner seeks to uniformly minimize its expected loss over $n$ predefined data distributions, while using as few samples as possible. In this paper, we establish the optimal sample complexity of these learning paradigms and give algorithms that meet this sample complexity. Importantly, our sample complexity bounds for multi-distribution learning exceed that of learning a single distribution by only an additive factor of $n \log(n) / ε^2$. This improves upon the best known sample complexity bounds for fair federated learning by Mohri et al. and collaborative learning by Nguyen and Zakynthinou by multiplicative factors of $n$ and $\log(n)/ε^3$, respectively. We also provide the first sample complexity bounds for the group DRO objective of Sagawa et al. To guarantee these optimal sample complexity bounds, our algorithms learn to sample from data distributions on demand. Our algorithm design and analysis are enabled by our extensions of online learning techniques for solving stochastic zero-sum games. In particular, we contribute stochastic variants of no-regret dynamics that can trade off between players' differing sampling costs.
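To make the idea of sampling on demand concrete, here is a minimal Python sketch of stochastic no-regret dynamics on a toy problem. It is an illustration under stated assumptions, not the paper's algorithm: both players run Hedge, the learner draws one fresh sample per round from a distribution chosen by the auditor's current weights, and the auditor uses a uniform importance-weighted gain estimate. The names `on_demand_hedge`, `toy_sample_loss`, and `MEAN_LOSS`, as well as the step sizes, are hypothetical choices made for this example.

```python
import numpy as np


def _softmax(z):
    """Numerically stable softmax over a 1-D array of log-weights."""
    w = np.exp(z - z.max())
    return w / w.sum()


def on_demand_hedge(sample_loss, n_dists, n_hyps, rounds, seed=0):
    """Toy stochastic no-regret dynamics for min_p max_i E_{D_i}[loss].

    sample_loss(i, rng) must return the losses (in [0, 1]) of all n_hyps
    hypotheses on one fresh sample drawn from distribution i.
    Returns the time-averaged hypothesis mixture and the sample count.
    """
    rng = np.random.default_rng(seed)
    # Step sizes are illustrative assumptions, not tuned constants.
    eta_h = np.sqrt(np.log(n_hyps) / rounds)
    eta_q = np.sqrt(np.log(n_dists) / (n_dists * rounds))
    log_p = np.zeros(n_hyps)   # learner's log-weights over hypotheses
    log_q = np.zeros(n_dists)  # auditor's log-weights over distributions
    p_sum = np.zeros(n_hyps)
    samples_used = 0
    for _ in range(rounds):
        p, q = _softmax(log_p), _softmax(log_q)
        p_sum += p
        # Learner: request one sample from a distribution drawn from q.
        i = rng.choice(n_dists, p=q)
        log_p -= eta_h * sample_loss(i, rng)
        # Auditor: probe one uniformly chosen distribution and use an
        # importance-weighted (unbiased, bounded by n_dists) gain estimate.
        j = rng.integers(n_dists)
        log_q[j] += eta_q * n_dists * (p @ sample_loss(j, rng))
        samples_used += 2
    return p_sum / rounds, samples_used


# Toy instance: rows are distributions, columns are hypotheses.
# Hypothesis 2 is the only one with uniformly low expected loss.
MEAN_LOSS = np.array([[0.1, 0.9, 0.3],
                      [0.9, 0.1, 0.3],
                      [0.5, 0.5, 0.3]])


def toy_sample_loss(i, rng):
    """One Bernoulli loss per hypothesis for a sample from distribution i."""
    return (rng.random(MEAN_LOSS.shape[1]) < MEAN_LOSS[i]).astype(float)


mixture, samples_used = on_demand_hedge(toy_sample_loss, 3, 3, rounds=4000)
worst_group_loss = float((MEAN_LOSS @ mixture).max())
```

In this sketch each round consumes two samples regardless of $n$, and the learner's sample comes from whichever distribution the auditor currently weights most heavily, which is the sense in which data is requested on demand. The returned `mixture` is the averaged play of the hypothesis player, and `worst_group_loss` evaluates it against the known toy means.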
