Generative Marginalization Models

Sulin Liu, Peter J. Ramadge, Ryan P. Adams

TL;DR

Generative Marginalization Models (MaMs) tackle efficient marginal inference for high-dimensional discrete data by explicitly modeling all induced marginals $p(\mathbf{x}_\mathcal{S})$ under a marginalization self-consistency constraint. Using a dual-network setup that learns marginals $p_\theta(\mathbf{x})$ and conditionals $p_\phi(\mathbf{x} \mid \cdot)$, together with an augmented input representation that marks marginalized-out variables with a missing-value symbol, MaMs can estimate any marginal with a single forward pass while supporting scalable maximum-likelihood (MLE) and energy-based (EB) training. The approach yields orders-of-magnitude speedups in marginal evaluation and scales any-order generative modeling to larger problems in the EB setting, outperforming autoregressive models (ARMs) and any-order ARMs (AO-ARMs) on a range of discrete tasks including images, text, molecules, and physical systems. These results highlight MaMs' potential for flexible, domain-guided marginal queries, outlier detection, and design tasks in real-world discrete-data problems.
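The augmented-input idea can be illustrated on a problem small enough to enumerate. In the sketch below, which is illustrative rather than the authors' implementation, the string "?" stands in for the missing-value symbol, the toy joint over three binary variables uses arbitrary weights, and an exhaustively built lookup table plays the role of the marginal network $p_\theta$: every marginal $p(\mathbf{x}_\mathcal{S})$ becomes a single query on an input whose unobserved variables are marked missing.

```python
import itertools

MISSING = "?"  # hypothetical stand-in for the paper's missing-value symbol


def normalize(weights):
    """Turn positive weights into a probability table."""
    total = sum(weights.values())
    return {x: w / total for x, w in weights.items()}


def exact_marginal(joint, x_aug):
    """Sum the joint over all full assignments that agree with the observed entries of x_aug."""
    return sum(
        p for x, p in joint.items()
        if all(a == MISSING or a == b for a, b in zip(x_aug, x))
    )


def build_marginal_table(joint, D, K=2):
    """Map every augmented input in {0, ..., K-1, ?}^D to its marginal probability.

    The table has (K + 1)^D entries; a MaM would replace it with one neural network.
    """
    symbols = list(range(K)) + [MISSING]
    return {
        x_aug: exact_marginal(joint, x_aug)
        for x_aug in itertools.product(symbols, repeat=D)
    }


# Toy joint over D = 3 binary variables (arbitrary illustrative weights).
D = 3
weights = {x: 1 + sum(x) + 2 * x[0] * x[2] for x in itertools.product(range(2), repeat=D)}
joint = normalize(weights)
table = build_marginal_table(joint, D)

# Each marginal is now one query on an augmented input.
print(table[(1, MISSING, MISSING)])  # p(x_1 = 1)
print(table[(1, MISSING, 0)])        # p(x_1 = 1, x_3 = 0)
```

The table grows as $(K+1)^D$, which is why MaMs amortize it into a single network. Summing the table over the values of one observed variable reproduces the entry in which that variable is marked missing; this is the self-consistency property that a trained MaM satisfies only approximately.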

Abstract

We introduce marginalization models (MaMs), a new family of generative models for high-dimensional discrete data. They offer scalable and flexible generative modeling by explicitly modeling all induced marginal distributions. Marginalization models enable fast approximation of arbitrary marginal probabilities with a single forward pass of the neural network, which overcomes a major limitation of arbitrary marginal inference models such as any-order autoregressive models. MaMs also address the scalability bottleneck encountered in training any-order generative models for high-dimensional problems in the context of energy-based training, where the goal is to match the learned distribution to a given target distribution (specified by an unnormalized log-probability function such as an energy or reward function). We propose scalable methods for learning the marginals, grounded in the concept of "marginalization self-consistency". We demonstrate the effectiveness of the proposed model on a variety of discrete data distributions, including images, text, physical systems, and molecules, in both maximum likelihood and energy-based training settings. MaMs achieve orders-of-magnitude speedups in evaluating marginal probabilities in both settings. For energy-based training tasks, MaMs enable any-order generative modeling of high-dimensional problems beyond the scale of previous methods. Code is available at https://github.com/PrincetonLIPS/MaM.
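For exact distributions the chain rule gives, for any ordering $\sigma$ and step $d$, $\log p(\mathbf{x}_{\sigma(<d)}) + \log p(x_{\sigma(d)} \mid \mathbf{x}_{\sigma(<d)}) = \log p(\mathbf{x}_{\sigma(\le d)})$; marginalization self-consistency asks the learned marginals $p_\theta$ and conditionals $p_\phi$ to satisfy the same identity. The sketch below is a minimal illustration under stated assumptions, not the paper's training code: it estimates the squared violation of the identity by sampling $\sigma$ and $d$ at random, uses a toy three-variable binary joint, replaces both networks with exact brute-force functions, and introduces a hypothetical `biased_log_marginal` that adds 0.1 nats per observed variable so the penalty has something to detect.

```python
import itertools
import math
import random

MISSING = "?"  # hypothetical missing-value symbol

# Toy joint over D = 3 binary variables (arbitrary positive weights, so every marginal is > 0).
D = 3
_weights = {x: 1 + sum(x) + 2 * x[0] * x[2] for x in itertools.product(range(2), repeat=D)}
_total = sum(_weights.values())
JOINT = {x: w / _total for x, w in _weights.items()}


def mask_outside(x, keep):
    """Augmented input: keep the coordinates in `keep`, mark the rest as MISSING."""
    return tuple(v if i in keep else MISSING for i, v in enumerate(x))


def exact_log_marginal(x_aug):
    """Exact log p(x_S) by brute-force summation (stands in for the marginal network)."""
    p = sum(q for x, q in JOINT.items()
            if all(a == MISSING or a == b for a, b in zip(x_aug, x)))
    return math.log(p)


def exact_log_conditional(x, i, prefix):
    """Exact log p(x_i | x_prefix) (stands in for the conditional network)."""
    return (exact_log_marginal(mask_outside(x, prefix | {i}))
            - exact_log_marginal(mask_outside(x, prefix)))


def biased_log_marginal(x_aug):
    """Hypothetical faulty marginal model: adds 0.1 nats per observed variable."""
    observed = sum(v != MISSING for v in x_aug)
    return exact_log_marginal(x_aug) + 0.1 * observed


def self_consistency_loss(log_marginal, log_conditional, x, num_samples, rng):
    """Monte Carlo mean of the squared self-consistency error over random (sigma, d)."""
    n = len(x)
    total = 0.0
    for _ in range(num_samples):
        sigma = rng.sample(range(n), n)  # uniformly random ordering of the variables
        d = rng.randrange(n)             # uniformly random step (0-based)
        prefix = set(sigma[:d])          # variables placed before step d
        lhs = log_marginal(mask_outside(x, prefix)) + log_conditional(x, sigma[d], prefix)
        rhs = log_marginal(mask_outside(x, prefix | {sigma[d]}))
        total += (lhs - rhs) ** 2
    return total / num_samples


rng = random.Random(0)
x = (1, 0, 1)
loss_exact = self_consistency_loss(exact_log_marginal, exact_log_conditional, x, 100, rng)
loss_biased = self_consistency_loss(biased_log_marginal, exact_log_conditional, x, 100, rng)
print(loss_exact)   # zero up to floating-point rounding
print(loss_biased)  # close to 0.01
```

With exact marginals the estimate vanishes up to rounding; with the biased marginals every sampled term equals $(-0.1)^2$, so the average is 0.01. Only prefix marginals along random orderings are touched, which is what keeps this kind of penalty cheap compared with summing out variables explicitly.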

Paper Structure

This paper contains 44 sections, 1 theorem, 23 equations, 35 figures, 11 tables, and 2 algorithms.

Key Result

Proposition 1

Solving the optimization problem in eq:mar_ml_constr is equivalent to a two-stage optimization procedure, under the mild assumption that the neural networks used are universal approximators. In that procedure, $\sigma \sim \mathcal{U}(S_D)$ is an ordering drawn uniformly from the set $S_D$ of permutations of the $D$ variables, $d \sim \mathcal{U}(1,\dots,D)$ is a uniformly drawn position, and $q(\mathbf{x})$ is the distribution of interest for marginal likelihood inference.
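For orientation, one way to write the two stages is sketched below. It assumes that stage 1 fits the conditional network $p_\phi$ with the standard any-order autoregressive likelihood and that stage 2 fits the marginal network $p_\theta$ by penalizing squared violations of marginalization self-consistency; it uses only the symbols defined above ($\sigma$, $d$, $q$) and is a reconstruction for readability, not a verbatim copy of the paper's equations.

```latex
\begin{aligned}
\text{Stage 1:}\quad & \max_{\phi}\;
  \mathbb{E}_{\mathbf{x}\sim p_{\text{data}}}\,
  \mathbb{E}_{\sigma\sim\mathcal{U}(S_D)}
  \sum_{d=1}^{D} \log p_\phi\big(x_{\sigma(d)} \mid \mathbf{x}_{\sigma(<d)}\big), \\
\text{Stage 2:}\quad & \min_{\theta}\;
  \mathbb{E}_{\mathbf{x}\sim q(\mathbf{x})}\,
  \mathbb{E}_{\sigma\sim\mathcal{U}(S_D)}\,
  \mathbb{E}_{d\sim\mathcal{U}(1,\dots,D)}
  \Big[\log p_\theta\big(\mathbf{x}_{\sigma(<d)}\big)
     + \log p_\phi\big(x_{\sigma(d)} \mid \mathbf{x}_{\sigma(<d)}\big)
     - \log p_\theta\big(\mathbf{x}_{\sigma(\le d)}\big)\Big]^2 .
\end{aligned}
```

Under this reading, stage 2 only needs samples from $q(\mathbf{x})$, so the marginal network can be specialized to the distribution on which marginal queries will actually be made.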

Figures (35)

  • Figure 1: Training- and test-time scalability of sequential discrete generative models. The unit is the number of function (i.e., neural network) evaluations (NFE).
  • Figure 2: Marginalization models (MaMs) enable estimation of any marginal probability with a neural network $\theta$ that learns to "marginalize out" variables (represented by "$?$"). The figure illustrates marginalization of a single variable on bit strings (representing molecules), with two possible values per variable for clarity ($K$ in general). The bars represent probability masses.
  • Figure 3: Approximating $\log p_\phi(\mathbf{x})$ with a one-step conditional (ARM-MC) results in extremely high gradient variance in energy-based training.
  • Figure 4: An example of the marginal estimates for an ImageNet32 image along a generation trajectory with a random ordering. The numbers in the captions compare the learned log-marginals (left) with the learned log-conditionals (right), showing that the two are approximately self-consistent.
  • Figure 5: Ising model. Left: $D = 10 \times 10$. Right: $D = 30 \times 30$.
  • ...and 30 more figures
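The point of Figure 3 about one-step estimates can be illustrated numerically. The sketch assumes, as an interpretation rather than a quotation of the paper, that ARM-MC replaces the full sum $\sum_{d=1}^{D} \log p_\phi(x_{\sigma(d)} \mid \mathbf{x}_{\sigma(<d)})$ with $D \cdot \log p_\phi(x_{\sigma(d)} \mid \mathbf{x}_{\sigma(<d)})$ for a single uniformly drawn $d$; the conditional log-probabilities are made-up numbers for one hypothetical sample. It shows the spread of the log-likelihood estimate itself, not the gradient-variance measurement reported in the figure.

```python
import statistics


def full_log_likelihood(cond_logps):
    """Exact log p(x) under one ordering: the sum of all D conditionals (D network calls)."""
    return sum(cond_logps)


def one_step_estimates(cond_logps):
    """Every possible one-step estimate D * log p(x_sigma(d) | x_sigma(<d)), one per choice of d."""
    D = len(cond_logps)
    return [D * lp for lp in cond_logps]


# Hypothetical conditional log-probabilities for one sample x under one ordering sigma.
cond_logps = [-0.1, -0.2, -3.0, -0.05, -1.5, -0.3]

exact = full_log_likelihood(cond_logps)
estimates = one_step_estimates(cond_logps)
print(exact)
print(statistics.mean(estimates))    # matches `exact`: unbiased when d is uniform
print(statistics.pstdev(estimates))  # spread of a single one-step estimate
```

The one-step estimate needs one network call instead of $D$ and is unbiased over $d$, but a single draw can be several times larger or smaller than the true value when the conditionals are uneven, which is consistent with the high gradient variance the caption reports.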

Theorems & Definitions (3)

  • Proposition 1
  • Proof
  • Remark 1