Learning on One Mode: Addressing Multi-modality in Offline Reinforcement Learning

Mianchu Wang, Yue Jin, Giovanni Montana

TL;DR

Offline RL often struggles with distribution shift and multi-modal data. The paper proposes Learning on One Mode (LOM), which models the offline behaviour policy with a Gaussian mixture model (GMM). It then uses a hyper-Markov decision process and a hyper Q-function to select the most promising mode in each state, and performs weighted imitation learning on that mode. The approach comes with policy-improvement guarantees and achieves state-of-the-art results on the D4RL benchmarks, especially in highly multi-modal settings. Because it avoids modelling the full multi-modal distribution, LOM offers a simple yet effective way to exploit diverse offline data.

Abstract

Offline reinforcement learning (RL) seeks to learn optimal policies from static datasets without interacting with the environment. A common challenge is handling multi-modal action distributions, where multiple behaviours are represented in the data. Existing methods often assume unimodal behaviour policies, leading to suboptimal performance when this assumption is violated. We propose weighted imitation Learning on One Mode (LOM), a novel approach that focuses on learning from a single, promising mode of the behaviour policy. By using a Gaussian mixture model to identify modes and selecting the best mode based on expected returns, LOM avoids the pitfalls of averaging over conflicting actions. Theoretically, we show that LOM improves performance while maintaining simplicity in policy learning. Empirically, LOM outperforms existing methods on standard D4RL benchmarks and demonstrates its effectiveness in complex, multi-modal scenarios.
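The approach described in the abstract (identify modes with a GMM, select the mode with the highest expected return, then imitate only that mode with weights) can be sketched for a single fixed state as follows. This is a minimal illustration under several assumptions, not the authors' implementation: the GMM components are hard-coded rather than produced by a network, the critic `q_value` is a hypothetical closed-form function, each mode is scored by a Monte Carlo average of Q over its sampled actions, and the imitation weight is an AWAC-style exponential advantage that uses the selected mode's score as the baseline. All function names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for step 1: four GMM components over 2-D actions for one fixed state s.
# In LOM a network would output these parameters per state; here they are hard-coded.
means = np.array([[1.0, 1.0], [1.0, -1.0], [-1.0, -1.0], [-1.0, 1.0]])
stds = np.full((4, 2), 0.1)


def q_value(actions):
    """Hypothetical critic Q(s, a) for the fixed state: largest near (1, 1)."""
    return -np.sum((actions - np.array([1.0, 1.0])) ** 2, axis=-1)


def sample_mode(k, n):
    """Draw n actions from Gaussian component k."""
    return means[k] + stds[k] * rng.standard_normal((n, means.shape[1]))


def mode_values(q_fn, n_samples=256):
    """Step 2 (assumed form): Monte Carlo estimate of the mean Q of each mode's actions."""
    return np.array([q_fn(sample_mode(k, n_samples)).mean() for k in range(len(means))])


def gaussian_nll(actions, mu, log_std):
    """Per-sample negative log-likelihood of a diagonal Gaussian policy."""
    z = (actions - mu) / np.exp(log_std)
    return 0.5 * np.sum(z ** 2 + 2.0 * log_std + np.log(2.0 * np.pi), axis=-1)


def weighted_imitation_loss(actions, q, v, mu, log_std, beta=1.0, max_weight=20.0):
    """Step 3: exponential-advantage-weighted NLL (AWAC-style stand-in), weights clipped."""
    w = np.minimum(np.exp(beta * (q - v)), max_weight)
    return float(np.mean(w * gaussian_nll(actions, mu, log_std)))


values = mode_values(q_value)
best = int(np.argmax(values))             # index of the selected mode
actions = sample_mode(best, 128)          # imitate only actions drawn from that mode
loss = weighted_imitation_loss(
    actions, q_value(actions), values[best],
    mu=np.zeros(2), log_std=np.zeros(2),  # current (untrained) policy parameters
)
print(best, values.round(2), round(loss, 3))
```

Because only the selected component's actions enter the loss, the policy is never pulled toward the midpoint of conflicting modes; swapping in a different weighting scheme for `weighted_imitation_loss` would leave the mode-selection step unchanged.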

Paper Structure

This paper contains 32 sections, 4 theorems, 43 equations, 5 figures, 3 tables, 1 algorithm.

Key Result

Proposition 1

The hyper Q-function can be linked to the standard action-value function $Q^\pi(s, a)$ via: [equation omitted]. The proof can be found in the appendix.

Figures (5)

  • Figure 1: The three steps of LOM. (1) Learn a network producing the parameters of a GMM to model the behaviour policy. (2) Evaluate each mode via the expected return of its actions and then select the optimal mode $\phi^1$. (3) Sample actions from $\phi^1$ for weighted imitation learning.
  • Figure 2: Comparative study in the FetchReach task with highly multi-modal datasets. (a) The FetchReach robot is tasked with reaching one of four specified goals using an expert dataset. The robot arm receives a reward of $2$ for reaching the goal in the first quadrant and $1$ for reaching any of the other three goals. The dataset contains actions directed toward all four goals, with conflicting directions. (b) Action distribution learned by behaviour cloning using a unimodal Gaussian policy model. (c) Action distribution learned by MDN using a GMM policy model. (d) Action distribution learned by AWAC, which applies weighted imitation learning over the entire action distribution using a unimodal Gaussian policy. (e) Action distribution learned by LOM.
  • Figure 3: Normalised scores for varying numbers of Gaussian components in the medium replay and full replay datasets.
  • Figure 4: Action modes learned by MDN with varying numbers of mixtures. Red dots represent samples from the highest-reward mode. The original actions are clustered around $(1, 1)$, $(1, -1)$, $(-1, -1)$, and $(-1, 1)$ but do not extend beyond these points. In (d) and (e), the red dots collapse into a single point in the first quadrant due to the small standard deviation of the mode.
  • Figure 5: An example of modelling the multi-modal behaviour policy in a one-step MDP. The $x$-axis represents the state, and the $y$-axis represents the corresponding multi-modal actions. (a) shows the action distribution from the offline dataset. (b)-(e) illustrate the action distributions learned by a Gaussian model, a conditional VAE, a conditional GAN, and an MDN, respectively. (f) shows the action distribution of the offline dataset with rewards, where actions in the first and third quadrants receive a reward of $1$, and others receive $0$. (g)-(i) illustrate the action distributions learned by weighted Gaussian, weighted conditional VAE, and weighted conditional GAN models, respectively. (j) shows the action distribution learned by the top 10 of the 20 MDN components, ranked by a hyper Q-function.
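The failure mode illustrated in Figures 2 and 4 can be reproduced numerically with a toy version of the FetchReach data: actions clustered around $(1, 1)$, $(1, -1)$, $(-1, -1)$ and $(-1, 1)$, with reward $2$ for the first-quadrant cluster and $1$ for the others. The sketch below simplifies the setting to a single step in which each action directly earns its cluster's reward, uses the known cluster labels in place of a fitted GMM, and uses an exponential reward weight $e^{r}$ as a stand-in for AWAC's advantage weighting. These are assumptions for illustration only.

```python
import numpy as np

rng = np.random.default_rng(1)

# Toy one-step dataset: four action clusters, as described for Figures 2 and 4.
centres = np.array([[1.0, 1.0], [1.0, -1.0], [-1.0, -1.0], [-1.0, 1.0]])
n_per_mode = 500
labels = np.repeat(np.arange(4), n_per_mode)
actions = centres[labels] + 0.05 * rng.standard_normal((labels.size, 2))
rewards = np.where(labels == 0, 2.0, 1.0)  # 2 for the first-quadrant goal, 1 otherwise

# (b)-style unimodal behaviour cloning: a single Gaussian's MLE mean is the sample mean.
bc_mean = actions.mean(axis=0)

# (d)-style weighted unimodal fit, with exp(reward) as a stand-in for AWAC's weights.
w = np.exp(rewards)
weighted_mean = (w[:, None] * actions).sum(axis=0) / w.sum()

# (e)-style one-mode selection: score each cluster by mean reward, imitate only the best.
mode_scores = np.array([rewards[labels == k].mean() for k in range(4)])
best = int(np.argmax(mode_scores))
one_mode_mean = actions[labels == best].mean(axis=0)

print(bc_mean.round(2), weighted_mean.round(2), best, one_mode_mean.round(2))
```

The unimodal fit lands near the origin, between all four clusters. The weighted fit lands near $(e-1)/(e+3) \approx 0.30$ in each coordinate, because the first-quadrant cluster has weight $e^2$ against $e$ for each of the others; that action still lies outside every cluster. Only the one-mode estimate recovers an action near $(1, 1)$, mirroring panels (b), (d) and (e) of Figure 2.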

Theorems & Definitions (7)

  • Proposition 1
  • Theorem 1
  • Theorem 2
  • Theorem 3
  • Proof
  • Proof
  • Proof