What's in a Prior? Learned Proximal Networks for Inverse Problems

Zhenghan Fang, Sam Buchanan, Jeremias Sulam

TL;DR

This paper tackles ill-posed inverse problems by introducing Learned Proximal Networks (LPNs), which parameterize proximal operators of learned, potentially nonconvex regularizers. A new training objective, proximal matching, promotes learning the MAP denoiser associated with the log-prior, which in turn promotes recovery of the log-prior of the true data distribution. The authors prove that LPNs implement exact proximal operators, derive a method to recover the associated regularizer, and establish convergence guarantees for plug-and-play schemes using LPNs. Empirically, LPNs achieve state-of-the-art unsupervised reconstructions on CelebA deblurring and Mayo-CT reconstruction. They also provide interpretable priors and robust convergence behavior across diverse inverse problems.
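
A scalar toy, in the spirit of Figure 2, illustrates why the choice of loss determines which denoiser is learned. For a fixed set of posterior samples, the constant estimate that minimizes the squared $\ell_2$ loss is their mean, and the $\ell_1$ loss gives their median. A Gaussian-kernel loss with small width $\gamma$ instead picks out their mode, which is the MAP estimate. The sketch assumes a proximal matching loss of the form $1 - \exp(-\|f_\theta(y) - x\|_2^2/\gamma^2)$. The paper's exact normalization and its schedule for shrinking $\gamma$ may differ, and the samples below are made up for illustration.

```python
import numpy as np

# Made-up posterior samples of a scalar x for one fixed observation y:
# a cluster at 0 (the mode) plus a long right tail.
samples = np.array([0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 5.0, 9.0])

# Candidate constant estimates c on a grid with step 0.01.
grid = np.linspace(-1.0, 10.0, 1101)
err = grid[:, None] - samples[None, :]  # shape: (num candidates, num samples)

def proximal_matching_loss(err, gamma):
    """Assumed kernel form 1 - exp(-err^2 / gamma^2); normalization may differ from the paper."""
    return 1.0 - np.exp(-(err ** 2) / gamma ** 2)

c_l2 = grid[np.argmin(np.mean(err ** 2, axis=1))]      # tracks the posterior mean (MMSE)
c_l1 = grid[np.argmin(np.mean(np.abs(err), axis=1))]   # tracks the posterior median
c_pm = grid[np.argmin(np.mean(proximal_matching_loss(err, gamma=0.1), axis=1))]  # tracks the mode (MAP)

print(f"l2: {c_l2:.2f}  l1: {c_l1:.2f}  PM: {c_pm:.2f}")
```

As $\gamma$ shrinks, the kernel loss only rewards estimates that land within roughly $\gamma$ of many samples. Its minimizer therefore tracks the densest point rather than the average, and here it selects the cluster at 0.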

Abstract

Proximal operators are ubiquitous in inverse problems, commonly appearing as part of algorithmic strategies to regularize problems that are otherwise ill-posed. Modern deep learning models have been brought to bear on these tasks too, as in the framework of plug-and-play or deep unrolling, where they loosely resemble proximal operators. Yet, something essential is lost in employing these purely data-driven approaches: there is no guarantee that a general deep network represents the proximal operator of any function, nor is there any characterization of the function for which the network might provide some approximate proximal operator. This not only makes guaranteeing convergence of iterative schemes challenging but, more fundamentally, complicates the analysis of what has been learned by these networks about their training data. Herein we provide a framework to develop learned proximal networks (LPNs), prove that they provide exact proximal operators for a data-driven nonconvex regularizer, and show how a new training strategy, dubbed proximal matching, provably promotes the recovery of the log-prior of the true data distribution. Such LPNs provide general, unsupervised, expressive proximal operators that can be used for general inverse problems with convergence guarantees. We illustrate our results in a series of cases of increasing complexity, demonstrating that these models not only result in state-of-the-art performance, but also provide a window into the priors learned from data.
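
By Proposition 2.1, a continuous map is a proximal operator exactly when it is the gradient of a convex function. One way to build a network that is a proximal operator by design is therefore to output the gradient of an input-convex potential. Our understanding is that LPNs do this with a deep input convex neural network (ICNN). The sketch below simplifies this to a hypothetical one-hidden-layer potential $\psi_\theta(y) = \sum_j w_j \,\mathrm{softplus}(a_j^\top y + b_j) + \tfrac{\alpha}{2}\|y\|^2$ with $w_j \ge 0$. All names and sizes are illustrative. It checks numerically that $f_\theta = \nabla\psi_\theta$ has a symmetric, positive semidefinite Jacobian, as the gradient of a convex function must.

```python
import numpy as np

rng = np.random.default_rng(0)
n, hidden = 3, 8                      # input dimension and hidden width (arbitrary)
A = rng.normal(size=(hidden, n))      # first-layer weights: unconstrained
b = rng.normal(size=hidden)
w = np.abs(rng.normal(size=hidden))   # output weights kept nonnegative -> psi is convex
alpha = 0.1                           # quadratic term makes psi strongly convex

def softplus(t):
    return np.logaddexp(0.0, t)       # log(1 + e^t), numerically stable

def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))   # derivative of softplus

def psi(y):
    """Hypothetical convex potential: sum_j w_j softplus(a_j . y + b_j) + alpha/2 ||y||^2."""
    return w @ softplus(A @ y + b) + 0.5 * alpha * y @ y

def f(y):
    """LPN-style map: the gradient of psi, written in closed form."""
    return A.T @ (w * sigmoid(A @ y + b)) + alpha * y

def jacobian(y, eps=1e-6):
    """Central finite-difference Jacobian of f; column j is df/dy_j."""
    cols = [(f(y + eps * e) - f(y - eps * e)) / (2 * eps) for e in np.eye(n)]
    return np.stack(cols, axis=1)

y0 = rng.normal(size=n)
J = jacobian(y0)
print("symmetric:", np.allclose(J, J.T, atol=1e-6))
print("min eigenvalue:", np.linalg.eigvalsh((J + J.T) / 2).min())  # bounded below by alpha
```

Constraining only the output weights to be nonnegative suffices here because a nonnegative combination of convex functions of affine maps is convex. Deeper ICNNs also need nonnegative hidden-to-hidden weights and convex, nondecreasing activations.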

Paper Structure (56 sections, 19 theorems, 97 equations, 8 figures, 4 tables, 4 algorithms)

Key Result

Proposition 2.1

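[Characterization of continuous proximal operators (Gribonval & Nikolova, 2020)] Let $\mathcal{Y} \subset \mathbb{R}^n$ be non-empty and open and $f : \mathcal{Y} \rightarrow \mathbb{R}^n$ be a continuous function. Then, $f$ is a proximal operator of a function $R : \mathbb{R}^n \rightarrow \mathbb{R} \cup \{+\infty\}$ if and only if there exists a convex differentiable function $\psi$ such that $f(y) = \nabla \psi(y)$ for each $y \in \mathcal{Y}$.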
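
Figure 1 sketches this proposition for $R(\cdot) = \|\cdot\|_1$. In one dimension with $R(x) = \lambda|x|$, the proximal operator is soft-thresholding. It equals the derivative of the convex potential $\psi(y) = \tfrac{1}{2}\max(|y| - \lambda, 0)^2$. The sketch below checks this numerically. It also recovers $R$ from $\psi$ through the identity $R(f(y)) = y\,f(y) - \tfrac{1}{2}f(y)^2 - \psi(y)$, which holds up to an additive constant whenever $f = \nabla\psi$ is a proximal operator. We believe this matches the paper's regularizer-recovery method, but treat that as an assumption. The value $\lambda = 0.5$ is arbitrary.

```python
import numpy as np

lam = 0.5  # arbitrary weight in R(x) = lam * |x|

def soft_threshold(y):
    """prox of lam*|.|: argmin_x 0.5*(x - y)^2 + lam*|x|."""
    return np.sign(y) * np.maximum(np.abs(y) - lam, 0.0)

def psi(y):
    """Convex potential whose derivative is soft_threshold."""
    return 0.5 * np.maximum(np.abs(y) - lam, 0.0) ** 2

def recovered_R(y):
    """R evaluated at x = f(y) via R(f(y)) = y*f(y) - 0.5*f(y)^2 - psi(y)."""
    x = soft_threshold(y)
    return y * x - 0.5 * x ** 2 - psi(y)

y = np.linspace(-2.0, 2.0, 401)
eps = 1e-6
grad_psi = (psi(y + eps) - psi(y - eps)) / (2 * eps)   # finite-difference psi'
print("max |psi' - prox|:", np.max(np.abs(grad_psi - soft_threshold(y))))
print("max |recovered R - lam|x||:",
      np.max(np.abs(recovered_R(y) - lam * np.abs(soft_threshold(y)))))
```

Note that $\psi$ is flat on $[-\lambda, \lambda]$, which is exactly the dead zone where the proximal operator returns 0.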

Figures (8)

  • Figure 1: Sketch of Prop. 2.1 for $R(\cdot) = \|\cdot\|_1$.
  • Figure 2: The proximal $f_\theta$, convex potential $\psi_\theta$, and log-prior $R_\theta$ learned by LPN via the squared $\ell_2$ loss, $\ell_1$ loss, and proximal matching loss $\mathcal{L}_{PM}$ for a Laplacian distribution (ground truth in gray).
  • Figure 3: Left: log-prior $R_\theta$ learned by LPN on MNIST (computed over 100 test images), evaluated at images corrupted by (a) additive Gaussian noise, and (b) convex combination of two images $(1-\lambda) \mathbf{x} + \lambda \mathbf{x}'$. Right: the prior evaluated at individual examples.
  • Figure 4: Results on the Mayo-CT dataset (details in text).
  • Figure 5: Visual results for deblurring on CelebA using Plug-and-Play with different denoisers (BM3D, DnCNN, the gradient step (GS) Prox-DRUNet, and our LPN), for different Gaussian blur kernel standard deviation $\sigma_{blur}$ and noise standard deviation $\sigma_{noise}$. PSNR and SSIM are presented above each prediction.
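
Figure 3 evaluates the learned log-prior $R_\theta$ at arbitrary images. The recovery identity $R_\theta(f_\theta(y)) = \langle y, f_\theta(y)\rangle - \tfrac{1}{2}\|f_\theta(y)\|^2 - \psi_\theta(y)$ (up to a constant), however, only gives $R_\theta$ at points of the form $x = f_\theta(y)$. One way to evaluate it at other points is to first invert the network. If $\psi_\theta$ is strongly convex, then $y = f_\theta^{-1}(x)$ is the unique minimizer of the convex function $\psi_\theta(y) - \langle x, y\rangle$, and plain gradient descent can find it. We believe this mirrors the paper's approach, but treat it as an assumption. The potential below is a hypothetical one-hidden-layer stand-in for a trained LPN.

```python
import numpy as np

# Hypothetical stand-in for a trained LPN potential: nonnegative output weights
# plus a quadratic term, so psi is strongly convex and f = grad psi is invertible.
rng = np.random.default_rng(1)
n, hidden = 4, 16
A = rng.normal(size=(hidden, n)) / np.sqrt(n)
b = rng.normal(size=hidden)
w = np.abs(rng.normal(size=hidden))
alpha = 1.0

def psi(y):
    return w @ np.logaddexp(0.0, A @ y + b) + 0.5 * alpha * y @ y

def f(y):
    """Denoiser f = grad psi."""
    return A.T @ (w / (1.0 + np.exp(-(A @ y + b)))) + alpha * y

# grad psi is L-Lipschitz with L <= alpha + lambda_max(A^T diag(w) A) / 4, since sigmoid' <= 1/4.
L = alpha + 0.25 * np.linalg.eigvalsh(A.T @ (w[:, None] * A)).max()

def invert(x, steps=2000):
    """Solve f(y) = x by gradient descent (step 1/L) on g(y) = psi(y) - <x, y>."""
    y = x.copy()
    for _ in range(steps):
        y = y - (f(y) - x) / L
    return y

def log_prior(x):
    """R(x) = <y, x> - 0.5 ||x||^2 - psi(y) with y = f^{-1}(x), up to an additive constant."""
    y = invert(x)
    return y @ x - 0.5 * x @ x - psi(y)

# Evaluate R along the convex combination (1 - t) x0 + t x1, as in Figure 3(b).
x0, x1 = rng.normal(size=n), rng.normal(size=n)
for t in np.linspace(0.0, 1.0, 5):
    print(f"t = {t:.2f}: R = {log_prior((1 - t) * x0 + t * x1):.4f}")
```

The strong convexity supplied by the quadratic term is what makes the inversion well-posed. Without it, $f_\theta$ could map many inputs to the same output, as soft-thresholding does.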

Theorems & Definitions (36)

  • Proposition 2.1
  • Proposition 3.1: Learned Proximal Networks
  • Theorem 3.2: Learning via Proximal Matching
  • Theorem 4.1: Convergence guarantee for running PnP-ADMM with LPNs
  • Theorem B.1: Learning via Proximal Matching (Discrete Case)
  • Theorem B.2: Convergence guarantee for running PnP-PGD with LPNs
  • proof
  • proof
  • Remark: Other loss choices
  • proof
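
Theorem 4.1 concerns plug-and-play ADMM, in which the proximal step of ADMM is replaced by a denoiser. Below is a minimal sketch of the scaled-form iteration for $\min_x \tfrac{1}{2}\|Ax - y\|^2 + R(x)$. `pnp_admm` is a hypothetical helper, and soft-thresholding stands in for a trained LPN. Soft-thresholding is the exact proximal operator of a convex $R$, so convergence in this toy follows from standard ADMM theory. The conditions under which Theorem 4.1 guarantees convergence for a learned, nonconvex $R_\theta$ are not reproduced here.

```python
import numpy as np

def pnp_admm(A, y, denoiser, rho=1.0, iters=1000):
    """Plug-and-play ADMM (scaled form) for min_x 0.5*||A x - y||^2 + R(x).

    `denoiser` plays the role of prox_{R/rho}; with an LPN it would be f_theta.
    """
    n = A.shape[1]
    x, z, u = np.zeros(n), np.zeros(n), np.zeros(n)
    M = A.T @ A + rho * np.eye(n)   # system matrix of the x-update, formed once
    Aty = A.T @ y
    for _ in range(iters):
        x = np.linalg.solve(M, Aty + rho * (z - u))  # data-fidelity (least-squares) step
        z = denoiser(x + u)                          # prior step: the plugged-in denoiser
        u = u + x - z                                # scaled dual update
    return z

# Usage: with rho = 1, soft-thresholding is the exact prox of lam*||.||_1.
lam = 0.3

def soft(v):
    return np.sign(v) * np.maximum(np.abs(v) - lam, 0.0)

rng = np.random.default_rng(0)
A = rng.normal(size=(8, 5))
y = rng.normal(size=8)
x_hat = pnp_admm(A, y, soft)
print(x_hat)
```

Returning `z` rather than `x` guarantees that the output lies in the range of the denoiser, which is where the prior is defined.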