Disentangling Dynamical Systems: Causal Representation Learning Meets Local Sparse Attention

Markus W. Baumgartner, Anson Lei, Joe Watson, Ingmar Posner

Abstract

Parametric system identification methods estimate the parameters of explicitly defined physical systems from data. Yet, they remain constrained by the need to provide an explicit function space, typically through a predefined library of candidate functions chosen via available domain knowledge. In contrast, deep learning can demonstrably model systems of broad complexity with high fidelity, but black-box function approximation typically fails to yield explicit descriptive or disentangled representations revealing the structure of a system. We develop a novel identifiability theorem, leveraging causal representation learning, to uncover disentangled representations of system parameters without structural assumptions. We derive a graphical criterion specifying when system parameters can be uniquely disentangled from raw trajectory data, up to permutation and diffeomorphism. Crucially, our analysis demonstrates that global causal structures provide a lower bound on the disentanglement guarantees achievable when considering local state-dependent causal structures. We instantiate system parameter identification as a variational inference problem, leveraging a sparsity-regularised transformer to uncover state-dependent causal structures. We empirically validate our approach across four synthetic domains, demonstrating its ability to recover highly disentangled representations that baselines fail to recover. Corroborating our theoretical analysis, our results confirm that enforcing local causal structure is often necessary for full identifiability.
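
The abstract's central mechanism is a sparsity-regularised attention layer whose attention pattern changes with the state. A minimal NumPy sketch of that idea follows. It is not the authors' SPARTAN implementation: the embedding scheme, the weight names `W_q`, `W_k`, `W_v`, the penalty weight `1e-2` and the threshold `0.5` are illustrative assumptions. It also swaps softmax for independent sigmoid gates. Softmax rows always sum to one, so an L1 penalty on them would be constant.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gated_attention(state, theta, W_q, W_k, W_v):
    """State components (queries) attend to latent system parameters (keys/values).

    state: (n,) current state, one scalar per component
    theta: (m,) latent system parameters
    W_q: (n, d), W_k: (m, d), W_v: (m, d) hypothetical embedding weights
    Returns messages (n, d) for the decoder and the gate matrix (n, m).
    """
    q = state[:, None] * W_q              # one query vector per state component
    k = theta[:, None] * W_k              # one key vector per parameter
    v = theta[:, None] * W_v              # one value vector per parameter
    d = q.shape[1]
    # Independent sigmoid gates: unlike softmax rows they need not sum to one,
    # so penalising their sum actually pushes individual edges towards zero.
    gates = sigmoid(q @ k.T / np.sqrt(d))
    return gates @ v, gates

rng = np.random.default_rng(0)
n, m, d = 4, 3, 8                         # state components, parameters, embedding size
W_q, W_k, W_v = (rng.normal(size=s) for s in [(n, d), (m, d), (m, d)])
theta = rng.normal(size=m)
state = rng.normal(size=n)

msgs, gates = gated_attention(state, theta, W_q, W_k, W_v)
penalty = 1e-2 * gates.sum()              # L1 sparsity term added to the training loss
local_graph = gates.T > 0.5               # (m, n): edge parameter -> component at this state
```

Because the gates are recomputed for every state, thresholding them can yield a different graph in different regions of the state space. This is the sense in which the learnt causal structure is local rather than global.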

Paper Structure (55 sections, 7 theorems, 60 equations, 18 figures)

Key Result

Theorem 1

Let $\mathcal{S}$ and $\hat{\mathcal{S}}$ correspond to two systems satisfying Assumptions assumption:existence, assumption:obs_eq, and assumption:markov. Further assume that … Then the dynamical representations of $\mathcal{S}$ and $\hat{\mathcal{S}}$ are equivalent up to permutation and element-wise diffeomorphism if $\mathcal{G}$ satisfies the following graphical criterion: … where $\text{Ch}_{\dots}$ …
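
The conclusion of Theorem 1 relies on the notion named in Definition 6. As a hedged reference, assume $\theta, \hat{\theta} \in \mathbb{R}^m$ denote the latent system parameters of $\mathcal{S}$ and $\hat{\mathcal{S}}$ (this notation is assumed, not taken from this extract). Equivalence up to permutation and element-wise diffeomorphism is then commonly written in the style of lachapelle2022synergies as follows. The paper's exact statement may differ in detail.

```latex
% \mathfrak{S}_m: the set of permutations of \{1, \dots, m\}
\exists\, \pi \in \mathfrak{S}_m,\;
\exists\, h_1, \dots, h_m \text{ diffeomorphisms, } h_i : \Theta_{\pi(i)} \to \hat{\Theta}_i,
\quad \text{such that} \quad
\hat{\theta}_i = h_i\bigl(\theta_{\pi(i)}\bigr) \quad \forall\, i \in \{1, \dots, m\}.
```

Each learnt parameter therefore depends on exactly one ground-truth parameter, through an invertible smooth map. This is the structure visible in the entanglement plots of Figure 5.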

Figures (18)

  • Figure 1: High-level overview of the developed theory. An observed trajectory (left) is encoded into a vector of latent system parameters (marked in dark blue). The developed theory shows that enforcing sparse causal relations between parameters and system components in the decoder (which performs one-step prediction) provably disentangles the system parameter representation.
  • Figure 2: Depiction of the ground-truth DAGs of the evaluation environments. The left and centre-left graphs correspond to the dual particle and springs environments, respectively, and both satisfy the global graphical criterion for disentanglement. The centre-right and right graphs correspond to the local particle and bounce environments, respectively. Coloured arrows indicate causal edges that are only active in subsets of the state space; for the bounce environment, only a limited number of these subsets are shown. The local particle and bounce environments satisfy only the local, not the global, graphical criterion for disentanglement.
  • Figure 3: Comparison of disentanglement across the test environments, where an MCC of $1.0$ represents perfect disentanglement. All trials are repeated over eight random seeds. Box plots display the minimum, lower quartile, upper quartile, and maximum values. The validation reconstruction loss is shown at the bottom, indicating that all models perform approximately equally well in these environments. The VCD baseline, which learns static graphs, strongly disentangles in the first two environments, which satisfy the global graphical criterion. In contrast, only SPARTAN, which learns state-dependent graphs, consistently disentangles in all environments.
  • Figure 4: Autoregressive rollouts of the dual particle environment produced by the learnt (a) SPARTAN and (b) MLP models. Starting from the same initial state, the red particle experiences resistive damping forces independently along the horizontal and vertical axes, whilst the blue particle additionally experiences a spring force towards the origin. The ground-truth system parameters are the damping coefficients and the spring constant. (a) shows a clear separation of parameters: increasing $\hat{\theta}_2$ evidently reduces the spring strength, and because the red particle is unaffected by the spring, it is invariant to changes in $\hat{\theta}_2$. Increasing $\hat{\theta}_1$ increases horizontal damping, and increasing $\hat{\theta}_0$ reduces vertical damping. In contrast, the entangled representations produced by the MLP yield no such insight.
  • Figure 5: Plots of the entanglement mapping for all models in the dual particle environment. Each subplot shows the marginal distribution of a learnt parameter (x-axis) against the ground truth (y-axis). Representative trials are shown. Disentanglement is exemplified by the clear element-wise diffeomorphic structure learnt by both the SPARTAN and VCD architectures, whereas the MLP and Transformer architectures learn entangled representations. Extended results for all environments and representative samples of the learnt causal graphs are provided in the extended-results appendix.
  • ...and 13 more figures
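
Figure 3 scores disentanglement with the MCC. The sketch below shows how such a score is typically computed; it is a common variant, not necessarily the paper's evaluation code. Every learnt parameter is correlated with every ground-truth parameter, the permutation that maximises the mean absolute correlation is found, and that mean is reported. Spearman rank correlation is used here because it is invariant to monotone element-wise maps, which matches the up-to-diffeomorphism notion of equivalence. The brute-force permutation search is only practical for a handful of parameters; for more, a Hungarian solver such as `scipy.optimize.linear_sum_assignment` would replace it.

```python
import itertools
import numpy as np

def rank(x):
    """Ranks of the entries of a 1-D array (continuous data, no ties assumed)."""
    r = np.empty(len(x))
    r[np.argsort(x)] = np.arange(len(x))
    return r

def spearman_matrix(true, learnt):
    """|Spearman correlation| between each ground-truth and each learnt parameter."""
    d = true.shape[1]
    rt = np.column_stack([rank(true[:, i]) for i in range(d)])
    rl = np.column_stack([rank(learnt[:, j]) for j in range(d)])
    # corrcoef stacks the columns of rt and rl; keep the true-vs-learnt block.
    c = np.corrcoef(rt, rl, rowvar=False)[:d, d:]
    return np.abs(c)

def mcc(true, learnt):
    """Mean correlation under the best one-to-one matching of parameters."""
    corr = spearman_matrix(true, learnt)
    d = corr.shape[0]
    best = max(
        np.mean([corr[i, p[i]] for i in range(d)])
        for p in itertools.permutations(range(d))
    )
    return float(best)

rng = np.random.default_rng(0)
theta = rng.uniform(0.1, 2.0, size=(500, 3))    # synthetic ground-truth parameters
# Disentangled: permuted columns, each passed through a monotone map.
disentangled = np.column_stack(
    [np.log(theta[:, 2]), theta[:, 0] ** 3, np.exp(-theta[:, 1])]
)
# Entangled: every learnt coordinate mixes all parameters.
entangled = theta @ np.array([[1.0, 1.0, 1.0], [1.0, -1.0, 1.0], [1.0, 1.0, -1.0]])
print(mcc(theta, disentangled))   # ≈ 1.0
print(mcc(theta, entangled))
```

A Pearson-based MCC is also common in the literature; it would penalise the nonlinear maps in the disentangled example, which is why the rank-based version is shown here.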

Theorems & Definitions (24)

  • Theorem 1: Disentanglement of Latent System Parameters
  • Definition 1: Identifiable Models khemakhem2020variational
  • Definition 2: Identifiable up to Equivalence
  • Definition 3: Diffeomorphism
  • Definition 4: Surjectivity
  • Definition 5: $C^k$ Functions
  • Definition 6: Equiv. up to Elementwise Perm. & Diffeomorphism lachapelle2022synergies
  • Definition 7: Disentanglement
  • Definition 8: Subsets of Graphs
  • Definition 9: Graph Arithmetic
  • ...and 14 more
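
Definitions 8 and 9 concern subsets of graphs and graph arithmetic, whose statements are not reproduced in this extract. The following is a minimal sketch under stated assumptions. Each causal graph between parameters and state components is assumed to be a binary adjacency matrix (rows: parameters, columns: state components). Graph union is assumed to be element-wise OR. Under these assumptions, the global graph is the union of the local graphs, every local graph is a subset of it, and edges present in only some local graphs are the state-dependent edges that Figure 2 draws in colour. The toy matrices are hypothetical and do not correspond to the paper's environments.

```python
import numpy as np

def union(graphs):
    """Element-wise OR of binary adjacency matrices (parameters x components)."""
    return np.logical_or.reduce(np.asarray(graphs, dtype=bool), axis=0)

def is_subgraph(a, b):
    """True if every edge of graph a is also an edge of graph b."""
    return bool(np.all(~a | b))

def locally_active_edges(graphs):
    """Edges present in some, but not all, of the local graphs."""
    g = np.asarray(graphs, dtype=bool)
    return np.any(g, axis=0) & ~np.all(g, axis=0)

# Hypothetical local graphs for two regions of the state space.
local_a = np.array([[1, 0], [0, 1], [0, 0]], dtype=bool)
local_b = np.array([[1, 0], [0, 1], [1, 0]], dtype=bool)  # third parameter acts only here
global_graph = union([local_a, local_b])
print(global_graph.astype(int))
print(locally_active_edges([local_a, local_b]).astype(int))
```

In the global graph, the first and third parameters share a child. Only the local graph for the first region separates them, which gives an intuition for why local structure can support stronger disentanglement than the global graph alone.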