A Tractable Inference Perspective of Offline RL
Xuejie Liu, Anji Liu, Guy Van den Broeck, Yitao Liang
TL;DR
This work addresses offline RL by questioning the primacy of expressive sequence models and highlighting the role of tractable inference in achieving high returns. It introduces Trifle, a framework that combines Tractable Probabilistic Models (TPMs) with conventional sequence models to compute exact marginal and conditional probabilities, enabling high-return action sampling even in multi-step, stochastic, or constrained settings. The approach uses per-dimension TPM-corrected sampling, beam search, and adaptive thresholds to bias actions toward high expected returns while staying within the offline data distribution. This yields state-of-the-art results on 7 of 9 Gym-MuJoCo benchmarks and strong performance on stochastic and safe RL tasks. The empirical results show that tractability can substantially improve inference-time optimality and overall performance, offering a practical path toward inference-aware offline RL. Limitations include the dependence on TPM accuracy and the added computational overhead. Overall, Trifle advances offline RL by foregrounding tractable probabilistic reasoning as a key component of effective inference-time decision making.
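To make "per-dimension TPM-corrected sampling" concrete, the sketch below shows one plausible reading of it. It assumes a discretized action space, a sequence model that proposes a distribution over bins for each action dimension given the state and the dimensions already chosen, and a TPM that returns P(V >= v | s, a_<=i) for each candidate bin. Each dimension's proposal is multiplied by that probability and renormalized. The threshold v is set here as a quantile of attainable returns, which is an assumed stand-in for the adaptive-threshold rule. The functions `seq_model`, `tpm`, `toy_seq_model`, and `toy_tpm` are hypothetical placeholders, not part of any released code.

```python
import numpy as np


def tpm_corrected_probs(seq_probs, p_high):
    """Reweight one action dimension's proposal by P(V >= v | s, a_<=i).

    seq_probs: sequence-model probabilities over candidate bins, shape (K,)
    p_high:    hypothetical TPM estimates of reaching the return threshold
               for each candidate bin, shape (K,)
    """
    weights = seq_probs * p_high
    total = weights.sum()
    if total <= 0.0:
        # No candidate can reach the threshold: fall back to the proposal.
        return seq_probs / seq_probs.sum()
    return weights / total


def adaptive_threshold(return_samples, quantile=0.9):
    """Assumed rule: set v to a quantile of returns attainable from the state."""
    return float(np.quantile(return_samples, quantile))


def sample_action(seq_model, tpm, state, num_dims, num_bins, v, rng):
    """Sample a discretized action one dimension at a time."""
    action = []
    for _ in range(num_dims):
        p_seq = np.asarray(seq_model(state, action), dtype=float)
        # Query the (hypothetical) TPM once per candidate value of this dimension.
        p_high = np.array([tpm(state, action + [b], v) for b in range(num_bins)])
        p = tpm_corrected_probs(p_seq, p_high)
        action.append(int(rng.choice(num_bins, p=p)))
    return action


# Toy stand-ins (hypothetical): a uniform proposal and a "TPM" that says
# only the top bin of each dimension can reach the return threshold.
def toy_seq_model(state, prefix, num_bins=4):
    return np.full(num_bins, 1.0 / num_bins)


def toy_tpm(state, prefix, v, num_bins=4):
    return 1.0 if prefix[-1] == num_bins - 1 else 0.0


rng = np.random.default_rng(0)
v = adaptive_threshold(np.arange(11.0))
print(sample_action(toy_seq_model, toy_tpm, None, num_dims=3, num_bins=4, v=v, rng=rng))
# prints [3, 3, 3]
```

Because the correction multiplies the sequence model's proposal rather than replacing it, bins the sequence model considers implausible stay unlikely, which is one way to keep actions close to the offline data distribution. Beam search over several candidate actions, also named above, is omitted from this sketch.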
Abstract
A popular paradigm for offline Reinforcement Learning (RL) is to first fit the offline trajectories with a sequence model and then prompt the model for actions that lead to high expected returns. Beyond obtaining accurate sequence models, this paper highlights that tractability, the ability to exactly and efficiently answer various probabilistic queries, plays an important role in offline RL. Specifically, due to the fundamental stochasticity of the offline data-collection policies and the environment dynamics, highly non-trivial conditional/constrained generation is required to elicit rewarding actions. While it is still possible to approximate such queries, we observe that such crude estimates significantly undermine the benefits brought by expressive sequence models. To overcome this problem, this paper proposes Trifle (Tractable Inference for Offline RL), which leverages modern Tractable Probabilistic Models (TPMs) to bridge the gap between good sequence models and high expected returns at evaluation time. Empirically, compared with strong baselines, Trifle achieves the most state-of-the-art scores across 9 Gym-MuJoCo benchmarks. Further, owing to its tractability, Trifle significantly outperforms prior approaches in stochastic environments and safe RL tasks (e.g., with action constraints) with minimal algorithmic modifications.
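The kind of conditional/constrained query the abstract refers to can be illustrated on a toy joint distribution. The sketch below assumes a small table p(a1, a2, V) over discretized action dimensions and return values, standing in for a TPM, and computes the exact p(a1 | V >= v, a1 in C) for an allowed set C by marginalizing a2, summing over the return event, applying the constraint, and renormalizing. The function name `constrained_conditional` and the table layout are illustrative assumptions. A TPM such as a probabilistic circuit would answer the same query without materializing the table, whose size grows exponentially with the number of variables.

```python
import numpy as np


def constrained_conditional(joint, v, allowed):
    """Exact p(a1 | V >= v, a1 in allowed) from a joint table p(a1, a2, V).

    joint:   toy stand-in for a TPM, shape (num_a1, num_a2, num_returns)
    v:       return threshold (index into the last axis)
    allowed: action values of a1 permitted by the constraint
    """
    returns = np.arange(joint.shape[2])
    # Marginalize a2 and sum over all return values with V >= v.
    scores = joint[:, :, returns >= v].sum(axis=(1, 2))
    # Zero out actions that violate the constraint.
    scores = scores * np.isin(np.arange(joint.shape[0]), allowed)
    total = scores.sum()
    if total == 0.0:
        raise ValueError("the conditioning event has zero probability")
    return scores / total


rng = np.random.default_rng(0)
joint = rng.random((3, 3, 4))
joint /= joint.sum()  # normalize into a valid joint distribution
print(constrained_conditional(joint, v=2, allowed=[0, 2]))
```

The constraint and the return condition enter the same exact computation, which is why, under this view, adding an action constraint requires little change to the sampling procedure.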
