Generalizing Stochastic Smoothing for Differentiation and Gradient Estimation
Felix Petersen, Christian Borgelt, Aashwin Mishra, Stefano Ermon
TL;DR
Problem: gradient estimation for stochastic relaxations of non-differentiable black-box functions. Approach: generalized stochastic smoothing that relaxes inputs via a density $\mu$ to form $f_\epsilon(x)=\mathbb{E}_{\epsilon\sim\mu}[f(x+\epsilon)]$, with unbiased gradient estimators such as $\nabla_{x} f_\epsilon(x)=\mathbb{E}_{\epsilon\sim\mu}[f(x+\epsilon)\nabla_{\epsilon}(-\log\mu(\epsilon))]$, extended to vector-valued outputs and anisotropic scale matrices $\mathbf{L}$, together with variance-reduction techniques. Key contributions include relaxing assumptions on $\mu$ (including non-differentiable and compact-support densities like Laplace and Triangular), a $k$-sample median extension, and a clear algorithm-vs-loss smoothing distinction, with broad empirical validation. Significance: enables differentiating a wide class of non-differentiable black-box components (sorting, shortest-paths, rendering, cryo-ET) with controllable variance, and guides practical choices of distributions and variance-reduction strategies for improved performance.
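The scalar estimator $\nabla_{x} f_\epsilon(x)=\mathbb{E}_{\epsilon\sim\mu}[f(x+\epsilon)\nabla_{\epsilon}(-\log\mu(\epsilon))]$ can be illustrated with a minimal Monte Carlo sketch. It assumes a hypothetical black box `heaviside` (a step function at 0) and two choices of $\mu$: a Gaussian, where $-\nabla_\epsilon\log\mu(\epsilon)=\epsilon/\sigma^2$, and a Laplace density, which is not differentiable at 0 but still has $-\nabla_\epsilon\log\mu(\epsilon)=\mathrm{sign}(\epsilon)/b$ almost everywhere. The helper names and sample counts are illustrative choices, not the authors' implementation. For the step function, the smoothed gradient at $x=0$ equals $\mu(0)$: $1/\sqrt{2\pi}$ for a standard Gaussian and $1/(2b)$ for Laplace.

```python
import math
import random

def heaviside(x: float) -> float:
    """Hypothetical non-differentiable black box: a step at 0."""
    return 1.0 if x > 0 else 0.0

def gaussian_score(eps: float, sigma: float) -> float:
    # mu = N(0, sigma^2): -d/d(eps) log mu(eps) = eps / sigma^2
    return eps / sigma**2

def laplace_score(eps: float, b: float) -> float:
    # mu = Laplace(0, b): -log mu(eps) = |eps| / b + const, whose derivative
    # is sign(eps) / b everywhere except eps = 0 (a set of probability zero).
    return math.copysign(1.0, eps) / b

def estimate_grad(f, x, sample_eps, score, n=200_000, seed=0):
    """Monte Carlo estimate of d/dx E_{eps~mu}[f(x + eps)] = E[f(x + eps) * score(eps)]."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        eps = sample_eps(rng)
        total += f(x + eps) * score(eps)
    return total / n

sigma, b = 1.0, 1.0
g_gauss = estimate_grad(heaviside, 0.0,
                        lambda r: r.gauss(0.0, sigma),
                        lambda e: gaussian_score(e, sigma))
# Laplace(0, b) sample: exponential magnitude with mean b and a random sign.
g_laplace = estimate_grad(heaviside, 0.0,
                          lambda r: r.expovariate(1.0 / b) * r.choice((-1.0, 1.0)),
                          lambda e: laplace_score(e, b))

print(g_gauss, 1.0 / math.sqrt(2.0 * math.pi))  # estimate vs. exact mu(0)
print(g_laplace, 1.0 / (2.0 * b))               # estimate vs. exact mu(0)
```

Only samples of $\epsilon$ and evaluations of $f$ are needed, so `f` can remain a black box. The density enters only through its score $-\nabla_\epsilon\log\mu$.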
Abstract
We deal with the problem of gradient estimation for stochastic differentiable relaxations of algorithms, operators, simulators, and other non-differentiable functions. Stochastic smoothing conventionally perturbs the input of a non-differentiable function with noise drawn from a differentiable density with full support, which smooths the function and enables gradient estimation. Our theory starts from first principles to derive stochastic smoothing under reduced assumptions, requiring neither a differentiable density nor full support, and we present a general framework for relaxation and gradient estimation of non-differentiable black-box functions $f:\mathbb{R}^n\to\mathbb{R}^m$. We develop variance reduction for gradient estimation from 3 orthogonal perspectives. Empirically, we benchmark 6 distributions and up to 24 variance reduction strategies for differentiable sorting and ranking, differentiable shortest-paths on graphs, differentiable rendering for pose estimation, as well as differentiable cryo-ET simulations.
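For a vector-valued black box $f:\mathbb{R}^n\to\mathbb{R}^m$, the same identity gives a Jacobian estimate as an average of outer products $f(x+\epsilon)\,\nabla_\epsilon(-\log\mu(\epsilon))^\top$. The sketch below assumes isotropic Gaussian noise $\epsilon\sim\mathcal{N}(0,\sigma^2 I)$, not the general scale matrix $\mathbf{L}$. It shows one simple, standard variance-reduction device: subtracting the baseline $f(x)$, which leaves the estimator unbiased because $\mathbb{E}[\epsilon]=0$. This baseline is an illustrative assumption and is not necessarily one of the paper's strategies. The function name `smoothed_jacobian` and the linear test function are hypothetical.

```python
import random

def smoothed_jacobian(f, x, sigma=0.5, n=100_000, baseline=True, seed=0):
    """Monte Carlo Jacobian of f_eps(x) = E[f(x + eps)], eps ~ N(0, sigma^2 I).

    J ~= mean over samples of (f(x + eps) - c) (eps / sigma^2)^T,
    with c = f(x) if baseline else 0. Since E[eps] = 0, subtracting the
    constant c does not change the expectation but can reduce variance.
    """
    rng = random.Random(seed)
    y0 = f(x)
    c = list(y0) if baseline else [0.0] * len(y0)
    m, d = len(y0), len(x)
    inv_var = 1.0 / sigma**2
    acc = [[0.0] * d for _ in range(m)]
    for _ in range(n):
        eps = [rng.gauss(0.0, sigma) for _ in range(d)]
        y = f([xi + ei for xi, ei in zip(x, eps)])
        for i in range(m):
            r = y[i] - c[i]
            for j in range(d):
                acc[i][j] += r * eps[j] * inv_var
    return [[a / n for a in row] for row in acc]

def f_linear(v):
    # Hypothetical test function f(v) = A v with A = [[1, 2], [3, 4]];
    # smoothing leaves a linear map unchanged, so the exact Jacobian is A.
    return [v[0] + 2.0 * v[1], 3.0 * v[0] + 4.0 * v[1]]

J = smoothed_jacobian(f_linear, [1.0, -2.0])
print(J)  # close to [[1, 2], [3, 4]]
```

Calling `smoothed_jacobian(..., baseline=False)` targets the same expectation. In this example, however, the estimate also picks up the term $f(x)\,\epsilon^\top/\sigma^2$, which has zero mean but adds noise, so the baseline version is typically less noisy.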
