Offline Imitation from Observation via Primal Wasserstein State Occupancy Matching
Kai Yan, Alexander G. Schwing, Yu-Xiong Wang
TL;DR
This work tackles offline Learning from Observations (LfO) by introducing PW-DICE, a novel method that minimizes the primal Wasserstein distance between the learner’s and expert state occupancies using a contrastively learned distance metric. By integrating KL-based pessimistic regularizers, PW-DICE yields a single-level convex optimization whose dual variables enable weighted behavior cloning to recover the policy, and it recovers SMODICE as a special case. Empirically, PW-DICE outperforms state-of-the-art DICE-based methods and other Wasserstein approaches on tabular and MuJoCo benchmarks, demonstrating the importance of the distance metric and robustness to distorted dynamics. The approach unifies $f$-divergence and Wasserstein minimization within a single framework and provides practical improvements for offline LfO tasks with limited expert data and diverse environments.
Abstract
In real-world scenarios, arbitrary interactions with the environment can often be costly, and actions of expert demonstrations are not always available. To reduce the need for both, offline Learning from Observations (LfO) is extensively studied: the agent learns to solve a task given only expert states and task-agnostic non-expert state-action pairs. The state-of-the-art DIstribution Correction Estimation (DICE) methods, as exemplified by SMODICE, minimize the state occupancy divergence between the learner's and empirical expert policies. However, such methods are limited to either $f$-divergences (KL and $\chi^2$) or Wasserstein distance with Rubinstein duality, the latter of which constrains the underlying distance metric crucial to the performance of Wasserstein-based solutions. To enable more flexible distance metrics, we propose Primal Wasserstein DICE (PW-DICE). It minimizes the primal Wasserstein distance between the learner and expert state occupancies and leverages a contrastively learned distance metric. Theoretically, our framework is a generalization of SMODICE, and is the first work that unifies $f$-divergence and Wasserstein minimization. Empirically, we find that PW-DICE improves upon several state-of-the-art methods. The code is available at https://github.com/KaiYan289/PW-DICE.
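
For reference, the contrast between the primal and dual forms of the Wasserstein distance can be written out on a finite state space. This is the standard textbook formulation, given as a sketch rather than as the paper's exact objective (it omits the regularizers and the policy constraints). The symbols $d^\pi$ (learner state occupancy), $d^E$ (expert state occupancy), $\Pi$ (coupling between the two), $c$ (cost between states), and $f$ (dual potential) are notation assumed here for illustration.

```latex
% Primal (Kantorovich) form: any non-negative cost c(s, s') is admissible.
\begin{align}
W_c(d^\pi, d^E) &= \min_{\Pi \ge 0} \sum_{s, s'} \Pi(s, s')\, c(s, s') \\
\text{s.t.}\quad & \sum_{s'} \Pi(s, s') = d^\pi(s)\ \ \forall s,
\qquad \sum_{s} \Pi(s, s') = d^E(s')\ \ \forall s'
\end{align}

% Kantorovich-Rubinstein dual: stated for a cost that is a metric d,
% with f constrained to be 1-Lipschitz with respect to d.
\begin{equation}
W_1(d^\pi, d^E) = \sup_{\|f\|_{L} \le 1}
\ \mathbb{E}_{s \sim d^\pi}[f(s)] - \mathbb{E}_{s \sim d^E}[f(s)]
\end{equation}
```

The primal form takes the cost $c$ as an explicit input, so a learned cost can be plugged in directly. The Kantorovich-Rubinstein dual is stated for a cost that is itself a metric $d$, with $f$ restricted to be 1-Lipschitz with respect to $d$, which is the restriction on the distance metric that the abstract refers to.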
