
BECAUSE: Bilinear Causal Representation for Generalizable Offline Model-based Reinforcement Learning

Haohong Lin, Wenhao Ding, Jian Chen, Laixi Shi, Jiacheng Zhu, Bo Li, Ding Zhao

TL;DR

This paper tackles objective mismatch in offline model-based RL by identifying confounders in offline data as a main source of distribution shift. It introduces BECAUSE, a causal-representation framework built on ASC-MDPs and bilinear MDPs that learns sparsified, confounder-aware representations $\phi$ and $\mu$ together with a core matrix $M$, enabling unconfounded world modeling and uncertainty-aware planning. BECAUSE combines regularized MLE-based causal mask discovery with an Energy-Based Model that quantifies uncertainty for planning, and it comes with finite-sample guarantees on suboptimality. Empirically, BECAUSE shows strong generalization and robustness across 18 tasks with varying data quality and environment contexts, outperforming a range of offline RL baselines and remaining resilient as the number of confounding factors grows.
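The bilinear structure mentioned above can be pictured with a small sketch. In a bilinear MDP the transition score for a candidate next state is $\phi(s,a)^\top M \mu(s')$; here a binary mask is applied element-wise to $M$ so that a zero entry removes the dependence of $\mu(s')$ on one feature of $\phi(s,a)$. The function names (`masked_core`, `bilinear_score`, `transition_probs`), the element-wise masking, and the normalization over a finite candidate set are illustrative assumptions, not the paper's implementation.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def masked_core(M: Matrix, mask: List[List[int]]) -> Matrix:
    """Element-wise product M * mask. A zero entry removes the edge from
    feature i of phi(s, a) to feature j of mu(s')."""
    return [[m * g for m, g in zip(row_m, row_g)] for row_m, row_g in zip(M, mask)]


def bilinear_score(phi_sa: Vector, core: Matrix, mu_next: Vector) -> float:
    """Compute phi(s, a)^T core mu(s') for one candidate next state."""
    return sum(
        phi_sa[i] * core[i][j] * mu_next[j]
        for i in range(len(phi_sa))
        for j in range(len(mu_next))
    )


def transition_probs(phi_sa: Vector, core: Matrix, candidates: List[Vector]) -> List[float]:
    """Normalize bilinear scores over a finite set of candidate next states.
    Assumes phi, core and mu are entrywise non-negative (hypothetical setup)."""
    scores = [bilinear_score(phi_sa, core, mu) for mu in candidates]
    total = sum(scores)
    if total <= 0:
        raise ValueError("bilinear scores must have a positive sum")
    return [s / total for s in scores]


if __name__ == "__main__":
    M = [[0.5, 0.1], [0.3, 0.4]]
    mask = [[1, 1], [0, 0]]  # hypothetical: feature 1 of phi(s, a) is treated as spurious
    core = masked_core(M, mask)
    # Changing the masked feature of phi(s, a) leaves the score unchanged.
    print(bilinear_score([1.0, 2.0], core, [1.0, 1.0]))
    print(bilinear_score([1.0, 5.0], core, [1.0, 1.0]))
    print(transition_probs([1.0, 2.0], core, [[1.0, 0.0], [0.0, 1.0]]))
```

Because the mask zeroes a whole row of $M$, any variation in that feature of $\phi(s,a)$ (for instance, one driven by a confounder) no longer reaches the predicted next state.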

Abstract

Offline model-based reinforcement learning (MBRL) enhances data efficiency by using pre-collected datasets to learn models and policies, especially in scenarios where exploration is costly or infeasible. Nevertheless, it often suffers from an objective mismatch between model learning and policy learning, which leads to poor policy performance despite accurate model predictions. This paper first identifies the underlying confounders present in offline data as the primary source of this mismatch. We then introduce BilinEar CAUSal rEpresentation (BECAUSE), an algorithm that captures causal representations of both states and actions to reduce the influence of distribution shift, thereby mitigating the objective mismatch problem. Comprehensive evaluations on 18 tasks that vary in data quality and environment context show that BECAUSE outperforms existing offline RL algorithms. We further show that BECAUSE remains generalizable and robust with fewer samples or a larger number of confounders. Finally, we provide a theoretical analysis of BECAUSE that establishes its error bound and sample efficiency when causal representation is integrated into offline MBRL.
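Capturing a causal representation here rests on regularized MLE-based causal mask discovery. As a stand-in, the sketch below assumes Gaussian transition noise, under which the negative log-likelihood reduces to squared error, and adds an L1 penalty solved by ISTA (proximal gradient). Coefficients that survive the penalty define the mask entries for one next-state dimension. The data generator, `lam`, the step size, and the threshold `tol` are hypothetical choices, and the paper's actual regularizer and optimizer may differ.

```python
import random
from typing import List, Tuple


def soft_threshold(v: float, t: float) -> float:
    """Proximal operator of t * |v|: shrinks v toward zero by t."""
    if v > t:
        return v - t
    if v < -t:
        return v + t
    return 0.0


def lasso_ista(X: List[List[float]], y: List[float], lam: float,
               step: float = 0.5, iters: int = 500) -> List[float]:
    """Minimize (1/2n)||y - Xw||^2 + lam * ||w||_1 with ISTA.
    The squared loss equals the Gaussian negative log-likelihood up to constants."""
    n, d = len(X), len(X[0])
    w = [0.0] * d
    for _ in range(iters):
        residual = [y[k] - sum(X[k][j] * w[j] for j in range(d)) for k in range(n)]
        grad = [-sum(X[k][j] * residual[k] for k in range(n)) / n for j in range(d)]
        w = [soft_threshold(w[j] - step * grad[j], step * lam) for j in range(d)]
    return w


def discover_mask(w: List[float], tol: float = 0.1) -> List[int]:
    """Keep an edge (1) only where the fitted coefficient is clearly non-zero."""
    return [1 if abs(wj) > tol else 0 for wj in w]


def make_data(n: int = 200, seed: int = 0) -> Tuple[List[List[float]], List[float]]:
    """Hypothetical data: the next-state dimension depends on x0 and x2 only."""
    rng = random.Random(seed)
    X, y = [], []
    for _ in range(n):
        x = [rng.uniform(-1.0, 1.0) for _ in range(3)]
        X.append(x)
        y.append(2.0 * x[0] - 1.0 * x[2] + rng.gauss(0.0, 0.1))
    return X, y


if __name__ == "__main__":
    X, y = make_data()
    w = lasso_ista(X, y, lam=0.05)
    print([round(wj, 2) for wj in w])
    print(discover_mask(w))  # should keep the edges from x0 and x2 and drop x1
```

The L1 term is what produces exact zeros: soft thresholding sets a coefficient to zero whenever its gradient signal stays below `lam`, which is how an irrelevant input drops out of the mask.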

Paper Structure (49 sections, 5 theorems, 67 equations, 9 figures, 13 tables, 1 algorithm)

Key Result

Theorem 1

Consider any $0<\delta<1$ and any initial state $\widetilde{s} \in \mathcal{S}$. Under the existence and feature assumptions, and assuming that the transition model $T$ is an SCM (Definition 4), for any accuracy level $0\leq \xi\leq 1$, with probability at least $1-\delta$, the output policy … where $C_1, C_s$ are universal constants, $\sigma$ is the SCM's noise level (Definition 4), and $M$ …

Figures (9)

  • Figure 1: The objective mismatch problem.
  • Figure 2: Comparison of our ASC-MDP with two existing formulations.
  • Figure 3: BECAUSE learns a causality-aware representation from the buffer and uses it in both the world model and uncertainty quantification to obtain a pessimistic planning policy.
  • Figure 4: Three environments used in this paper.
  • Figure 5: Results of BECAUSE and baselines in different tasks. (a) Average success rate in distribution and out of distribution. (b) Average success rate w.r.t. the ratio of offline samples. (c) Average success rate w.r.t. the spurious level in the environments. We report the mean and standard deviation of the best performance across 10 random seeds, with task-wise results in the overall results table in the Appendix.
  • ...and 4 more figures
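Figure 3 describes a pipeline in which an uncertainty estimate turns the learned world model into a pessimistic planner. A minimal tabular sketch of that last step, in the lower-confidence-bound style of Jin et al. (2021), runs value iteration on the learned model with each reward reduced by an uncertainty penalty. BECAUSE obtains its uncertainty from an Energy-Based Model; here a count-based penalty `beta / sqrt(count)` is swapped in purely for illustration, and the toy model, rewards, counts, and function names are hypothetical.

```python
import math
from typing import Dict, List, Tuple

SA = Tuple[int, int]                      # (state, action)
Model = Dict[SA, List[Tuple[int, float]]]  # (s, a) -> [(next_state, prob)]


def count_penalty(counts: Dict[SA, int], beta: float) -> Dict[SA, float]:
    """Hypothetical stand-in for an EBM uncertainty score: rarely seen pairs get larger penalties."""
    return {sa: beta / math.sqrt(max(n, 1)) for sa, n in counts.items()}


def pessimistic_value_iteration(
    n_states: int, actions: List[int], model: Model,
    reward: Dict[SA, float], penalty: Dict[SA, float],
    gamma: float = 0.9, iters: int = 200,
) -> Tuple[List[float], List[int]]:
    """Value iteration on the learned model with reward r(s, a) - penalty(s, a)."""
    V = [0.0] * n_states
    policy = [actions[0]] * n_states
    for _ in range(iters):
        new_V = [0.0] * n_states
        for s in range(n_states):
            best_q, best_a = -math.inf, actions[0]
            for a in actions:
                q = reward[(s, a)] - penalty[(s, a)]
                q += gamma * sum(p * V[s2] for s2, p in model[(s, a)])
                if q > best_q:
                    best_q, best_a = q, a
            new_V[s] = best_q
            policy[s] = best_a
        V = new_V
    return V, policy


if __name__ == "__main__":
    # One state, two actions: action 1 looks better in the model but was seen only once.
    model = {(0, 0): [(0, 1.0)], (0, 1): [(0, 1.0)]}
    reward = {(0, 0): 1.0, (0, 1): 1.5}
    pen = count_penalty({(0, 0): 100, (0, 1): 1}, beta=1.0)
    _, pessimistic = pessimistic_value_iteration(1, [0, 1], model, reward, pen)
    _, naive = pessimistic_value_iteration(1, [0, 1], model, reward, {k: 0.0 for k in pen})
    print(pessimistic, naive)
```

In the toy example the penalty flips the choice toward the well-covered action, which is the intended effect of pessimism when the model's optimistic estimate comes from poorly covered data.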

Theorems & Definitions (11)

  • Definition 1: Bilinear MDP (Yang and Wang, 2020)
  • Definition 2: ASC-MDP
  • Definition 3: Construction of causal graph $G$
  • Remark 1
  • Theorem 1: Performance guarantee
  • Definition 4: Structured Causal Model
  • Definition 5: $\delta$-Uncertainty Quantifier
  • Lemma 1: Decomposition of Suboptimality (Jin et al., 2021)
  • Lemma 2: Suboptimality in standard MDP (Jin et al., 2021)
  • Lemma 3: Uncertainty bound for Bilinear Causal Representation
  • ...and 1 more
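Definition 5 and Lemmas 1 and 2 build on the pessimism framework of Jin et al. (2021). For orientation, the finite-horizon statements from that work are reproduced below, with $\widehat{\mathbb{B}}_h$ the empirical Bellman operator estimated from the offline data, $\mathbb{B}_h$ the true one, and $\widehat{\pi}$ the policy returned by pessimistic value iteration. The confidence level is written as $\delta$ to match the name of Definition 5; the paper's own versions may use a different (for example, discounted) setting and notation.

```latex
% delta-uncertainty quantifier (finite-horizon form, after Jin et al., 2021)
\{\Gamma_h\}_{h=1}^{H},\ \Gamma_h:\mathcal{S}\times\mathcal{A}\to\mathbb{R},
\text{ is a } \delta\text{-uncertainty quantifier if } \Pr(\mathcal{E}) \ge 1-\delta, \text{ where}

\mathcal{E} = \Big\{ \big|(\widehat{\mathbb{B}}_h \widehat{V}_{h+1})(s,a)
  - (\mathbb{B}_h \widehat{V}_{h+1})(s,a)\big| \le \Gamma_h(s,a)
  \ \ \forall (s,a)\in\mathcal{S}\times\mathcal{A},\ \forall h\in[H] \Big\}.

% On \mathcal{E}, pessimistic value iteration that subtracts \Gamma_h satisfies
\mathrm{SubOpt}(\widehat{\pi}; s) \le
  2\sum_{h=1}^{H} \mathbb{E}_{\pi^*}\big[\Gamma_h(s_h,a_h) \,\big|\, s_1 = s\big].
```

The bound says the suboptimality is controlled by the uncertainty along the optimal policy's trajectory only, which is why a tighter uncertainty estimate, such as one built on a causal representation, translates directly into a better guarantee.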