Automated Discovery of Functional Actual Causes in Complex Environments
Caleb Chuck, Sankaran Vaidyanathan, Stephen Giguere, Amy Zhang, David Jensen, Scott Niekum
TL;DR
This work tackles the challenge of causal generalization in reinforcement learning by introducing Functional Actual Cause (FAC), a principled framework that constrains actual causes using invariant preimages (IVP) to capture context-specific independencies and normalize causation judgments. It then presents Joint Optimization for Actual Cause Inference (JACI), a neural approach that jointly learns a state-to-cause mapping and a masked forward model to recover functional actual causes from observational data in continuous, high-dimensional environments. The approach aligns with established causality intuitions on classic examples and outperforms baselines both in synthetic Random Vector domains and in RL-like tasks such as Mini-Breakout and 2D Pushing, demonstrating scalable, accurate identification of sparse, context-relevant causes. Together, FAC and JACI offer a scalable bridge between formal actual causality and practical inference for RL, enabling improved world modeling, explanations, and exploration in complex environments.
Abstract
Reinforcement learning (RL) algorithms often struggle to learn policies that generalize to novel situations due to issues such as causal confusion, overfitting to irrelevant factors, and failure to isolate control of state factors. These issues stem from a common source: a failure to accurately identify and exploit state-specific causal relationships in the environment. While some prior works in RL aim to identify these relationships explicitly, they rely on informal domain-specific heuristics such as spatial and temporal proximity. Actual causality offers a principled and general framework for determining the causes of particular events. However, existing definitions of actual cause often attribute causality to a large number of events, even if many of them rarely influence the outcome. Prior work on actual causality proposes normality as a solution to this problem, but its existing implementations are challenging to scale to complex and continuous-valued RL environments. This paper introduces functional actual cause (FAC), a framework that uses context-specific independencies in the environment to restrict the set of actual causes. We additionally introduce Joint Optimization for Actual Cause Inference (JACI), an algorithm that learns from observational data to infer functional actual causes. We demonstrate empirically that FAC agrees with known results on a suite of examples from the actual causality literature, and JACI identifies actual causes with significantly higher accuracy than existing heuristic methods in a set of complex, continuous-valued environments.
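To make the idea of a context-specific cause concrete, the toy sketch below shows a state-to-cause mapping that returns a binary mask over state variables, a masked forward model that only sees the selected variables, and an invariance check: within a given context, perturbing the variables outside the mask should leave the outcome unchanged. Everything here is a hypothetical illustration under stated assumptions, not JACI itself. The environment (`true_dynamics`), the hand-coded `cause_mask` (JACI would learn this mapping), and the `is_invariant` check (only a loose analogue of the paper's invariant-preimage idea, not its formal definition) are all invented for this example.

```python
import numpy as np

# Hypothetical toy environment: state = [switch, a, b].
# The outcome copies `a` when switch > 0, otherwise it copies `b`,
# so which variable is a cause depends on the context (the switch).
def true_dynamics(state: np.ndarray) -> float:
    switch, a, b = state
    return a if switch > 0 else b

# Hand-coded stand-in for a learned state-to-cause mapping (assumption):
# returns a binary mask marking the variables that matter in this context.
def cause_mask(state: np.ndarray) -> np.ndarray:
    if state[0] > 0:
        return np.array([1.0, 1.0, 0.0])
    return np.array([1.0, 0.0, 1.0])

# Stand-in for a learned masked forward model (assumption): masked-out
# entries are zeroed, so the prediction cannot depend on them.
def masked_forward(state: np.ndarray, mask: np.ndarray) -> float:
    return true_dynamics(state * mask)

# Loose analogue of invariance: resample every variable outside the mask
# and check that the true outcome never changes.
def is_invariant(state: np.ndarray, mask: np.ndarray,
                 rng: np.random.Generator, trials: int = 100) -> bool:
    base = true_dynamics(state)
    for _ in range(trials):
        noise = rng.normal(size=state.shape)
        perturbed = np.where(mask > 0, state, noise)
        if true_dynamics(perturbed) != base:
            return False
    return True

rng = np.random.default_rng(0)
state = np.array([1.0, 0.3, -2.0])
mask = cause_mask(state)
print(mask)                                                # [1. 1. 0.]
print(masked_forward(state, mask) == true_dynamics(state))  # True
print(is_invariant(state, mask, rng))                      # True
print(is_invariant(state, np.array([1.0, 0.0, 1.0]), rng))  # False
```

The last line shows why the mask must track the context: when the switch is on, dropping `a` from the mask lets a resampled `a` change the outcome, so the sparser but wrong mask fails the invariance check.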
