
Federated Causal Representation Learning in State-Space Systems for Decentralized Counterfactual Reasoning

Nazal Mohamed, Ayush Mohanty, Nagi Gebraeel

TL;DR

A federated framework for causal representation learning in state-space systems is proposed that captures interdependencies among clients whose data are high-dimensional and private and whose local models cannot be modified; the framework provably converges to a centralized oracle and comes with privacy guarantees.

Abstract

Networks of interdependent industrial assets (clients) are tightly coupled through physical processes and control inputs, raising a key question: how would the output of one client change if another client were operated differently? This is difficult to answer because client-specific data are high-dimensional and private, making centralization of raw data infeasible. Each client also maintains proprietary local models that cannot be modified. We propose a federated framework for causal representation learning in state-space systems that captures interdependencies among clients under these constraints. Each client maps high-dimensional observations into low-dimensional latent states that disentangle intrinsic dynamics from control-driven influences. A central server estimates the global state-transition and control structure. This enables decentralized counterfactual reasoning, in which clients predict how outputs would change under alternative control inputs at other clients while exchanging only compact latent states. We prove convergence to a centralized oracle and provide privacy guarantees. Experiments on synthetic and real-world industrial control system datasets demonstrate scalability and accurate cross-client counterfactual inference.
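The abstract says a central server estimates the global state-transition and control structure, and the theorem list names normal equations (Lemma 7.2). The sketch below is only a toy illustration of that server-side step, not the authors' algorithm: it assumes the server already holds the latent states of all clients stacked into one vector z^t, assumes it also knows the stacked control inputs u^t, and fits z^t = A z^{t-1} + B u^{t-1} + noise by ordinary least squares. It omits the client encoders, the disentanglement penalty, and any privacy mechanism. The function name `estimate_global_dynamics` and the toy matrices are hypothetical.

```python
import numpy as np


def estimate_global_dynamics(z, u):
    """Ordinary least-squares fit of z^t = A z^{t-1} + B u^{t-1} + noise.

    z: (T+1, d_z) latent states of all clients stacked column-wise (assumed layout).
    u: (T+1, d_u) control inputs of all clients stacked column-wise; the last row is unused.
    Returns (A_hat, B_hat) with shapes (d_z, d_z) and (d_z, d_u).
    """
    phi = np.hstack([z[:-1], u[:-1]])  # regressors [z^{t-1}, u^{t-1}], shape (T, d_z + d_u)
    target = z[1:]                      # next latent states z^t, shape (T, d_z)
    # Normal equations (phi^T phi) theta = phi^T target, where theta = [A B]^T
    theta = np.linalg.solve(phi.T @ phi, phi.T @ target)
    d_z = z.shape[1]
    return theta[:d_z].T, theta[d_z:].T


# Toy two-client example: one latent state and one control per client (illustrative values).
rng = np.random.default_rng(0)
A = np.array([[0.9, 0.1],   # client 2's state feeds into client 1
              [0.0, 0.8]])
B = np.array([[1.0, 0.5],   # client 2's control also acts on client 1
              [0.0, 1.0]])
T = 2000
u = rng.normal(size=(T + 1, 2))
z = np.zeros((T + 1, 2))
for t in range(1, T + 1):
    z[t] = A @ z[t - 1] + B @ u[t - 1] + 0.01 * rng.normal(size=2)

A_hat, B_hat = estimate_global_dynamics(z, u)
print(np.round(A_hat, 2))
print(np.round(B_hat, 2))
```

The off-diagonal entries of `A_hat` and `B_hat` are the cross-client couplings; at this noise level the fit recovers `A` and `B` to within a few hundredths.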

Paper Structure (52 sections, 12 theorems, 106 equations, 6 figures, 3 tables, 1 algorithm)

Key Result

Proposition 4.1

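Let the input to a linear time-invariant (LTI) system be $u^{t-1} = u_0$. Then the average treatment effect (ATE) on the measurements $y^t$ of the intervention $\mathrm{do}(u^{t-1} = u_1)$ equals $CB\,(u_1 - u_0)$, where $B$ and $C$ are the system's input and output matrices.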
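A quick numerical check of Proposition 4.1, under the assumption (ours; this summary does not state the model) that the system is $x^t = A x^{t-1} + B u^{t-1} + w^t$, $y^t = C x^t + v^t$, with $x^{t-1}$ and the noise terms unaffected by the intervention. Then $y^t$ under $\mathrm{do}(u^{t-1} = u_1)$ minus $y^t$ under $u^{t-1} = u_0$ equals $CB\,(u_1 - u_0)$ for every draw, because $A x^{t-1}$, $w^t$, and $v^t$ cancel, so its expectation (the ATE) does too. The sketch reuses the same draws of $x^{t-1}$, $w^t$, $v^t$ in both arms (common random numbers); all dimensions and values are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(1)
d_x, d_u, d_y, N = 3, 2, 2, 10_000
A = 0.5 * rng.normal(size=(d_x, d_x))  # state transition (arbitrary; cancels in the ATE)
B = rng.normal(size=(d_x, d_u))        # input matrix
C = rng.normal(size=(d_y, d_x))        # measurement matrix

# Quantities the intervention on u^{t-1} cannot change, shared by both arms.
x_prev = rng.normal(size=(N, d_x))     # x^{t-1}
w = 0.1 * rng.normal(size=(N, d_x))    # process noise w^t
v = 0.1 * rng.normal(size=(N, d_y))    # measurement noise v^t


def measure(u_prev):
    """Measurements y^t for all N samples when u^{t-1} is set to u_prev."""
    x_t = x_prev @ A.T + u_prev @ B.T + w
    return x_t @ C.T + v


u0 = np.array([1.0, -0.5])  # baseline input
u1 = np.array([2.0, 0.5])   # intervened input
ate_mc = (measure(u1) - measure(u0)).mean(axis=0)  # Monte Carlo ATE
ate_closed = C @ B @ (u1 - u0)                     # closed form from Proposition 4.1
print(ate_mc, ate_closed)
```

Drawing independent noise in each arm instead gives the same answer up to Monte Carlo error. If, as a further assumption, $C$ is block-diagonal across clients, the same identity reads as a cross-client effect: changing client $n$'s input by $\Delta u_n$ shifts client $m$'s measurements by $C_m B_{mn} \Delta u_n$.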

Figures (6)

  • Figure 1: A simplified view of the computation inside client $m$
  • Figure 2: Global loss at the server and client losses from different models (augmented client loss $L_{m,a}^k$; proprietary client loss $\frac{1}{T} \sum_{t = 1}^T \lVert r^t_{m,c} \rVert_2^2$), plotted against the number of iterations ($k$).
  • Figure 3: Disentanglement penalty at the server and $\delta d$ at the clients versus the number of iterations
  • Figure 4: Comparison between the Penalty Method (PM) and the Augmented Lagrangian (AL) for the two-client system. $L_s$, $\mathcal{D}$, the local loss (client residual) $L_{m,a}$, and $\delta d_m$ are plotted against the number of iterations ($k$), where $\delta d_m := \lVert \phi_m - \frac{1}{T}\sum_{t = 1}^T \big(\sum_{n \neq m}\hat{B}_{mn} u_n^{t-1} \big) \rVert_2$.
  • Figure 5: HAI dataset
  • ...and 1 more figure
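The two diagnostics defined in the Figure 2 and Figure 4 captions are straightforward to compute once their inputs are available. A minimal sketch, assuming `phi_m` is a client-held vector with the output dimension of $\hat{B}_{mn}$ (this summary does not define $\phi_m$), the row blocks $\hat{B}_{mn}$ are stored in a dict keyed by client index, and the residuals $r^t_{m,c}$ are stacked as a `(T, d)` array. The function names and data layouts are hypothetical.

```python
import numpy as np


def delta_d(phi_m, B_hat_row, u, m):
    """delta d_m = || phi_m - (1/T) sum_t sum_{n != m} B_hat_mn u_n^{t-1} ||_2 (Figure 4).

    phi_m: (d_m,) client-side vector.
    B_hat_row: dict n -> B_hat_mn, each of shape (d_m, p_n).
    u: dict n -> (T, p_n) array whose rows are u_n^{t-1}, t = 1..T.
    """
    # Averaging over t commutes with the linear map: (1/T) sum_t B u^{t-1} = B mean_t(u^{t-1}).
    cross = sum(B_mn @ u[n].mean(axis=0) for n, B_mn in B_hat_row.items() if n != m)
    return float(np.linalg.norm(phi_m - cross))


def proprietary_client_loss(r):
    """(1/T) sum_t ||r^t||_2^2 for residuals r of shape (T, d) (Figure 2)."""
    return float(np.mean(np.sum(r ** 2, axis=1)))


rng = np.random.default_rng(2)
T = 100
u = {0: rng.normal(size=(T, 2)), 1: rng.normal(size=(T, 3))}
B_hat_row0 = {0: rng.normal(size=(4, 2)), 1: rng.normal(size=(4, 3))}  # blocks B_hat_{0n}
phi_0 = B_hat_row0[1] @ u[1].mean(axis=0)  # matches the cross-client term exactly
print(delta_d(phi_0, B_hat_row0, u, m=0))        # zero, up to rounding
print(delta_d(phi_0 + 1.0, B_hat_row0, u, m=0))  # shifting all 4 entries by 1 gives about 2
print(proprietary_client_loss(np.ones((T, 3))))   # each ||r^t||^2 is 3
```

Note that $\delta d_m$ excludes the block $n = m$, so client $m$'s own control never enters it; the sketch skips that block explicitly.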

Theorems & Definitions (32)

  • Proposition 4.1
  • Claim 5.3
  • Claim 5.5
  • Theorem 6.1
  • Corollary 6.2
  • Lemma 7.2: Normal equations
  • Theorem 7.3: Convergence to the oracle
  • Definition A.1: Local dataset and horizon
  • Definition A.2: Neighboring datasets
  • Definition A.5: Message generation map
  • ...and 22 more