
Beyond Squared Error: Exploring Loss Design for Enhanced Training of Generative Flow Networks

Rui Hu, Yifan Zhang, Zhuoran Li, Longbo Huang

TL;DR

We rigorously prove that distinct regression losses correspond to specific divergence measures, which lets us design and analyze regression losses according to the desired properties of the corresponding divergences; we examine two key properties: zero-forcing and zero-avoiding.
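One standard identity behind this kind of correspondence (the textbook change of variables for $f$-divergences, not necessarily the exact statement of the paper's Theorem 4.1) rewrites an $f$-divergence between two positive flows $p_F$ and $p_B$ over a finite set as a $p_B$-weighted regression loss on their log-ratio:

```latex
D_f\bigl(p_F \,\|\, p_B\bigr)
  = \sum_{x} p_B(x)\, f\!\left(\frac{p_F(x)}{p_B(x)}\right)
  = \sum_{x} p_B(x)\, g\bigl(\log p_F(x) - \log p_B(x)\bigr),
  \qquad g(t) := f\bigl(e^{t}\bigr).
```

For example, the convex generator $f(u) = u - 1 - \log u$ (with $f(1) = 0$) gives $g(t) = e^{t} - t - 1$, a LINEX-type loss, and for normalized $p_F, p_B$ the resulting divergence is $\mathrm{KL}(p_B \,\|\, p_F)$.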

Abstract

Generative Flow Networks (GFlowNets) are a novel class of generative models designed to sample from unnormalized distributions. They have found applications in various important tasks, and their training algorithms have attracted considerable research interest. In general, GFlowNets are trained by fitting the forward flow to the backward flow on sampled training objects. Prior work has focused on the choice of training objects, parameterizations, sampling and resampling strategies, and backward policies, aiming to improve credit assignment, exploration, or exploitation during training. However, the choice of regression loss, which can strongly influence the exploration and exploitation behavior of the policy being trained, has been overlooked. Lacking a theoretical basis for choosing an appropriate regression loss, most existing algorithms train the flow network by minimizing the squared error between the forward and backward flows in log-space, i.e., using the quadratic regression loss. In this work, we rigorously prove that distinct regression losses correspond to specific divergence measures, which enables us to design and analyze regression losses according to the desired properties of the corresponding divergences. Specifically, we examine two key properties: zero-forcing and zero-avoiding; the former promotes exploitation and higher rewards, while the latter encourages exploration and enhances diversity. Based on our theoretical framework, we propose three novel regression losses: Shifted-Cosh, Linex(1/2), and Linex(1). We evaluate them on three benchmarks: hyper-grid, bit-sequence generation, and molecule generation. Our proposed losses are compatible with most existing training algorithms and significantly improve their convergence speed, sample diversity, and robustness.
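The zero-forcing / zero-avoiding distinction can be illustrated with a small experiment that is not from the paper: fit a single discretized Gaussian $q$ to a bimodal target by grid search, once minimizing $\mathrm{KL}(\text{target}\,\|\,q)$ and once minimizing $\mathrm{KL}(q\,\|\,\text{target})$. The grid, the target, and the helper names (`gaussian`, `kl`, `best_fit`) are illustrative assumptions; this is a sketch of the textbook behavior of the two KL directions, not of the paper's losses or of GFlowNet training.

```python
import math

# Discretized 1-D support standing in for a finite set of terminal objects (assumption).
XS = [-6.0 + 0.1 * i for i in range(121)]


def normalize(weights):
    total = sum(weights)
    return [w / total for w in weights]


def gaussian(mu, sigma):
    """Discretized Gaussian on XS; strictly positive at every grid point."""
    return normalize([math.exp(-((x - mu) ** 2) / (2.0 * sigma ** 2)) for x in XS])


def kl(p, q):
    """KL(p || q) for discrete distributions where q > 0 wherever p > 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)


# Bimodal target: two well-separated high-probability regions (made-up example).
left, right = gaussian(-3.0, 0.7), gaussian(3.0, 0.7)
target = [0.5 * a + 0.5 * b for a, b in zip(left, right)]


def best_fit(objective):
    """Grid search over a single Gaussian q(mu, sigma); returns the minimizer."""
    best = None
    for i in range(33):  # mu in [-4, 4] with step 0.25
        mu = -4.0 + 0.25 * i
        for j in range(48):  # sigma in [0.3, 5.0] with step 0.1
            sigma = 0.3 + 0.1 * j
            value = objective(gaussian(mu, sigma))
            if best is None or value < best[0]:
                best = (value, mu, sigma)
    return best[1], best[2]


# KL(target || q) is large wherever q ~ 0 but target > 0, so q spreads out (zero-avoiding).
fwd_mu, fwd_sigma = best_fit(lambda q: kl(target, q))
# KL(q || target) is large wherever q > 0 but target ~ 0, so q picks a mode (zero-forcing).
rev_mu, rev_sigma = best_fit(lambda q: kl(q, target))

print(f"KL(target||q) fit: mu={fwd_mu:.2f}, sigma={fwd_sigma:.2f}")
print(f"KL(q||target) fit: mu={rev_mu:.2f}, sigma={rev_sigma:.2f}")
```

Under these assumptions the first fit should land between the two modes with a wide $\sigma$ (covering both high-probability regions, i.e., more diversity), while the second should sit on one mode with a narrow $\sigma$ (concentrating on a single high-probability region, i.e., more exploitation).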

Paper Structure (35 sections, 6 theorems, 31 equations, 5 figures, 5 tables)

Key Result

Theorem 4.1

Let $\theta$ be the parameters of the forward policies. For each minimal cut $C\in\mathcal{C}$, the restrictions of both the forward and backward flow functions to $C$ can be viewed as unnormalized distributions over $C$, denoted $\widehat{p}^C_F$ and $\widehat{p}^C_B$, respectively. If there exists $w:$ …
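To make the cut-level setting of Theorem 4.1 concrete, the sketch below treats the forward and backward flows on one cut as two dictionaries of positive, unnormalized values and scores them with a loss applied to the per-state log-ratio, weighted by the backward flow. The loss forms are assumptions: `linex(a)` is the standard LINEX loss rescaled by $1/a^2$ and `shifted_cosh` is $\cosh(t) - 1$; the paper's exact definitions and its weighting $w$ may differ, and `cut_divergence` and the flow values are hypothetical.

```python
import math

# Hypothetical loss forms on the log-ratio t = log F(s) - log B(s);
# each is zero at t = 0, i.e. when the forward and backward flows agree.


def quadratic(t):
    return 0.5 * t * t  # squared error in log-space; the 1/2 only rescales gradients


def linex(a):
    # Standard LINEX loss scaled by 1/a^2, so it behaves like t^2 / 2 near t = 0.
    return lambda t: (math.exp(a * t) - a * t - 1.0) / (a * a)


def shifted_cosh(t):
    return math.cosh(t) - 1.0  # symmetric, exponential growth in both directions


LOSSES = {
    "quadratic": quadratic,
    "linex(1/2)": linex(0.5),
    "linex(1)": linex(1.0),
    "shifted-cosh": shifted_cosh,
}


def cut_divergence(loss, forward, backward):
    """Sum over cut states of B(s) * loss(log F(s) - log B(s)).

    forward, backward: dicts mapping each state of one cut to a positive,
    unnormalized flow value (stand-ins for the two restricted flows).
    """
    return sum(
        backward[s] * loss(math.log(forward[s]) - math.log(backward[s]))
        for s in backward
    )


# Toy cut with three states; all flow values are made up for illustration.
backward = {"s1": 2.0, "s2": 1.0, "s3": 0.5}
too_much = {"s1": 2.0, "s2": 1.0, "s3": 4.0}  # forward flow 8x too high on s3
too_little = {"s1": 2.0, "s2": 1.0, "s3": 0.0625}  # forward flow 8x too low on s3

for name, loss in LOSSES.items():
    print(
        f"{name:>12}: matched={cut_divergence(loss, backward, backward):.3f}  "
        f"overshoot={cut_divergence(loss, too_much, backward):.3f}  "
        f"undershoot={cut_divergence(loss, too_little, backward):.3f}"
    )
```

With these made-up flows, the symmetric losses (quadratic, shifted-cosh) score the factor-8 overshoot and undershoot on `s3` identically, while the LINEX forms, especially `linex(1)`, penalize the overshoot (forward flow above backward flow) more heavily than the undershoot. Asymmetry of this kind is one way a loss can differ in behavior from the default squared error.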

Figures (5)

  • Figure 1: An illustration of our main theoretical results: the unified framework for GFlowNet training algorithms and the correspondence between regression losses over forward and backward flows on training objects and $f$-divergences between the two flows over minimal cuts.
  • Figure 2: Our proposed regression losses and their properties.
  • Figure 3: Hyper-grid results: the empirical $L_1$ distance between $P_T$ and $P_R$.
  • Figure 4: The number of modes found by the algorithm during training.
  • Figure 5: Molecule generation results. Top: average reward and pairwise similarity of all 200k generated molecules during each training episode; similarities are computed on a randomly chosen subset of 1,000 molecules. Bottom: average reward and pairwise similarity of the top-$k$ generated molecules during each training episode.

Theorems & Definitions (9)

  • Theorem 4.1
  • Remark 4.2
  • Remark 4.3
  • Proposition 4.4 (Liese & Vajda, 2006)
  • Proposition 4.5 (Liese & Vajda, 2006)
  • Definition 4.6
  • Theorem 4.7
  • Theorem B.1: An extension of Theorem 4.1
  • Theorem D.1