Learning Likelihoods with Conditional Normalizing Flows

Christina Winkler, Daniel Worrall, Emiel Hoogeboom, Max Welling

TL;DR

The paper proposes conditional normalizing flows (CNFs) to model $p_{Y|X}$ for high-dimensional outputs, capturing inter-output correlations and multimodality without hand-designed per-pixel losses. The method conditions both the base density and the invertible mapping on X, which gives efficient sampling and exact likelihoods; training maximizes the likelihood evaluated in z-space. Experiments on super-resolution and retinal vessel segmentation show competitive likelihoods and traditional metrics, with qualitative evidence of crisper details and better-calibrated probability estimates. Overall, CNFs offer a flexible, probabilistic framework for structured prediction that avoids the mode collapse and training instability typical of some alternative approaches.

Abstract

Normalizing Flows (NFs) are able to model complicated distributions p(y) with strong inter-dimensional correlations and high multimodality by transforming a simple base density p(z) through an invertible neural network under the change of variables formula. Such behavior is desirable in multivariate structured prediction tasks, where handcrafted per-pixel loss-based methods inadequately capture strong correlations between output dimensions. We present a study of conditional normalizing flows (CNFs), a class of NFs where the base density to output space mapping is conditioned on an input x, to model conditional densities p(y|x). CNFs are efficient in sampling and inference, can be trained with a likelihood-based objective, and, being generative flows, do not suffer from mode collapse or training instabilities. We provide an effective method to train continuous CNFs for binary problems; in particular, we apply these CNFs to super-resolution and vessel segmentation tasks, demonstrating competitive performance on standard benchmark datasets in terms of likelihood and conventional metrics.
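
To make the conditional change of variables concrete, the following is a minimal NumPy sketch, not the authors' architecture. It evaluates log p(y|x) = log p(z|x) + log|det ∂f(y,x)/∂y| with z = f(y, x), using a single elementwise affine map and a Gaussian base density whose mean depends on x. The linear "conditioner", the weight names (W_mu, W_s, W_base), and the unit-variance base are all assumptions chosen for brevity.

```python
import numpy as np

def conditioner(x, W_mu, W_s):
    """Hypothetical conditioning network: maps x to a shift and a log-scale."""
    mu = x @ W_mu
    log_s = np.tanh(x @ W_s)  # bounded log-scale keeps exp() well behaved
    return mu, log_s

def forward(y, x, params):
    """Return z = f(y, x) and log|det df/dy| for an elementwise affine map."""
    mu, log_s = conditioner(x, *params)
    z = (y - mu) * np.exp(-log_s)
    log_det = -np.sum(log_s, axis=-1)  # diagonal Jacobian
    return z, log_det

def inverse(z, x, params):
    """Return y = f^{-1}(z, x), used when sampling."""
    mu, log_s = conditioner(x, *params)
    return z * np.exp(log_s) + mu

def base_log_prob(z, x, W_base):
    """log p(z | x): unit-variance Gaussian whose mean depends on x (assumed form)."""
    mean = x @ W_base
    d = z.shape[-1]
    return -0.5 * np.sum((z - mean) ** 2, axis=-1) - 0.5 * d * np.log(2 * np.pi)

def log_likelihood(y, x, params, W_base):
    """Exact log p(y | x) via the change of variables formula."""
    z, log_det = forward(y, x, params)
    return base_log_prob(z, x, W_base) + log_det

def sample(x, params, W_base, rng):
    """Draw z ~ p(z | x), then map it to output space with the inverse flow."""
    mean = x @ W_base
    z = mean + rng.standard_normal(mean.shape)
    return inverse(z, x, params)

# Toy usage with random weights: 4 inputs of dimension 3, outputs of dimension 2.
rng = np.random.default_rng(0)
dx, dy = 3, 2
params = (rng.standard_normal((dx, dy)), rng.standard_normal((dx, dy)))
W_base = rng.standard_normal((dx, dy))
x = rng.standard_normal((4, dx))
y = sample(x, params, W_base, rng)
nll = -np.mean(log_likelihood(y, x, params, W_base))  # training objective to minimize
```

In this sketch x enters in two places, the base density and the bijective map, mirroring the description of Figure 1. A practical model would stack many such invertible layers and use learned neural conditioners.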

Paper Structure

This paper contains 28 sections, 9 equations, 9 figures, 8 tables.

Figures (9)

  • Figure 1: Diagram of our model in the train and sampling phases. Solid lines represent deterministic mappings and dashed lines represent sampling. The conditioning variable enters the network in base density $p({\mathbf{z}}|{\mathbf{x}})$ and the bijective mappings $f({\mathbf{y}},{\mathbf{x}})$.
  • Figure 2: Super-resolution results on the ImageNet64 test data. Samples are taken from the CNF $x_{hr} \sim p(x_{hr} | x_{lr})$ and the mode is visualized for the factorized baseline model. Best viewed electronically.
  • Figure 3: Conditional samples from the CNF (ours) for sampling temperatures $\{0, 0.5, 1.0\}$ and the factorized discrete baseline for 2x upscaling. The conditioning image is a baboon from the Set14 test set. Both models were trained on ImageNet64. Best viewed electronically.
  • Figure 4: Example of retinal segmentations using DRIU, our likelihood baseline trained with the same loss, and our CNF. For the CNF, the mean of 100 samples is visualized. Notice that our segmentations more accurately capture the vessel width, which is overdilated in the DRIU and factored models.
  • Figure 5: Here we show two visualizations of the same data. Left: We show the PR-curves generated from a sweeping threshold on soft images output by each listed method. Maximal F-scores for each curve are shown as circles, with the green lines indicating constant F-score. We see that our method beats all traditional methods and is on par with DRIU, which, unlike ours, was pretrained on ImageNet. Right: We show a scatter plot in PR-space of samples drawn from each model. To draw samples from all the factored models, we sample images from a factored Bernoulli whose mean is the soft image. We see that the DRIU and HED models, while having good precision, have poor recall in this regime. This indicates that while the outputs of their networks produce a good ranking of probabilities, the values of the probabilities are poorly calibrated. Our model drops slightly in precision but gains greatly in recall, indicating that our samples are drawn from a better calibrated distribution, overlapping significantly with the Human distribution.
  • ...and 4 more figures
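
The evaluation described in the Figure 5 caption can be sketched as follows. This is an illustrative NumPy version under stated assumptions, not the authors' evaluation code. It sweeps a threshold over a soft (probability) image to trace a PR curve, takes the maximal F-score along that curve, and draws binary samples from a factored Bernoulli whose mean is the soft image. All function names are hypothetical.

```python
import numpy as np

def precision_recall(pred, target):
    """Precision and recall of a boolean prediction against a boolean target."""
    tp = np.sum(pred & target)
    precision = tp / max(np.sum(pred), 1)   # defined as 0 when nothing is predicted
    recall = tp / max(np.sum(target), 1)
    return precision, recall

def pr_sweep(soft, target, thresholds):
    """One (precision, recall) pair per threshold applied to the soft image."""
    return np.array([precision_recall(soft >= t, target) for t in thresholds])

def max_f_score(pr):
    """Maximal F1 score along a PR curve of shape (n_thresholds, 2)."""
    p, r = pr[:, 0], pr[:, 1]
    f = np.where(p + r > 0, 2 * p * r / np.maximum(p + r, 1e-12), 0.0)
    return f.max()

def sample_factored_bernoulli(soft, n, rng):
    """Draw n binary images, each pixel independent with mean soft[pixel]."""
    return rng.random((n,) + soft.shape) < soft

# Toy usage: a random "soft image" scored against a random binary target.
rng = np.random.default_rng(0)
target = rng.random((8, 8)) < 0.3
soft = np.clip(0.6 * target + 0.4 * rng.random((8, 8)), 0.0, 1.0)
curve = pr_sweep(soft, target, np.linspace(0.05, 0.95, 19))
best_f = max_f_score(curve)
samples = sample_factored_bernoulli(soft, 100, rng)
sample_pr = np.array([precision_recall(s, target) for s in samples])
```

Plotting `curve` gives the left-hand style of view and scattering `sample_pr` gives the right-hand one. A model whose soft outputs rank pixels well but are poorly calibrated can score highly on the former while its samples land far from the human annotations on the latter.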