
Adaptive Symmetrization of the KL Divergence

Omri Ben-Dov, Luiz F. O. Chamon

TL;DR

This work sets out to develop a new approach to minimizing the Jeffreys divergence. It uses a proxy model whose goal is not only to fit the data but also to assist in optimizing the Jeffreys divergence of the main model.
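The Jeffreys divergence is the sum of the forward and reverse KL divergences, $D_J(\pi, p) = D_{\mathrm{KL}}(\pi\parallel p) + D_{\mathrm{KL}}(p\parallel\pi)$, which makes it symmetric in its arguments. The minimal sketch below computes both KL terms on two hypothetical discrete distributions (the values of `pi_data` and `model` are made up for illustration) to show that they differ while their sum does not depend on argument order. The reverse term can be evaluated here only because the toy $\pi$ is given explicitly; with samples from $\pi$ alone, its density is unavailable.

```python
import math

def kl(p, q):
    """KL(p || q) for discrete distributions given as equal-length lists of probabilities.

    Terms with p_i == 0 contribute 0; q_i must be positive wherever p_i is.
    """
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def jeffreys(p, q):
    """Jeffreys divergence: forward plus reverse KL, symmetric in its arguments."""
    return kl(p, q) + kl(q, p)

# Hypothetical bimodal "data" distribution and a model that favors only one mode.
pi_data = [0.45, 0.05, 0.05, 0.45]
model = [0.70, 0.20, 0.05, 0.05]

# KL(pi || p): up to a constant, the quantity MLE minimizes using samples of pi.
forward = kl(pi_data, model)
# KL(p || pi): requires the density of pi itself.
reverse = kl(model, pi_data)

print(f"forward KL = {forward:.3f}, reverse KL = {reverse:.3f}")
print(f"Jeffreys(pi, p) = {jeffreys(pi_data, model):.3f}, "
      f"Jeffreys(p, pi) = {jeffreys(model, pi_data):.3f}")
```

The forward term heavily penalizes the model for the data mode it underweights, while the reverse term is dominated by the mass the model places where the data has little; the Jeffreys divergence accounts for both.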

Abstract

Many tasks in machine learning can be described as, or reduced to, learning a probability distribution given a finite set of samples. A common approach is to minimize a statistical divergence between the (empirical) data distribution and a parameterized distribution, e.g., a normalizing flow (NF) or an energy-based model (EBM). In this context, the forward KL divergence is ubiquitous due to its tractability, though its asymmetry may prevent capturing some properties of the target distribution. Symmetric alternatives involve either brittle min-max formulations and adversarial training (e.g., generative adversarial networks) or evaluating the reverse KL divergence, as is the case for the symmetric Jeffreys divergence; the reverse KL, however, is challenging to compute from samples. This work sets out to develop a new approach to minimizing the Jeffreys divergence. To do so, it uses a proxy model whose goal is not only to fit the data but also to assist in optimizing the Jeffreys divergence of the main model. This joint training task is formulated as a constrained optimization problem, yielding a practical algorithm that adapts the models' priorities throughout training. We illustrate how this framework can be used to combine the advantages of NFs and EBMs in tasks such as density estimation, image generation, and simulation-based inference.
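The abstract casts joint training as a constrained optimization problem whose algorithm adapts the models' priorities during training. The sketch below illustrates only the generic mechanism on a hypothetical one-dimensional problem, not the paper's actual objectives: minimize $(x-3)^2$ subject to $x \le 1$. A fixed weight $w$ on the constraint (a weighted, or penalty, formulation) must be tuned by hand, whereas gradient descent-ascent on the Lagrangian $(x-3)^2 + \lambda(x-1)$ adjusts the multiplier $\lambda$ automatically. The function names and step sizes are illustrative assumptions.

```python
def penalty_solution(w):
    """Minimizer of the fixed-weight problem (x - 3)^2 + w * (x - 1), i.e. x = 3 - w / 2."""
    return 3.0 - w / 2.0

def primal_dual(steps=2000, lr_x=0.1, lr_lam=0.1):
    """Simultaneous gradient descent on x and projected gradient ascent on lam
    for the toy problem: minimize (x - 3)^2 subject to x - 1 <= 0."""
    x, lam = 0.0, 0.0
    for _ in range(steps):
        grad_x = 2.0 * (x - 3.0) + lam  # derivative of the Lagrangian w.r.t. x
        violation = x - 1.0             # derivative of the Lagrangian w.r.t. lam
        x -= lr_x * grad_x
        lam = max(0.0, lam + lr_lam * violation)  # multipliers must stay nonnegative
    return x, lam

# Fixed weights: each choice gives a different trade-off.
for w in (0.5, 2.0, 8.0):
    x_w = penalty_solution(w)
    status = "met" if x_w <= 1.0 else "violated"
    print(f"fixed weight {w:>4}: x = {x_w:.2f}, constraint x <= 1 {status}")

# Primal-dual: the weight lam is learned alongside x.
x, lam = primal_dual()
print(f"primal-dual: x = {x:.4f}, learned weight lam = {lam:.4f}")
```

Small fixed weights violate the constraint and large ones satisfy it at the cost of a worse objective; only $w = 4$ lands exactly on the boundary $x = 1$, and that is the value the multiplier converges to.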

Paper Structure

This paper contains 43 sections, 1 theorem, 36 equations, 9 figures, 8 algorithms.

Key Result

Theorem 1

Let $p_{\theta}, q_{\psi} > 0$ and let $h$ be a convex and Lipschitz continuous function. Suppose there exists $\nu \geq 0$ such that for each $p \in \bar{\mathcal{H}}$, the closed convex hull of $\mathcal{H}$, there exists $p_\theta \in \mathcal{H}$ such that $\|p_\theta - p\|_\text{TV} \leq \nu$ …

Figures (9)

  • Figure 1: Illustration of different training dynamics, where the goal is to train $q_{\psi}$ and $p_{\theta}$ to match the data distribution $\pi$. Each point represents a probability distribution in a conceptual distribution space, and the arrows indicate the direction of gradient descent (GD). (a) In maximum likelihood estimation (MLE), GD minimizes the divergence $D_{\mathrm{KL}}\left(\pi\parallel p_{\theta}\right)$; in practice, it minimizes a finite-sample estimate of this KL rather than the true KL. (b) In generative adversarial networks (GANs), GD pulls the generator $p_{\theta}$ towards the discriminator $q_{\psi}$ and the discriminator towards $\pi$, but it also pushes $q_{\psi}$ away from $p_{\theta}$. This repelling dynamic makes GAN training unstable. (c) In our framework, GD pulls $p_{\theta}$ and $q_{\psi}$ towards each other while also pulling both towards $\pi$. The mutual attraction stabilizes training.
  • Figure 2: Solving the dual problem (solid black line) achieves better results than 25 weight configurations of the weighted (penalty) problem (dashed colored lines) on a synthetic Gaussian mixture dataset. (a) The $p_{\theta}$ obtained from the dual problem has a better and more stable NLL than any configuration of the weighted problem. (b) The partition function of the EBM of the dual problem remains stable at unity, while the weighted problem may reach large, unstable values.
  • Figure 3: Our dual is more stable than NF and WGAN and outperforms them on a synthetic 2D GMM dataset. (a) The dual problem (solid lines) attains a stable KL divergence that is lower than that of NF (dashed line). NF's KL increases because it overfits to the finite training distribution rather than the test GMM distribution. (b) WGANs (dashed lines) become unstable as the learning rate increases, while our method (solid lines) remains stable. Triangles denote the high learning rate and X markers the low learning rate.
  • Figure 4: Our framework accurately learns the density of various 2D datasets. The left column (a,c) presents the lowest NLL each model achieved during training ("Best") and the NLL at the end of training ("Last"). The error bars represent the standard deviation over 5 seeded runs, showing that our method outperforms NF in both value and consistency. The right column (b,d) qualitatively displays the finite dataset (left) and the learned densities of $p_{\theta}$ and the quasi-normalized $\tilde{p}_{\psi}$ (middle and right, respectively). The density maps show that both models perfectly capture the shape and all modes.
  • Figure 5: Our $p_{\theta}$ learns the 100D encoded space of CelebA comparably to NF. (a) reports the FID and qualitative samples generated from our $p_{\theta}$, and (b) reports the same for NF. We used the same seed to generate the samples, which explains the similarity between the images.
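Figure 2(b) tracks the partition function of the EBM, and Figure 4 shows a quasi-normalized $\tilde{p}_{\psi}$. Assuming, as the captions suggest, that $p_{\theta}$ is a tractable density (such as an NF) while the EBM is known only up to its normalizer, one standard way to monitor $Z_{\psi} = \int e^{-E_{\psi}(x)}\,dx$ is importance sampling with $p_{\theta}$ as the proposal: $Z_{\psi} = \mathbb{E}_{x \sim p_{\theta}}\left[e^{-E_{\psi}(x)} / p_{\theta}(x)\right]$. The sketch below uses a hypothetical quadratic energy and a Gaussian proposal standing in for the two models; it is not necessarily how the paper computes these quantities.

```python
import math
import random

def energy(x):
    # Hypothetical EBM energy E(x) = x^2 / 2, so exp(-E) is an unnormalized N(0, 1)
    # and the exact partition function is sqrt(2 * pi).
    return 0.5 * x * x

def proposal_logpdf(x, sigma):
    # Log-density of the tractable proposal N(0, sigma^2), standing in for p_theta.
    return -0.5 * (x / sigma) ** 2 - math.log(sigma * math.sqrt(2.0 * math.pi))

def estimate_partition(n, sigma=1.5, seed=0):
    """Importance-sampling estimate of Z = integral of exp(-E(x)) dx using proposal samples."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        x = rng.gauss(0.0, sigma)
        # Importance weight exp(-E(x)) / p_theta(x), computed in log space.
        total += math.exp(-energy(x) - proposal_logpdf(x, sigma))
    return total / n

z_hat = estimate_partition(50_000)
print(f"Z estimate: {z_hat:.4f} (exact: {math.sqrt(2.0 * math.pi):.4f})")
```

The proposal must not have much lighter tails than $e^{-E}$: in this example, the importance weights have finite variance only when $\sigma^2 > 1/2$. Dividing $e^{-E}$ by such an estimate of $Z$ gives an approximately normalized density whose normalizer can be checked against unity.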

Theorems & Definitions (2)

  • Theorem 1
  • proof