DIGIC: Domain Generalizable Imitation Learning by Causal Discovery

Yang Chen, Yitao Liang, Zhouchen Lin

TL;DR

The paper tackles domain generalization in imitation learning with DIGIC, a two-stage framework: a causal-discovery module first identifies the direct causes of the expert action from the demonstration data distribution, and an imitation policy is then trained on these causal features alone. Because the behavior cloning (BC) policy conditions only on the direct causes, it generalizes to unseen domains without requiring multi-domain data, and the method can complement cross-domain variation-based approaches under non-structural assumptions on the underlying causal models. The authors implement the causal-discovery step with a learning-based generalized inverse-covariance approach and validate DIGIC on OpenAI Gym control tasks, where it performs well in shifted domains and, when paired with a multi-domain method such as IRM, improves robustness to spurious features that are invariant across domains. Overall, DIGIC offers a practical route to robust imitation policies grounded in causal structure derived from demonstrations, reducing reliance on cross-domain data.

Abstract

Causality has been combined with machine learning to produce robust representations for domain generalization. Most existing methods of this type require massive data from multiple domains to identify causal features by cross-domain variations, which can be expensive or even infeasible and may lead to misidentification in some cases. In this work, we take a different approach and leverage the demonstration data distribution to discover the causal features for a domain generalizable policy. We design a novel framework, called DIGIC, that identifies the causal features by finding the direct cause of the expert action from the demonstration data distribution via causal discovery. Our framework can achieve domain generalizable imitation learning with only single-domain data and can serve as a complement to cross-domain variation-based methods under non-structural assumptions on the underlying causal models. Our empirical study on various control tasks shows that the proposed framework clearly improves domain generalization performance while performing comparably to the expert in the original domain.

Paper Structure

This paper contains 14 sections, 1 theorem, 5 equations, 4 figures, 1 table.

Key Result

Theorem 1

Suppose that $\bm X \subset \bm O$ is the set of direct causes of the expert action $\bm A$ in the original SCM ${\mathcal{M}}_0$. If the assumptions labeled as:ids and as:coverage hold, then the behavior cloning policy (at the population level) is a domain generalizable policy with respect to ${\mathbb{M}}$.
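
The intuition behind Theorem 1 can be illustrated with a small toy example. This is only a sketch under stated assumptions, not the paper's experiment: the linear expert `expert`, the spurious feature `s`, and the helper `fit_linear_bc` are all hypothetical. In the original domain `s` is an exact copy of the direct cause `x1`, so a policy cloned from all features is not pinned down outside the training support, whereas a policy cloned from the direct causes alone keeps matching the expert when the mechanism of `s` changes.

```python
import numpy as np

rng = np.random.default_rng(0)

def expert(x1, x2):
    # Hypothetical expert: the action depends only on the direct causes x1, x2.
    return x1 - x2

def fit_linear_bc(features, actions):
    # Least-squares behavior cloning. For rank-deficient inputs, lstsq
    # returns the minimum-norm solution among all exact fits.
    w, *_ = np.linalg.lstsq(features, actions, rcond=1e-8)
    return w

n = 5000

# Original domain (assumed): the spurious feature s duplicates x1.
x1, x2 = rng.normal(size=n), rng.normal(size=n)
s = x1.copy()
a = expert(x1, x2)

w_all = fit_linear_bc(np.column_stack([x1, x2, s]), a)   # uses x1, x2, s
w_causal = fit_linear_bc(np.column_stack([x1, x2]), a)   # uses x1, x2 only

# Shifted domain (assumed): the mechanism of s changes, the expert does not.
x1_t, x2_t = rng.normal(size=n), rng.normal(size=n)
s_t = -x1_t
a_t = expert(x1_t, x2_t)

pred_all = np.column_stack([x1_t, x2_t, s_t]) @ w_all
pred_causal = np.column_stack([x1_t, x2_t]) @ w_causal
mse_all = np.mean((pred_all - a_t) ** 2)
mse_causal = np.mean((pred_causal - a_t) ** 2)

print(f"shifted-domain MSE, all features:    {mse_all:.3f}")
print(f"shifted-domain MSE, causal features: {mse_causal:.3e}")
```

Both policies fit the original-domain demonstrations exactly; they differ only in the shifted domain, where the all-feature policy has split its weight between `x1` and its copy `s`. The minimum-norm fit is just one of many policies consistent with the training data, which is the point of the illustration.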

Figures (4)

  • Figure 1: Various principles to identify causal features. In (a), distribution-based methods leverage the conditional independence relations (i.e., the stopping action is independent of the city conditioned on the traffic light) in the demonstration data distribution to identify that the traffic light is the direct cause of the stopping action. In (b), cross-domain variation-based methods identify the traffic light as the causal feature because it is invariant across domains.
  • Figure 2: Overview of our DIGIC framework implementation. The training stage comprises two key modules: Causal Discovery and Imitation Learning. In Causal Discovery, the moralized structure is estimated using a generalized inverse based on the observations and the action, and the mask of $\operatorname{PA}(\bm A)$ is extracted. The Hadamard product of $\bm O$ and this mask yields the causal features, which the policy net in the Imitation Learning module uses to derive a domain generalizable policy. During the inference stage, causal discovery is bypassed: the previously discovered mask of $\operatorname{PA}(\bm A)$ is applied to the observations $\bm O$, and the resulting causal features are fed into the policy model for decision-making.
  • Figure 3: Evaluation results of single-domain generalization in different tasks in the original and shifted domains. Each curve shows the mean and the standard deviation of the average total rewards over 10 trials. In each subfigure, the left part shows the evaluation results in the original domain, and the right part shows the evaluation results in the shifted domain. The $x$-axis indicates the training epoch at which the evaluated checkpoint was saved. The $y$-axis indicates the average total reward. The legend is shared by all the subfigures. DIGIC achieves performance comparable to the expert in most tasks when evaluated in the original domains and clearly outperforms BC and CCIL in the shifted domains.
  • Figure 4: Comparison of domain generalization performance in the presence of invariant spurious features in two tasks. Each curve represents the mean and standard deviation of average total rewards across 10 trials, with the $x$-axis showing the training epoch and the $y$-axis indicating the average total reward. Checkpoints are evaluated at different epochs. The legend is shared by all the subfigures. The results show that IRM-DIGIC clearly outperforms IRM in the presence of invariant spurious features.
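
The training and inference flow described for Figure 2 (estimate a mask over the observations, take its Hadamard product with $\bm O$, then fit and run the policy on the masked features) can be sketched as follows. This is a minimal stand-in, not the paper's learning-based estimator: it assumes linear-Gaussian data, uses the Moore-Penrose pseudo-inverse of the empirical covariance as the generalized inverse, and assumes the action has no children or spouses among the observations, so that its neighbors in the moralized graph coincide with its parents. The helper `parent_mask`, the threshold value, and the synthetic data are hypothetical.

```python
import numpy as np

def parent_mask(obs, act, threshold=0.1):
    """Return a 0/1 mask over observation dimensions whose partial
    correlation with the action exceeds `threshold` (hypothetical helper)."""
    data = np.column_stack([obs, act])                 # shape (n, d + 1)
    precision = np.linalg.pinv(np.cov(data, rowvar=False))
    d = obs.shape[1]
    diag = np.diag(precision)
    # Partial correlation of each O_i with A given all remaining variables.
    partial_corr = -precision[:d, d] / np.sqrt(diag[:d] * diag[d])
    return (np.abs(partial_corr) > threshold).astype(float)

# Synthetic demonstrations (assumed): x1, x2 cause the action, s is a child of x1.
rng = np.random.default_rng(0)
n = 20_000
x1, x2 = rng.normal(size=n), rng.normal(size=n)
s = x1 + 0.5 * rng.normal(size=n)
act = x1 - x2 + 0.1 * rng.normal(size=n)
obs = np.column_stack([x1, x2, s])

# Training stage: causal discovery, then imitation on the masked observations.
mask = parent_mask(obs, act)
w, *_ = np.linalg.lstsq(obs * mask, act, rcond=1e-8)   # linear stand-in for the policy net

# Inference stage: discovery is skipped, the stored mask is reused.
def policy(o):
    return (o * mask) @ w

print("mask over (x1, x2, s):", mask)
print("policy weights:", np.round(w, 3))
```

Because the masked dimension is zeroed before it reaches the policy, the policy output cannot change when the spurious feature is shifted at test time. With a neural policy the same mask would simply be applied to the network input.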

Theorems & Definitions (3)

  • Example 1
  • Definition 1: Domain Generalizable Imitation Policy
  • Theorem 1