
MGDA: Model-based Goal Data Augmentation for Offline Goal-conditioned Weighted Supervised Learning

Xing Lei, Xuetao Zhang, Donglin Wang

TL;DR

The paper addresses offline goal-conditioned RL and the stitching limitation of GCWSL by introducing a principled model-based data augmentation method, MGDA, that uses a learned dynamics model with local Lipschitz control to generate plausible augmented goals. MGDA is guided by three unified principles (Goal Diversity, Action Optimality, and Goal Reachability) and comes with theoretical guarantees showing that it approximates a one-step stitching distribution $p^{1-step}(g|s,a)$ within an error bound $\mathcal{O}(\epsilon_k L_1)$. The approach is validated on state-based and vision-based offline maze datasets, where MGDA-enhanced GCWSL improves stitching performance over existing augmentation methods, with ablations highlighting the critical role of the LLC constraint. Overall, MGDA offers a principled, scalable augmentation framework that enhances trajectory stitching in offline GCWSL and has potential applicability to other goal-conditioned supervised-learning paradigms.

Abstract

Recently, a state-of-the-art family of algorithms, known as Goal-Conditioned Weighted Supervised Learning (GCWSL) methods, has been introduced to tackle challenges in offline goal-conditioned reinforcement learning (RL). GCWSL optimizes a lower bound of the goal-conditioned RL objective and has demonstrated outstanding performance across diverse goal-reaching tasks, providing a simple, effective, and stable solution. However, prior research has identified a critical limitation of GCWSL: the lack of trajectory stitching capabilities. To address this, goal data augmentation strategies have been proposed to enhance these methods. Nevertheless, existing techniques often struggle to sample suitable augmented goals for GCWSL effectively. In this paper, we establish unified principles for goal data augmentation, focusing on goal diversity, action optimality, and goal reachability. Based on these principles, we propose a Model-based Goal Data Augmentation (MGDA) approach, which leverages a learned dynamics model to sample more suitable augmented goals. MGDA uniquely incorporates the local Lipschitz continuity assumption within the learned model to mitigate the impact of compounding errors. Empirical results show that MGDA significantly enhances the performance of GCWSL methods on both state-based and vision-based maze datasets, surpassing previous goal data augmentation techniques in improving stitching capabilities.
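For readers unfamiliar with GCWSL, the following is a minimal sketch of the generic recipe the abstract refers to. It hindsight-relabels a future state as the goal, then fits the policy by weighted maximum likelihood. Several pieces are illustrative assumptions: the discount-style relabeling weight gamma ** (j - t), the linear softmax policy over discrete actions, and the helper names `relabel` and `gcwsl_loss`. Concrete GCWSL methods use their own weighting schemes, and this is not the paper's implementation.

```python
import numpy as np


def relabel(traj_states, traj_actions, rng, gamma=0.99):
    """Hindsight relabeling for one trajectory (hypothetical helper).

    traj_states has T + 1 entries and traj_actions has T entries.
    Picks a time t, a later goal index j in (t, T], and returns
    (s_t, a_t, s_j, gamma ** (j - t)) as one weighted training tuple.
    """
    T = len(traj_actions)
    t = int(rng.integers(0, T))          # t in [0, T - 1]
    j = int(rng.integers(t + 1, T + 1))  # j in [t + 1, T]
    weight = gamma ** (j - t)            # assumed discount-style weight
    return traj_states[t], traj_actions[t], traj_states[j], weight


def log_softmax(logits):
    """Numerically stable log-softmax over the last axis."""
    z = logits - logits.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))


def gcwsl_loss(theta, states, actions, goals, weights):
    """Weighted negative log-likelihood -mean(w * log pi(a | s, g)).

    The policy is a linear softmax over discrete actions applied to the
    concatenated [state, goal] input (an assumption for illustration).
    """
    x = np.concatenate([states, goals], axis=1)   # (N, d_s + d_g)
    logp = log_softmax(x @ theta)                 # (N, n_actions)
    chosen = logp[np.arange(len(actions)), actions]
    return -np.mean(weights * chosen)


# Usage: build a small relabeled batch from one toy trajectory.
rng = np.random.default_rng(0)
traj_states = np.arange(12, dtype=float).reshape(6, 2)  # 6 states, T = 5
traj_actions = np.array([0, 1, 2, 3, 0])
batch = [relabel(traj_states, traj_actions, rng) for _ in range(4)]
S, A, G, W = (np.array(col) for col in zip(*batch))
theta = np.zeros((4, 4))  # input dim 2 + 2, four discrete actions
loss = gcwsl_loss(theta, S, A, G, W)
```

With all-zero parameters the policy is uniform, so the loss reduces to the mean weight times log 4. Relabeled pairs whose goals lie further in the future contribute less because of the assumed discount weight.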

Paper Structure

This paper contains 17 sections, 2 theorems, 13 equations, 6 figures, 4 tables, 1 algorithm.

Key Result

Theorem 1

As described in Figure fig:augment goal, under the assumption of local Lipschitz continuity, when the training error of the model is $\epsilon$, the 1-step residual dynamics model $\hat{f}$ satisfies the following bound when predicting the correct nearby states $s_n$ of goal $g$, where $K$ is the Lipschitz constant of the true environment dynamics $f$ and $\Delta(\lambda_n)$ is the

Figures (6)

  • Figure 1: Counterexamples for the generalized principles and related goal data augmentation methods. The states $s^0_0$ through $s^5_0$ correspond to trajectories $\tau_0$ through $\tau_5$, respectively. The goals $g^{0}_{t}$ through $g^{5}_{t}$ represent the relabeled goals for each respective trajectory. Note that $s^{0}_0$ equals $s^{2}_0$, and red points denote intersection states of two trajectories. The light blue circles represent the nearby states of $g^{0}_{t}$ after k-means clustering in TGDA ghugare2024closing. Given the original goal $g^{0}_{t}$, SGDA yang2023swapped randomly selects any goal $g\in \left[g^{1}_{t},g^{2}_{t},g^{3}_{t},g^{4}_{t}, g^{5}_{t} \right]$ as an augmented goal, while TGDA selects the goals $g\in \left[g^{1}_{t},g^{3}_{t},g^{4}_{t}\right]$ from later in the trajectories corresponding to the nearby states $s\in\left[s^{1}_{i},s^{2}_{j},s^{3}_{k}\right]$. Our MGDA method selects more appropriate goals $g\in \left[g^{1}_{t},g^{4}_{t}\right]$, building upon TGDA by avoiding unreachable goals.
  • Figure 2: (left) States searched by other goal data augmentation methods. (right) States searched by our MGDA, whose relationship to the goal is constrained by the dynamics model. MGDA ensures that the searched state and goal $g$ satisfy the one-step transition criterion, thereby defining this searched state as a nearby state $s_n$.
  • Figure 3: Performance of the original GCWSL methods and the impact of incorporating different goal augmentation approaches on state-based datasets. We report the final mean success rate. Error bars denote 95$\%$ bootstrap confidence intervals.
  • Figure 4: Performance comparison between the original GCWSL approach and its enhancement with MGDA on vision-based datasets. We again report the final mean success rate. Error bars denote 95$\%$ bootstrap confidence intervals.
  • Figure 5: Ablation study on the local Lipschitz continuity (LLC) assumption. The results clearly show that the modified MSE version, incorporating LLC, outperforms the standard MSE-based dynamics model. This highlights the crucial role of the LLC assumption in enhancing performance.
  • ...and 1 more figure
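Figures 1 and 2 describe goal selection in words. The following is a minimal sketch of that idea, built on stated assumptions. A candidate "nearby" state is accepted only if it lies within a tolerance `tol` of the dynamics model's one-step prediction from the goal, using the action recorded at the goal. An augmented goal is then sampled from later in that state's trajectory. The function names, the toy model, the tolerance test, and the omission of TGDA's k-means clustering are all hypothetical simplifications. The paper's Algorithm 1 is not reproduced in this summary.

```python
import numpy as np


def nearby_state_indices(goal, goal_action, candidates, model, tol):
    """Indices of candidate states within `tol` of the model's one-step prediction from the goal.

    Assumption: the one-step check starts from the goal state using the
    action recorded at the goal; `model(s, a)` maps batches to next states.
    """
    pred = model(goal[None, :], goal_action[None, :])[0]
    dists = np.linalg.norm(candidates - pred, axis=1)
    return np.flatnonzero(dists <= tol)


def augment_goals(goal, goal_action, trajectories, model, tol, rng):
    """For each trajectory with an accepted nearby state s_n at index i,
    sample a state strictly later than i as an augmented goal."""
    augmented = []
    for traj in trajectories:
        # Exclude the final state so that a later state always exists.
        idx = nearby_state_indices(goal, goal_action, traj[:-1], model, tol)
        if idx.size == 0:
            continue  # no one-step-reachable state: skip this trajectory
        i = int(idx[0])  # use the first accepted nearby state
        k = int(rng.integers(i + 1, len(traj)))
        augmented.append(traj[k])
    return augmented


def toy_model(s, a):
    """Hypothetical point-mass dynamics: next state = state + action."""
    return s + a


# Usage: one trajectory passes through the predicted next state, one does not.
goal, goal_action = np.array([0.0, 0.0]), np.array([1.0, 0.0])
reachable = np.array([[1.0, 0.0], [2.0, 0.0], [3.0, 0.0]])
far_away = np.array([[5.0, 5.0], [6.0, 5.0]])
rng = np.random.default_rng(0)
new_goals = augment_goals(goal, goal_action, [reachable, far_away], toy_model, 0.1, rng)
```

Here only the trajectory that contains the model-predicted successor of the goal contributes an augmented goal. This loosely mirrors how Figure 1 describes MGDA discarding goals that TGDA would accept but that are not reachable.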

Theorems & Definitions (2)

  • Theorem 1: model smoothness
  • Theorem 2
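Theorem 1 (model smoothness) and Figure 5 rely on a dynamics model trained with a local Lipschitz continuity (LLC) constraint, but this summary does not give the exact modified MSE objective. The sketch below shows one plausible way to encode LLC for a 1-step residual model $\hat{f}(s,a) = s + \text{residual}(s,a)$. It adds a hinge penalty to the usual MSE whenever a finite-difference estimate of the local Lipschitz ratio around a data point exceeds a target constant `K`. The linear residual parameterization and the hyperparameters `K`, `sigma`, and `lam` are hypothetical.

```python
import numpy as np


def residual_predict(params, s, a):
    """Hypothetical linear 1-step residual model: s_hat' = s + s W_s^T + a W_a^T + b."""
    W_s, W_a, b = params
    return s + s @ W_s.T + a @ W_a.T + b


def llc_regularized_mse(params, s, a, s_next, K=1.0, sigma=0.01, lam=1.0, rng=None):
    """MSE plus a hinge penalty on estimated local Lipschitz violations.

    For each data point, the state is perturbed by small Gaussian noise delta,
    and ||f(s + delta, a) - f(s, a)|| / ||delta|| is compared with K.
    Returns (total_loss, mse, penalty).
    """
    rng = np.random.default_rng(0) if rng is None else rng
    pred = residual_predict(params, s, a)
    mse = np.mean(np.sum((pred - s_next) ** 2, axis=1))
    delta = sigma * rng.standard_normal(s.shape)
    pred_pert = residual_predict(params, s + delta, a)
    ratio = np.linalg.norm(pred_pert - pred, axis=1) / np.linalg.norm(delta, axis=1)
    penalty = np.mean(np.maximum(ratio - K, 0.0) ** 2)
    return mse + lam * penalty, mse, penalty


# Usage: a model with W_s = 0 has local ratio 1 (identity part only);
# with W_s = 2 I the ratio becomes 3, exceeding K = 1.5.
rng = np.random.default_rng(1)
N, d, m = 8, 3, 2
s, a = rng.standard_normal((N, d)), rng.standard_normal((N, m))
smooth = (np.zeros((d, d)), rng.standard_normal((d, m)), np.zeros(d))
s_next = residual_predict(smooth, s, a)
_, mse_smooth, pen_smooth = llc_regularized_mse(smooth, s, a, s_next, K=1.5)
steep = (2.0 * np.eye(d), smooth[1], smooth[2])
_, mse_steep, pen_steep = llc_regularized_mse(steep, s, a, s_next, K=1.5)
```

The hinge form leaves the model unpenalized wherever it is already locally smoother than `K`. A Jacobian-norm penalty computed with automatic differentiation would be a common alternative to the finite-difference estimate used here.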