
CausalFormer: An Interpretable Transformer for Temporal Causal Discovery

Lingbai Kong, Wengen Li, Hanchen Yang, Yichao Zhang, Jihong Guan, Shuigeng Zhou

TL;DR

CausalFormer addresses temporal causal discovery by moving beyond partial-model interpretability to a holistic, interpretable framework. It introduces a causality-aware transformer with multi-kernel causal convolution to learn causal representations while respecting temporal priority, and a regression relevance propagation-based detector to extract causal edges and delays from the entire model. The approach achieves state-of-the-art performance across synthetic, simulated, and real datasets, including accurate causal graph construction and reasonable delay estimation. This work advances practical temporal causality analysis by providing global interpretability and robust causal graph discovery applicable to domains like neuroscience and climate science.

Abstract

Temporal causal discovery is a crucial task aimed at uncovering the causal relations within time series data. The latest temporal causal discovery methods usually train deep learning models on prediction tasks to uncover the causality between time series. They capture causal relations by analyzing the parameters of some components of the trained models, e.g., attention weights and convolution weights. However, this is an incomplete mapping from the model parameters to the causality, and it fails to investigate the other components, e.g., fully connected layers and activation functions, that are also significant for causal discovery. To facilitate the use of the whole deep learning model in temporal causal discovery, we propose an interpretable transformer-based causal discovery model termed CausalFormer, which consists of the causality-aware transformer and the decomposition-based causality detector. The causality-aware transformer learns the causal representation of time series data using a prediction task with the designed multi-kernel causal convolution, which aggregates each input time series along the temporal dimension under the temporal priority constraint. Then, the decomposition-based causality detector interprets the global structure of the trained causality-aware transformer with the proposed regression relevance propagation to identify potential causal relations and finally constructs the causal graph. Experiments on synthetic, simulated, and real datasets demonstrate the state-of-the-art performance of CausalFormer in discovering temporal causality. Our code is available at https://github.com/lingbai-kong/CausalFormer.
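
The temporal priority constraint mentioned in the abstract can be illustrated with a small sketch. This is not the authors' implementation; it is a generic left-padded (causal) convolution in NumPy, assuming one kernel per (source series, target series) pair. The function name `causal_conv`, the kernel layout `(N, N, K)`, and the window length `K` are all hypothetical choices made for illustration.

```python
import numpy as np

def causal_conv(x, kernels):
    """Causal convolution with one kernel per (source, target) pair.

    x:       (N, T) array holding N time series of length T.
    kernels: (N, N, K) array; kernels[i, j] is the filter applied to
             source series i when building a representation for target j
             (an assumed layout, chosen for this sketch).
    Returns: (N, N, T) array where out[i, j, t] depends only on
             x[i, t-K+1], ..., x[i, t], never on later time steps.
    """
    n, t = x.shape
    k = kernels.shape[-1]
    # Left-pad with zeros so every window ends at the current time step.
    padded = np.pad(x, ((0, 0), (k - 1, 0)))
    out = np.zeros((n, n, t))
    for i in range(n):
        # windows[s] == padded[i, s:s+k], i.e. the K values ending at time s.
        windows = np.lib.stride_tricks.sliding_window_view(padded[i], k)
        out[i] = kernels[i] @ windows.T  # (N, K) @ (K, T) -> (N, T)
    return out

rng = np.random.default_rng(0)
series = rng.normal(size=(3, 10))
filters = rng.normal(size=(3, 3, 4))
features = causal_conv(series, filters)
print(features.shape)
```

Changing the input at time step 6 or later leaves `features[:, :, :6]` unchanged, which is the property a temporal priority constraint is meant to guarantee: a cause cannot come after its effect.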

Paper Structure (25 sections, 21 equations, 10 figures, 3 tables)


Figures (10)

  • Figure 1: One example of temporal causality with a diamond causal structure, where the numbers associated with edges are time lags of causal relations.
  • Figure 2: The workflow of CausalFormer.
  • Figure 3: The structure of CausalFormer, where the causality-aware transformer (a) learns the causal representation of time series with sequential layers; the decomposition-based causality detector (b) backward propagates both the relevance scores and gradients to the attention matrix and the causal convolution kernels, selects the true causal relations by clustering, and finally outputs the temporal causal graph; (c) illustrates the internal structure of multi-kernel causal convolution block; and (d) demonstrates the regression relevance propagation process for time series.
  • Figure 4: The computational flow of layer-wise relevance propagation (LRP) on a single time series input. The prediction function $f(\boldsymbol{x})$ first makes the classification with the given observation $\boldsymbol{x}$. Then the output neuron is assigned relevance $\boldsymbol{R}^{(L)} = f(\boldsymbol{x})$, which is decomposed as a sum of terms called relevance scores $\boldsymbol{R}$ and propagated backward to the neurons of each layer. The neurons with a high relevance score are colored in blue.
  • Figure 5: The search process for the nearest root point $\boldsymbol{x}_0$ (empty circle) of the interpreted input $\boldsymbol{x}$ (solid circle).
  • ...and 5 more figures
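
The backward pass described in the Figure 4 caption can be sketched on a toy network. The example below uses the standard epsilon rule of layer-wise relevance propagation on a bias-free two-layer ReLU network. It is not the paper's regression relevance propagation, and the helper `lrp_dense`, the layer sizes, and the stabiliser `eps` are assumptions made for illustration.

```python
import numpy as np

def lrp_dense(a, w, r_out, eps=1e-9):
    """Redistribute relevance through a bias-free dense layer z = a @ w.

    a:     (in,) activations entering the layer.
    w:     (in, out) weight matrix.
    r_out: (out,) relevance assigned to the layer's outputs.
    Returns (in,) relevance for the inputs, using the epsilon rule.
    """
    z = a @ w
    # Small signed stabiliser to avoid division by zero.
    z = z + eps * np.where(z >= 0, 1.0, -1.0)
    s = r_out / z
    return a * (w @ s)

rng = np.random.default_rng(0)
x = rng.normal(size=4)         # toy input
w1 = rng.normal(size=(4, 5))   # hidden layer weights
w2 = rng.normal(size=(5, 1))   # output layer weights

h = np.maximum(x @ w1, 0.0)    # ReLU hidden activations
y = h @ w2                     # f(x), the network output

r_hidden = lrp_dense(h, w2, y)        # start from R^(L) = f(x)
r_input = lrp_dense(x, w1, r_hidden)  # propagate back to the input

# The input relevance scores sum (up to eps) to the output f(x).
print(r_input.sum(), y[0])
```

In this bias-free setting the scores are conserved from layer to layer, so the input relevances form a decomposition of $f(\boldsymbol{x})$ as a sum of per-input terms, matching the description in the caption.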