Causal-Aware Generative Adversarial Networks with Reinforcement Learning

Tu Anh Hoang Nguyen, Dang Nguyen, Tri-Nhan Vo, Thuc Duy Le, Sunil Gupta

TL;DR

CA-GAN addresses synthetic tabular data generation with explicit causal preservation by first extracting a causal graph $\mathcal{G}_{real}$ via PC-based causal discovery and then training a graph-conditioned WGAN-GP with $M$ sub-generators. A reinforcement-learning-based causal loss uses an SHD-based reward to align the causal structure of the synthetic data with that of the real data, and is optimized with a policy gradient that uses the log-likelihoods of mixed-type outputs. Across 14 datasets, CA-GAN achieves superior causal preservation, downstream utility, and privacy protection compared to six baselines, with competitive computation times. The approach enables practical, privacy-preserving synthetic data generation that maintains reliable causal inferences for enterprise analytics and secure research.
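The SHD-based reward and policy-gradient loss described above can be sketched in a few lines. This is a minimal illustration, not the authors' implementation: the function names, the choice of reward as the negative SHD, the convention that a reversed edge counts as one error, and the sign of the loss are all assumptions. In practice the fake graph would come from running causal discovery on a generated batch and the log-probabilities would be differentiable tensors; here both are hand-written numbers.

```python
def shd(graph_a, graph_b):
    """Structural Hamming Distance between two directed graphs.

    Graphs are 0/1 adjacency matrices (graph[i][j] == 1 means i -> j).
    Assumption: each node pair whose edge status differs (missing, extra,
    or reversed) contributes exactly 1 to the distance.
    """
    n = len(graph_a)
    distance = 0
    for i in range(n):
        for j in range(i + 1, n):
            pair_a = (graph_a[i][j], graph_a[j][i])
            pair_b = (graph_b[i][j], graph_b[j][i])
            if pair_a != pair_b:
                distance += 1
    return distance


def shd_reward(graph_real, graph_fake):
    # Assumed reward shape: fewer structural errors means a higher reward.
    return -float(shd(graph_real, graph_fake))


def causal_loss(log_probs, reward):
    # REINFORCE-style surrogate: the reward scales the mean log-probability
    # of the generated samples; the leading minus sign turns reward
    # maximization into loss minimization.
    mean_log_prob = sum(log_probs) / len(log_probs)
    return -reward * mean_log_prob


# Toy graphs over three variables: real is 0 -> 1 -> 2, fake reverses 0 -> 1.
g_real = [[0, 1, 0], [0, 0, 1], [0, 0, 0]]
g_fake = [[0, 0, 0], [1, 0, 1], [0, 0, 0]]

reward = shd_reward(g_real, g_fake)
loss = causal_loss([-1.0, -3.0], reward)
```

With these toy inputs the two graphs differ on one node pair, so the reward is -1.0 and the loss is -2.0.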

Abstract

The utility of tabular data for tasks ranging from model training to large-scale data analysis is often constrained by privacy concerns or regulatory hurdles. While existing data generation methods, particularly those based on Generative Adversarial Networks (GANs), have shown promise, they frequently struggle to capture complex causal relationships, maintain data utility, and provide provable privacy guarantees suitable for enterprise deployment. We introduce CA-GAN, a novel generative framework specifically engineered to address these challenges for real-world tabular datasets. CA-GAN uses a two-step approach: causal graph extraction to learn robust, comprehensive causal relationships in the data's manifold, followed by a custom Conditional WGAN-GP (Wasserstein GAN with Gradient Penalty) that operates strictly according to the node structure of the causal graph. More importantly, the generator is trained with a new Reinforcement Learning-based objective that aligns the causal graphs constructed from real and fake data, ensuring causal awareness in both the training and sampling phases. We demonstrate CA-GAN's superiority over six SOTA methods across 14 tabular datasets. Our evaluations focus on core data engineering metrics: causal preservation, utility preservation, and privacy preservation. Our method offers a practical, high-performance solution for data engineers seeking to create high-quality, privacy-compliant synthetic datasets to benchmark database systems, accelerate software development, and facilitate secure data-driven research.
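To make the graph-structured generation step concrete, the sketch below samples one row by visiting variables in a topological order of the causal graph, feeding each sub-generator the values of its parents plus noise. Everything specific here is hypothetical: the variable names, the parent map, and the linear functions standing in for the neural sub-generators are illustrative assumptions, and no adversarial training is shown.

```python
import random
from graphlib import TopologicalSorter

# Hypothetical causal graph: each variable maps to the list of its parents.
parents = {"age": [], "income": ["age"], "spend": ["age", "income"]}


def make_sub_generator(weights, bias):
    """Return a toy linear stand-in for a neural sub-generator G_j."""
    def generate(parent_values, z):
        return bias + sum(w * v for w, v in zip(weights, parent_values)) + z
    return generate


# One sub-generator per variable (weights are arbitrary illustration values).
sub_generators = {
    "age": make_sub_generator([], 40.0),
    "income": make_sub_generator([1.5], 10.0),
    "spend": make_sub_generator([0.2, 0.5], 0.0),
}


def sample_row(parents, sub_generators, rng):
    # static_order() yields every node after all of its parents.
    order = list(TopologicalSorter(parents).static_order())
    row = {}
    for name in order:
        parent_values = [row[p] for p in parents[name]]
        z = rng.gauss(0.0, 1.0)  # per-variable noise z_j
        row[name] = sub_generators[name](parent_values, z)
    return row


row = sample_row(parents, sub_generators, random.Random(0))
```

Because a variable is generated only after its parents, each value in a row depends on the rest of the row only through the edges of the graph.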

Paper Structure

This paper contains 31 sections, 11 equations, 8 figures, 6 tables, 1 algorithm.

Figures (8)

  • Figure 1: Comparison between real data sharing and synthetic data sharing. Sharing synthetic data enables secure collaborations by preserving both utility and privacy whereas sharing real data risks information leakage.
  • Figure 2: Our framework CA-GAN. First, we use PC (a causal discovery algorithm) to extract the causal graph $\mathcal{G}_{real}$ from the real dataset $\mathcal{D}_{real}$. Next, we train a WGAN-GP model with $M$ sub-generators and one discriminator. Each sub-generator $G_{j}$ synthesizes variable $X_{j}$ sequentially, following the node ordering in $\mathcal{G}_{real}$. The input of $G_{j}$ consists of the parent values $Pa(X_{j})$ and noise $z_{j}$. We train the discriminator with a standard discriminator loss using fake samples $\hat{x}$ and real samples $x$ to match the two distributions $p(\mathcal{D}_{fake})$ and $p(\mathcal{D}_{real})$. Finally, we design a reward $R(\hat{x})$ based on the Structural Hamming Distance (SHD) between $\mathcal{G}_{real}$ and the graph $\mathcal{G}_{fake}$ extracted from $\hat{x}$. Using the Policy Gradient theorem, we construct the causal loss as the product of the log-probability of the generator output $\log p(\hat{x})$ and the reward $R(\hat{x})$. We train the generator with the final loss = adversarial loss + causal loss, which unifies adversarial and causal learning and ensures both statistical realism and causal fidelity.
  • Figure 3: True causal graphs (DAGs) used to construct synthetic benchmark datasets.
  • Figure 4: DCR distributions of tabular generation methods on the dataset bank. A very small DCR suggests that the method may copy certain feature values from the original data, leading to information leakage. In contrast, a very high DCR indicates that the generated samples might be outliers and unrealistic. Our method CA-GAN balances realism and privacy better than other methods.
  • Figure 5: Average risk score of re-identification attacks across eight real-world tabular datasets. $\downarrow$ means “lower is better”.
  • ...and 3 more figures
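
The DCR referred to in the Figure 4 caption is commonly read as the distance from each synthetic record to its closest real record. A minimal sketch of that reading follows, assuming purely numeric rows and Euclidean distance; the metric and any preprocessing the paper actually uses are not stated in this section, so both are assumptions, as is the function name.

```python
import math


def distance_to_closest_record(synthetic_rows, real_rows):
    """For each synthetic row, return the distance to its nearest real row.

    Assumption: rows are equal-length numeric tuples compared with
    Euclidean distance.
    """
    return [
        min(math.dist(s, r) for r in real_rows)
        for s in synthetic_rows
    ]


real_rows = [(0.0, 0.0), (3.0, 4.0)]
synthetic_rows = [(0.0, 0.0), (3.0, 0.0)]

dcr = distance_to_closest_record(synthetic_rows, real_rows)
```

Here the first synthetic row is an exact copy of a real row, so its DCR is 0.0, the kind of value the caption associates with information leakage. The second row sits 3.0 away from its nearest real record.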