Mutual Information Regularized Offline Reinforcement Learning

Xiao Ma, Bingyi Kang, Zhongwen Xu, Min Lin, Shuicheng Yan

TL;DR

MISA is a general offline RL framework that unifies conservative Q-learning (CQL) and behavior regularization methods (e.g., TD3+BC) as special cases; experiments show that it performs significantly better than existing methods and achieves new state-of-the-art results on various tasks of the D4RL benchmark.

Abstract

The major challenge of offline RL is the distribution shift that appears when out-of-distribution actions are queried, which biases the policy improvement direction through extrapolation errors. Most existing methods address this problem by penalizing the policy or value for deviating from the behavior policy during policy improvement or evaluation. In this work, we propose a novel framework, MISA, that approaches offline RL from the perspective of Mutual Information between States and Actions in the dataset by directly constraining the policy improvement direction. MISA constructs lower bounds on mutual information parameterized by the policy and Q-values. We show that optimizing this lower bound is equivalent to maximizing the likelihood of a one-step improved policy on the offline dataset. Hence, we constrain the policy improvement direction to lie in the data manifold. The resulting algorithm simultaneously augments policy evaluation and improvement by adding mutual information regularizations. MISA is a general framework that unifies conservative Q-learning (CQL) and behavior regularization methods (e.g., TD3+BC) as special cases. We introduce 3 different variants of MISA and empirically demonstrate that a tighter mutual information lower bound gives better offline RL performance. In addition, our extensive experiments show that MISA significantly outperforms a wide range of baselines on various tasks of the D4RL benchmark, e.g., achieving 742.9 total points on gym-locomotion tasks. Our code is available at https://github.com/sail-sg/MISA.
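
The abstract frames offline RL through the mutual information $I(S;A)$ between dataset states and actions. As a generic illustration only (MISA itself parameterizes the bound with the policy and Q-values, which is not reproduced here), the sketch below estimates a Donsker-Varadhan lower bound on $I(S;A)$ from paired samples, drawing product-of-marginals samples by shuffling the actions. The toy data (a bivariate Gaussian with correlation `rho`), the closed-form critic, and all function names are hypothetical, chosen so the estimate can be checked against the exact mutual information.

```python
import math
import random


def gaussian_mi(rho):
    """Exact I(S;A) for a standard bivariate Gaussian with correlation rho."""
    return -0.5 * math.log(1.0 - rho ** 2)


def make_optimal_critic(rho):
    """Hypothetical critic T(s, a) = log p(a|s) - log p(a) for the toy Gaussian model."""
    var = 1.0 - rho ** 2

    def critic(s, a):
        log_cond = -0.5 * math.log(2 * math.pi * var) - (a - rho * s) ** 2 / (2 * var)
        log_marg = -0.5 * math.log(2 * math.pi) - a ** 2 / 2
        return log_cond - log_marg

    return critic


def sample_pairs(n, rho, rng):
    """Draw toy (state, action) pairs with corr(S, A) = rho."""
    states, actions = [], []
    for _ in range(n):
        s = rng.gauss(0.0, 1.0)
        a = rho * s + math.sqrt(1.0 - rho ** 2) * rng.gauss(0.0, 1.0)
        states.append(s)
        actions.append(a)
    return states, actions


def dv_mi_lower_bound(states, actions, critic, rng):
    """Donsker-Varadhan bound on I(S;A): E_joint[T] - log E_marginals[exp(T)].

    Dataset pairs (s_i, a_i) are joint samples; shuffling the actions breaks the
    pairing and yields (approximate) samples from p(s) p(a)."""
    n = len(states)
    joint_term = sum(critic(s, a) for s, a in zip(states, actions)) / n
    shuffled = list(actions)
    rng.shuffle(shuffled)
    marginal_term = math.log(
        sum(math.exp(critic(s, a)) for s, a in zip(states, shuffled)) / n
    )
    return joint_term - marginal_term


rng = random.Random(0)
rho = 0.5
states, actions = sample_pairs(100_000, rho, rng)
estimate = dv_mi_lower_bound(states, actions, make_optimal_critic(rho), rng)
# A constant critic only recovers the trivial bound I(S;A) >= 0.
trivial = dv_mi_lower_bound(states, actions, lambda s, a: 0.0, rng)
print(f"DV estimate: {estimate:.4f}, exact I(S;A): {gaussian_mi(rho):.4f}")
print(f"constant critic: {trivial:.4f}")
```

Any critic gives a valid lower bound in expectation; the quality of the critic decides how tight it is, which is the lever a learned parameterization of $T$ is meant to pull.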

Paper Structure (25 sections, 3 theorems, 25 equations, 1 figure, 2 tables, 1 algorithm)

Key Result

Lemma 3.1

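The KL divergence admits the following lower bound:

$$D_{\mathrm{KL}}(\mathbb{P}\,\|\,\mathbb{Q}) \ge \sup_{T \in \mathcal{F}} \mathbb{E}_{\mathbb{P}}[T] - \mathbb{E}_{\mathbb{Q}}\left[e^{T-1}\right],$$

where the supremum is taken over a function family $\mathcal{F}$ satisfying the integrability constraints.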
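
A quick numerical check of the $f$-divergence bound $\mathbb{E}_{\mathbb{P}}[T] - \mathbb{E}_{\mathbb{Q}}[e^{T-1}]$ from Lemma 3.1, assuming $\mathbb{P} = \mathcal{N}(1, 1)$ and $\mathbb{Q} = \mathcal{N}(0, 1)$, so that $D_{\mathrm{KL}}(\mathbb{P}\,\|\,\mathbb{Q}) = 0.5$. The critic $T^*(x) = 1 + \log \frac{p(x)}{q(x)} = x + 0.5$ attains the bound, while another critic, here the hypothetical $T(x) = 0.5x + 0.75$, gives a smaller value. The distributions, critics, and the function name `nwj_lower_bound` are illustrative assumptions, not taken from the paper.

```python
import math
import random


def nwj_lower_bound(p_samples, q_samples, critic):
    """Monte Carlo estimate of E_P[T] - E_Q[exp(T - 1)], a lower bound on KL(P || Q)."""
    e_p = sum(critic(x) for x in p_samples) / len(p_samples)
    e_q = sum(math.exp(critic(x) - 1.0) for x in q_samples) / len(q_samples)
    return e_p - e_q


rng = random.Random(0)
n = 200_000
p_samples = [rng.gauss(1.0, 1.0) for _ in range(n)]  # P = N(1, 1)
q_samples = [rng.gauss(0.0, 1.0) for _ in range(n)]  # Q = N(0, 1)
exact_kl = 0.5  # KL(N(1, 1) || N(0, 1)) = (1 - 0)^2 / 2

# Optimal critic for this bound: T*(x) = 1 + log p(x)/q(x) = x + 0.5
optimal = nwj_lower_bound(p_samples, q_samples, lambda x: x + 0.5)
# Hypothetical suboptimal critic; its expected bound is 1.25 - exp(-0.125)
suboptimal = nwj_lower_bound(p_samples, q_samples, lambda x: 0.5 * x + 0.75)
print(f"exact KL: {exact_kl}, optimal critic: {optimal:.4f}, suboptimal critic: {suboptimal:.4f}")
```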

Figures (1)

  • Figure 1: t-SNE of the Q-value network embeddings of the walker2d-medium-v2 dataset, where red denotes high reward and blue denotes low reward.

Theorems & Definitions (3)

  • Lemma 3.1: $f$-divergence representation
  • Lemma 3.2: Donsker-Varadhan representation
  • Theorem 4.1
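
Lemmas 3.1 and 3.2 give two lower bounds on the same KL divergence. The small sketch below, assuming Gaussian samples $\mathbb{P} = \mathcal{N}(1,1)$, $\mathbb{Q} = \mathcal{N}(0,1)$ and a few hypothetical critics, compares the $f$-divergence bound $\mathbb{E}_{\mathbb{P}}[T] - \mathbb{E}_{\mathbb{Q}}[e^{T-1}]$ with the Donsker-Varadhan bound $\mathbb{E}_{\mathbb{P}}[T] - \log \mathbb{E}_{\mathbb{Q}}[e^{T}]$. For any fixed critic the DV value is never smaller, because $\mathbb{E}_{\mathbb{Q}}[e^{T-1}] = \mathbb{E}_{\mathbb{Q}}[e^{T}]/e$ and $\log y \le y/e$ for $y > 0$. The DV bound is also unchanged when a constant is added to $T$. The function names and critics are illustrative assumptions.

```python
import math
import random


def nwj_bound(p_samples, q_samples, critic):
    """f-divergence representation: E_P[T] - E_Q[exp(T - 1)]."""
    e_p = sum(critic(x) for x in p_samples) / len(p_samples)
    e_q = sum(math.exp(critic(x) - 1.0) for x in q_samples) / len(q_samples)
    return e_p - e_q


def dv_bound(p_samples, q_samples, critic):
    """Donsker-Varadhan representation: E_P[T] - log E_Q[exp(T)]."""
    e_p = sum(critic(x) for x in p_samples) / len(p_samples)
    e_q = sum(math.exp(critic(x)) for x in q_samples) / len(q_samples)
    return e_p - math.log(e_q)


rng = random.Random(1)
n = 100_000
p_samples = [rng.gauss(1.0, 1.0) for _ in range(n)]  # P = N(1, 1), KL(P || Q) = 0.5
q_samples = [rng.gauss(0.0, 1.0) for _ in range(n)]  # Q = N(0, 1)

# Hypothetical critics; x - 0.5 is exactly log p(x)/q(x)
critics = {
    "log-ratio": lambda x: x - 0.5,
    "log-ratio + 2": lambda x: x + 1.5,
    "half slope": lambda x: 0.5 * x,
}
results = {
    name: (dv_bound(p_samples, q_samples, T), nwj_bound(p_samples, q_samples, T))
    for name, T in critics.items()
}
for name, (dv, nwj) in results.items():
    print(f"{name:>14}: DV = {dv:.4f}, NWJ = {nwj:.4f}")
```

With the exact log-ratio critic the DV bound is tight, whereas the $f$-divergence bound only becomes tight once the critic is shifted by $+1$; this is one concrete sense in which the DV representation is the tighter of the two.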