
Gradient Estimation Using Stochastic Computation Graphs

John Schulman, Nicolas Heess, Theophane Weber, Pieter Abbeel

TL;DR

The paper introduces stochastic computation graphs, a directed acyclic graph framework combining deterministic operations and stochastic nodes to model loss functions defined as expectations over random variables. It derives unbiased gradient estimators via a surrogate loss that can be differentiated with standard backpropagation, unifying and extending gradient estimators from variational inference and reinforcement learning. It also presents variance-reduction techniques using baselines and discusses higher-order derivatives and practical algorithms for implementation with automatic differentiation. The framework generalizes prior methods, enabling efficient gradient computation for complex models with attention, memory, and control components, and provides a practical route to sophisticated stochastic-deterministic architectures.
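The baseline idea mentioned above can be illustrated on a one-node example. This is a minimal sketch, not the paper's code. The Gaussian node x ~ N(theta, 1), the cost f(x) = x^2 + 10, and the helper names score_terms and compare are all hypothetical. A baseline b that does not depend on the sampled x leaves the score-function estimator unbiased, because E[d/dtheta log p(x | theta)] = 0. It can still sharply reduce variance when the cost carries a large constant offset.

```python
import numpy as np

def score_terms(theta, x, baseline):
    # Per-sample score-function estimates (x - theta) * (f(x) - b) for
    # x ~ N(theta, 1) and a hypothetical cost f(x) = x^2 + 10.
    # (x - theta) is d/dtheta log N(x; theta, 1).
    f = x ** 2 + 10.0
    return (x - theta) * (f - baseline)

def compare(theta=1.0, n=100_000, seed=0):
    rng = np.random.default_rng(seed)
    x = rng.normal(theta, 1.0, size=n)
    # Baseline: mean cost over an independent batch, so it does not depend on x
    x_ref = rng.normal(theta, 1.0, size=1_000)
    b = float(np.mean(x_ref ** 2 + 10.0))
    plain = score_terms(theta, x, 0.0)   # no baseline
    with_b = score_terms(theta, x, b)    # baseline subtracted
    return plain.mean(), plain.var(), with_b.mean(), with_b.var()

mean0, var0, mean_b, var_b = compare()
print(f"no baseline:   mean={mean0:.3f} var={var0:.1f}")
print(f"with baseline: mean={mean_b:.3f} var={var_b:.1f}")
# True gradient: d/dtheta E[x^2 + 10] = 2 * theta = 2.0 at theta = 1
```

Both estimators should average to about 2.0. For this toy setup, the per-sample variance drops from roughly 210 to roughly 18 because the baseline removes the constant offset of the cost.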

Abstract

In a variety of problems originating in supervised, unsupervised, and reinforcement learning, the loss function is defined by an expectation over a collection of random variables, which might be part of a probabilistic model or the external world. Estimating the gradient of this loss function, using samples, lies at the core of gradient-based learning algorithms for these problems. We introduce the formalism of stochastic computation graphs---directed acyclic graphs that include both deterministic functions and conditional probability distributions---and describe how to easily and automatically derive an unbiased estimator of the loss function's gradient. The resulting algorithm for computing the gradient estimator is a simple modification of the standard backpropagation algorithm. The generic scheme we propose unifies estimators derived in a variety of prior work, along with variance-reduction techniques therein. It could assist researchers in developing intricate models involving a combination of stochastic and deterministic operations, enabling, for example, attention, memory, and control actions.

Paper Structure

This paper contains 24 sections, 1 theorem, 17 equations, 7 figures, and 1 algorithm.

Key Result

Corollary 1

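Let $L(\Theta,\mathcal{S}) \vcentcolon= \sum_{w} \log p(w \:\vert\:\textsc{deps}_{w}) \hat{Q}_{w} + \sum_{c \in \mathcal{C}} c(\textsc{deps}_{c}),$ where $w$ ranges over the stochastic nodes and $\hat{Q}_{w}$ is the sampled sum of the costs downstream of $w$, held constant during differentiation. Then differentiation of $L$ gives us an unbiased gradient estimate: $\frac{\partial}{\partial \theta} \mathbb{E}\left[ \sum_{c \in \mathcal{C}} c\right] = \mathbb{E}\left[ \frac{\partial}{\partial \theta} L(\Theta,\mathcal{S}) \right].$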
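Corollary 1 can be checked numerically on a toy graph. The sketch below is an illustration under stated assumptions, not the paper's implementation. It uses a hypothetical graph theta -> x ~ N(theta, 1) with a downstream cost c1 = (x - 3)^2, plus a deterministic cost c2 = theta^2 that depends on theta directly. In this graph Q_hat_x = c1, since c2 is not downstream of x. The derivative of the surrogate L is written by hand rather than obtained by backpropagation, with x and Q_hat_x held constant as they would be under a stop-gradient. The function names are hypothetical.

```python
import numpy as np

def surrogate_grad(theta, x):
    """d/dtheta of the surrogate L for samples x, with x and Q_hat held constant.

    Hypothetical toy graph: theta -> x ~ N(theta, 1) -> cost c1 = (x - 3)^2,
    plus a deterministic cost c2 = theta^2 that depends on theta directly.
    L = log p(x | theta) * Q_hat_x + c2(theta), where Q_hat_x = c1(x).
    """
    q_hat = (x - 3.0) ** 2   # sum of costs downstream of x (treated as a constant)
    dlogp = x - theta        # d/dtheta log N(x; theta, 1)
    dc2 = 2.0 * theta        # pathwise derivative of the deterministic cost
    return dlogp * q_hat + dc2

def estimate_grad(theta, n, seed=0):
    # Monte Carlo average of the surrogate's derivative over n samples of x
    rng = np.random.default_rng(seed)
    x = rng.normal(loc=theta, scale=1.0, size=n)
    return surrogate_grad(theta, x).mean()

def exact_grad(theta):
    # E[c1 + c2] = (theta - 3)^2 + 1 + theta^2, so the derivative is 4*theta - 6
    return 4.0 * theta - 6.0

theta = 0.5
print(estimate_grad(theta, 200_000), exact_grad(theta))
```

The two printed values should agree up to Monte Carlo error. The score-function term contributes 2(theta - 3) in expectation, and the deterministic path contributes 2 * theta. In an autodiff framework, the same estimate would come from differentiating L directly. The hand-written derivative only keeps this sketch dependent on nothing beyond NumPy.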

Figures (7)

  • Figure 1: Simple stochastic computation graphs
  • Figure 2: Deterministic computation graphs obtained as surrogate loss functions of the stochastic computation graphs from Figure 1.

Theorems & Definitions (1)

  • Corollary 1