
Discriminator Guidance for Autoregressive Diffusion Models

Filip Ekström Kelvinius, Fredrik Lindsten

TL;DR

The paper addresses the generation of discrete data with autoregressive diffusion models (ARDMs) by introducing discriminator guidance for them. It develops ARDG, which uses a discriminator to directly correct the intermediate conditionals of a pretrained ARDM. It also introduces two Sequential Monte Carlo (SMC) variants, BSDG and FADG, which mitigate errors from imperfect discriminators and enable parallel sampling. With an optimal discriminator, ARDG samples exactly from the data distribution; with sub-optimal discriminators, BSDG and FADG give more robust, higher-quality samples. On molecular graph generation, the approach improves key generation metrics on QM9 and MOSES, and the number of SMC particles and the choice of generation order give a controllable trade-off between compute and sample quality.
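One way to see why an optimal discriminator makes ARDG exact is the standard density-ratio argument below. It is a sketch in notation introduced here, not the paper's own derivation. It assumes that the discriminator is applied to partially assigned inputs $\mathbf{x}_{\sigma(\le t)}$ (as in Figure 1), that it is optimal at every step, and that real and generated samples are balanced during its training. The symbol $r_t$ is a name introduced here, with $r_0 := 1$ for the empty assignment.

```latex
% Optimal discriminator on a partial assignment, and the density ratio it implies
d^{*}_{t}\big(\mathbf{x}_{\sigma(\le t)}\big)
  = \frac{p_{\rm{data}}(\mathbf{x}_{\sigma(\le t)})}
         {p_{\rm{data}}(\mathbf{x}_{\sigma(\le t)}) + p_{\theta}(\mathbf{x}_{\sigma(\le t)})}
\quad\Longrightarrow\quad
r_t\big(\mathbf{x}_{\sigma(\le t)}\big) := \frac{d^{*}_{t}}{1 - d^{*}_{t}}
  = \frac{p_{\rm{data}}(\mathbf{x}_{\sigma(\le t)})}{p_{\theta}(\mathbf{x}_{\sigma(\le t)})}

% Reweighting the pretrained conditional by the change in the ratio
p_{\theta}\big(x_{\sigma(t)} \mid \mathbf{x}_{\sigma(<t)}\big)\,
  \frac{r_t(\mathbf{x}_{\sigma(\le t)})}{r_{t-1}(\mathbf{x}_{\sigma(<t)})}
  = \frac{p_{\theta}(\mathbf{x}_{\sigma(\le t)})}{p_{\theta}(\mathbf{x}_{\sigma(<t)})}
    \cdot \frac{p_{\rm{data}}(\mathbf{x}_{\sigma(\le t)})}{p_{\theta}(\mathbf{x}_{\sigma(\le t)})}
    \cdot \frac{p_{\theta}(\mathbf{x}_{\sigma(<t)})}{p_{\rm{data}}(\mathbf{x}_{\sigma(<t)})}
  = p_{\rm{data}}\big(x_{\sigma(t)} \mid \mathbf{x}_{\sigma(<t)}\big)
```

Multiplying these corrected conditionals over all steps recovers $p_{\rm{data}}(\mathbf{x})$, so sampling from them step by step is exact. With a sub-optimal discriminator, $r_t$ is only an estimate of the density ratio, which is the case the SMC variants address.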

Abstract

We introduce discriminator guidance in the setting of Autoregressive Diffusion Models. Discriminators have previously been used to guide continuous diffusion models. In this work, we derive ways of using a discriminator together with a pretrained generative model in the discrete case. First, we show that an optimal discriminator corrects the pretrained model and enables exact sampling from the underlying data distribution. Second, to account for the realistic scenario of a sub-optimal discriminator, we derive a sequential Monte Carlo algorithm that iteratively takes the discriminator's predictions into account during generation. We test these approaches on the task of generating molecular graphs and show that the discriminator improves generative performance compared with using the pretrained model alone.
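The sequential Monte Carlo idea can be illustrated with a generic bootstrap-style particle sampler. Each particle proposes the next variable from the pretrained model $p_\theta$ and is then weighted by the change in the discriminator's density-ratio estimate $d/(1-d)$. The population is resampled between steps. This is a minimal sketch under stated assumptions, not BSDG or FADG as specified in the paper. The toy tables `p_data` and `p_theta`, the fixed left-to-right generation order, the optimal `discriminator`, and every function name are hypothetical. With an optimal discriminator, the weighted particles target $p_{\rm{data}}$. With an imperfect one, they target a distribution proportional to $p_\theta(\mathbf{x})\,d(\mathbf{x})/(1-d(\mathbf{x}))$.

```python
import numpy as np

T, K = 3, 2  # toy sizes: 3 categorical variables with 2 values each (assumed)
rng = np.random.default_rng(1)


def random_joint():
    """A strictly positive toy joint distribution over all K**T assignments."""
    p = rng.random((K,) * T) + 0.1
    return p / p.sum()


p_data, p_theta = random_joint(), random_joint()  # hypothetical stand-ins


def marginal(p, prefix):
    """Probability that the first len(prefix) variables equal prefix (fixed order)."""
    return p[tuple(prefix)].sum()


def discriminator(prefix):
    """Optimal discriminator on a partial assignment; a trained d_phi would replace it."""
    a, b = marginal(p_data, prefix), marginal(p_theta, prefix)
    return a / (a + b)


def log_ratio(prefix, disc):
    """log d / (1 - d): the discriminator's estimate of log p_data - log p_theta."""
    d = disc(prefix)
    return np.log(d) - np.log1p(-d)


def smc_guided_sample(n_particles, disc, rng):
    particles = np.zeros((n_particles, 0), dtype=int)
    log_r_prev = np.full(n_particles, log_ratio((), disc))
    log_w = np.zeros(n_particles)
    for t in range(T):
        # Propose x_t for every particle from the pretrained conditional p_theta(x_t | x_<t).
        probs = np.array([[marginal(p_theta, tuple(row) + (k,)) for k in range(K)]
                          for row in particles])
        probs /= probs.sum(axis=1, keepdims=True)
        u = rng.random(n_particles)
        new_vals = np.minimum((u[:, None] > probs.cumsum(axis=1)).sum(axis=1), K - 1)
        particles = np.column_stack([particles, new_vals])
        # Incremental importance weight r_t / r_{t-1}, computed in log space.
        log_r = np.array([log_ratio(tuple(row), disc) for row in particles])
        log_w += log_r - log_r_prev
        log_r_prev = log_r
        if t < T - 1:
            # Multinomial resampling between steps; weights are uniform afterwards.
            w = np.exp(log_w - log_w.max())
            idx = rng.choice(n_particles, size=n_particles, p=w / w.sum())
            particles, log_r_prev = particles[idx], log_r_prev[idx]
            log_w = np.zeros(n_particles)
    w = np.exp(log_w - log_w.max())
    return particles, w / w.sum()


particles, weights = smc_guided_sample(10_000, discriminator, rng)
estimate = np.zeros((K,) * T)
np.add.at(estimate, tuple(particles.T), weights)  # weighted histogram over assignments
tv = 0.5 * np.abs(estimate - p_data).sum()
print(f"total variation to p_data: {tv:.3f}")
```

In this sketch, `n_particles` is the compute knob. More particles give a lower-variance approximation of the guided distribution at a proportionally higher cost.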

Paper Structure

This paper contains 33 sections, 39 equations, 1 figure, 8 tables, 3 algorithms.

Figures (1)

  • Figure 1: Illustration of the ARDG method applied to graphs. The variables $x_t$ (the nodes and edges of the graph) are assigned values one by one; nodes and edges marked "?" have not yet been assigned. The white node could be assigned by sampling from the conditional distribution $p_{\theta}(x_{\sigma(t)} |\mathbf{x}_{\sigma(<t)})$, which has been learnt by a neural network. In our method, however, a separate discriminator $d_{\phi}$ has been trained to distinguish real samples from fake ones (generated by the generative model $p_{\theta}$). With the help of this discriminator, we can correct the distribution $p_{\theta}(x_{\sigma(t)} | \mathbf{x}_{\sigma(<t)})$ so that it becomes closer to the true underlying data distribution $p_{\rm{data}}(x_{\sigma(t)}|\mathbf{x}_{\sigma(<t)})$.
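The correction shown in Figure 1 can be checked on a toy problem where every quantity is computable exactly. The sketch below is an illustration under stated assumptions, not the authors' implementation. `p_data` and `p_theta` are made-up joint tables over three binary variables, the discriminator is taken to be optimal on every partial assignment, and all function names are hypothetical. The sketch reweights each pretrained conditional by the change in the discriminator's implied density ratio $d/(1-d)$ and compares the result with the true data conditional for a fixed generation order.

```python
import itertools

import numpy as np

T, K = 3, 2  # toy sizes: 3 variables (e.g. nodes/edges), 2 values each (assumed)
rng = np.random.default_rng(0)


def random_joint():
    """A strictly positive toy joint distribution, so every ratio below is defined."""
    p = rng.random((K,) * T) + 0.1
    return p / p.sum()


p_data, p_theta = random_joint(), random_joint()  # hypothetical stand-ins


def prefix_prob(p, sigma, prefix):
    """p(x_{sigma(1)}, ..., x_{sigma(t)} = prefix), marginalising unassigned variables."""
    idx = [slice(None)] * T
    for pos, val in zip(sigma, prefix):
        idx[pos] = val
    return p[tuple(idx)].sum()


def optimal_disc(sigma, prefix):
    """d*(x) = p_data(x) / (p_data(x) + p_theta(x)) on a partially assigned input."""
    a, b = prefix_prob(p_data, sigma, prefix), prefix_prob(p_theta, sigma, prefix)
    return a / (a + b)


def density_ratio(d):
    """Density-ratio estimate p_data / p_theta implied by a discriminator output d."""
    return d / (1.0 - d)


def conditional(p, sigma, prefix):
    """p(x_{sigma(t)} = k | x_{sigma(<t)} = prefix) for k = 0..K-1."""
    joint = np.array([prefix_prob(p, sigma, prefix + (k,)) for k in range(K)])
    return joint / joint.sum()


def guided_conditional(sigma, prefix, disc):
    """Reweight p_theta's conditional by the change in the discriminator's ratio."""
    r_prev = density_ratio(disc(sigma, prefix))
    r_next = np.array([density_ratio(disc(sigma, prefix + (k,))) for k in range(K)])
    w = conditional(p_theta, sigma, prefix) * r_next / r_prev
    return w / w.sum()  # renormalising only matters for a sub-optimal discriminator


sigma = (2, 0, 1)  # one generation order; ARDMs typically draw sigma at random
max_err = 0.0
for t in range(T):
    for prefix in itertools.product(range(K), repeat=t):
        err = np.abs(guided_conditional(sigma, prefix, optimal_disc)
                     - conditional(p_data, sigma, prefix)).max()
        max_err = max(max_err, err)
print(f"max |guided - p_data conditional| = {max_err:.1e}")
```

Replacing `optimal_disc` with an imperfect classifier makes the guided conditionals only approximate. This is the regime the paper's SMC variants are designed for.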