Does TabPFN Understand Causal Structures?

Omar Swelam, Lennart Purucker, Jake Robertson, Hanne Raum, Joschka Boedecker, Frank Hutter

TL;DR

This work asks whether a tabular foundation model pre-trained on synthetic causal data encodes causal structure in its representations. It introduces an adapter with a learnable dual-attention decoder and universal tokens that extracts causal signals from TabPFN's frozen embeddings and decodes them into adjacency matrices. The study shows that TabPFN embeddings contain causal information, concentrated in mid-range layers, and that the approach can outperform traditional causal discovery algorithms on synthetic benchmarks. Overall, the results point to a promising direction for using pre-trained tabular models to support interpretable and adaptable causal discovery across domains.

Abstract

Causal discovery is fundamental to many scientific domains, yet extracting causal information from real-world data remains a significant challenge. Motivated by TabPFN's recent success on real data, we investigate whether this transformer-based tabular foundation model, pre-trained on synthetic datasets generated from structural causal models, encodes causal information in its internal representations. We develop an adapter framework that uses a learnable decoder and causal tokens to extract causal signals from TabPFN's frozen embeddings and decode them into adjacency matrices for causal discovery. Our evaluations show that TabPFN's embeddings contain causal information, concentrated in mid-range layers, and that the adapter outperforms several traditional causal discovery algorithms. These findings establish a new direction for interpretable and adaptable foundation models and demonstrate the potential of pre-trained tabular models for causal discovery.
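
To make the idea of decoding embeddings into an adjacency matrix concrete, here is a minimal sketch. It is not the paper's decoder: it assumes one embedding vector per feature and swaps in a simple query/key scoring head. The names `decode_adjacency`, `w_query`, and `w_key` are hypothetical, and random arrays stand in for the frozen TabPFN embeddings and the learned weights.

```python
import numpy as np

rng = np.random.default_rng(0)

def decode_adjacency(feature_emb, w_query, w_key):
    """Score every ordered feature pair (i -> j) and return edge probabilities.

    feature_emb: (num_features, emb_dim) array, one vector per feature
                 (assumed to come from a frozen encoder).
    w_query, w_key: (emb_dim, head_dim) projections that would be learned.
    """
    q = feature_emb @ w_query                  # "parent" role of each feature
    k = feature_emb @ w_key                    # "child" role of each feature
    logits = q @ k.T / np.sqrt(q.shape[1])     # scaled pairwise scores
    probs = 1.0 / (1.0 + np.exp(-logits))      # sigmoid -> edge probability
    np.fill_diagonal(probs, 0.0)               # a DAG has no self-loops
    return probs

num_features, emb_dim, head_dim = 5, 16, 8
feature_emb = rng.normal(size=(num_features, emb_dim))  # stand-in embeddings
w_query = rng.normal(size=(emb_dim, head_dim)) * 0.1    # untrained placeholders
w_key = rng.normal(size=(emb_dim, head_dim)) * 0.1

edge_probs = decode_adjacency(feature_emb, w_query, w_key)
```

Using separate query and key projections keeps the score for i -> j different from the score for j -> i, which a directed graph needs. In a trained adapter only the projections would be updated, while the encoder producing `feature_emb` stays frozen.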

Paper Structure

This paper contains 32 sections, 2 equations, 7 figures, 1 table.

Figures (7)

  • Figure 1: Overall architecture of our approach, where the data embeddings from the frozen TabPFN (left) are attended to in the decoder (middle) to extract aggregated feature-representations for the adjacency matrix prediction (right)
  • Figure 2: Our approach outperforms the statistical baselines (in green) and stays consistently close to AVICI on ROC AUC (left), but its AP degrades increasingly at scale (right).
  • Figure 3: AP scores of our approach trained using different encoder/decoder layers (left) and different initializations/embeddings (right), showing that the middle layers, and embeddings from TabPFN's encoder with optimal weights, encode causal information better.
  • Figure 4: Samples of different graphs sampled from the graph structures. Each plot represents the adjacency matrix of a sampled DAG, where the yellow dots represent the edges.
  • Figure 5: Different decoder architectures.
  • ...and 2 more figures
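
The figure captions report ROC AUC and AP for predicted adjacency matrices. As a rough illustration of how such scores can be computed, the sketch below treats every off-diagonal entry as a binary edge/no-edge prediction. The helper names and the toy graph are assumptions for illustration, not the paper's evaluation code.

```python
import numpy as np

def off_diagonal(a):
    """Flatten a square matrix, dropping the self-loop entries on the diagonal."""
    mask = ~np.eye(a.shape[0], dtype=bool)
    return a[mask]

def roc_auc(y_true, scores):
    """Probability that a random true edge outranks a random non-edge (ties count 0.5)."""
    pos = scores[y_true == 1]
    neg = scores[y_true == 0]
    greater = (pos[:, None] > neg[None, :]).sum()
    ties = (pos[:, None] == neg[None, :]).sum()
    return (greater + 0.5 * ties) / (len(pos) * len(neg))

def average_precision(y_true, scores):
    """Mean of the precision values measured at the rank of each true edge."""
    order = np.argsort(-scores, kind="stable")   # highest score first
    hits = y_true[order]
    precision_at_k = np.cumsum(hits) / np.arange(1, len(hits) + 1)
    return precision_at_k[hits == 1].mean()

# Toy 3-node DAG with edges 0 -> 1 and 1 -> 2 (hypothetical example data).
true_adj = np.array([[0, 1, 0],
                     [0, 0, 1],
                     [0, 0, 0]])
pred_adj = np.array([[0.0, 0.9, 0.4],
                     [0.1, 0.0, 0.3],
                     [0.2, 0.05, 0.0]])

y = off_diagonal(true_adj)
s = off_diagonal(pred_adj)
auc = roc_auc(y, s)            # 0.875
ap = average_precision(y, s)   # 5/6, about 0.833
```

Because DAG adjacency matrices are typically sparse, true edges are the minority class. AP is generally more sensitive than ROC AUC to false positives ranked above true edges in that setting, so the two metrics can diverge as graphs grow.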