
CTRL Your Shift: Clustered Transfer Residual Learning for Many Small Datasets

Gauri Jain, Dominik Rothenhäusler, Kirk Bansak, Elisabeth Paulson

TL;DR

This paper introduces Clustered Transfer Residual Learning (CTRL), a meta-learning method that combines the strengths of cross-domain residual learning and adaptive pooling/clustering in order to simultaneously improve overall accuracy and preserve source-level heterogeneity.

Abstract

Machine learning (ML) tasks often utilize large-scale data that is drawn from several distinct sources, such as different locations, treatment arms, or groups. In such settings, practitioners often desire predictions that not only exhibit good overall accuracy, but also remain reliable within each source and preserve the differences that matter across sources. For instance, several asylum and refugee resettlement programs now use ML-based employment predictions to guide where newly arriving families are placed within a host country, which requires generating informative and differentiated predictions for many and often small source locations. However, this task is made challenging by several common characteristics of the data in these settings: the presence of numerous distinct data sources, distributional shifts between them, and substantial variation in sample sizes across sources. This paper introduces Clustered Transfer Residual Learning (CTRL), a meta-learning method that combines the strengths of cross-domain residual learning and adaptive pooling/clustering in order to simultaneously improve overall accuracy and preserve source-level heterogeneity. We establish new theory showing that high-quality clusters can be learned efficiently, bypassing the need for repeated model refitting over candidate subsets. We evaluate CTRL alongside other state-of-the-art benchmarks on 5 large-scale datasets. This includes a dataset from the national asylum program in Switzerland, where the algorithmic geographic assignment of asylum seekers is currently being piloted. CTRL consistently outperforms the benchmarks across several key metrics and when using a range of different base learners.
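To make the residual-learning idea concrete, here is a minimal sketch of a two-stage predictor of the kind the abstract and the Figure 1 caption describe: a pooled model fit on every source, plus a residual model fit only on a cluster of sources that contains the target. This is an illustration, not the authors' implementation. The cluster is passed in by hand rather than learned, the base learner is a ridge-regularized linear model chosen for simplicity, and the names `add_intercept`, `fit_ridge`, and `pooled_plus_residual` are hypothetical.

```python
import numpy as np

def add_intercept(X):
    """Prepend a column of ones so the linear model has an intercept."""
    return np.column_stack([np.ones(len(X)), X])

def fit_ridge(X, y, alpha=0.0):
    """Linear least squares with an L2 penalty `alpha` on all coefficients.
    Hypothetical stand-in for an arbitrary base learner."""
    Xb = add_intercept(X)
    A = Xb.T @ Xb + alpha * np.eye(Xb.shape[1])
    return np.linalg.solve(A, Xb.T @ y)

def pooled_plus_residual(X, y, source, cluster, alpha=0.0):
    """Sketch: pooled model on all sources plus a residual model fit only on
    the rows whose source is in `cluster` (assumed to include the target)."""
    pooled = fit_ridge(X, y)                       # stage 1: pool every source
    mask = np.isin(source, cluster)                # rows from the clustered sources
    resid = y[mask] - add_intercept(X[mask]) @ pooled
    residual = fit_ridge(X[mask], resid, alpha)    # stage 2: learn the shift
    def predict(X_new):
        Xb = add_intercept(X_new)
        return Xb @ pooled + Xb @ residual         # final = pooled + residual
    return predict

# Toy data: 4 sources; sources 0 and 1 share an outcome shift of +3.
rng = np.random.default_rng(0)
source = np.repeat(np.arange(4), 50)
X = rng.uniform(0.0, 1.0, size=(200, 1))
y = 2.0 * X[:, 0] + np.where(np.isin(source, [0, 1]), 3.0, 0.0)

x_new = np.array([[1.0]])
print(pooled_plus_residual(X, y, source, cluster=[0, 1])(x_new))  # close to [5.]
```

With `alpha=0` and a linear learner, the two stages collapse to an ordinary least-squares fit on the clustered sources alone, because the residual fit exactly undoes the pooled coefficients. As `alpha` grows, the residual is shrunk toward zero and predictions fall back to the pooled model; that shrinkage toward a shared model is one simple way to see why fitting a residual, rather than a fresh model, can help small sources.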

Paper Structure

This paper contains 38 sections, 2 theorems, 15 equations, 10 figures, 5 tables, 2 algorithms.

Key Result

proposition 1

Let $\mathcal{F}$ be the class of functions that are constant on leaves $L \in \mathcal{L}$, where $\mathcal{L}$ is a finite partition of $\mathcal{X}$ (e.g., as in regression trees with fixed splits). Let $\mathcal{F}'$ be a class of functions that are constant on leaves $B\in\mathcal{B}$, where […] Fix weights $w\in\Delta^{|\mathcal{M}|}:=\{w : w_m\ge0,\ \sum_m w_m=1\}$ and define […] with final predictor […]
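The full risk in Proposition 1 is not given in this summary, so the sketch below assumes one purely for illustration: squared loss and a $w$-weighted average of per-source mean risks, $\sum_m w_m \frac{1}{n_m}\sum_{i \in m}(y_i - f(x_i))^2$. Under that assumption, minimizing over functions constant on the leaves of a fixed partition has a closed form: each leaf's value is a weighted mean of the outcomes in that leaf, where a row from source $m$ carries weight $w_m/n_m$. The function name `fit_leaf_constant` and the encoding of leaves and sources as integers $0,\dots,|\mathcal{M}|-1$ are assumptions of this sketch.

```python
import numpy as np

def fit_leaf_constant(leaf, source, y, w):
    """Minimise sum_m w[m] * mean_{i in source m} (y_i - f(leaf_i))**2 over
    functions f that are constant on each leaf (assumed risk, see lead-in).
    Returns a dict mapping leaf id -> fitted constant."""
    leaf, source, y = map(np.asarray, (leaf, source, y))
    w = np.asarray(w, dtype=float)
    if np.any(w < 0) or not np.isclose(w.sum(), 1.0):
        raise ValueError("w must lie on the probability simplex")
    n_per_source = np.bincount(source, minlength=len(w))
    omega = w[source] / n_per_source[source]       # per-row weight w_m / n_m
    values = {}
    for L in np.unique(leaf):
        in_leaf = leaf == L
        # Weighted mean of outcomes in the leaf (nan if all its weights are 0).
        values[int(L)] = float(np.sum(omega[in_leaf] * y[in_leaf]) / np.sum(omega[in_leaf]))
    return values

print(fit_leaf_constant([0, 0, 0, 1], [0, 0, 1, 1], [2.0, 4.0, 10.0, 5.0], [0.75, 0.25]))
# {0: 4.0, 1: 5.0}
```

Because the objective separates across leaves, each leaf is fit independently; with equal weights and equal source sizes the result reduces to the plain per-leaf mean.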

Figures (10)

  • Figure 1: An overview of CTRL on the example task of predicting educational attainment in the state of Alaska. CTRL's predictions are the sum of (1) predictions from a pooled model using all education data from the US and (2) predictions from a residual model that clusters location datasets from Alaska, Hawaii, Montana, and North Carolina.
  • Figure 2: Average model performance ranks across datasets.
  • Figure 3: Average performance gaps relative to CTRL across all applicable datasets for each metric, excluding the synthetic dataset, which was specifically constructed for our use case. A positive value indicates worse performance than CTRL.
  • Figure 4: Change in the top 5 locations selected as more iterations of Algorithm \ref{alg:xlearn_clustering} are run. After 150 iterations, this reaches about 0.25 changes per location.
  • Figure 5: Change in the rank vector $\mathbf{k}^*$ returned by Algorithm \ref{alg:cluster_selection} for all locations $g \in \mathcal{M}$, measured as the $L_1$ distance between $\mathbf{k}^*$ and $\mathbf{k}^{*(+5)}$, the rank vectors obtained on the $i$th and $(i+5)$th validation splits.
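Figure 5 measures the stability of cluster selection as an $L_1$ distance between rank vectors from validation splits five apart. A minimal sketch of that computation, assuming the rank vectors $\mathbf{k}^*$ from successive splits are stacked into an integer array with one row per split and one column per location $g \in \mathcal{M}$. The function name `lagged_l1_changes` is hypothetical, and how the cluster-selection algorithm produces each $\mathbf{k}^*$ is not reproduced here.

```python
import numpy as np

def lagged_l1_changes(k_star_by_split, lag=5):
    """Given rank vectors k* (rows = validation splits, columns = locations),
    return ||k*_i - k*_{i+lag}||_1 for every split i that has a partner."""
    if lag < 1:
        raise ValueError("lag must be a positive integer")
    K = np.asarray(k_star_by_split)             # shape: (n_splits, n_locations)
    return np.abs(K[lag:] - K[:-lag]).sum(axis=1)

# Hypothetical rank vectors for 7 splits and 3 locations.
K = [[1, 2, 3], [1, 2, 3], [2, 2, 3], [1, 3, 3], [1, 2, 3], [1, 2, 3], [3, 2, 1]]
print(lagged_l1_changes(K))  # [0 4]
```

A sequence that settles toward zero indicates that the selected ranks stop changing as more validation splits are drawn.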

Theorems & Definitions (2)

  • proposition 1: Link between residual optimization and CTRL risk
  • proposition 2: Excess risk under distribution shift