
Masked Auto-Regressive Variational Acceleration: Fast Inference Makes Practical Reinforcement Learning

Yuxuan Gu, Weimin Bai, Yifei Wang, Weijian Luo, He Sun

TL;DR

MARVAL demonstrates the first practical path to distillation and RL of masked auto-regressive diffusion models, enabling fast sampling and better preference alignment, resulting in scalable yet human-preferred fast generative models.

Abstract

Masked auto-regressive diffusion models (MAR) benefit from the expressive modeling ability of diffusion models and the flexibility of masked auto-regressive ordering. However, vanilla MAR suffers from slow inference due to its hierarchical inference mechanism: an outer AR unmasking loop and an inner diffusion denoising chain. This decoupled structure not only harms generation efficiency but also hinders the practical use of MAR for reinforcement learning (RL), an increasingly critical paradigm for generative model post-training. To address this fundamental issue, we introduce MARVAL (Masked Auto-regressive Variational Acceleration), a distillation-based framework that compresses the diffusion chain into a single AR generation step while preserving the flexible auto-regressive unmasking order. Distillation with MARVAL not only yields substantial inference acceleration but, crucially, makes RL post-training with verifiable rewards practical, resulting in scalable yet human-preferred fast generative models. Our contributions are twofold: (1) a novel score-based variational objective for distilling masked auto-regressive diffusion models into a single generation step without sacrificing sample quality; and (2) an efficient RL framework for masked auto-regressive models, MARVAL-RL. On ImageNet 256×256, MARVAL-Huge achieves an FID of 2.00 with a more than 30× speedup compared with MAR-diffusion, and MARVAL-RL yields consistent improvements in CLIP and image-reward scores on ImageNet datasets with entity names. In conclusion, MARVAL demonstrates the first practical path to distillation and RL of masked auto-regressive diffusion models, enabling fast sampling and better preference alignment.
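
To make the cost of the hierarchical scheme concrete, the sketch below counts diffusion-head evaluations for an outer loop of K unmasking iterations, each running an inner chain of T denoising steps, against a distilled model that takes one step per iteration. The values K = 64 and T = 100 are illustrative assumptions (the Figure 1 caption cites 100 diffusion steps for MAR-B), and the helper `denoiser_calls` is hypothetical. The count ignores the transformer backbone pass at every AR iteration, so it is not a prediction of wall-clock speedup.

```python
def denoiser_calls(ar_iterations: int, diffusion_steps: int) -> int:
    """Diffusion-head evaluations needed for one sample.

    Every outer AR unmasking iteration runs a full inner denoising chain,
    so the two loop lengths multiply rather than add.
    """
    if ar_iterations < 1 or diffusion_steps < 1:
        raise ValueError("both loop lengths must be positive")
    return ar_iterations * diffusion_steps


# Illustrative settings (assumptions, not the paper's configuration).
K = 64   # outer AR unmasking iterations
T = 100  # inner diffusion steps used by the teacher

teacher = denoiser_calls(K, T)  # MAR: full inner chain per iteration
student = denoiser_calls(K, 1)  # distilled: one step per iteration
print(teacher, student, teacher // student)  # 6400 64 100
```

Because each AR iteration also pays for a backbone forward pass, the measured speedup (more than 30× for MARVAL-Huge) need not match this ratio of head evaluations.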

Paper Structure

This paper contains 25 sections, 1 theorem, 20 equations, 6 figures, 6 tables.

Key Result

Theorem 1

Here $sg[\cdot]$ denotes the stop-gradient operator applied to the enclosed parameters. If the distribution $q_\theta(\boldsymbol{x})$ satisfies mild regularity conditions, the gradient equivalence of Theorem 1 holds for any score function $s_{p_t}(\cdot)$. Please refer to the supplementary materials for the proof.
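
The stop-gradient operator $sg[\cdot]$ is what allows two objectives to share a forward value yet have different gradients. The sketch below illustrates only the operator's semantics, using minimal forward-mode dual numbers in plain Python; the names `Dual`, `sg`, and `variable` are hypothetical, and this is not an implementation of the theorem itself.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Dual:
    """Forward-mode dual number: a value and its derivative w.r.t. theta."""
    val: float
    der: float

    def __add__(self, other: "Dual") -> "Dual":
        return Dual(self.val + other.val, self.der + other.der)

    def __mul__(self, other: "Dual") -> "Dual":
        # Product rule: (f g)' = f' g + f g'
        return Dual(self.val * other.val,
                    self.der * other.val + self.val * other.der)


def sg(x: Dual) -> Dual:
    """Stop gradient: keep the forward value, drop the derivative."""
    return Dual(x.val, 0.0)


def variable(theta: float) -> Dual:
    """Seed theta as the quantity we differentiate with respect to."""
    return Dual(theta, 1.0)


theta = variable(3.0)
f = theta * theta       # f = theta^2, f' = 2 theta = 6 at theta = 3
plain = f * f           # d/dtheta theta^4 = 4 theta^3 = 108
stopped = f * sg(f)     # only the first factor is differentiated: 6 * 9 = 54
print(plain.val, plain.der)      # 81.0 108.0
print(stopped.val, stopped.der)  # 81.0 54.0
```

Both objectives evaluate to 81.0, but the stopped factor contributes no gradient. In autodiff frameworks the same role is played by built-ins such as `jax.lax.stop_gradient` or PyTorch's `Tensor.detach()`.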

Figures (6)

  • Figure 1: Performance and qualitative results of MARVAL-RL. Top-left: Comparison of image quality between the MAR-B model (100 diffusion steps) and the MARVAL-RL-B model (1 diffusion step). MARVAL-RL significantly surpasses the MAR-B model in semantic quality, fidelity, and clarity. Bottom-left: FID-50k (y-axis) versus inference time per image (x-axis), compared with other state-of-the-art methods. The MARVAL series (red dots) achieves low FID scores with significantly faster inference (e.g., MARVAL-H achieves an FID of 2.00, and MARVAL-B is 32.95× faster than MAR-B). Right: A collection of diverse, high-quality images generated by the MARVAL-RL-B model, showcasing its strong generative capabilities at an average speed of only 0.61 seconds per image.
  • Figure 1: Qualitative results of MARVAL-RL-L.
  • Figure 2: Illustration of our overall framework. (Top-left) The MAR inference process consists of an outer auto-regressive (AR) loop and an inner diffusion chain. Starting from the class embedding token $c$, MAR performs multiple AR iterations, where each iteration predicts a new subset of latent tokens through a short diffusion process. (Bottom-left) The student one-step generator $g_\theta$ and the auxiliary network are optimized alternately. In this stage, a portion of tokens is masked, and $g_\theta$ performs a single AR iteration to predict all masked tokens guided by the teacher MAR model's CFG-based predictions. (Right) The RL refinement stage further improves perceptual fidelity. Here, the distilled generator $g_\theta$ generates images through multi-step AR inference, and a reward model evaluates the outputs based on textual prompts. The reward loss then fine-tunes $g_\theta$ to better align with human perceptual preferences.
  • Figure 2: Qualitative results of MARVAL-RL-H.
  • Figure 3: Effect of AR iterations and diffusion steps on FID/IS for the MAR base model compared with the MARVAL base model.
  • ...and 1 more figure
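
The MAR inference process in Figure 2 is an outer loop that reveals a new subset of tokens at each AR iteration. The sketch below shows that loop with a one-step generator standing in for the inner diffusion chain, so each iteration costs a single generator call. The random unmasking order, the MaskGIT-style cosine schedule for how many tokens to reveal, and the `Generator` signature are assumptions for illustration; all names are hypothetical, and tokens are plain floats rather than latent vectors.

```python
import math
import random
from typing import Callable, List, Optional

# Hypothetical generator signature: (current tokens, indices to fill, class label)
# -> one value per index. None marks a token that is still masked.
Generator = Callable[[List[Optional[float]], List[int], int], List[float]]


def cosine_unmask_counts(num_tokens: int, num_iters: int) -> List[int]:
    """Tokens revealed at each AR iteration under an assumed cosine schedule."""
    if not 1 <= num_iters <= num_tokens:
        raise ValueError("need 1 <= num_iters <= num_tokens")
    counts: List[int] = []
    revealed = 0
    for k in range(1, num_iters + 1):
        # After iteration k, a fraction cos(pi/2 * k / K) of tokens stays masked.
        still_masked = math.floor(num_tokens * math.cos(math.pi / 2 * k / num_iters))
        target = num_tokens - still_masked
        # Reveal at least one token, but leave one for every remaining iteration.
        limit = num_tokens - revealed - (num_iters - k)
        step = min(max(1, target - revealed), limit)
        counts.append(step)
        revealed += step
    return counts


def masked_ar_sample(generator: Generator, num_tokens: int, num_iters: int,
                     label: int, seed: int = 0) -> List[float]:
    """Outer AR loop: one generator call per iteration, no inner diffusion chain."""
    rng = random.Random(seed)
    order = list(range(num_tokens))
    rng.shuffle(order)  # random unmasking order
    tokens: List[Optional[float]] = [None] * num_tokens
    start = 0
    for count in cosine_unmask_counts(num_tokens, num_iters):
        idx = order[start:start + count]
        for i, value in zip(idx, generator(tokens, idx, label)):
            tokens[i] = value
        start += count
    missing = [i for i, t in enumerate(tokens) if t is None]
    if missing:
        raise RuntimeError(f"{len(missing)} tokens were never unmasked")
    return [float(t) for t in tokens]


calls: List[int] = []


def toy_generator(tokens, idx, label):
    calls.append(len(idx))  # record one call per AR iteration
    return [float(label)] * len(idx)


sample = masked_ar_sample(toy_generator, num_tokens=256, num_iters=64, label=7)
print(len(sample), len(calls), sum(calls))  # 256 64 256
```

Replacing the single `generator` call with a multi-step denoising chain would recover the hierarchical MAR structure described in the abstract.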

Theorems & Definitions (1)

  • Theorem 1: Gradient Equivalent Theorem