Out-of-Distribution Adaptation in Offline RL: Counterfactual Reasoning via Causal Normalizing Flows

Minjae Cho, Jonathan P. How, Chuangchuang Sun

TL;DR

This work tackles offline RL by addressing distributional shift and the need to explore beyond the training data. It introduces MOOD-CRL, a model-based offline RL method that uses Causal Normalizing Flows (CNF) to learn a world model with causal structure, enabling counterfactual reasoning for effective OOD adaptation without penalizing exploration. The approach combines an autoregressive CNF with a base-space MLP to predict dynamics and rewards, guided by a physics-informed causal graph, and uses log-likelihood-based truncation to manage OOD predictions. Empirical results across discrete and continuous domains show that MOOD-CRL can match or approach online performance, outperforms both model-based and model-free baselines, and handles OOD states robustly in robotic control tasks.
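The TL;DR mentions log-likelihood-based truncation of OOD predictions. One plausible reading, sketched below under stated assumptions, is to stop a model rollout as soon as the world model assigns a predicted transition a log-likelihood below a threshold. The `truncated_rollout` and `toy_model` names, the threshold value, and the Gaussian stand-in for the flow's density are all hypothetical; the paper's exact truncation rule may differ.

```python
import math
from typing import Callable, List, Tuple

# Hypothetical world-model interface (not the paper's API):
# (state, action) -> (next_state, reward, log-likelihood of the predicted transition)
WorldModel = Callable[[float, float], Tuple[float, float, float]]
Policy = Callable[[float], float]
Transition = Tuple[float, float, float, float]  # (s, a, r, s_next)


def truncated_rollout(model: WorldModel, policy: Policy, s0: float,
                      horizon: int, log_lik_threshold: float) -> List[Transition]:
    """Roll out the learned model from s0, stopping as soon as a predicted
    transition's log-likelihood drops below the threshold (treated as too far OOD)."""
    transitions: List[Transition] = []
    s = s0
    for _ in range(horizon):
        a = policy(s)
        s_next, r, log_lik = model(s, a)
        if log_lik < log_lik_threshold:
            break  # unreliable prediction: discard it and end the rollout here
        transitions.append((s, a, r, s_next))
        s = s_next
    return transitions


def toy_model(s: float, a: float) -> Tuple[float, float, float]:
    """Toy stand-in for a flow world model (assumption): deterministic drift,
    with the log-likelihood taken as the standard-normal log-density of the
    next state, so predictions far from the 'data' around 0 score low."""
    s_next = s + a
    log_lik = -0.5 * s_next ** 2 - 0.5 * math.log(2 * math.pi)
    return s_next, -abs(s_next), log_lik


rollout = truncated_rollout(toy_model, policy=lambda s: 1.0, s0=0.0,
                            horizon=10, log_lik_threshold=-3.0)
print(len(rollout))  # 2: the third step (s_next = 3.0) falls below the threshold
```

Truncating, rather than penalizing the reward, keeps low-likelihood predictions out of the synthetic data without biasing the returns of the transitions that are kept.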

Abstract

Despite notable successes of Reinforcement Learning (RL), the prevalent use of an online learning paradigm prevents its widespread adoption, especially in hazardous or costly scenarios. Offline RL has emerged as an alternative that learns from pre-collected static datasets. However, offline learning introduces a new challenge known as distributional shift, which degrades performance when the policy is evaluated on scenarios that are Out-Of-Distribution (OOD) with respect to the training dataset. Most existing offline RL methods resolve this issue by regularizing policy learning to stay within the support of the given dataset. However, such regularization overlooks the potential for high-reward regions that may exist beyond the dataset. This motivates novel offline learning techniques that can improve beyond the data support without compromising policy performance, potentially by learning causation (cause and effect) instead of correlation from the dataset. In this paper, we propose the MOOD-CRL (Model-based Offline OOD-Adapting Causal RL) algorithm, which addresses the challenge of extrapolation in offline policy training through causal inference instead of policy regularization. Specifically, we develop a Causal Normalizing Flow (CNF) to learn the transition and reward functions, which are then used to generate and augment data for offline policy evaluation and training. Based on a data-invariant, physics-based qualitative causal graph and the observational data, we develop a novel learning scheme for the CNF to learn the quantitative structural causal model. As a result, the CNF gains predictive and counterfactual reasoning capabilities for sequential decision-making tasks, showing high potential for OOD adaptation. Empirical evaluations validate our CNF-based offline RL approach, which outperforms model-free and model-based methods by a significant margin.
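To illustrate the counterfactual reasoning described in the abstract, here is a minimal sketch of an additive-noise structural causal model over one transition $(s, a, s', r)$. Evaluating the mechanisms in causal (topological) order makes the map from exogenous noise to data an affine autoregressive flow with a triangular Jacobian, so it can be inverted exactly, which is what enables the abduction, action, and prediction steps of a counterfactual query. The graph, the hand-written mechanisms, and all names (`GRAPH`, `forward`, `inverse`, `counterfactual`) are hypothetical assumptions for illustration; in MOOD-CRL the quantitative mechanisms would be learned by the CNF from data, and the paper's physics-based graph may differ.

```python
from typing import Callable, Dict, List, Optional, Tuple

# Hypothetical causal graph over one transition, listed in topological order.
# Each node: (name, parents, loc_fn(parent values), scale); the mechanism is
#   x = loc_fn(parents) + scale * u,  with exogenous noise u in the base space.
Node = Tuple[str, List[str], Callable[..., float], float]
GRAPH: List[Node] = [
    ("s", [], lambda: 0.0, 1.0),
    ("a", [], lambda: 0.0, 1.0),
    ("s_next", ["s", "a"], lambda s, a: s + 0.5 * a, 0.1),
    ("r", ["s_next"], lambda sn: -sn ** 2, 0.1),
]


def forward(u: Dict[str, float], do: Optional[Dict[str, float]] = None) -> Dict[str, float]:
    """Base space -> data space, evaluated autoregressively in causal order.
    Variables listed in `do` are clamped to the given values (an intervention)."""
    do = do or {}
    x: Dict[str, float] = {}
    for name, parents, loc, scale in GRAPH:
        if name in do:
            x[name] = do[name]
        else:
            x[name] = loc(*(x[p] for p in parents)) + scale * u[name]
    return x


def inverse(x: Dict[str, float]) -> Dict[str, float]:
    """Data space -> base space (abduction): recover each exogenous noise term."""
    return {name: (x[name] - loc(*(x[p] for p in parents))) / scale
            for name, parents, loc, scale in GRAPH}


def counterfactual(x_obs: Dict[str, float], do: Dict[str, float]) -> Dict[str, float]:
    """Abduction-action-prediction: keep the noise inferred from the factual
    transition, apply the intervention, and re-run the downstream mechanisms."""
    return forward(inverse(x_obs), do)


x_obs = forward({"s": 1.0, "a": 0.0, "s_next": 2.0, "r": 0.5})
x_cf = counterfactual(x_obs, do={"a": 2.0})
print(x_obs)  # factual: s_next ~ 1.2, r ~ -1.39
print(x_cf)   # counterfactual: s_next ~ 2.2, r ~ -4.79
```

Because the noise inferred from the factual transition is reused, the result answers "what would this transition have looked like under a different action" rather than sampling a fresh outcome.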

Paper Structure (30 sections, 9 equations, 10 figures, 1 algorithm)

Figures (10)

  • Figure 1: Causation in MDP
  • Figure 2: Causation in detail
  • Figure 4: A depiction of Causal Reinforcement Learning with a normalizing flow world model. The system handles two MDP tuples. The first is an arbitrary input for counterfactual reasoning: it takes the state and action as inputs, initializes the next state with the current state for stability, and sets the reward to zero. The second tuple represents the original scenario. The loss between the two tuples is then computed in the base space and used to train the mapping function.
  • Figure 5: Illustration of the autoregressive flow transformation $F_k(\mathbf{z}_k \mid \tau, c^i) = \mathbf{z}_{k+1}$. The conditioner $c$ gathers information from the preceding elements, and the transformer $\tau$ maps the current element, conditioned on that history, to a new value. The figure is adapted from Papamakarios et al. (2021).
  • Figure 6: A classical toy problem implemented in Gymnasium: a 15 $\times$ 15 grid with non-slippery dynamics, chosen to make the environment deterministic. States are numbered from 0 at the top-left corner and increase by 1 toward the right. Any action toward a boundary or an obstacle leaves the agent in the same position.
  • ...and 5 more figures
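The state numbering and boundary behavior described for Figure 6 can be sketched as a deterministic transition function. The indexing follows the caption (0 at the top-left, increasing by 1 to the right), assuming row-major order for wrapping to the next row; the action encoding, the `step` function, and the obstacle set are hypothetical assumptions, not the environment's actual implementation.

```python
from typing import Dict, Set, Tuple

N = 15  # grid side length from the caption (15 x 15)

# Hypothetical action encoding: 0=left, 1=down, 2=right, 3=up -> (d_row, d_col).
MOVES: Dict[int, Tuple[int, int]] = {0: (0, -1), 1: (1, 0), 2: (0, 1), 3: (-1, 0)}


def to_state(row: int, col: int) -> int:
    """Row-major index: 0 at the top-left, +1 per step to the right."""
    return row * N + col


def to_rowcol(state: int) -> Tuple[int, int]:
    return divmod(state, N)


def step(state: int, action: int, obstacles: Set[int]) -> int:
    """Deterministic (non-slippery) move; hitting a boundary or an obstacle
    leaves the agent where it is."""
    row, col = to_rowcol(state)
    d_row, d_col = MOVES[action]
    new_row, new_col = row + d_row, col + d_col
    if not (0 <= new_row < N and 0 <= new_col < N):
        return state  # blocked by the grid boundary
    if to_state(new_row, new_col) in obstacles:
        return state  # blocked by an obstacle
    return to_state(new_row, new_col)


print(step(0, 2, obstacles=set()))  # 1: one cell to the right of the top-left corner
print(step(0, 0, obstacles=set()))  # 0: moving left from the corner is blocked
```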