Frequentist Consistency of Prior-Data Fitted Networks for Causal Inference

Valentyn Melnychuk, Vahid Balazadeh, Stefan Feuerriegel, Rahul G. Krishnan

Abstract

Foundation models based on prior-data fitted networks (PFNs) have shown strong empirical performance in causal inference by framing the task as an in-context learning problem. However, it is unclear whether PFN-based causal estimators provide uncertainty quantification that is consistent with classical frequentist estimators. In this work, we address this gap by analyzing the frequentist consistency of PFN-based estimators for the average treatment effect (ATE). (1) We show that existing PFNs, when interpreted as Bayesian ATE estimators, can exhibit prior-induced confounding bias: the prior is not asymptotically overwritten by data, which, in turn, prevents frequentist consistency. (2) As a remedy, we suggest employing a calibration procedure based on a one-step posterior correction (OSPC). We show that the OSPC helps to restore frequentist consistency and can yield a semi-parametric Bernstein-von Mises theorem for calibrated PFNs (i.e., both the calibrated PFN-based estimators and the classical semi-parametric efficient estimators converge in distribution with growing data size). (3) Finally, we implement OSPC by tailoring martingale posteriors on top of the PFNs. In this way, we are able to recover functional nuisance posteriors from PFNs, required by the OSPC. In multiple (semi-)synthetic experiments, PFNs calibrated with our martingale posterior OSPC produce ATE uncertainty that (i) asymptotically matches frequentist uncertainty and (ii) is well calibrated in finite samples in comparison to other Bayesian ATE estimators.
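A minimal sketch of how a one-step correction can be applied to posterior draws of the ATE nuisance functions. It assumes the OSPC takes the standard one-step (AIPW) form, with the plug-in term computed over the empirical covariate distribution: for each posterior draw $\tilde{\eta} = (\tilde{\mu}_0, \tilde{\mu}_1, \tilde{\pi})$, the corrected draw is $\frac{1}{n}\sum_i [\tilde{\mu}_1(x_i) - \tilde{\mu}_0(x_i)] + \frac{1}{n}\sum_i \big(\frac{a_i}{\tilde{\pi}(x_i)} - \frac{1-a_i}{1-\tilde{\pi}(x_i)}\big)(y_i - \tilde{\mu}_{a_i}(x_i))$. The function name `ospc_ate_draws`, the array layout, and the propensity clipping are hypothetical illustrative choices, not the authors' implementation.

```python
import numpy as np


def ospc_ate_draws(y, a, mu0_draws, mu1_draws, pi_draws, clip=1e-3):
    """Hypothetical sketch: one-step-corrected ATE draws (assumed AIPW form).

    y, a: observed outcomes and binary treatments, shape (n,).
    mu0_draws, mu1_draws, pi_draws: B posterior draws of the nuisance
        functions evaluated at the observed covariates, shape (B, n).
    Returns an array of B corrected ATE draws.
    """
    y = np.asarray(y, dtype=float)
    a = np.asarray(a, dtype=float)
    # Clip propensities away from 0 and 1 to keep the inverse weights finite.
    pi = np.clip(pi_draws, clip, 1.0 - clip)
    # Plug-in ATE for each draw, averaged over the empirical covariates.
    plug_in = (mu1_draws - mu0_draws).mean(axis=1)
    # Outcome prediction under the treatment that was actually received.
    mu_a = a * mu1_draws + (1.0 - a) * mu0_draws
    # Inverse-propensity weights of the ATE efficient influence function.
    weights = a / pi - (1.0 - a) / (1.0 - pi)
    # Empirical mean of the remaining influence-function term per draw.
    correction = (weights * (y - mu_a)).mean(axis=1)
    return plug_in + correction
```

With correct propensities, a constant bias in $\tilde{\mu}_0$ shifts the plug-in term but is cancelled by the correction term; this is the double-robustness property that the one-step form relies on.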

Paper Structure (33 sections, 3 theorems, 43 equations, 10 figures, 5 tables)

Key Result

Theorem 1

Assume the following holds for the observational data and a Bayesian estimator of the nuisance functions. Specifically, there exists a sequence of measurable subsets $H_n$ of $\mathcal{H}$ for which $\Pi(\tilde{\eta} \in H_n \mid \mathcal{D}) \to 1$ and for which conditions (a)–(c) hold for $a \in \{0, 1\}$: … (b) Uniform bounding: for large $n$, there exist $C, \varepsilon > 0$ such that, for all …

Figures (10)

  • Figure 1: Recovering functional posteriors from TabPFN. In (a), we show a PPD $P(y\mid \mathcal{D}, x)$, where $\mathcal{D}=\{(x_i,y_i)\}_{i=1}^{50}$. Then, in (b)-(d), we draw PPD samples from different functional posteriors $\tilde{\mu} \sim \Pi(\mu \mid \mathcal{D})$ (recovered with martingale posteriors), where $\mu(x) = \mathbb{E}(Y \mid x)$. Notably, the same PPD $P(y\mid \mathcal{D}, x)$ in (a) can encompass different functional posteriors (in the causal setting, those, in turn, lead to different ATE posteriors): (b) $x$-independent posterior, (c) $x$-parallel posterior, (d) smooth posterior.
  • Figure 2: Prior-induced confounding bias of different PFNs. High values of $\Delta$ correspond to a high degree of observed confounding. Thus, when $\Delta$ is concentrated around zero, the prior might induce a confounding bias that does not vanish asymptotically with growing data (as strong observed confounding is excluded a priori). Here, we sample $B = 512$ causal datasets with $n=10000$ each.
  • Figure 3: Overview of our MP-OSPC calibration procedure. Here, $N$ is the number of MP steps and $B$ is the number of posterior draws from the functional posteriors. Our MP-OSPC with PFNs thus yields the OSPC ATE posterior and serves as a Bayesian ATE estimator.
  • Figure 4: $L_2$-convergence check based on the synthetic data with varying size of the train data, $n_\text{train}$ (here: $d_x = 25$). Reported: mean $\hat{R}_2 \pm$ se over 10 runs (lower is better). Note that both x- and y-axes are log-scaled.
  • Figure 5: Quality of the asymptotic uncertainty for Bayesian ATE estimators based on the synthetic data with (a) varying size of the train data, $n_\text{train}$, and (b) varying dimensionality of covariates, $d_x$. Reported: mean $\hat{d}_\text{TV} \pm$ se over 40 runs (lower is better).
  • ...and 5 more figures
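To illustrate the roles of $N$ (number of MP steps) and $B$ (number of posterior draws) in Figure 3, here is a minimal sketch of predictive resampling, the mechanism behind martingale posteriors. It assumes a Pólya-urn predictive (drawing one point uniformly from the current sample) as a stand-in for a PFN's PPD, and uses a scalar functional (the mean) instead of a full nuisance function; with this predictive, the MP approximates the Bayesian bootstrap. The names `polya_urn_predictive` and `martingale_posterior` are hypothetical.

```python
import numpy as np


def polya_urn_predictive(sample, rng):
    """Stand-in predictive (assumption): resample one point from the current sample."""
    return sample[rng.integers(len(sample))]


def martingale_posterior(data, functional, n_steps, n_draws,
                         predictive=polya_urn_predictive, seed=0):
    """Hypothetical sketch of predictive resampling.

    For each of n_draws posterior draws, impute n_steps future observations
    one at a time from the predictive (conditioning on all points so far),
    then evaluate the functional on the completed sample.
    """
    rng = np.random.default_rng(seed)
    draws = np.empty(n_draws)
    for b in range(n_draws):
        completed = list(data)
        for _ in range(n_steps):
            completed.append(predictive(completed, rng))
        draws[b] = functional(np.asarray(completed))
    return draws
```

For example, `martingale_posterior(y, np.mean, n_steps=500, n_draws=200)` returns 200 posterior draws of the mean of `y`; larger `n_steps` brings the draws closer to the limiting (infinite-step) martingale posterior.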

Theorems & Definitions (7)

  • Definition 1: Semi-parametric Bernstein-von Mises (BvM) theorem
  • Theorem 1: Semi-parametric BvM theorem of the OSPC ATE posterior
  • proof
  • Proposition 1: Asymptotic variance of MP-based posteriors
  • proof
  • Proposition 1: Asymptotic variance of MP-based posteriors
  • proof
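For reference, a common form of the semi-parametric BvM property requires the posterior of the ATE $\psi$ to be asymptotically normal in total variation, centered at an efficient estimator $\hat{\psi}_n$ and with the efficient variance; the paper's Definition 1 may differ in its exact centering and mode of convergence. The notation $\mu_a(x) = \mathbb{E}(Y \mid x, a)$ and $\pi(x) = P(A = 1 \mid x)$ is assumed here, not taken from the paper:

$$\sup_{S} \left| \Pi\big(\sqrt{n}\,(\psi - \hat{\psi}_n) \in S \mid \mathcal{D}\big) - \mathcal{N}(0, V_0)(S) \right| \xrightarrow{P} 0,$$

where the supremum is over Borel sets $S$, $V_0 = \mathbb{E}[\phi_0(Z)^2]$ is the semi-parametric efficiency bound, and $\phi_0$ is the efficient influence function of the ATE:

$$\phi_0(Z) = \frac{A}{\pi(X)}\big(Y - \mu_1(X)\big) - \frac{1 - A}{1 - \pi(X)}\big(Y - \mu_0(X)\big) + \mu_1(X) - \mu_0(X) - \psi_0.$$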