Diffusion Self-Weighted Guidance for Offline Reinforcement Learning

Augusto Tagle, Javier Ruiz-del-Solar, Felipe Tobar

TL;DR

SWG introduces a diffusion model (DM) that jointly samples actions and their weights by augmenting the data to $Z=(A,W)$, enabling exact guidance from a single DM without external networks. The method derives the target score via the data-prediction formula, allowing self-guided sampling that approximates the target policy $\pi(a|s)$ in offline RL. Empirically, SWG achieves competitive results on D4RL benchmarks, samples robustly in toy cases, and yields informative ablations on weight formulations and guidance scale; its resampling variant, SWG-R, offers additional gains. While SWG has a simpler training pipeline than prior DM-guided offline RL approaches, it shows limitations on mixed-action datasets, pointing to future directions such as dedicated weight modules and scalable solvers.

Abstract

Offline reinforcement learning (RL) recovers the optimal policy $\pi$ given historical observations of an agent. In practice, $\pi$ is modeled as a weighted version of the agent's behavior policy $\mu$, using a weight function $w$ that acts as a critic of the agent's behavior. Though recent approaches to offline RL based on diffusion models have exhibited promising results, computing the required scores is challenging because they depend on the unknown $w$. In this work, we alleviate this issue by constructing a diffusion over both the actions and the weights. With the proposed setting, the required scores are obtained directly from the diffusion model without learning extra networks. Our main conceptual contribution is a novel guidance method in which the guidance (a function of $w$) comes from the same diffusion model; hence the name Self-Weighted Guidance (SWG). We show that SWG generates samples from the desired distribution on toy examples and performs on par with state-of-the-art methods on D4RL's challenging environments, while maintaining a streamlined training pipeline. We further validate SWG through ablation studies on weight formulations and scalability.
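
The data augmentation behind SWG, a diffusion over $\mathbf{z}_0=[\mathbf{a}_0,w_0]$ rather than over actions alone, can be sketched in a few lines of NumPy. This is a minimal sketch, not the authors' implementation: the exponential weight $w=\exp(\beta\cdot\text{advantage})$, the precomputed `advantages` array, the variance-preserving noising coefficients (here $\alpha_t^2+\sigma_t^2=1$), and the function names `build_augmented_batch`, `forward_noise`, and `denoising_loss` are all hypothetical.

```python
import numpy as np

def build_augmented_batch(actions, advantages, beta):
    """Stack actions a (N, d) and scalar weights w (N,) into z0 = [a, w] of shape (N, d + 1).

    Hypothetical weight: w = exp(beta * advantage), so beta = 0 gives w = 1 for every sample.
    """
    weights = np.exp(beta * np.asarray(advantages, dtype=float))
    return np.concatenate([np.asarray(actions, dtype=float), weights[:, None]], axis=1)

def forward_noise(z0, alpha_t, sigma_t, rng):
    """Assumed Gaussian forward process z_t = alpha_t * z0 + sigma_t * eps, applied to a and w alike."""
    eps = rng.standard_normal(z0.shape)
    return alpha_t * z0 + sigma_t * eps, eps

def denoising_loss(eps_pred, eps):
    """Standard epsilon-prediction MSE; one network predicts the noise on both a and w."""
    return float(np.mean((eps_pred - eps) ** 2))

# Usage with random stand-in data
rng = np.random.default_rng(0)
actions = rng.uniform(-1.0, 1.0, size=(4, 2))
advantages = rng.standard_normal(4)      # hypothetical critic outputs
z0 = build_augmented_batch(actions, advantages, beta=2.0)
zt, eps = forward_noise(z0, alpha_t=0.8, sigma_t=0.6, rng=rng)
print(z0.shape, zt.shape)                # (4, 3) (4, 3)
print(denoising_loss(eps, eps))          # 0.0
```

Because the weight is just one more coordinate of $\mathbf{z}$, no separate guidance network is trained; the same denoiser that models the actions also learns to predict their weights.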

Paper Structure

This paper contains 28 sections, 2 theorems, 34 equations, 6 figures, 10 tables, 3 algorithms.

Key Result

Proposition 3.3

The score required to sample from the target distribution defined in the extended weighted equation (eq:extended_weighted_eq) is given by
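
The following derivation sketches why a single DM can supply its own guidance. It is not the statement of Proposition 3.3; it rests on assumptions whose notation may differ from the paper's: the state $\mathbf{s}$ is omitted, $\mathbf{z}_0=[\mathbf{a}_0,w_0]\sim q_0$ with $w_0>0$, the extended target is $p_0(\mathbf{z}_0)\propto w_0\,q_0(\mathbf{z}_0)$, and the forward process is $q(\mathbf{z}_t|\mathbf{z}_0)=\mathcal{N}(\alpha_t\mathbf{z}_0,\sigma_t^2 I)$.

```latex
\begin{align}
p_t(\mathbf{z}_t)
  &\propto \int q(\mathbf{z}_t \mid \mathbf{z}_0)\, w_0\, q_0(\mathbf{z}_0)\, d\mathbf{z}_0
   = q_t(\mathbf{z}_t)\, \mathbb{E}_q[w_0 \mid \mathbf{z}_t], \\
\nabla_{\mathbf{z}_t} \log p_t(\mathbf{z}_t)
  &= \nabla_{\mathbf{z}_t} \log q_t(\mathbf{z}_t)
   + \nabla_{\mathbf{z}_t} \log \mathbb{E}_q[w_0 \mid \mathbf{z}_t], \\
\mathbb{E}_q[\mathbf{z}_0 \mid \mathbf{z}_t]
  &= \frac{\mathbf{z}_t + \sigma_t^2\, \nabla_{\mathbf{z}_t} \log q_t(\mathbf{z}_t)}{\alpha_t}
  \quad \text{(data prediction, Tweedie's formula)}, \\
\int p_0(\mathbf{a}_0, w_0)\, dw_0
  &\propto \mu(\mathbf{a}_0)\, \mathbb{E}_q[w_0 \mid \mathbf{a}_0]
   = w(\mathbf{a}_0)\, \mu(\mathbf{a}_0) \quad \text{if } w_0 = w(\mathbf{a}_0).
\end{align}
```

Under these assumptions, the weight coordinate of the data prediction is exactly $\mathbb{E}_q[w_0|\mathbf{z}_t]$, so both terms of the target score come from the one DM trained on $\mathbf{z}$, and the action marginal of the target is the weighted behavior policy.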

Figures (6)

  • Figure 1: Comparison between the proposed SWG approach (left) and the standard diffusion model with a separate guidance network (right). SWG jointly learns actions and weights within a single diffusion model, whereas standard diffusion-based RL approaches train an additional guidance network to steer the diffusion model. The figure depicts one forward pass of the diffusion model during training. Recall that $\mathbf{z}_0=[\mathbf{a}_0,w_0]$.
  • Figure 2: SWG applied to a spiral distribution for different temperature coefficients $\beta$. Top: samples generated by SWG. Bottom: ground truth samples. When $\beta=0$, we have $w(\mathbf{a})=1$ and thus the target distribution matches the data distribution.
  • Figure 3: SWG applied to an 8-Gaussian mixture distribution for different temperature coefficients $\beta$. Top: samples generated by SWG. Bottom: ground truth samples. When $\beta=0$, we have $w(\mathbf{a})=1$ and thus the target distribution matches the data distribution.
  • Figure 4: SWG depth scaling: inference time (s) vs number of layers (mean $\pm$ 2 standard deviations over 10,000 runs).
  • Figure 5: SWG width scaling: inference time (s) vs hidden dimension (mean $\pm$ 2 standard deviations over 10,000 runs).
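
Figures 2 and 3 evaluate SWG on toy targets obtained by reweighting the data distribution with a temperature $\beta$. The NumPy sketch below checks, on a hypothetical 1-D toy problem rather than the paper's spiral or 8-Gaussian setups, the identity that makes self-guidance possible: the data score plus the gradient of the log of the weight coordinate of the data prediction $\mathbb{E}[\mathbf{z}_0|\mathbf{z}_t]$ equals the score of the reweighted distribution. Assumptions: exact empirical-mixture scores replace a trained DM, the weight is the illustrative $w(a)=\exp(\beta a)$ (so $\beta=0$ gives $w(a)=1$, as in the captions), and the gradient is approximated by central finite differences rather than automatic differentiation.

```python
import numpy as np

def mixture_score(z, centers, log_mix, alpha, sigma):
    """Exact score of sum_i pi_i N(z; alpha * c_i, sigma^2 I) at one point z of shape (D,)."""
    diffs = alpha * centers - z                              # (N, D)
    logits = log_mix - np.sum(diffs ** 2, axis=1) / (2 * sigma ** 2)
    resp = np.exp(logits - logits.max())
    resp /= resp.sum()                                       # posterior responsibilities r_i
    return resp @ diffs / sigma ** 2

def data_prediction(z, score, alpha, sigma):
    """Tweedie's formula: E[z0 | z_t] = (z_t + sigma^2 * score) / alpha."""
    return (z + sigma ** 2 * score) / alpha

def self_weighted_score(z, centers, alpha, sigma, h=1e-5):
    """Data score plus grad log of the predicted weight (last coordinate of E[z0 | z_t]).

    The exact empirical score stands in for a trained DM; the gradient uses central differences.
    """
    z = np.asarray(z, dtype=float)
    log_uniform = np.full(len(centers), -np.log(len(centers)))

    def log_w_hat(x):
        s = mixture_score(x, centers, log_uniform, alpha, sigma)
        return np.log(data_prediction(x, s, alpha, sigma)[-1])

    grad = np.zeros_like(z)
    for k in range(z.size):
        step = np.zeros_like(z)
        step[k] = h
        grad[k] = (log_w_hat(z + step) - log_w_hat(z - step)) / (2 * h)
    return mixture_score(z, centers, log_uniform, alpha, sigma) + grad

# Hypothetical 1-D toy problem: uniform behavior actions, weight w(a) = exp(beta * a)
rng = np.random.default_rng(0)
beta, alpha, sigma = 1.0, 0.9, 0.5
actions = rng.uniform(-2.0, 2.0, size=50)
weights = np.exp(beta * actions)
centers = np.stack([actions, weights], axis=1)              # z0 = [a0, w0]
z_t = np.array([0.3, 1.2])

guided = self_weighted_score(z_t, centers, alpha, sigma)
target = mixture_score(z_t, centers, np.log(weights / weights.sum()), alpha, sigma)
print(np.allclose(guided, target, atol=1e-4))                # True
```

Setting every weight to 1 makes the guidance term vanish and recovers the unguided data score, mirroring the $\beta=0$ panels.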

Theorems & Definitions (7)

  • Remark 2.1
  • Definition 3.1: Extraction function
  • Remark 3.2
  • Proposition 3.3
  • Remark 3.4
  • Theorem 3.5
  • Remark 5.1