Diffusion Generative Flow Samplers: Improving learning signals through partial trajectory optimization

Dinghuai Zhang, Ricky T. Q. Chen, Cheng-Hao Liu, Aaron Courville, Yoshua Bengio

TL;DR

This work presents Diffusion Generative Flow Samplers (DGFS), a sampling-based framework where the learning process can be tractably broken down into short partial trajectory segments by parameterizing an additional "flow function".

Abstract

We tackle the problem of sampling from intractable high-dimensional density functions, a fundamental task that often appears in machine learning and statistics. We extend recent sampling-based approaches that leverage controlled stochastic processes to model approximate samples from these target densities. The main drawback of these approaches is that the training objective requires full trajectories to compute, resulting in sluggish credit assignment issues due to use of entire trajectories and a learning signal present only at the terminal time. In this work, we present Diffusion Generative Flow Samplers (DGFS), a sampling-based framework where the learning process can be tractably broken down into short partial trajectory segments, via parameterizing an additional "flow function". Our method takes inspiration from the theory developed for generative flow networks (GFlowNets), allowing us to make use of intermediate learning signals. Through various challenging experiments, we demonstrate that DGFS achieves more accurate estimates of the normalization constant than closely-related prior methods.
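
The abstract does not spell out the training objective, so the following is only an illustrative sketch of what learning from a partial trajectory segment could look like. It assumes a GFlowNet-style subtrajectory balance condition, in which a flow function F, a forward transition P_F, and a backward transition P_B satisfy log F_n(x_n) + sum of log P_F = log F_m(x_m) + sum of log P_B over the steps n to m-1. All names here (`subtrajectory_residual`, `log_flow`, `log_pf`, `log_pb`, the `toy_*` functions, and the step variance `S`) are hypothetical and not taken from the paper; the toy Gaussian process is chosen only because the condition holds exactly for it.

```python
import numpy as np


def gaussian_logpdf(x, mean, var):
    """Log-density of an isotropic Gaussian N(mean, var * I) at x."""
    d = x.shape[-1]
    sq = np.sum((x - mean) ** 2, axis=-1)
    return -0.5 * (sq / var + d * np.log(2.0 * np.pi * var))


def subtrajectory_residual(xs, n, m, log_flow, log_pf, log_pb):
    """Balance residual on the segment x_n, ..., x_m (assumed form, see lead-in).

    xs has shape (num_steps + 1, dim); only steps n..m are touched, so no
    full trajectory or terminal-time signal is needed.
    """
    res = log_flow(xs[n], n) - log_flow(xs[m], m)
    for l in range(n, m):
        res += log_pf(xs[l], xs[l + 1], l) - log_pb(xs[l], xs[l + 1], l)
    return res


def subtrajectory_loss(xs, n, m, log_flow, log_pf, log_pb):
    """Squared residual; zero when the segment is perfectly balanced."""
    return subtrajectory_residual(xs, n, m, log_flow, log_pf, log_pb) ** 2


# Toy setting (assumption): discretized Brownian motion started at 0 with
# per-step variance S, for which the exact marginals and backward kernel
# are known in closed form.
S = 0.25


def toy_log_pf(x, x_next, l):
    # Forward step: x_{l+1} | x_l ~ N(x_l, S * I)
    return gaussian_logpdf(x_next, x, S)


def toy_log_pb(x, x_next, l):
    # Backward step: x_l | x_{l+1} ~ N(w * x_{l+1}, w * S * I), w = l / (l + 1)
    w = l / (l + 1.0)
    return gaussian_logpdf(x, w * x_next, w * S)


def toy_log_flow(x, l):
    # Exact marginal at step l >= 1: N(0, l * S * I)
    return gaussian_logpdf(x, 0.0, l * S)
```

In this toy setting the residual vanishes on any segment with n >= 1, whichever points are plugged in, because the exact marginals play the role of the flow function. In a learned sampler the flow and forward transition would be neural networks trained to drive the squared residual toward zero on short segments.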

Paper Structure

This paper contains 37 sections, 31 equations, 10 figures, 4 tables, 1 algorithm.

Figures (10)

  • Figure 1: Illustration of the DGFS algorithm. The proposed method can update from partial trajectories (colored segments) as well as intermediate signals (gray arrows).
  • Figure 2: Gradient variance of DGFS and PIS, explaining the better training effects of DGFS.
  • Figure 3: The learned DGFS flow function and the ground truth samples from target process at different diffusion steps. This shows DGFS is able to learn flow functions correctly.
  • Figure 4: Manywell plots. DGFS and DDS but not PIS recover all modes.
  • Figure 5: MoG visualization of DGFS and other diffusion-based samplers shows that DGFS could capture the diverse modes well. The contours display the landscape of the target density.
  • ...and 5 more figures