
COSTAR: Improved Temporal Counterfactual Estimation with Self-Supervised Learning

Chuizheng Meng, Yihe Dong, Sercan Ö. Arık, Yan Liu, Tomas Pfister

TL;DR

COSTAR tackles temporal counterfactual outcome estimation under time-varying confounding and distribution shifts by learning expressive history representations through self-supervised learning and a Transformer encoder that jointly models temporal and feature interactions. It introduces component-wise contrastive losses, a non-autoregressive future predictor, and an unsupervised domain adaptation perspective to bound transfer errors, enabling effective zero-shot and data-efficient transfer. Empirical results on synthetic and real-world datasets show COSTAR outperforms baselines in estimation accuracy and cross-domain generalization, while ablations highlight the importance of the encoder design and SSL losses. The approach offers practical benefits for decision-making in domains where RCTs are costly or impractical and data from target populations are scarce.
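The TL;DR mentions a non-autoregressive future predictor. To illustrate what "non-autoregressive" means, the sketch below decodes every future outcome in a single pass from a history representation and the planned treatment sequence. It never feeds a prediction back in as the next input. The architecture is an assumption and not the one in Figure 2(c): the running-mean treatment summary, the horizon embedding `P`, the function name, and the weight names in `params` are all hypothetical.

```python
import numpy as np

def non_autoregressive_predict(h, future_treatments, params):
    """Predict all future outcomes at once (hypothetical sketch).

    h: (d_h,) history representation from the encoder.
    future_treatments: (tau, d_a) planned treatment sequence.
    params: dict of hypothetical weight matrices W_h, W_a, P, W_out.
    Returns: (tau, d_y) predicted outcomes.
    """
    tau = future_treatments.shape[0]
    # Causal running mean of the plan: step s summarizes treatments 1..s only.
    steps = np.arange(1, tau + 1)[:, None]
    treat_summary = np.cumsum(future_treatments, axis=0) / steps
    # One learned horizon embedding per future step.
    horizon = params["P"][:tau]
    hidden = np.tanh(h @ params["W_h"] + treat_summary @ params["W_a"] + horizon)
    # Every step is decoded in parallel; no prediction is fed back as input.
    return hidden @ params["W_out"]

rng = np.random.default_rng(0)
d_h, d_a, d, d_y, max_tau = 8, 2, 16, 2, 10
params = {
    "W_h": rng.normal(size=(d_h, d)),
    "W_a": rng.normal(size=(d_a, d)),
    "P": rng.normal(size=(max_tau, d)),
    "W_out": rng.normal(size=(d, d_y)),
}
h = rng.normal(size=d_h)                             # stand-in encoder output
a = rng.integers(0, 2, size=(5, d_a)).astype(float)  # 5-step treatment plan
y = non_autoregressive_predict(h, a, params)
print(y.shape)  # (5, 2)
```

Because step s only sees the planned treatments up to s, changing a later treatment leaves earlier predictions unchanged. This is the kind of time ordering a counterfactual query over a treatment plan relies on.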

Abstract

Estimation of temporal counterfactual outcomes from observed history is crucial for decision-making in many domains such as healthcare and e-commerce, particularly when randomized controlled trials (RCTs) suffer from high cost or impracticality. For real-world datasets, modeling time-dependent confounders is challenging due to complex dynamics, long-range dependencies, and the fact that both past treatments and past covariates affect future outcomes. In this paper, we introduce Counterfactual Self-Supervised Transformer (COSTAR), a novel approach that integrates self-supervised learning for improved historical representations. We propose a component-wise contrastive loss tailored for temporal treatment outcome observations and explain its effectiveness from the view of unsupervised domain adaptation. COSTAR yields superior performance in estimation accuracy and generalization to out-of-distribution data compared to existing models, as validated by empirical results on both synthetic and real-world datasets.
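Below is a minimal sketch of a component-wise contrastive objective. It assumes that each history component (covariates, treatments, outcomes) and the whole sequence already have their own (B, k) representations for two randomly transformed views of each sample, as described for Figure 2(b). InfoNCE is used here as a stand-in loss, because the paper's theory (Theorem 4.1) is stated for a generalized spectral contrastive loss. The function names, the equal default weights, and the temperature are illustrative assumptions.

```python
import numpy as np

def info_nce(z1, z2, temperature=0.1):
    """Symmetric InfoNCE: row i of z1 and z2 is a positive pair,
    and every other row in the batch acts as a negative."""
    z1 = z1 / np.linalg.norm(z1, axis=1, keepdims=True)
    z2 = z2 / np.linalg.norm(z2, axis=1, keepdims=True)
    logits = z1 @ z2.T / temperature  # (B, B) scaled cosine similarities

    def xent(l):
        l = l - l.max(axis=1, keepdims=True)  # numerical stability
        log_prob = l - np.log(np.exp(l).sum(axis=1, keepdims=True))
        return -np.mean(np.diag(log_prob))    # positives sit on the diagonal

    return 0.5 * (xent(logits) + xent(logits.T))

# Hypothetical component names: whole sequence plus the three history components.
COMPONENTS = ("sequence", "covariates", "treatments", "outcomes")

def component_wise_loss(view1, view2, weights=None, temperature=0.1):
    """Weighted sum of one InfoNCE loss per component.
    view1 / view2 map component name -> (B, k) representations of the two views."""
    weights = weights or {c: 1.0 for c in COMPONENTS}
    return sum(weights[c] * info_nce(view1[c], view2[c], temperature)
               for c in COMPONENTS)

rng = np.random.default_rng(0)
B, k = 32, 16
view1 = {c: rng.normal(size=(B, k)) for c in COMPONENTS}
# Matched second view: a small perturbation of each sample's representation.
view2 = {c: z + 0.01 * rng.normal(size=z.shape) for c, z in view1.items()}
# Mismatched second view: rows shuffled, so "positives" come from other samples.
view2_shuffled = {c: rng.permutation(z) for c, z in view1.items()}
print(component_wise_loss(view1, view2), component_wise_loss(view1, view2_shuffled))
```

Keeping a separate term per component pushes the encoder to keep covariate, treatment, and outcome information individually recoverable, rather than letting the whole-sequence term be dominated by whichever component is easiest to match.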

Paper Structure

This paper contains 45 sections, 3 theorems, 26 equations, 4 figures, 13 tables, and 1 algorithm.

Key Result

Theorem 4.1

Suppose that the cross-cluster, intra-cluster, and relative-expansion assumptions hold for the set of observed histories $\mathcal{H}$ and its positive-pair graph $G(\mathcal{H}, w)$, and that the representation dimension satisfies $k \geq 2m$. Let $r$ be a minimizer of the generalized spectral contrastive loss …
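Theorem 4.1 is stated for a minimizer of a *generalized* spectral contrastive loss, whose exact form is not reproduced in this summary. For orientation, the sketch below computes the standard (non-generalized) spectral contrastive loss of HaoChen et al. (2021), $-2\,\mathbb{E}[f(x)^\top f(x^+)] + \mathbb{E}[(f(x)^\top f(x'))^2]$, on a mini-batch. It treats off-diagonal rows as the independent pair $(x, x')$. The function name and the batch construction are assumptions made for illustration.

```python
import numpy as np

def spectral_contrastive_loss(z, z_pos):
    """Mini-batch estimate of the standard spectral contrastive loss.

    z, z_pos: (B, k) representations; row i of z and row i of z_pos form a
    positive pair (two augmented views of one history). Rows i != j stand in
    for independent samples x, x'.
    """
    B = z.shape[0]
    # Attraction term: -2 * E[f(x)^T f(x+)] over positive pairs.
    attract = -2.0 * np.mean(np.sum(z * z_pos, axis=1))
    # Repulsion term: E[(f(x)^T f(x'))^2] over all non-paired rows.
    gram = z @ z_pos.T
    off_diag = ~np.eye(B, dtype=bool)
    repel = np.mean(gram[off_diag] ** 2)
    return attract + repel

rng = np.random.default_rng(0)
z = rng.normal(size=(64, 8))
aligned = spectral_contrastive_loss(z, z + 0.05 * rng.normal(size=z.shape))
mismatched = spectral_contrastive_loss(z, rng.permutation(z))
print(aligned < mismatched)
```

Unlike InfoNCE, this objective uses no normalization or softmax. Its population minimizer is tied to the top eigenvectors of the normalized positive-pair graph, which is what lets guarantees like Theorem 4.1 be phrased in terms of the graph $G(\mathcal{H}, w)$.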

Figures (4)

  • Figure 1: We illustrate the problem of treatment outcome estimation over time with an example from healthcare. We propose COSTAR as a temporal counterfactual estimator enhanced with self-supervised learning, which induces transferability to both ① cold-start cases from unseen subpopulations and ② counterfactual outcome estimation.
  • Figure 2: Overview of COSTAR. (a) Encoder architecture. The Temporal Attention Block applies causal attention along the time dimension, in parallel for each feature, while the Feature-wise Attention Block computes full self-attention along the feature dimension at every time step. (b) Self-supervised learning of the history representations. Positive pairs are generated by applying random transformations $\mathcal{T}(\cdot)$ to the same sample. In addition to the standard contrastive loss on the entire sequence, we construct component-wise contrastive losses for the historical covariates, treatments, and outcomes. (c) Non-autoregressive outcome predictor architecture.
  • Figure 3: t-SNE visualization of learned representations on the semi-synthetic MIMIC-III dataset.
  • Figure 4: Examples of counterfactual treatment outcome estimation with semi-synthetic MIMIC-III data in the zero-shot transfer setting. For clarity, we plot only one of the two output dimensions. Each row shows the results for one counterfactual treatment sequence, and each column shows the estimates of one method across all tested treatment sequences. In each subfigure, the observed historical outcomes are plotted as black solid lines and the ground-truth counterfactual outcomes as black dashed lines. The blue solid lines show the estimated outcomes.
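The encoder in Figure 2(a) alternates two attention patterns over a grid of (time step, feature) tokens. The sketch below shows one pass of each using single-head attention in NumPy. It assumes the input is already embedded as an array of shape (T, F, d), and it omits multi-head projections, layer normalization, and feed-forward sublayers. The function names and weight arguments are hypothetical.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(q, k, v, mask=None):
    """Scaled dot-product attention over the second-to-last axis."""
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(q.shape[-1])
    if mask is not None:
        scores = np.where(mask, scores, -1e9)  # block disallowed positions
    return softmax(scores) @ v

def temporal_block(x, Wq, Wk, Wv):
    """x: (T, F, d). Causal attention along time, run independently per feature."""
    T = x.shape[0]
    xt = np.swapaxes(x, 0, 1)                      # (F, T, d): one sequence per feature
    causal = np.tril(np.ones((T, T), dtype=bool))  # step t attends to steps <= t
    out = attention(xt @ Wq, xt @ Wk, xt @ Wv, causal)
    return x + np.swapaxes(out, 0, 1)              # residual, back to (T, F, d)

def feature_block(x, Wq, Wk, Wv):
    """x: (T, F, d). Full (unmasked) attention across features at each time step."""
    return x + attention(x @ Wq, x @ Wk, x @ Wv)

rng = np.random.default_rng(0)
T, F, d = 6, 4, 8
W = [rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(6)]
x = rng.normal(size=(T, F, d))
y = feature_block(temporal_block(x, *W[:3]), *W[3:])
print(y.shape)  # (6, 4, 8)
```

Placing the temporal block before the feature-wise block keeps the whole layer causal in time. The representation at step t depends only on inputs up to t, while every feature at step t can still exchange information with every other feature.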

Theorems & Definitions (5)

  • Theorem 4.1: Upper bound of counterfactual outcome estimator
  • Definition D.1: Expansion
  • Theorem D.5: Upper bound of 0-1 error on the target domain (HaoChen et al., 2022)
  • Lemma D.6: Relation between the L2 regression error and 0-1 classification error
  • Proof