On Divergence Measures for Training GFlowNets

Tiago da Silva, Eliezer de Souza da Silva, Diego Mesquita

TL;DR

We design control variates, based on the REINFORCE leave-one-out and score-matching estimators, that reduce the variance of the gradients of divergence-based learning objectives, narrowing the gap between GFlowNet training and generalized variational approximations.

Abstract

Generative Flow Networks (GFlowNets) are amortized inference models designed to sample from unnormalized distributions over composable objects, with applications in generative modeling for tasks in fields such as causal discovery, NLP, and drug discovery. Traditionally, the training procedure for GFlowNets seeks to minimize the expected log-squared difference between a proposal (forward policy) and a target (backward policy) distribution, which enforces certain flow-matching conditions. While this training procedure is closely related to variational inference (VI), directly attempting standard Kullback-Leibler (KL) divergence minimization can lead to provably biased and potentially high-variance estimators. Therefore, we first review four divergence measures, namely the Rényi-$\alpha$, Tsallis-$\alpha$, reverse KL, and forward KL divergences, and design statistically efficient estimators for their stochastic gradients in the context of training GFlowNets. Then, we verify that properly minimizing these divergences yields a provably correct and empirically effective training scheme, often leading to significantly faster convergence than previously proposed optimization. To achieve this, we design control variates based on the REINFORCE leave-one-out and score-matching estimators to reduce the variance of the learning objectives' gradients. Our work contributes by narrowing the gap between GFlowNets training and generalized variational approximations, paving the way for algorithmic ideas informed by the divergence minimization viewpoint.
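
The leave-one-out idea mentioned in the abstract can be illustrated on a toy problem. The sketch below is not the paper's estimator: it assumes a categorical distribution parameterized by logits and a hypothetical reward `f`, and all function names are made up for illustration. It estimates $\nabla_\theta \mathbb{E}_{x \sim p_\theta}[f(x)]$ with the score-function (REINFORCE) estimator, optionally subtracting from each sample a baseline equal to the mean of `f` over the other samples in the batch. Because that baseline does not depend on the sample it is applied to, the estimator remains unbiased.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]


def grad_log_prob(probs, x):
    """Gradient of log softmax(logits)[x] w.r.t. the logits: onehot(x) - probs."""
    return [(1.0 if k == x else 0.0) - p for k, p in enumerate(probs)]


def score_function_grad(logits, f, n, rng, leave_one_out):
    """Monte Carlo estimate of d/d(logits) E[f(x)] from n samples (n >= 2)."""
    probs = softmax(logits)
    xs = rng.choices(range(len(logits)), weights=probs, k=n)
    fs = [f(x) for x in xs]
    total = sum(fs)
    grad = [0.0] * len(logits)
    for x, fx in zip(xs, fs):
        # Leave-one-out baseline: mean of f over the other n - 1 samples.
        baseline = (total - fx) / (n - 1) if leave_one_out else 0.0
        g = grad_log_prob(probs, x)
        for k in range(len(grad)):
            grad[k] += (fx - baseline) * g[k] / n
    return grad


def exact_grad(logits, f):
    """Closed-form gradient for the toy problem: p_k * (f(k) - E[f])."""
    probs = softmax(logits)
    mean_f = sum(p * f(x) for x, p in enumerate(probs))
    return [p * (f(x) - mean_f) for x, p in enumerate(probs)]


def empirical_variance(logits, f, n, repeats, leave_one_out, seed=0):
    """Total variance (summed over coordinates) of the estimator across repeats."""
    rng = random.Random(seed)
    samples = [
        score_function_grad(logits, f, n, rng, leave_one_out)
        for _ in range(repeats)
    ]
    dim = len(logits)
    means = [sum(s[k] for s in samples) / repeats for k in range(dim)]
    return sum(
        sum((s[k] - means[k]) ** 2 for s in samples) / repeats
        for k in range(dim)
    )
```

Calling `empirical_variance` with `leave_one_out=True` and `leave_one_out=False` on the same logits and reward gives a quick comparison of the two estimators. The gap is largest when `f` has a large constant offset, since the baseline cancels it.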

Paper Structure

This paper contains 20 sections, 4 theorems, 32 equations, 7 figures, 1 table.

Key Result

Proposition 1

Let $\mathcal{L}_{TB}(\tau ; \theta) = \left(\log \frac{Z \, p_{F_\theta}(\tau | s_{o}; \theta)}{r(x) \, p_{B}(\tau | x)}\right)^{2}$ and $p_{B}(\tau) = \frac{r(x)}{Z} p_{B}(s_{n-1:o} | x)$ for $\tau = (s_{o}, \dots, s_{n - 1}, x, s_{f})$. Then, where $\mathcal{D}_{KL} [ p_{F_\theta} || p_{B} ] = \mathbb{E}_{\tau \sim P_{F}(s_{o}, \cdot)} \left[ \log \frac{p_{F_\theta}(\tau | s_{o}; \theta)}{p_{B}(\tau)}\right]$ is

Figures (7)

  • Figure 1: Mode-seeking ($\alpha = 2$) versus mass-covering ($\alpha = -2$) behaviour in $\alpha$-divergences.
  • Figure 2: Variance of the estimated gradients as a function of the trajectories' batch size. Our control variates greatly reduce the estimator's variance, even for relatively small batch sizes.
  • Figure 3: Divergence-based learning objectives often lead to faster training than the TB loss. Notably, in contrast to the experiments of Malkin et al. (2023), no single loss function always yields the fastest convergence, and minimizing well-known divergence measures is often on par with or better than minimizing the TB loss in terms of convergence speed. Results were averaged across three different seeds. We fix $\alpha = 0.5$ for both the Tsallis-$\alpha$ and Rényi-$\alpha$ divergences.
  • Figure 4: Learned distributions for the banana-shaped target. Tsallis-$\alpha$, Rényi-$\alpha$, and forward KL lead to better models than TB and reverse KL, which behave similarly, as predicted by Proposition 1.
  • Figure 5: Learning curves for different objective functions in the task of set generation. The reduced variance of the gradient estimates notably increases training stability and speed.
  • ...and 2 more figures

Theorems & Definitions (10)

  • Definition 1: Measurable pointed DAG theory
  • Definition 2: GFlowNets theory
  • Definition 3: Trajectory balance condition
  • Definition 4: Detailed balance condition
  • Proposition 1: TB loss- and KL divergence gradients for topological spaces
  • Definition 5: Rényi-$\alpha$ and Tsallis-$\alpha$ divergences
  • Lemma 1: Gradients for $\mathcal{R}_{\alpha}$ and $\mathcal{T}_{\alpha}$
  • Definition 6: Forward and reverse KL
  • Lemma 2: Gradients for the KL divergence
  • Proposition 2: Control variate for gradients
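
To make the divergences listed above concrete, here is a minimal sketch for discrete distributions with strictly positive probabilities. It assumes the common textbook forms $\mathcal{R}_{\alpha}(p \| q) = \frac{1}{\alpha - 1} \log \sum_i p_i^{\alpha} q_i^{1 - \alpha}$ and $\mathcal{T}_{\alpha}(p \| q) = \frac{1}{\alpha - 1} \left( \sum_i p_i^{\alpha} q_i^{1 - \alpha} - 1 \right)$. The paper's Definition 5 is not reproduced on this page and may use a different parameterization or sign convention, so treat the function names and formulas here as assumptions.

```python
import math


def power_sum(p, q, alpha):
    """Sum_i p_i^alpha * q_i^(1 - alpha); assumes p_i, q_i > 0."""
    return sum(pi ** alpha * qi ** (1.0 - alpha) for pi, qi in zip(p, q))


def renyi(p, q, alpha):
    """Renyi-alpha divergence (textbook form), defined for alpha != 1."""
    return math.log(power_sum(p, q, alpha)) / (alpha - 1.0)


def tsallis(p, q, alpha):
    """Tsallis-alpha divergence (textbook form), defined for alpha != 1."""
    return (power_sum(p, q, alpha) - 1.0) / (alpha - 1.0)


def kl(p, q):
    """KL(p || q) for discrete distributions with strictly positive entries."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))
```

Under these forms both divergences vanish when `p == q` and approach `kl(p, q)` as `alpha` tends to 1, which is a convenient sanity check when experimenting with different values of `alpha`.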