Towards Causal Model-Based Policy Optimization
Alberto Caron, Vasilios Mavroudis, Chris Hicks
TL;DR
This work addresses the brittleness of traditional Model-Based Reinforcement Learning (MBRL) under distribution shifts caused by spurious correlations. It introduces C-MBPO, which augments standard RL with a Causal Markov Decision Process (C-MDP) and a local Structural Causal Model (SCM). C-MBPO learns causal graphs over the state and reward dynamics, uses counterfactual rollouts to guide policy optimization in a Dyna-Q/MBPO-inspired framework, and demonstrates robustness to both near and far distribution shifts while offering improved interpretability. The key contributions are formalizing the C-MDP, proposing a causal learning and planning pipeline, and empirically showing that policies derived from C-MBPO are more robust than non-causal baselines to changes in the environment's spurious, non-causal relationships. This causally grounded approach has practical implications for deploying RL in non-stationary environments where spurious correlations can mislead policy learning.
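To make the counterfactual-rollout step concrete, here is a minimal sketch using Pearl's three-step counterfactual procedure (abduction, action, prediction) on a hypothetical linear additive-noise SCM. The matrices `A` and `B`, the reward mechanism, the `policy`, and all function names are illustrative assumptions, not the paper's learned model. The one-step rollout branched from logged states only loosely mirrors the Dyna/MBPO idea of augmenting real experience with model-generated transitions.

```python
import numpy as np

# Hypothetical SCM for a 2-D state and 1-D action (assumed, not learned here):
#   s' = A s + B a + u_s             (state mechanism)
#   r  = -||s||^2 - c ||a||^2 + u_r  (reward mechanism)
A = np.array([[1.0, 0.1], [0.0, 0.9]])
B = np.array([[0.0], [0.5]])
c = 0.1

def f_state(s, a):
    return A @ s + B @ a

def f_reward(s, a):
    return -float(s @ s) - c * float(a @ a)

def counterfactual(s, a, s_next, r, a_cf):
    # 1) Abduction: recover the exogenous noise consistent with the observed transition.
    u_s = s_next - f_state(s, a)
    u_r = r - f_reward(s, a)
    # 2) Action: replace the logged action by do(a = a_cf).
    # 3) Prediction: push the same noise through the unchanged mechanisms.
    return f_state(s, a_cf) + u_s, f_reward(s, a_cf) + u_r

def policy(s):
    # Hypothetical current policy: linear feedback on the second state component.
    return np.array([-0.8 * s[1]])

# Collect a few "real" transitions with an exploratory behaviour policy.
rng = np.random.default_rng(0)
real_buffer = []
s = np.array([1.0, -1.0])
for _ in range(5):
    a = rng.normal(size=1)
    s_next = f_state(s, a) + 0.05 * rng.normal(size=2)
    r = f_reward(s, a) + 0.01 * rng.normal()
    real_buffer.append((s, a, s_next, r))
    s = s_next

# Dyna-style augmentation: one counterfactual, on-policy transition per logged one.
model_buffer = []
for s, a, s_next, r in real_buffer:
    a_cf = policy(s)
    s_cf, r_cf = counterfactual(s, a, s_next, r, a_cf)
    model_buffer.append((s, a_cf, s_cf, r_cf))
```

Because the abducted noise is reused, replaying the logged action reproduces the observed transition exactly; only the effect of changing the action is simulated. In C-MBPO the mechanisms would instead be learned from trajectories gathered online.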
Abstract
Real-world decision-making problems are often marked by complex, uncertain dynamics that can shift or break under changing conditions. Traditional Model-Based Reinforcement Learning (MBRL) approaches learn predictive models of environment dynamics from queried trajectories and then use these models to simulate rollouts for policy optimization. However, such methods do not account for the underlying causal mechanisms that govern the environment, and can therefore inadvertently capture spurious correlations, which makes them sensitive to distributional shifts and limits their ability to generalize. The same holds for model-free approaches. In this work, we introduce Causal Model-Based Policy Optimization (C-MBPO), a novel framework that integrates causal learning into the MBRL pipeline to obtain more robust, explainable, and generalizable policy learning algorithms. Our approach first infers a Causal Markov Decision Process (C-MDP) by learning a local Structural Causal Model (SCM) of both the state and reward transition dynamics from trajectories gathered online. C-MDPs differ from classic MDPs in that the causal dependencies in the environment dynamics are decomposed by specifying an associated Causal Bayesian Network. C-MDPs allow for targeted interventions and counterfactual reasoning, enabling the agent to distinguish mere statistical correlations from causal relationships. The learned SCM is then used to simulate counterfactual on-policy transitions and rewards under hypothetical actions (or "interventions"), thereby guiding policy optimization more effectively. The policy learned by C-MBPO can be shown to be robust to a class of distributional shifts that affect spurious, non-causal relationships in the dynamics. We demonstrate this through simple experiments involving near and far out-of-distribution (OOD) drifts in the dynamics.
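The robustness claim can be illustrated with a toy example built on assumptions of our own: a next-state mechanism whose true causal parents are the state `s` and action `a`, plus a variable `z` that copies `s` during training but is decoupled from it after a shift. A ridge regression on all inputs spreads its weight across the near-collinear `s` and `z`, while a model restricted to the causal parents ignores `z`. Here the parent mask is simply given; in C-MBPO the causal graph would be learned, and the data-generating process, noise levels, and ridge penalty are hypothetical choices.

```python
import numpy as np

rng = np.random.default_rng(0)

def sample(n, shifted=False):
    s = rng.normal(size=n)
    a = rng.normal(size=n)
    # Spurious variable: a noisy copy of s in training, independent of s after the shift.
    z = rng.normal(size=n) if shifted else s + 0.01 * rng.normal(size=n)
    # True mechanism: the causal parents of s' are (s, a) only.
    s_next = s + a + 0.1 * rng.normal(size=n)
    return np.column_stack([s, a, z]), s_next

def fit_ridge(X, y, lam=1.0):
    # Closed-form ridge regression: (X^T X + lam I)^{-1} X^T y
    d = X.shape[1]
    return np.linalg.solve(X.T @ X + lam * np.eye(d), X.T @ y)

X_tr, y_tr = sample(2000)
X_te, y_te = sample(2000, shifted=True)

parents = np.array([True, True, False])  # assumed-known causal mask over (s, a, z)

w_full = fit_ridge(X_tr, y_tr)                 # non-causal model: uses every input
w_causal = fit_ridge(X_tr[:, parents], y_tr)   # causal model: parents only

mse_full = np.mean((X_te @ w_full - y_te) ** 2)
mse_causal = np.mean((X_te[:, parents] @ w_causal - y_te) ** 2)
print(f"non-causal MSE under shift:    {mse_full:.3f}")
print(f"causal-parent MSE under shift: {mse_causal:.3f}")
```

Under the shift, the full model's error grows because part of its weight sits on the spurious copy `z`, whereas the parent-only model keeps an error close to the noise level of the true mechanism.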
