
Adaptive teachers for amortized samplers

Minsu Kim, Sanghyeok Choi, Taeyoung Yun, Emmanuel Bengio, Leo Feng, Jarrid Rector-Brooks, Sungsoo Ahn, Jinkyoo Park, Nikolay Malkin, Yoshua Bengio

TL;DR

This work tackles exploration in amortized inference with generative flow networks (GFlowNets) by introducing an adaptive Teacher that samples high-loss regions to guide the Student toward underexplored modes. The Teacher's reward $R_{\text{Teacher}}(x)$ is built from the Student's trajectory balance (TB) loss, with extra weight on undersampled regions and a mixing term with the Student's reward controlled by $\alpha$. Teacher and Student are trained jointly, and a fixed backward policy and local search help stabilize the resulting nonstationarity. Across discrete and continuous domains, including deceptive grids, diffusion-based sampling, and biochemical discovery, the approach improves mode coverage and sample efficiency over TB, $\epsilon$-exploration, GAFN, PER, and PRT baselines. The results show better multimodal coverage and faster convergence, suggesting the method is broadly applicable to scalable amortized inference over complex, multimodal distributions.
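The Teacher's reward described above can be sketched in a few lines. This is a minimal illustration under stated assumptions: the residual sign convention, the hypothetical names tb_residual and teacher_log_reward, the extra factor $1 + c$ on undersampled trajectories, and the $R(x)^{\alpha}$ mixing form are illustrative choices, not necessarily the paper's exact formula.

```python
import math
from typing import Sequence


def tb_residual(log_z: float,
                log_pf: Sequence[float],
                log_pb: Sequence[float],
                log_reward: float) -> float:
    """Trajectory-balance residual delta(tau) for one complete trajectory.

    delta = log Z + sum log P_F - log R(x) - sum log P_B.
    The Student's TB loss is delta ** 2. With this sign convention,
    delta < 0 means the Student puts too little flow on x (undersampled).
    """
    return log_z + sum(log_pf) - log_reward - sum(log_pb)


def teacher_log_reward(delta: float,
                       log_reward: float,
                       alpha: float = 0.5,
                       c: float = 10.0,
                       eps: float = 1e-3) -> float:
    """Hypothetical Teacher log-reward built from the Student's TB loss.

    - delta ** 2 is the Student's TB loss on the trajectory;
    - undersampled trajectories (delta < 0) have their loss scaled by (1 + c);
    - alpha * log R(x) mixes in the Student's own reward, i.e. R(x) ** alpha.
    The defaults for alpha, c and eps are illustrative, not the paper's values.
    """
    weight = 1.0 + c if delta < 0 else 1.0
    return math.log(eps + weight * delta ** 2) + alpha * log_reward


# Same |delta| and reward: the undersampled trajectory earns a larger Teacher reward.
under = teacher_log_reward(tb_residual(0.0, [-1.0, -1.0], [0.0, 0.0], 0.0), 0.0)
over = teacher_log_reward(tb_residual(4.0, [-1.0, -1.0], [0.0, 0.0], 0.0), 0.0)
print(under > over)  # True
```

The asymmetric weight is what pushes the Teacher toward regions the Student has not yet covered, while the $\alpha$ term keeps it from chasing high-loss but low-reward regions.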

Abstract

Amortized inference is the task of training a parametric model, such as a neural network, to approximate a distribution with a given unnormalized density where exact sampling is intractable. When sampling is implemented as a sequential decision-making process, reinforcement learning (RL) methods, such as generative flow networks, can be used to train the sampling policy. Off-policy RL training facilitates the discovery of diverse, high-reward candidates, but efficient exploration remains a challenge for existing methods. We propose to use an adaptive training distribution (the Teacher) to guide the training of the primary amortized sampler (the Student). The Teacher, an auxiliary behavior model, is trained to sample high-loss regions of the Student and can generalize across unexplored modes, thereby enhancing mode coverage by providing an efficient training curriculum. We validate the effectiveness of this approach in a synthetic environment designed to present an exploration challenge, two diffusion-based sampling tasks, and four biochemical discovery tasks, demonstrating its ability to improve sample efficiency and mode coverage. Source code is available at https://github.com/alstn12088/adaptive-teacher.
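The curriculum effect of a Teacher that samples the Student's high-loss regions can be seen in a toy tabular setting. The sketch below assumes one-step trajectories $s_0 \to x$ over four terminal states (so the backward probability is 1) and a hypothetical tabular Teacher that simply samples $x$ in proportion to the Student's TB loss; it illustrates the idea and is not the paper's Teacher parameterization.

```python
import math


def softmax(logits):
    # Numerically stable softmax over a list of floats.
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def per_state_tb_loss(student_logits, log_z, log_rewards):
    # One-step trajectories s0 -> x: sum of log P_F is log q(x), log P_B is 0.
    probs = softmax(student_logits)
    return [(log_z + math.log(p) - lr) ** 2 for p, lr in zip(probs, log_rewards)]


def teacher_distribution(losses, eps=1e-3):
    # Hypothetical tabular Teacher: probability proportional to eps + loss.
    weights = [eps + loss for loss in losses]
    total = sum(weights)
    return [w / total for w in weights]


log_rewards = [2.0, 0.0, 2.0, 0.0]     # target has modes at states 0 and 2
student_logits = [3.0, 0.0, 0.0, 0.0]  # Student has only found mode 0
log_z = 2.0                            # Student's current log-partition estimate

student = softmax(student_logits)
teacher = teacher_distribution(per_state_tb_loss(student_logits, log_z, log_rewards))
print(max(range(4), key=teacher.__getitem__))  # 2: the mode the Student misses
```

Here the Teacher puts roughly 0.79 of its mass on state 2, versus about 0.04 under the Student, so training on Teacher samples concentrates updates on the missing mode.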

Paper Structure

This paper contains 57 sections, 1 theorem, 20 equations, 14 figures, 10 tables, and 1 algorithm.

Key Result

Proposition 1

Let the behavior policy $P_{\beta}(\tau)$ be a distribution over trajectories $\tau \in \mathcal{T}$ that satisfies full support. If the parameters $\theta^*$ and $\phi^*$ of the Student and Teacher policies, respectively, jointly optimize the objective functions to 0 in expectation over $P_{\beta}(\tau)$, where:
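Assuming, as the TL;DR indicates, that both policies are trained with trajectory balance (TB) objectives and that the backward policy $P_B$ is held fixed, the Student's objective for a complete trajectory $\tau = (s_0 \to s_1 \to \dots \to s_n = x)$ has the standard TB form below; the Teacher's objective would presumably be the same with $\phi$ in place of $\theta$ and $R_{\text{Teacher}}$ in place of $R$.

$$\mathcal{L}_{\text{TB}}(\tau;\theta) = \left( \log Z_{\theta} + \sum_{t=0}^{n-1} \log P_F(s_{t+1} \mid s_t; \theta) - \log R(x) - \sum_{t=0}^{n-1} \log P_B(s_t \mid s_{t+1}) \right)^2$$

Because this loss is nonnegative, a zero expectation under a full-support $P_{\beta}$ forces it to vanish on (almost) every trajectory, and a TB loss of zero everywhere makes the policy's terminal distribution proportional to its reward. This is presumably why the proposition requires full support.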

Figures (14)

  • Figure 1: Training an amortized sampler (Student) with an adaptive Teacher. Left: The behavior policy mixes Student, Teacher, and replay buffer policies to generate trajectories that train Student and store experiences. Teacher is updated based on Student's loss. Right: Student and Teacher distributions co-evolve, with Teacher targeting uncovered modes until Student converges to the target distribution.
  • Figure 2: Empirical distribution plots of $10^5$ test samples from policies on the $(d=2, H=256)$ grid.
  • Figure 3: Samples from trained models on the Manywell task (projected onto the first two dimensions).
  • Figure 4: KDE plots for 25GMM (left three) and Manywell (right three) at intermediate states of training. In each panel label, Student (ratio) gives the fraction of total training steps completed. The Teacher adaptively shifts the training distribution toward the modes the Student is missing.
  • Figure 5: Training graphs for molecule design (QM9, sEH) and biological sequence design (TFbind8, L14-RNA1) tasks. Mean and standard deviation over five runs are shown.
  • ...and 9 more figures
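Figure 1 describes a behavior policy that mixes Student, Teacher, and replay-buffer samples. A minimal sketch of the per-trajectory source selection, assuming a fixed categorical mixture; the function name choose_sources, the source labels, and the weights are hypothetical, and the paper's actual mixing schedule may differ.

```python
import random
from collections import Counter


def choose_sources(batch_size, mix, rng):
    """Pick, for each trajectory in a batch, which behavior source produces it.

    `mix` maps a source name ("student", "teacher", "replay_buffer") to an
    unnormalized mixing weight. Names and weights are hypothetical.
    """
    names = list(mix)
    return rng.choices(names, weights=[mix[n] for n in names], k=batch_size)


rng = random.Random(0)
mix = {"student": 1.0, "teacher": 1.0, "replay_buffer": 2.0}  # illustrative weights
sources = choose_sources(64, mix, rng)

# All sampled trajectories would train the Student; only freshly generated
# (Student or Teacher) ones would also be stored in the replay buffer.
num_fresh = sum(s != "replay_buffer" for s in sources)
print(Counter(sources), num_fresh)
```

Selecting the source per trajectory, rather than per batch, keeps every Student update exposed to all three distributions; a per-batch alternation would be an equally plausible design.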

Theorems & Definitions (2)

  • Proposition 1
  • Proof