Multiply Robust Estimation for Local Distribution Shifts with Multiple Domains

Steven Wilkins-Reeves, Xu Chen, Qi Ma, Christine Agarwal, Aude Hofleitner

TL;DR

The paper tackles distribution shifts that vary across population segments by relaxing global shift assumptions to local shifts within each segment. It introduces a two-stage Multiply Robust (MR) estimation framework that clusters segments to train base models, derives a Stage-1 domain-specific linear combination, and refines predictions with a Stage-2 weighted estimator, backed by a formal generalization bound. Theoretical results decompose error into weight-estimation and function-estimation components and provide convergence rates under covariate and label shift weight estimators. Empirically, MR outperforms state-of-the-art domain adaptation baselines across simulated and real datasets, including a Meta user-city dataset, demonstrating improved prediction accuracy and robustness for tabular data under local distribution shifts.
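The TL;DR mentions covariate-shift weight estimators. The sketch below is only an illustration and not the paper's estimator. It estimates within-segment density-ratio weights w(x) = p_test(x)/p_train(x) with a logistic-regression discriminator, a standard probabilistic-classification approach. The function names fit_logistic and covariate_shift_weights, the learning rate, the iteration count, the L2 penalty and the toy data are all hypothetical choices. It uses only NumPy.

```python
import numpy as np

def fit_logistic(X, t, lr=0.1, n_iter=2000, l2=1e-3):
    """Fit P(t = 1 | x) by gradient descent on the L2-regularized mean log-loss."""
    Xb = np.hstack([X, np.ones((len(X), 1))])  # append intercept column
    theta = np.zeros(Xb.shape[1])
    for _ in range(n_iter):
        p = 1.0 / (1.0 + np.exp(-Xb @ theta))
        theta -= lr * (Xb.T @ (p - t) / len(t) + l2 * theta)
    return theta

def covariate_shift_weights(X_train, X_test, **kw):
    """Estimate w(x) = p_test(x) / p_train(x) at the training rows of ONE segment."""
    X = np.vstack([X_train, X_test])
    t = np.r_[np.zeros(len(X_train)), np.ones(len(X_test))]  # 1 = test (deployment) row
    theta = fit_logistic(X, t, **kw)
    Xb = np.hstack([X_train, np.ones((len(X_train), 1))])
    p = 1.0 / (1.0 + np.exp(-Xb @ theta))
    # Bayes' rule: p_test(x)/p_train(x) = [P(test|x) / P(train|x)] * (n_train / n_test)
    return (p / (1.0 - p)) * (len(X_train) / len(X_test))

# Toy segment (hypothetical data): test features are shifted by +1 relative to training
rng = np.random.default_rng(0)
x_tr = rng.normal(0.0, 1.0, size=(500, 1))
x_te = rng.normal(1.0, 1.0, size=(500, 1))
w = covariate_shift_weights(x_tr, x_te)
```

The discriminator is fit separately inside each segment, in line with the local-shift setting: rows from different segments are never compared against each other.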

Abstract

Distribution shifts are ubiquitous in real-world machine learning applications, posing a challenge to the generalization of models trained on one data distribution to another. We focus on scenarios where data distributions vary across multiple segments of the entire population and only make local assumptions about the differences between training and test (deployment) distributions within each segment. We propose a two-stage multiply robust estimation method to improve model performance on each individual segment for tabular data analysis. The method involves fitting a linear combination of the base models, learned using clusters of training data from multiple segments, followed by a refinement step for each segment. Our method is designed to be implemented with commonly used off-the-shelf machine learning models. We establish theoretical guarantees on the generalization bound of the method on the test risk. With extensive experiments on synthetic and real datasets, we demonstrate that the proposed method substantially improves over existing alternatives in prediction accuracy and robustness on both regression and classification tasks. We also assess its effectiveness on a user city prediction dataset from Meta.
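A minimal sketch of the two-stage idea for squared-loss regression follows. It rests on three assumptions of our own, none confirmed by the summary above: (i) the base models are linear least-squares fits, one per cluster of training segments; (ii) Stage 1 is a weighted least-squares fit of the target segment's responses on the base-model predictions; and (iii) Stage 2 is a weighted ridge regression on the Stage-1 residuals, so the refinement is shrunk toward the Stage-1 combination. The helper names (fit_ls, stage1_combination, mr_fit), the ridge penalty and the toy data are hypothetical. The weights ws would come from a per-segment shift-weight estimator; uniform weights are used here.

```python
import numpy as np

def fit_ls(X, y, w=None, ridge=1e-6):
    """Weighted ridge least squares with an intercept; returns a predict function."""
    Xb = np.hstack([X, np.ones((len(X), 1))])  # append intercept column
    w = np.ones(len(y)) if w is None else np.asarray(w, dtype=float)
    A = Xb.T @ (w[:, None] * Xb) + ridge * np.eye(Xb.shape[1])
    coef = np.linalg.solve(A, Xb.T @ (w * y))
    return lambda Xn: np.hstack([Xn, np.ones((len(Xn), 1))]) @ coef

def stage1_combination(base_models, X, y, w):
    """Stage 1: beta = argmin sum_i w_i (y_i - beta^T h(x_i))^2 over base-model predictions."""
    H = np.column_stack([h(X) for h in base_models])
    A = H.T @ (w[:, None] * H)
    return np.linalg.lstsq(A, H.T @ (w * y), rcond=None)[0]

def mr_fit(base_models, X, y, w, ridge=1.0):
    """Two-stage sketch: Stage-1 combination as a prior, then a weighted ridge correction."""
    beta = stage1_combination(base_models, X, y, w)
    prior = lambda Xn: np.column_stack([h(Xn) for h in base_models]) @ beta
    # Stage 2 (assumed form): ridge on residuals; a larger ridge keeps predictions near the prior
    correction = fit_ls(X, y - prior(X), w, ridge=ridge)
    return lambda Xn: prior(Xn) + correction(Xn)

# Toy usage: two hypothetical clusters of training segments, one base model each
rng = np.random.default_rng(0)
d = 3
coefs = {"A": np.array([1.0, 2.0, 0.0]), "B": np.array([-1.0, 0.5, 1.0])}
base_models = []
for c in ("A", "B"):
    Xc = rng.normal(size=(300, d))
    yc = Xc @ coefs[c] + rng.normal(scale=0.1, size=300)
    base_models.append(fit_ls(Xc, yc))

# Target segment: a mix of both mechanisms plus a segment-specific term
Xs = rng.normal(size=(200, d))
ys = 0.6 * Xs @ coefs["A"] + 0.4 * Xs @ coefs["B"] + 0.3 * Xs[:, 2] + rng.normal(scale=0.1, size=200)
ws = np.ones(200)  # uniform here; in practice, per-segment shift weights
predict = mr_fit(base_models, Xs, ys, ws)
```

The Stage-1 coefficients minimize the weighted error over all linear combinations, so in-sample the combination is never worse than any single base model. The Stage-2 ridge penalty controls how far the refined predictor may move away from that combination.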

Paper Structure

This paper contains 32 sections, 6 theorems, 38 equations, 4 figures, 15 tables, and 2 algorithms.

Key Result

Theorem 5.3

Suppose that the regularity-conditions assumption holds, and consider the two-stage estimator of Algorithm alg:MR_estimator, where $\nu' = \lVert \beta^{*(s)\intercal}\mathbf{h} - f^{*(s)}_{\nu_{q,s}} \rVert_{\mathcal{F}} + C\frac{M_F}{\tilde{\lambda}_1}\sqrt{\frac{M \log(e/\delta)}{n_s}}$ and $C$ is a universal constant. Then, for $n \geq M_{eff} M \log(e M / \delta)$, with probability at least $1 \ldots$
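To show how the bound scales, the sketch below evaluates the statistical part of ν′ and checks the sample-size condition of Theorem 5.3. The statement leaves the universal constant C unspecified, so C = 1 is a placeholder. The numeric inputs in the usage are arbitrary. The symbols M_F, λ̃₁, M, n_s, M_eff and δ are used exactly as they appear in the theorem; their precise definitions are in the paper. The function names are hypothetical.

```python
import math

def stat_term(M_F, lam1, M, n_s, delta, C=1.0):
    """Second term of nu' in Theorem 5.3: C * M_F / lam1 * sqrt(M * log(e / delta) / n_s).

    C is an unspecified universal constant in the theorem; C = 1 is only a placeholder.
    """
    return C * M_F / lam1 * math.sqrt(M * math.log(math.e / delta) / n_s)

def sample_size_ok(n, M_eff, M, delta):
    """Check the theorem's condition n >= M_eff * M * log(e * M / delta)."""
    return n >= M_eff * M * math.log(math.e * M / delta)

# Arbitrary illustrative inputs: quadrupling n_s should halve the statistical term
t_small = stat_term(M_F=1.0, lam1=0.5, M=4, n_s=1000, delta=0.05)
t_large = stat_term(M_F=1.0, lam1=0.5, M=4, n_s=4000, delta=0.05)
```

The statistical term decays at the usual $1/\sqrt{n_s}$ rate. The approximation term $\lVert \beta^{*(s)\intercal}\mathbf{h} - f^{*(s)}_{\nu_{q,s}} \rVert_{\mathcal{F}}$ does not depend on $n_s$, so more segment data cannot remove it.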

Figures (4)

  • Figure 1: Schematic representation of the problem setup with two segments $\{S_1, S_2\}$. We observe both features and responses for the training data but only features for the test domain. We make no global assumptions about distribution shift, only assumptions about shifts within each segment.
  • Figure 2: Performance of the proposed MR method against the competitors in the simulation study. Error bars indicate 99% confidence intervals.
  • Figure 3: Test CE relative to the production model (black dotted line) by country. Error bars denote 99% confidence intervals.
  • Figure 4: Data splitting for the multiply robust estimator. We use a holdout set to refine the linear combination and train the second-stage model on the entire training-fold segment.
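The splitting scheme in Figure 4 can be sketched as index bookkeeping for one segment. One part of this is an assumption. The caption only says that the holdout is used for the linear combination and that the second stage uses the whole training-fold segment. Training the base models on the rows outside the holdout is our own reading. The function name split_segment and the default 30% holdout fraction are hypothetical.

```python
import numpy as np

def split_segment(segment_idx, holdout_frac=0.3, seed=0):
    """Split one segment's training-fold row indices (hypothetical scheme after Figure 4).

    Returns (base_idx, holdout_idx, stage2_idx):
      base_idx    -- rows outside the holdout (assumed to be used for base-model training)
      holdout_idx -- rows used to fit/refine the Stage-1 linear combination
      stage2_idx  -- all rows of the segment's training fold, used for the Stage-2 model
    """
    seg = np.asarray(segment_idx)
    rng = np.random.default_rng(seed)
    perm = rng.permutation(seg)
    n_hold = int(round(holdout_frac * len(seg)))
    holdout_idx = np.sort(perm[:n_hold])
    base_idx = np.sort(perm[n_hold:])
    stage2_idx = np.sort(seg)
    return base_idx, holdout_idx, stage2_idx

# Usage on a hypothetical segment occupying rows 100..199 of the training fold
base_idx, hold_idx, s2_idx = split_segment(np.arange(100, 200), holdout_frac=0.25, seed=1)
```

The holdout keeps the Stage-1 weights from being fit on in-sample base-model predictions, while the second stage still sees every row of the segment.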

Theorems & Definitions (16)

  • Remark 4.1
  • Remark 5.2
  • Theorem 5.3: Generalization Bound For a Multiply Robust Estimator
  • Remark 5.4: Error decomposition
  • Remark 5.5: Comparison with global estimators
  • Remark 5.6: Comparison with doubly robust estimator
  • Remark 5.7: Curse of dimensionality
  • Lemma 2.1: Best Linear Combination As A Prior
  • Proof
  • Lemma 2.2: Lemma 2 of reddi2015doubly
  • ...and 6 more
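Lemma 2.1 treats the best linear combination of base models as a prior. The derivation below illustrates the idea under two assumptions made only here: squared loss, and pure covariate shift within segment $s$, so that $p_{T}^{(s)}(y \mid x) = p_{S}^{(s)}(y \mid x)$. It also assumes that $\beta^{*(s)}$ denotes the target-risk-minimizing combination of the stacked base-model predictions $\mathbf{h}(x)$, with importance weights $w_s(x) = p_{T}^{(s)}(x)/p_{S}^{(s)}(x)$. This is a standard importance-weighting identity, not the lemma's statement, and it requires the weighted second-moment matrix to be invertible.

```latex
\begin{aligned}
\beta^{*(s)} &= \arg\min_{\beta}\ \mathbb{E}_{T_s}\big[(Y - \beta^{\intercal}\mathbf{h}(X))^2\big]
             = \arg\min_{\beta}\ \mathbb{E}_{S_s}\big[w_s(X)\,(Y - \beta^{\intercal}\mathbf{h}(X))^2\big], \\
0 &= \mathbb{E}_{S_s}\big[w_s(X)\,\mathbf{h}(X)\,\big(Y - \mathbf{h}(X)^{\intercal}\beta^{*(s)}\big)\big]
\;\Longrightarrow\;
\beta^{*(s)} = \mathbb{E}_{S_s}\big[w_s(X)\,\mathbf{h}(X)\mathbf{h}(X)^{\intercal}\big]^{-1}\,
               \mathbb{E}_{S_s}\big[w_s(X)\,\mathbf{h}(X)\,Y\big].
\end{aligned}
```

Replacing the two expectations with weighted sample averages over the segment's training rows gives a weighted least-squares estimate of the combination. The error in the estimated weights $w_s$ and the error in the base models $\mathbf{h}$ then enter separately.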