Causal Action Influence Aware Counterfactual Data Augmentation

Núria Armengol Urpí, Marco Bagatella, Marin Vlastelica, Georg Martius

TL;DR

The paper addresses the challenge of causal confusion in offline reinforcement learning by introducing CAIAC, a data augmentation method that constructs feasible counterfactual transitions by swapping action-uninfluenced state factors across trajectories. It relies on a state-conditioned mutual-information-based CAI score, $C^j(s) = I(S'_j; A \mid S=s)$, estimated with a Gaussian transition model to identify uncontrollable factors and infer a local causal factorization. By exchanging uncontrollable components between trajectories that share similar local structures, CAIAC enlarges the joint state-space support and mitigates distributional shift without requiring online interaction or counterfactual rollouts. Empirically, CAIAC improves robustness and data efficiency on offline goal-conditioned tasks in Franka-Kitchen and Fetch-Push/Pick&Lift, outperforming heuristic baselines like CoDA and RSC and demonstrating feasible, beneficial counterfactuals that reduce causal confusion and enhance generalization.
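Given a Gaussian transition model, the CAI score above can be estimated by Monte Carlo. The sketch below illustrates this under stated assumptions and is not the authors' implementation. It samples K actions from an assumed proposal distribution and queries a hypothetical model `model(s, a)` that returns the mean and diagonal variance of the next state. It then moment-matches the resulting action marginal of $S'_j$ with a single Gaussian and averages the KL divergences from each action-conditional prediction to that Gaussian. The helper names (`cai_score`, `toy_model`, `uniform_actions`) are hypothetical.

```python
import numpy as np


def gaussian_kl(mu_p, var_p, mu_q, var_q):
    """KL(N(mu_p, var_p) || N(mu_q, var_q)) for diagonal Gaussians, summed over the last axis."""
    return 0.5 * np.sum(
        np.log(var_q / var_p) + (var_p + (mu_p - mu_q) ** 2) / var_q - 1.0,
        axis=-1,
    )


def cai_score(model, s, action_sampler, factor_idx, n_actions=64, rng=None):
    """Monte Carlo sketch of C^j(s) = I(S'_j; A | S = s).

    model(s, a) -> (mean, var): hypothetical Gaussian transition model giving the
        mean and diagonal variance of the next state.
    action_sampler(rng, n) -> (n, action_dim) array of proposal actions (assumed).
    factor_idx: slice selecting the dimensions of state factor j.
    """
    rng = np.random.default_rng() if rng is None else rng
    actions = action_sampler(rng, n_actions)
    preds = [model(s, a) for a in actions]
    mus = np.stack([mu[factor_idx] for mu, _ in preds])   # (K, d_j)
    vs = np.stack([var[factor_idx] for _, var in preds])  # (K, d_j)
    # Moment-match the action marginal (1/K) sum_k N(mu_k, v_k) with one Gaussian q.
    mu_q = mus.mean(axis=0)
    var_q = (vs + mus**2).mean(axis=0) - mu_q**2
    # E_k[KL(p_k || q)] equals the sample MI plus KL(mixture || q), so it upper-bounds
    # the MI and is zero when the action has no effect on factor j.
    return float(np.mean(gaussian_kl(mus, vs, mu_q, var_q)))


# Toy usage: factor 0 is pushed by the action, factor 1 ignores it.
def toy_model(s, a):
    return s + np.array([a[0], 0.0]), np.full_like(s, 0.01)


def uniform_actions(rng, n):
    return rng.uniform(-1.0, 1.0, size=(n, 1))


s0 = np.zeros(2)
rng = np.random.default_rng(0)
c_controlled = cai_score(toy_model, s0, uniform_actions, slice(0, 1), rng=rng)
c_uncontrolled = cai_score(toy_model, s0, uniform_actions, slice(1, 2), rng=rng)
```

The moment-matched Gaussian only approximates the true mixture, so this estimate upper-bounds the sample mutual information. It is exactly zero when the prediction for factor $j$ does not depend on the action, and that is the property that matters for flagging uncontrollable factors.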

Abstract

Offline data are both valuable and practical resources for teaching robots complex behaviors. Ideally, learning agents should not be constrained by the scarcity of available demonstrations, but rather generalize beyond the training distribution. However, the complexity of real-world scenarios typically requires huge amounts of data to prevent neural network policies from picking up on spurious correlations and learning non-causal relationships. We propose CAIAC, a data augmentation method that can create feasible synthetic transitions from a fixed dataset without having access to online environment interactions. By utilizing principled methods for quantifying causal influence, we are able to perform counterfactual reasoning by swapping $\textit{action}$-unaffected parts of the state-space between independent trajectories in the dataset. We empirically show that this leads to a substantial increase in robustness of offline learning algorithms against distributional shift.
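The swap described above can be sketched in a few lines. This is a minimal illustration, not the authors' code. It assumes the state is split into factors by index slices, that per-factor influence scores are already available, and that a factor counts as uncontrollable when its score falls below a hypothetical threshold `theta`. Only factors that are uncontrollable in both the original and the donor transition are exchanged. This is one simple way to require compatible local structure between the two transitions. The function name `counterfactual_swap` is hypothetical.

```python
import numpy as np


def counterfactual_swap(s, s_next, scores, donor_s, donor_s_next, donor_scores,
                        factor_slices, theta=0.1):
    """Sketch of one counterfactual transition (s_cf, a, s_next_cf).

    scores, donor_scores: per-factor action-influence scores C^j(s) for the
        original and the donor transition (computed elsewhere).
    factor_slices: one index slice per state factor (e.g. agent, each object).
    theta: hypothetical threshold; factors scoring below it count as uncontrollable.
    The action of the original transition is kept unchanged.
    """
    # Assumption: swap only factors that are uncontrollable in *both* transitions,
    # so the donor's values were also produced without action influence.
    swap = (np.asarray(scores) < theta) & (np.asarray(donor_scores) < theta)
    s_cf, s_next_cf = s.copy(), s_next.copy()
    for j, sl in enumerate(factor_slices):
        if swap[j]:
            # Copy the donor's (s_j, s'_j) pair together so factor j's own
            # transition stays consistent.
            s_cf[sl] = donor_s[sl]
            s_next_cf[sl] = donor_s_next[sl]
    return s_cf, s_next_cf


# Toy usage: state = [agent_x, agent_y, object_a, object_b].
slices = [slice(0, 2), slice(2, 3), slice(3, 4)]
s, s_next = np.array([0.0, 0.0, 1.0, 2.0]), np.array([0.1, 0.1, 1.0, 2.5])
donor_s, donor_s_next = np.array([5.0, 5.0, 7.0, 9.0]), np.array([5.0, 5.0, 7.0, 9.0])
s_cf, s_next_cf = counterfactual_swap(
    s, s_next, scores=[2.0, 0.01, 0.5],
    donor_s=donor_s, donor_s_next=donor_s_next, donor_scores=[1.5, 0.02, 0.03],
    factor_slices=slices,
)
# Only object_a is uncontrollable in both, so only it is replaced:
# s_cf == [0, 0, 7, 2], s_next_cf == [0.1, 0.1, 7, 2.5]
```

The sketch swaps both the current and next value of a factor rather than only one of them. This keeps that factor's own transition intact, so the sample stays plausible under the assumption (also stated for Figure 3) that objects rarely or never interact with each other.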

Paper Structure (42 sections, 4 equations, 14 figures, 3 tables, 1 algorithm)

Figures (14)

  • Figure 1: Overview of the proposed approach. Interactions between the agent and entities in the world are sparse. We use causal action influence (CAI), a local causal measure, to determine action-independent entities and create counterfactual data by swapping states of these entities from other observations in the dataset. Offline learning with these augmentations leads to better generalization.
  • Figure 1: CAIAC
  • Figure 2: CAIAC counterfactual samples are consistent with the environment's dynamics and increase the support of the joint state space distribution, enabling the agent to be robust to distributional shift. Left: Log-likelihoods under the environment transition kernel of counterfactuals created with different methods. Right: Original data and counterfactual augmentations from CAIAC, visualized with t-SNE. Details on this evaluation are reported in the Appendix.
  • Figure 3: Illustration of counterfactual data augmentation. The global causal graph does not allow for factorization (a). Our local causal graph (b) is pruned by causal action influence. Object-object interactions are assumed to be rare or nonexistent (gray dashed edges). We replace the elements that are not under the agent's control (i.e. those in the set $\mathcal{U}$) with samples from the data, thus creating counterfactual samples. We omit the exogenous variables from the global graph for compactness.
  • Figure 4: Motivating Franka-Kitchen example. The experimental setup (left) and success rates for in-distribution and out-of-distribution tasks (right). Metrics are averaged over 10 seeds and 10 episodes per task, with 95% simple bootstrap confidence intervals.
  • ...and 9 more figures
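Figure 4 reports 95% bootstrap confidence intervals over 10 seeds and 10 episodes per task. The snippet below sketches a percentile bootstrap for the mean success rate. Two choices are our assumptions: reading "simple bootstrap" as the percentile method, and resampling per-seed success rates rather than individual episodes. The input data are synthetic placeholders, not results from the paper, and `bootstrap_ci` is a hypothetical helper.

```python
import numpy as np


def bootstrap_ci(values, n_boot=10_000, level=0.95, rng=None):
    """Percentile bootstrap confidence interval for the mean of `values`."""
    rng = np.random.default_rng() if rng is None else rng
    values = np.asarray(values, dtype=float)
    # Resample with replacement: n_boot rows, each the same size as `values`.
    idx = rng.integers(0, len(values), size=(n_boot, len(values)))
    boot_means = values[idx].mean(axis=1)
    alpha = (1.0 - level) / 2.0
    lo, hi = np.quantile(boot_means, [alpha, 1.0 - alpha])
    return values.mean(), lo, hi


# Synthetic placeholder data (not results from the paper): 10 seeds x 10 episodes
# of binary task success, reduced to one success rate per seed.
rng = np.random.default_rng(0)
successes = rng.random((10, 10)) < 0.7
per_seed = successes.mean(axis=1)
mean, lo, hi = bootstrap_ci(per_seed, rng=rng)
```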

Theorems & Definitions (5)

  • Definition 2.1: Markov Decision Process (MDP)
  • Definition 2.2: Structural Causal Model (SCM) [pearl2009causality]
  • Definition 2.3: $\textit{do}$-intervention [pearl2009causality]
  • Definition 2.4: Counterfactual
  • Definition 3.1: Local Causal Model [Pitis2020]
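Definitions 2.2 to 2.4 build from an SCM to do-interventions and counterfactuals. The toy sketch below illustrates Pearl's abduction, action, and prediction steps. For brevity it assumes an additive-noise SCM for a single transition, $S' = f(S, A) + U$; the definitions listed above are not restricted to this form. The mechanism `f` and the helpers `step` and `counterfactual` are hypothetical.

```python
import numpy as np


def f(s, a):
    """Hypothetical mechanism: factor 0 (agent) moves with the action,
    factor 1 (object) drifts by a constant and ignores the action."""
    return np.array([s[0] + a, s[1] + 0.5])


def step(s, a, rng):
    """Sample S' = f(S, A) + U with exogenous Gaussian noise U."""
    return f(s, a) + rng.normal(scale=0.1, size=2)


def counterfactual(s, a, s_next, a_cf):
    """What would S' have been under do(A = a_cf), given we observed (s, a, s_next)?"""
    u = s_next - f(s, a)     # 1. abduction: recover the noise (exact for additive noise)
    return f(s, a_cf) + u    # 2. action do(A = a_cf), 3. prediction with the same u


rng = np.random.default_rng(0)
s, a = np.array([0.0, 1.0]), 0.3
s_next = step(s, a, rng)
s_cf = counterfactual(s, a, s_next, a_cf=-0.3)
# The agent factor shifts by a_cf - a = -0.6; the object factor is unchanged,
# because there is no edge from A to S'_1 in this toy structure.
```

The object factor's counterfactual value equals its observed value. In miniature, this is the intuition for treating action-uninfluenced factors as exchangeable between transitions.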