Synergizing Deconfounding and Temporal Generalization For Time-series Counterfactual Outcome Estimation

Yiling Liu, Juncheng Dong, Chen Fu, Wei Shi, Ziyang Jiang, Zhigang Hua, David Carlson

TL;DR

The paper tackles time-series counterfactual outcome estimation under time-varying confounding by introducing Sub-treatment Group Alignment (SGA) to achieve finer-grained deconfounding and Random Temporal Masking (RTM) to boost temporal generalization. The authors derive a tighter counterfactual risk bound via SGA and demonstrate that RTM encourages reliance on stable historical patterns, improving long-horizon predictions. Empirical results on fully synthetic PK-PD tumor growth and semi-synthetic MIMIC-based data show state-of-the-art performance, with ablations confirming complementary benefits of SGA and RTM. The framework is architecture-agnostic and can be integrated with CRN or CT, offering a practical and robust approach to causal inference in observational time series. Together, SGA and RTM provide a flexible, scalable method to better estimate time-series counterfactuals in the presence of time-varying confounding and evolving covariates.

Abstract

Estimating counterfactual outcomes from time-series observations is crucial for effective decision-making (e.g., deciding when to administer a life-saving treatment), yet it remains challenging because (i) the counterfactual trajectory is never observed and (ii) confounders evolve over time and distort estimation at every step. To address these challenges, we propose a novel framework that synergistically integrates two complementary approaches: Sub-treatment Group Alignment (SGA) and Random Temporal Masking (RTM). Instead of the coarse practice of aligning the marginal distributions of the treatment groups in latent space, SGA uses iterative treatment-agnostic clustering to identify fine-grained sub-treatment groups. Aligning these fine-grained groups achieves better distributional matching and thus more effective deconfounding. We theoretically demonstrate that SGA optimizes a tighter upper bound on counterfactual risk and empirically verify its deconfounding efficacy. RTM promotes temporal generalization by randomly replacing input covariates with Gaussian noise during training. This encourages the model to rely less on potentially noisy or spuriously correlated covariates at the current step and more on stable historical patterns, thereby improving its ability to generalize across time and better preserve underlying causal relationships. Our experiments demonstrate that while applying SGA or RTM individually improves counterfactual outcome estimation, their synergistic combination consistently achieves state-of-the-art performance. This success stems from their distinct yet complementary roles: RTM enhances temporal generalization and robustness across time steps, while SGA improves deconfounding at each specific time point.
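The abstract contrasts aligning the marginal treated and control distributions with aligning fine-grained sub-treatment groups. The sketch below illustrates that contrast on a toy latent space; it is not the authors' implementation. Assumptions: plain k-means (Lloyd's algorithm) stands in for the iterative treatment-agnostic clustering, the squared distance between group means (a linear-kernel MMD) stands in for the alignment distance, only binary treatments are handled, and all function names are hypothetical.

```python
import numpy as np

def kmeans_labels(z, k, n_iter=20, seed=0):
    """Lloyd's k-means on representations z (n, d); treatment labels are NOT used."""
    rng = np.random.default_rng(seed)
    centers = z[rng.choice(len(z), size=k, replace=False)].copy()
    for _ in range(n_iter):
        # Assign each representation to its nearest center.
        dists = ((z[:, None, :] - centers[None, :, :]) ** 2).sum(axis=-1)
        labels = dists.argmin(axis=1)
        # Move each non-empty cluster's center to the mean of its members.
        for j in range(k):
            if np.any(labels == j):
                centers[j] = z[labels == j].mean(axis=0)
    return labels

def mean_gap(z1, z0):
    """Squared distance between group means (a linear-kernel MMD estimate)."""
    return float(((z1.mean(axis=0) - z0.mean(axis=0)) ** 2).sum())

def marginal_alignment_loss(z, a):
    """Coarse baseline: all treated vs. all control representations."""
    return mean_gap(z[a == 1], z[a == 0])

def subgroup_alignment_loss(z, a, k, seed=0):
    """Sub-group loss: cluster without treatments, then align treated vs.
    control within each cluster, weighted by the cluster's share of samples."""
    labels = kmeans_labels(z, k, seed=seed)
    loss = 0.0
    for j in range(k):
        in_j = labels == j
        z1, z0 = z[in_j & (a == 1)], z[in_j & (a == 0)]
        if len(z1) > 0 and len(z0) > 0:  # skip clusters missing one treatment arm
            loss += float(in_j.mean()) * mean_gap(z1, z0)
    return loss

# Toy latent space: two clusters; inside each, treated and control are shifted
# in opposite directions, so the marginal treated/control means coincide.
z = np.array([[-9.0], [-9.2], [-11.0], [-10.8],
              [9.0], [9.2], [11.0], [10.8]])
a = np.array([1, 1, 0, 0, 1, 1, 0, 0])
print(marginal_alignment_loss(z, a))       # close to 0: marginals look aligned
print(subgroup_alignment_loss(z, a, k=2))  # about 3.24: within-cluster gaps remain
```

Here the marginal loss is already near zero, so marginal alignment would see nothing to correct, while the per-cluster loss still exposes the treated/control gap inside each cluster. This is the kind of mismatch that finer-grained alignment is meant to catch.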

Paper Structure

This paper contains 35 sections, 10 theorems, 44 equations, 9 figures, 9 tables, 1 algorithm.

Key Result

Theorem 4.1

Let $\Phi:\mathcal{X} \rightarrow \mathcal{R}$ be a one-to-one, Jacobian-normalized representation function. Let $h : \mathcal{R} \times \{0, 1\} \rightarrow \mathcal{Y}$ be a hypothesis with Lipschitz constant: where $B_\Phi$ is a constant and $p_{\Phi}^{a}$ is the distribution of the random variable $\Phi(X)$ conditioned on $A=a$, that is, the representations of individuals receiving treatment $a \in \{0,1\}$.
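In the bounds of Shalit et al. (2017) on which Theorem 4.1 is based, $p_{\Phi}^{1}$ and $p_{\Phi}^{0}$ enter through a distance between the two distributions (an integral probability metric such as MMD or the Wasserstein distance) scaled by $B_\Phi$; shrinking that distance is what representation alignment targets. The sketch below estimates one such distance, a biased RBF-kernel MMD², from samples. The kernel, bandwidth, and function names are illustrative assumptions, the Gaussian samples are synthetic stand-ins for learned representations, and the Lipschitz form of the theorem is typically paired with a Wasserstein distance rather than MMD.

```python
import numpy as np

def rbf_kernel(x, y, bandwidth=1.0):
    """Gaussian (RBF) kernel matrix between the rows of x and y."""
    sq_dists = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-sq_dists / (2.0 * bandwidth ** 2))

def mmd2(r1, r0, bandwidth=1.0):
    """Biased (V-statistic) estimate of squared MMD between samples r1 and r0."""
    k11 = rbf_kernel(r1, r1, bandwidth).mean()
    k00 = rbf_kernel(r0, r0, bandwidth).mean()
    k10 = rbf_kernel(r1, r0, bandwidth).mean()
    return float(k11 + k00 - 2.0 * k10)

rng = np.random.default_rng(0)
r_treated = rng.normal(loc=1.0, size=(200, 4))  # stand-in for Phi(X) | A=1
r_control = rng.normal(loc=0.0, size=(200, 4))  # stand-in for Phi(X) | A=0
print(mmd2(r_treated, r_control))  # noticeably positive: the two groups differ
print(mmd2(r_control, r_control))  # 0.0: identical samples
```

An alignment regularizer would add a term like `mmd2` on the two treatment groups' representations to the training loss; SGA's refinement is to apply such a term per sub-treatment group instead of once on the marginals.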

Figures (9)

  • Figure 1: Conceptual overview of SGA and RTM. (a) SGA identifies and aligns fine-grained sub-treatment groups at each time step to improve deconfounding. (b) RTM forces the model to leverage historical patterns, enhancing temporal generalization. (c) SGA and RTM are synergistically combined to improve counterfactual outcome estimation. Here, $k$ denotes the sub-treatment group index.
  • Figure 2: Overview of SGA & RTM at each timepoint. For simplicity, we show a binary treatment scenario.
  • Figure 3: Performance on $\tau$-step-ahead ($\tau$=6) prediction. Note that CT ($\alpha$=0) refers to CT without alignment.
  • Figure 4: Heatmap of self-attention weights. Left: attention to past time points. Rightmost columns: total attention to past time points (Sum of prev.) and attention to the current time point (t=58).
  • Figure 5: Causal directed acyclic graphs (DAGs) illustrating causal relationships. (a) A static (non-time-series) scenario. (b) A time-series scenario.
  • ...and 4 more figures
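Figure 1(b) and the abstract describe RTM as randomly replacing input covariates with Gaussian noise during training so that the model leans on historical patterns. A minimal sketch of such a masking step follows, under assumptions not stated in the source: masking acts on whole time steps of the covariate tensor, each step is masked independently with a hypothetical probability `mask_prob`, the noise is standard normal (covariates assumed standardized), and treatments and outcomes are left untouched. `random_temporal_mask` is a hypothetical name.

```python
import numpy as np

def random_temporal_mask(x, mask_prob=0.15, rng=None):
    """Randomly replace whole time steps of the covariates with N(0, 1) noise.

    x: covariates of shape (batch, time, n_covariates), assumed standardized.
    Each (sample, time step) is masked independently with probability
    mask_prob; a masked step gets fresh standard Gaussian noise in every
    covariate. Returns the noised copy and the boolean (batch, time) mask.
    """
    rng = np.random.default_rng() if rng is None else rng
    mask = rng.random(x.shape[:2]) < mask_prob      # (batch, time)
    noise = rng.standard_normal(x.shape)            # same shape as x
    x_noised = np.where(mask[..., None], noise, x)  # broadcast over covariates
    return x_noised, mask

# Training-time use only; at evaluation the model sees the real covariates.
x = np.ones((8, 30, 5))
x_noised, mask = random_temporal_mask(x, mask_prob=0.3, rng=np.random.default_rng(0))
print(mask.mean())  # fraction of masked (sample, time) positions
```

Masking whole time steps, rather than individual covariates, is one design choice among several; it forces the model to fall back on the surrounding history whenever a step is corrupted.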

Theorems & Definitions (26)

  • Theorem 4.1: Simplified Lemma A8 from Shalit et al. (2017); the complete version is provided in the Appendix.
  • Theorem 4.2: SGA Improves Generalization Bounds
  • Remark 4.3
  • Definition C.1: Definition A4 in Shalit et al. (2017)
  • Definition C.1: Definition A4 in Shalit et al. (2017)
  • Definition C.2: Definition A5 in Shalit et al. (2017)
  • Definition C.3
  • Definition C.6: Definition A12 in Shalit et al. (2017)
  • Definition C.7: Definition A13 in Shalit et al. (2017)
  • Remark C.8
  • ...and 16 more