On-Demand Sampling: Learning Optimally from Multiple Distributions

Nika Haghtalab, Michael I. Jordan, Eric Zhao

TL;DR

The paper tackles learning across multiple data distributions by introducing on-demand sampling, which achieves the optimal additive sample complexity overhead relative to single-distribution learning. It casts multi-distribution learning as a stochastic convex-concave zero-sum game between a hypothesis player and a distribution-loss auditor. It then solves this game with no-regret online dynamics driven by biased but bounded stochastic gradient estimates, and derives tight upper bounds and matching lower bounds for collaborative learning, group DRO, and agnostic federated learning. The key contributions are optimal sample complexity bounds of the form $O\bigl(\varepsilon^{-2}(\log|\mathcal{H}| + n\log(n/\delta))\bigr)$ for finite hypothesis classes, an extension of these results to Group DRO with a single smooth loss, and the first sample complexity bounds for Group DRO in this framework. The work also shows how these theoretical insights translate into practice via R-MDL, a resampling-based on-demand multi-distribution learning approach. R-MDL outperforms GDRO and ERM in worst-group accuracy across several datasets and remains robust to regularization choices, underscoring the practical value of adaptive data collection in multi-distribution settings.
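
The game described above can be written out explicitly. The notation here is ours rather than the paper's. $R_{D_i}(h)$ is the expected loss of hypothesis $h$ on distribution $D_i$, $p \in \Delta(\mathcal{H})$ is the hypothesis player's mixture, and $q \in \Delta_n$ is the auditor's weighting of the $n$ distributions. Because the payoff is linear in $q$, its maximum over the simplex is attained at a vertex $q = e_i$. The game value is therefore exactly the worst-distribution risk:

$$\min_{p \in \Delta(\mathcal{H})} \max_{q \in \Delta_n} \sum_{i=1}^{n} q_i \, \mathbb{E}_{h \sim p}\bigl[R_{D_i}(h)\bigr] \;=\; \min_{p \in \Delta(\mathcal{H})} \max_{i \in [n]} \mathbb{E}_{h \sim p}\bigl[R_{D_i}(h)\bigr].$$

For comparison, agnostic learning of a single distribution with a finite class needs $O\bigl(\varepsilon^{-2}(\log|\mathcal{H}| + \log(1/\delta))\bigr)$ samples. The bound $O\bigl(\varepsilon^{-2}(\log|\mathcal{H}| + n\log(n/\delta))\bigr)$ therefore pays for the $n$ distributions only through the additive $O\bigl(n\log(n/\delta)/\varepsilon^{2}\bigr)$ term, not through a multiplicative factor on $\log|\mathcal{H}|$.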

Abstract

Social and real-world considerations such as robustness, fairness, social welfare and multi-agent tradeoffs have given rise to multi-distribution learning paradigms, such as collaborative learning, group distributionally robust optimization, and fair federated learning. In each of these settings, a learner seeks to uniformly minimize its expected loss over $n$ predefined data distributions, while using as few samples as possible. In this paper, we establish the optimal sample complexity of these learning paradigms and give algorithms that meet this sample complexity. Importantly, our sample complexity bounds for multi-distribution learning exceed that of learning a single distribution by only an additive factor of $n \log(n) / ε^2$. This improves upon the best known sample complexity bounds for fair federated learning by Mohri et al. and collaborative learning by Nguyen and Zakynthinou by multiplicative factors of $n$ and $\log(n)/ε^3$, respectively. We also provide the first sample complexity bounds for the group DRO objective of Sagawa et al. To guarantee these optimal sample complexity bounds, our algorithms learn to sample from data distributions on demand. Our algorithm design and analysis are enabled by our extensions of online learning techniques for solving stochastic zero-sum games. In particular, we contribute stochastic variants of no-regret dynamics that can trade off between players' differing sampling costs.
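
To make the on-demand idea concrete, here is a minimal Python sketch of stochastic no-regret dynamics for a finite hypothesis class. It assumes that both players run Hedge and that each sampler returns one labeled example per call. The names on_demand_game, make_sampler and zero_one, the threshold classifiers, and the two synthetic distributions are hypothetical illustrations. This is not the paper's algorithm or its sample-size schedule. What it shows is where the samples go. The learner needs only one draw per round, taken from whichever mixture of distributions the auditor currently favors. The auditor spends one draw per distribution to estimate each distribution's loss.

```python
import math
import random


def hedge_weights(cum_costs, eta):
    """Hedge distribution: weight of action j is proportional to exp(-eta * cum_costs[j])."""
    shift = min(cum_costs)  # subtract the minimum for numerical stability
    w = [math.exp(-eta * (c - shift)) for c in cum_costs]
    total = sum(w)
    return [x / total for x in w]


def draw_index(probs, rng):
    """Sample an index with probability proportional to probs."""
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


def on_demand_game(samplers, hypotheses, loss, T, seed=0):
    """Stochastic Hedge-vs-Hedge dynamics for min over mixtures p of max_i E_{h~p} R_i(h).

    samplers[i](rng) returns one labeled example (x, y) drawn from distribution D_i.
    loss(h, x, y) must lie in [0, 1]. Returns the time-averaged learner mixture and
    the total number of examples drawn.
    """
    rng = random.Random(seed)
    n, k = len(samplers), len(hypotheses)
    eta_p = math.sqrt(math.log(k) / T)  # learner step size
    eta_q = math.sqrt(math.log(n) / T)  # auditor step size
    cost = [0.0] * k  # learner: cumulative estimated loss of each hypothesis
    gain = [0.0] * n  # auditor: cumulative estimated loss on each distribution
    avg_p = [0.0] * k
    used = 0
    for _ in range(T):
        p = hedge_weights(cost, eta_p)
        q = hedge_weights([-g for g in gain], eta_q)  # the auditor maximizes, so negate
        # Learner: one on-demand draw from the mixture sum_i q_i D_i is an unbiased
        # estimate of sum_i q_i R_i(h), simultaneously for every hypothesis h.
        x, y = samplers[draw_index(q, rng)](rng)
        used += 1
        for j, h in enumerate(hypotheses):
            cost[j] += loss(h, x, y)
        # Auditor: play h ~ p and estimate R_i(h) with one draw from each D_i.
        h = hypotheses[draw_index(p, rng)]
        for i, sample in enumerate(samplers):
            x, y = sample(rng)
            used += 1
            gain[i] += loss(h, x, y)
        avg_p = [a + pj / T for a, pj in zip(avg_p, p)]
    return avg_p, used


# Hypothetical example: x ~ Uniform[0, 1], labels from two different boundaries.
def make_sampler(boundary):
    def sample(rng):
        x = rng.random()
        return x, int(x > boundary)
    return sample


def zero_one(threshold, x, y):
    return float(int(x > threshold) != y)


thresholds = [j / 10 for j in range(11)]
avg_p, used = on_demand_game([make_sampler(0.3), make_sampler(0.7)], thresholds, zero_one, T=5000)
```

With boundaries at 0.3 and 0.7, no mixture of thresholds has error below 0.2 on both distributions. The worst-distribution risk of the averaged mixture should therefore approach 0.2 as T grows.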

Paper Structure (49 sections, 19 theorems, 23 equations, 1 figure, 2 tables, 3 algorithms)

Key Result

Lemma 2.0

Let $c^{(1:T)}$ be any linear cost sequence where $\max_{t \in [T]} \lVert c^{(t)} \rVert_\infty \leq 1$ and $A = \Delta_n$. When $\eta = \sqrt{\log(n)/T}$, the actions $a^{(1:T)}$ chosen by Hedge satisfy $\mathrm{Reg}(a^{(1:T)}\ldots$
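
The lemma concerns Hedge (multiplicative weights) over the simplex $\Delta_n$. Below is a minimal Python sketch, assuming costs in $[-1, 1]$ and the step size $\eta = \sqrt{\log(n)/T}$. The helper names hedge and regret and the random cost sequence are hypothetical. The check compares the realized regret with $\log(n)/\eta + \eta T = 2\sqrt{T \log n}$. The textbook potential-function analysis gives this bound at this step size whenever $\eta \le 1$. It is offered as a sanity check, not as the lemma's exact constant.

```python
import math
import random


def hedge(costs, eta):
    """Hedge over n actions: a^(t) is proportional to exp(-eta * sum of past cost vectors)."""
    n = len(costs[0])
    cum = [0.0] * n
    plays = []
    for c in costs:
        shift = min(cum)  # stabilizes exp() without changing the normalized distribution
        w = [math.exp(-eta * (s - shift)) for s in cum]
        total = sum(w)
        plays.append([x / total for x in w])
        cum = [s + ci for s, ci in zip(cum, c)]  # the cost is revealed after playing
    return plays


def regret(plays, costs):
    """Reg = sum_t <a^(t), c^(t)> - min_i sum_t c_i^(t)."""
    incurred = sum(sum(ai * ci for ai, ci in zip(a, c)) for a, c in zip(plays, costs))
    best_fixed = min(sum(c[i] for c in costs) for i in range(len(costs[0])))
    return incurred - best_fixed


rng = random.Random(1)
n, T = 10, 2000
costs = [[rng.uniform(-1.0, 1.0) for _ in range(n)] for _ in range(T)]  # ||c^(t)||_inf <= 1
eta = math.sqrt(math.log(n) / T)
plays = hedge(costs, eta)
reg = regret(plays, costs)
bound = 2 * math.sqrt(T * math.log(n))  # log(n)/eta + eta*T at this eta
```

Subtracting the running minimum before exponentiating leaves every normalized distribution unchanged. It only prevents overflow or underflow when cumulative costs grow large.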

Figures (1)

  • Figure 1: Training (light, dashed) and validation (dark, solid) accuracies for GDRO and R-MDL during training, plotted on a log scale. R-MDL validation accuracy is noisier than that of GDRO because R-MDL is constrained to a limited number of samples (drawn with replacement) from the validation set. In the left-most plot, training accuracy for every group except the blond male group (red) drops to zero for lack of data. The blond male group is the most challenging, so the adversary eventually stops sampling from the other groups. Under standard regularization, GDRO's red-group accuracy drops off, while R-MDL maintains high red-group accuracy by sampling heavily from the red group, as reflected in its near-perfect red-group training accuracy.
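
The caption's contrast between GDRO and R-MDL comes down to how the adversary's group weights are used. Below is a minimal Python sketch. It assumes an exponentiated-gradient update on the weights, as in the online group DRO algorithm of Sagawa et al., and uses hypothetical, fixed per-group losses. Reweighting keeps every group in the batch and scales its loss. Resampling draws the batch composition itself from the weights, so groups the adversary stops favoring receive no examples. The function names and numbers are illustrative, and this is not the authors' R-MDL implementation.

```python
import math
import random


def update_group_weights(q, group_losses, eta):
    """Exponentiated-gradient ascent: q_g <- q_g * exp(eta * loss_g), then renormalize."""
    w = [qg * math.exp(eta * lg) for qg, lg in zip(q, group_losses)]
    total = sum(w)
    return [x / total for x in w]


def reweighted_objective(q, group_losses):
    """Reweighting: every group contributes, scaled by its adversarial weight."""
    return sum(qg * lg for qg, lg in zip(q, group_losses))


def resampled_batch_counts(q, batch_size, rng):
    """Resampling: draw each batch slot's group from q (with replacement), count per group."""
    groups = rng.choices(range(len(q)), weights=q, k=batch_size)
    return [groups.count(g) for g in range(len(q))]


# Hypothetical setting: four groups, group 3 stays the hardest across 50 adversary steps.
losses = [0.1, 0.1, 0.1, 0.9]
q = [0.25] * 4
for _ in range(50):
    q = update_group_weights(q, losses, eta=0.5)

objective = reweighted_objective(q, losses)
counts = resampled_batch_counts(q, 64, random.Random(0))
```

With these numbers the adversary's weight on group 3 ends up above 0.99. Nearly every resampled slot therefore goes to group 3, which mirrors the caption's observation that the other groups' training accuracy collapses for lack of data.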

Theorems & Definitions (39)

  • Lemma 2.0: vishnoi_algorithms_2021
  • Lemma 2.1: mannor2011bandits
  • proof
  • Lemma 3.0
  • proof
  • proof
  • proof
  • proof: Proof of Lemma \ref{lemma:zerosumasymmetricfull}
  • Theorem 4.1
  • proof
  • ...and 29 more