Causal Action Influence Aware Counterfactual Data Augmentation
Núria Armengol Urpí, Marco Bagatella, Marin Vlastelica, Georg Martius
TL;DR
The paper addresses causal confusion in offline reinforcement learning by introducing CAIAC, a data augmentation method that constructs feasible counterfactual transitions by swapping action-uninfluenced state factors across trajectories. To decide which factors the action influences, it uses a state-conditioned causal action influence (CAI) score, $C^j(s) = I(S'_j; A \mid S=s)$. This is the conditional mutual information between the action and the next value of state factor $j$, estimated with a learned Gaussian transition model. Factors with low scores are treated as uncontrollable, which yields a local causal factorization of each transition. By exchanging uncontrollable components between trajectories that share similar local causal structure, CAIAC enlarges the support of the joint state distribution and mitigates distributional shift without requiring online interaction or counterfactual rollouts. Empirically, CAIAC improves robustness and data efficiency on offline goal-conditioned tasks in Franka-Kitchen and Fetch-Push/Pick&Lift and outperforms heuristic baselines such as CoDA and RSC. It also produces feasible counterfactuals that reduce causal confusion and improve generalization.
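The CAI score can be estimated once a Gaussian transition model is available. The sketch below is a minimal Monte Carlo estimator under stated assumptions, not the authors' implementation. It assumes that the model returns a per-factor mean and standard deviation and that actions are drawn from a fixed distribution (here uniform on $[-1, 1]$). It uses the identity $I(S'_j; A \mid S=s) = \mathbb{E}_{a}[\mathrm{KL}(p(s'_j \mid s, a) \,\|\, p(s'_j \mid s))]$, approximating the marginal $p(s'_j \mid s)$ by a mixture over $K$ sampled actions and each KL term by sampling. The names `cai_scores`, `toy_model` and `gaussian_logpdf`, and the threshold of 0.1, are hypothetical.

```python
import numpy as np


def gaussian_logpdf(x, mean, std):
    """Elementwise log-density of N(mean, std**2)."""
    return -0.5 * np.log(2.0 * np.pi * std**2) - 0.5 * ((x - mean) / std) ** 2


def cai_scores(model, s, action_sampler, n_actions=64, n_mc=256, rng=None):
    """Monte Carlo estimate of C^j(s) = I(S'_j; A | S=s) for every factor j.

    Assumptions (hypothetical interface): model(s, a) returns (mean, std),
    each of shape (n_factors,), for a factorized Gaussian p(s'_j | s, a);
    action_sampler(rng) draws one action from a fixed action distribution.
    """
    rng = np.random.default_rng() if rng is None else rng
    actions = [action_sampler(rng) for _ in range(n_actions)]
    stats = [model(s, a) for a in actions]
    means = np.stack([m for m, _ in stats])  # (K, J)
    stds = np.stack([sd for _, sd in stats])  # (K, J)
    K, J = means.shape
    total = np.zeros(J)
    for i in range(K):
        # Samples x ~ p(s'_j | s, a_i) for all factors at once: shape (n_mc, J).
        x = means[i] + stds[i] * rng.standard_normal((n_mc, J))
        log_p = gaussian_logpdf(x, means[i], stds[i])
        # Marginal p(s'_j | s) approximated by the mixture (1/K) sum_k N(mu_k, sigma_k).
        comp = gaussian_logpdf(x[None], means[:, None, :], stds[:, None, :])  # (K, n_mc, J)
        log_mix = np.logaddexp.reduce(comp, axis=0) - np.log(K)
        # Sample average of log p - log mix estimates KL(p(.|s,a_i) || p(.|s)).
        total += (log_p - log_mix).mean(axis=0)
    # Averaging the per-action KL terms over the sampled actions gives the MI estimate.
    return total / K


def toy_model(s, a):
    # Hypothetical 2-factor model: factor 0 is shifted by the action,
    # factor 1 decays independently of it.
    mean = np.array([s[0] + a, 0.9 * s[1]])
    std = np.array([0.1, 0.1])
    return mean, std


rng = np.random.default_rng(0)
scores = cai_scores(toy_model, np.array([0.0, 1.0]),
                    lambda r: r.uniform(-1.0, 1.0), rng=rng)
uncontrollable = scores < 0.1  # hypothetical threshold
print(scores, uncontrollable)
```

Factor 0 receives a clearly positive score because the action moves its mean across the whole interval $[-1, 1]$. Factor 1 scores approximately zero because every mixture component coincides with the per-action density. Thresholding the scores is one simple way to label factors as uncontrollable. The finite mixture makes the estimate slightly biased, so more sampled actions give a better approximation.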
Abstract
Offline data are both valuable and practical resources for teaching robots complex behaviors. Ideally, learning agents should not be constrained by the scarcity of available demonstrations, but rather generalize beyond the training distribution. However, the complexity of real-world scenarios typically requires huge amounts of data to prevent neural network policies from picking up on spurious correlations and learning non-causal relationships. We propose CAIAC, a data augmentation method that can create feasible synthetic transitions from a fixed dataset without having access to online environment interactions. By utilizing principled methods for quantifying causal influence, we are able to perform counterfactual reasoning by swapping action-unaffected parts of the state space between independent trajectories in the dataset. We empirically show that this leads to a substantial increase in robustness of offline learning algorithms against distributional shift.
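The swap of action-unaffected state factors can be sketched as follows. This relies on assumptions that go beyond the summary:

- States are arrays of factors, with one row per entity.
- Each transition already carries a per-factor CAI score.
- A factor counts as uncontrollable when its score falls below a hypothetical threshold.
- Only factors that are uncontrollable in both the original and the donor transition are exchanged, so the two transitions share the same local structure for those factors.

The sketch also takes for granted that swapped factors do not interact with the rest of the state, keeps the original action, and omits how donor transitions are sampled. The helpers `uncontrollable_mask` and `swap_uncontrollable` are illustrative names, not the authors' code.

```python
import numpy as np


def uncontrollable_mask(cai, threshold):
    """Assumption: a factor is uncontrollable when its CAI score is below threshold."""
    return np.asarray(cai) < threshold


def swap_uncontrollable(transition, donor, mask_t, mask_d):
    """Return a counterfactual copy of `transition` whose state factors that are
    uncontrollable in BOTH transitions are taken from `donor`.

    Assumed layout: 's' and 's_next' have shape (n_factors, factor_dim);
    'a' is the action vector; masks are boolean arrays of shape (n_factors,).
    """
    swap = np.asarray(mask_t) & np.asarray(mask_d)
    s = transition["s"].copy()
    s_next = transition["s_next"].copy()
    # Replace only the action-uninfluenced factors; controllable factors and
    # the action are kept from the original transition.
    s[swap] = donor["s"][swap]
    s_next[swap] = donor["s_next"][swap]
    return {"s": s, "a": transition["a"].copy(), "s_next": s_next}


# Toy data: factor 0 = robot, factors 1 and 2 = objects (2-D positions).
t = {"s": np.array([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]),
     "a": np.array([0.5, 0.0]),
     "s_next": np.array([[0.5, 0.0], [1.0, 1.0], [2.5, 2.0]])}
d = {"s": np.array([[9.0, 9.0], [7.0, 7.0], [8.0, 8.0]]),
     "a": np.array([0.0, 0.1]),
     "s_next": np.array([[9.0, 9.1], [7.0, 7.0], [8.0, 8.0]])}
mask_t = uncontrollable_mask([1.2, 0.01, 0.9], threshold=0.1)   # [False, True, False]
mask_d = uncontrollable_mask([1.1, 0.02, 0.03], threshold=0.1)  # [False, True, True]
cf = swap_uncontrollable(t, d, mask_t, mask_d)
print(cf["s"])  # only factor 1 comes from the donor: [[0, 0], [7, 7], [2, 2]]
```

Factor 2 is not swapped even though it is uncontrollable in the donor transition. The robot moves it in the original transition, so replacing it would break the action's effect on the next state. Requiring both masks is a conservative way to keep the counterfactual consistent with the local causal structure of both transitions.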
