
Data Augmentations for Improved (Large) Language Model Generalization

Amir Feder, Yoav Wald, Claudia Shi, Suchi Saria, David Blei

TL;DR

This work targets spurious correlations in text classification under distribution shift, notably in healthcare. It develops a counterfactual data augmentation method guided by the causal structure of the data: auxiliary data and LLMs are used to simulate interventions on the caregiver, generating how a clinical note would read had a different caregiver written it, which decorrelates writing style from patient condition. The authors formalize the approach, compare its sample complexity with that of importance re-weighting, and provide a practical algorithm, CATO, that matches examples using auxiliary data via difference-in-differences (diff-in-diff) methodology. Empirical results on clinical narratives and semi-synthetic data show improved OOD generalization over invariant-learning baselines, highlighting the method's potential for safer, more robust NLP in safety-critical domains.
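To make the augmentation idea concrete, the following is a minimal sketch, not the authors' CATO implementation. It assumes a labelled dataset of `(text, y, c)` triples and a generator that rewrites a note as if a different caregiver had written it; `counterfactual_augment` and `toy_rewrite` are hypothetical names, and the string substitution in `toy_rewrite` only stands in for the LLM. Each example is kept and paired with one counterfactual copy per alternative caregiver, with the label held fixed, so that in the augmented set the caregiver becomes uniform and independent of the label.

```python
from collections import Counter
from typing import Callable, List, Tuple

# (text, label y, caregiver c)
Example = Tuple[str, int, int]


def counterfactual_augment(
    data: List[Example],
    caregivers: List[int],
    rewrite: Callable[[str, int, int], str],
) -> List[Example]:
    """Keep every original example and add one counterfactual copy per
    alternative caregiver, with the label held fixed."""
    augmented: List[Example] = []
    for text, y, c in data:
        augmented.append((text, y, c))
        for c_new in caregivers:
            if c_new != c:
                # Simulated intervention on C: the style changes, y does not.
                augmented.append((rewrite(text, c, c_new), y, c_new))
    return augmented


def toy_rewrite(text: str, c_old: int, c_new: int) -> str:
    """Hypothetical stand-in for an LLM rewrite: swaps a style tag."""
    return text.replace(f"[style{c_old}]", f"[style{c_new}]")


# Toy training set in which caregiver 0 writes every label-1 note.
train = [
    ("[style0] pt febrile", 1, 0),
    ("[style0] pt febrile, cough", 1, 0),
    ("[style1] pt stable", 0, 1),
]
aug = counterfactual_augment(train, caregivers=[0, 1], rewrite=toy_rewrite)
counts = Counter((y, c) for _, y, c in aug)
print(counts)
```

After augmentation, each label appears equally often with each caregiver. The sketch assumes the rewrite changes only caregiver-specific style; if a generator also altered label-relevant content, holding `y` fixed would inject label noise.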

Abstract

The reliance of text classifiers on spurious correlations can lead to poor generalization at deployment, raising concerns about their use in safety-critical domains such as healthcare. In this work, we propose to use counterfactual data augmentation, guided by knowledge of the causal structure of the data, to simulate interventions on spurious features and to learn more robust text classifiers. We show that this strategy is appropriate in prediction problems where the label is spuriously correlated with an attribute. Under the assumptions of such problems, we discuss the favorable sample complexity of counterfactual data augmentation, compared to importance re-weighting. Pragmatically, we match examples using auxiliary data, based on diff-in-diff methodology, and use a large language model (LLM) to represent a conditional probability of text. Through extensive experimentation on learning caregiver-invariant predictors of clinical diagnoses from medical narratives and on semi-synthetic data, we demonstrate that our method for simulating interventions improves out-of-distribution (OOD) accuracy compared to baseline invariant learning algorithms.
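The abstract's matching step pairs examples across caregivers using auxiliary data. Below is a minimal sketch of plain nearest-neighbor matching, assuming each note comes with a numeric auxiliary vector (here hypothetical vitals such as temperature and heart rate). `match_across_caregivers` is a hypothetical helper, and the diff-in-diff component of the paper's procedure is not modelled.

```python
import math
from typing import Dict, List, Optional, Sequence


def euclidean(a: Sequence[float], b: Sequence[float]) -> float:
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def match_across_caregivers(
    aux: List[Sequence[float]],
    caregiver: List[int],
) -> Dict[int, Dict[int, int]]:
    """For every example i and every other caregiver c, return the index of
    the example written by c whose auxiliary features are closest to i's."""
    matches: Dict[int, Dict[int, int]] = {}
    all_caregivers = sorted(set(caregiver))
    for i, m_i in enumerate(aux):
        matches[i] = {}
        for c in all_caregivers:
            if c == caregiver[i]:
                continue
            best: Optional[int] = None
            best_dist = math.inf
            # Brute-force search over notes written by caregiver c.
            for j, m_j in enumerate(aux):
                if caregiver[j] != c:
                    continue
                d = euclidean(m_i, m_j)
                if d < best_dist:
                    best, best_dist = j, d
            if best is not None:
                matches[i][c] = best
    return matches


# Hypothetical auxiliary features: [temperature, heart rate].
aux = [[38.9, 120.0], [36.8, 72.0], [39.1, 118.0], [36.9, 70.0]]
caregiver = [0, 0, 1, 1]
m = match_across_caregivers(aux, caregiver)
print(m)
```

The search is O(n^2) and uses raw Euclidean distance; with features on different scales, standardizing them first (or using a learned distance) would usually matter more than the search itself.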

Paper Structure

This paper contains 32 sections, 2 theorems, 19 equations, 7 figures, 3 tables, 1 algorithm.

Key Result

Lemma 1

For the prediction problem defined in the problem setting, the Bayes optimal classifier under the unconfounded distribution $P_\bot\in\mathcal{P}$, where $C$ is uniformly distributed and independent of $Y$, is $h^*(\mathbf{x}) = \arg\max_{y\in[K]} P_\bot(Y=y \mid X^*=e(\mathbf{x}))$. It is a mi…
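Importance re-weighting, the baseline the abstract compares against, targets the same unconfounded distribution $P_\bot$ without generating new text. A minimal sketch, assuming $P_\bot$ keeps the training label marginal $P(y)$ (the lemma only fixes $C$ as uniform and independent of $Y$): then $P_\bot(y,c) = P(y)/|C|$ and the weight for an example is $w(y,c) = P_\bot(y,c)/P(y,c) = 1/(|C|\,P(c \mid y))$. `importance_weights` is a hypothetical helper using plug-in estimates from counts.

```python
from collections import Counter
from fractions import Fraction
from typing import List


def importance_weights(labels: List[int], caregivers: List[int]) -> List[Fraction]:
    """w_i = P_perp(y_i, c_i) / P(y_i, c_i), with P_perp(y, c) = P(y) / |C|.

    Probabilities are plug-in estimates from counts; Fraction keeps them exact.
    """
    n = len(labels)
    num_caregivers = len(set(caregivers))
    joint = Counter(zip(labels, caregivers))
    label_counts = Counter(labels)
    weights = []
    for y, c in zip(labels, caregivers):
        p_joint = Fraction(joint[(y, c)], n)
        p_perp = Fraction(label_counts[y], n) / num_caregivers
        weights.append(p_perp / p_joint)
    return weights


labels = [1, 1, 1, 0, 0, 0]
caregivers = [0, 0, 1, 1, 1, 0]
w = importance_weights(labels, caregivers)

# Reweighted mass of each (y, c) cell: equal across caregivers for each label.
mass = Counter()
for y, c, wi in zip(labels, caregivers, w):
    mass[(y, c)] += wi
print([float(x) for x in w])
print({k: float(v) for k, v in sorted(mass.items())})
```

The weights grow as $1/P(c \mid y)$, so rare label-caregiver combinations receive large weights and the reweighted estimate becomes high-variance; this is one common intuition for why generating counterfactual examples can be more sample-efficient, while the paper's own sample-complexity comparison is a formal argument.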

Figures (7)

  • Figure 1: Prediction problem with a spuriously correlated attribute.
  • Figure 2: Generating counterfactual clinical notes for patients using auxiliary data with Algorithm 1 (A).
  • Figure 3: Results ($F1$ averaged across 5 runs) for predicting clinical conditions (A) and for clinical note segmentation (B) from the text narratives. CATO (A) outperforms all baselines on OOD data.
  • Figure 4: OOD accuracy ($1-{\mathcal{R}}^{l_{01}}_{P_{\bot}}(h)$) and $Y,C$ correlation strength ($I(Y ; C)$). Lower values of $\lambda$ correspond to stronger corruptions of the augmentations. Even with substantial corruption ($\lambda=0.2$) and strong correlation, augmentations outperform baselines.
  • Figure 5: Possible causal structures that involve the auxiliary data $M$, where unobserved $M$ corresponds to unobserved confounding between $X$ and $C$.
  • ...and 2 more figures
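The two quantities on Figure 4's axes can be estimated from samples. A minimal sketch, assuming discrete labels and caregivers: `mutual_information` gives a plug-in estimate of $I(Y;C)$ in nats, and `accuracy` computes $1$ minus the empirical 0-1 risk, which estimates $1-\mathcal{R}^{l_{01}}_{P_\bot}(h)$ only when the test sample is drawn from $P_\bot$. Both helper names are hypothetical.

```python
import math
from collections import Counter
from typing import Sequence


def mutual_information(ys: Sequence[int], cs: Sequence[int]) -> float:
    """Plug-in estimate of I(Y; C) in nats from paired samples."""
    n = len(ys)
    joint = Counter(zip(ys, cs))
    count_y = Counter(ys)
    count_c = Counter(cs)
    mi = 0.0
    for (y, c), n_yc in joint.items():
        p_yc = n_yc / n
        mi += p_yc * math.log(p_yc / ((count_y[y] / n) * (count_c[c] / n)))
    return mi


def accuracy(preds: Sequence[int], labels: Sequence[int]) -> float:
    """1 minus the empirical 0-1 risk of the predictions."""
    errors = sum(p != y for p, y in zip(preds, labels))
    return 1.0 - errors / len(labels)


mi_confounded = mutual_information([0, 0, 1, 1], [0, 0, 1, 1])    # C determines Y
mi_unconfounded = mutual_information([0, 0, 1, 1], [0, 1, 0, 1])  # C independent of Y
acc = accuracy([1, 0, 1, 1], [1, 0, 0, 1])
print(mi_confounded, mi_unconfounded, acc)
```

When the caregiver fully determines a balanced binary label, the estimate equals $\log 2$; when they are independent it is zero.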

Theorems & Definitions (6)

  • Definition 1
  • Lemma 1
  • Definition 2
  • Lemma 2
  • Proof
  • Proof