Designing Instance-Level Sampling Schedules via REINFORCE with James-Stein Shrinkage

Peiyu Yu, Suraj Kothawade, Sirui Xie, Ying Nian Wu, Hongliang Fei

TL;DR

This work introduces instance-level rescheduling for frozen text-to-image samplers by learning a single-shot, prompt- and seed-conditioned scheduling policy implemented as a Dirichlet distribution. To stabilize high-variance policy gradients in this high-dimensional setting, it derives a principled James-Stein shrinkage baseline that adaptively combines per-context and cross-context information, yielding lower estimator variance and better learning efficiency. Empirically, learned schedules improve text–image alignment, text rendering, and fine-grained control across multiple backbones and budgets, including competitive performance at only 5 steps without distillation. The approach offers a model-agnostic post-training lever that unlocks additional generative potential without altering the pretrained backbone, with demonstrated impact on few-step generation and downstream alignment tasks.

Abstract

Most post-training methods for text-to-image samplers focus on model weights: either fine-tuning the backbone for alignment or distilling it for few-step efficiency. We take a different route: rescheduling the sampling timeline of a frozen sampler. Instead of a fixed, global schedule, we learn instance-level (prompt- and noise-conditioned) schedules through a single-pass Dirichlet policy. To ensure accurate gradient estimates in high-dimensional policy learning, we introduce a novel reward baseline based on a principled James-Stein estimator; it provably achieves lower estimation errors than commonly used variants and leads to superior performance. Our rescheduled samplers consistently improve text-image alignment including text rendering and compositional control across modern Stable Diffusion and Flux model families. Additionally, a 5-step Flux-Dev sampler with our schedules can attain generation quality comparable to deliberately distilled samplers like Flux-Schnell. We thus position our scheduling framework as an emerging model-agnostic post-training lever that unlocks additional generative potential in pretrained samplers.
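As a rough illustration of how a Dirichlet sample can act as a sampling schedule, the sketch below treats the sample as $L$ step sizes on the unit interval and accumulates them into a decreasing timeline. The function names, the `t = 1` to `t = 0` convention, and the fixed concentration vector are assumptions for illustration; in the paper the concentration parameters would come from a policy conditioned on the prompt and the initial noise, and the exact parameterization may differ.

```python
import math
import numpy as np

def sample_schedule(alpha, rng):
    """Draw one schedule: Dirichlet step sizes -> decreasing timesteps in [0, 1].

    `alpha` is a hypothetical concentration vector of length L (one entry per step).
    """
    alpha = np.asarray(alpha, dtype=float)
    increments = rng.dirichlet(alpha)  # L non-negative step sizes that sum to 1
    # L steps give L + 1 knots, assumed here to run from t = 1 down to t = 0.
    t = 1.0 - np.concatenate(([0.0], np.cumsum(increments)))
    t[-1] = 0.0  # guard against floating-point drift in the cumulative sum
    return t, increments

def dirichlet_log_prob(increments, alpha):
    """Log-density of the step sizes, the quantity a REINFORCE update differentiates."""
    x = np.asarray(increments, dtype=float)
    a = np.asarray(alpha, dtype=float)
    log_norm = math.lgamma(a.sum()) - sum(math.lgamma(v) for v in a)
    return log_norm + float(np.sum((a - 1.0) * np.log(x)))

rng = np.random.default_rng(0)
timesteps, steps = sample_schedule(np.ones(5), rng)  # a 5-step schedule
```

Because the whole schedule is drawn in one shot, a single log-probability term covers the entire trajectory, which is what makes a single-pass policy gradient possible.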

Paper Structure

This paper contains 48 sections, 2 theorems, 36 equations, 16 figures, 7 tables, 1 algorithm.

Key Result

Proposition 1

For the gradient estimator in eq:reinforce, the proposition gives the baseline $b$ that minimizes $\operatorname{Var}[(r - b)\nabla_\theta \log \pi_\theta(\tau)]$.

Sketch. Differentiate the variance with respect to $b$ and set the derivative to zero. See supplemental materials for a detailed proof.
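The closed-form expression is not reproduced in this summary. A worked version of the sketch, under the standard assumptions that $b$ is a scalar independent of $\tau$ and that the variance of the vector-valued estimator is read as the trace of its covariance, goes as follows. The shorthand $g$ is introduced here for illustration, and the paper's own statement (for example, one conditioned on the context) may differ in form.

```latex
\begin{aligned}
g &:= \nabla_\theta \log \pi_\theta(\tau), \qquad \mathbb{E}[g] = 0, \\
\operatorname{Var}\big[(r-b)\,g\big]
  &= \mathbb{E}\big[(r-b)^2 \lVert g \rVert^2\big]
     - \big\lVert \mathbb{E}[(r-b)\,g] \big\rVert^2
   = \mathbb{E}\big[(r-b)^2 \lVert g \rVert^2\big]
     - \big\lVert \mathbb{E}[r\,g] \big\rVert^2, \\
\frac{\partial}{\partial b}\operatorname{Var}\big[(r-b)\,g\big]
  &= -2\,\mathbb{E}\big[(r-b)\,\lVert g \rVert^2\big] = 0
  \;\;\Longrightarrow\;\;
  b^\star = \frac{\mathbb{E}\big[r\,\lVert g \rVert^2\big]}{\mathbb{E}\big[\lVert g \rVert^2\big]}.
\end{aligned}
```

The second equality uses $\mathbb{E}[g] = 0$, so the mean of the estimator does not depend on $b$ and only the second-moment term needs to be differentiated.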

Figures (16)

  • Figure 1: Instance-level schedules improve text-to-image generation. (a)-(d) illustrate four aspects where pretrained models like Flux-Dev benefit from our schedules. Samplers using our schedules (right) show consistent improvements over those with the default schedule (left), which are even more pronounced at only 5 inference steps (a). We visualize the schedules at the top left corner of each image; X-axis denotes the number of inference steps, and Y-axis denotes the actual timestep values. We plot the KDE of t values along the Y-axis.
  • Figure 2: James-Stein (JS) reward baseline. (a) Simulation results from analytical policies showing that JS is consistently better than RLOO across different numbers of rollouts $K$. More details in supplemental materials. (b) Diagram of the JS baseline: combining $b_{\mathrm{RLOO}}$ and $b_{\mathrm{xctx}}$ into $b_{\mathrm{JS}} = \alpha_c b_{\mathrm{RLOO}} + (1-\alpha_c) b_{\mathrm{xctx}}$ (eq:js-baseline).
  • Figure 3: Rescheduling improves general T2I alignment. Head-to-head comparisons between images generated with default schedules (upper) and our learned schedules (lower) from Flux-Dev with 40 steps. Figures henceforth follow the same format.
  • Figure 4: HPSv2 scores on HPD2 held-out prompts. Scores vs. number of inference steps ($L$). Rows per backbone: Default, Cross-Context RLOO (XCTX), RLOO, TPDM-style PPO (Flux only), and Ours (JS). Best per column in bold; second-best underlined.
  • Figure 5: Representative training curves (Flux-Dev) for different baselines and sampling steps $L$. Y-axis denotes the aggregated rewards; X-axis denotes the number of iterations. The JS baseline consistently outperforms the XCTX and RLOO baselines; the gap is clear at low budgets and persists even at large $L$, where the discretization error of the sampling trajectory diminishes.
  • ...and 11 more figures
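The combination stated in the Figure 2 caption, $b_{\mathrm{JS}} = \alpha_c b_{\mathrm{RLOO}} + (1-\alpha_c) b_{\mathrm{xctx}}$, can be sketched for a batch of rewards as below. This is a minimal sketch under assumptions: rewards are arranged as a `(C, K)` array of `C` contexts with `K` rollouts each, the cross-context baseline is taken to be the mean reward of the other contexts in the batch, and the shrinkage weight `alpha` is passed in rather than computed, because its formula is not given in this summary. All function names are hypothetical.

```python
import numpy as np

def rloo_baseline(rewards):
    """Leave-one-out mean over the K rollouts of each context: (C, K) -> (C, K)."""
    rewards = np.asarray(rewards, dtype=float)
    k = rewards.shape[1]
    return (rewards.sum(axis=1, keepdims=True) - rewards) / (k - 1)

def cross_context_baseline(rewards):
    """Assumed form: mean reward of the *other* contexts in the batch, shape (C, 1)."""
    rewards = np.asarray(rewards, dtype=float)
    c = rewards.shape[0]
    ctx_mean = rewards.mean(axis=1, keepdims=True)
    return (ctx_mean.sum() - ctx_mean) / (c - 1)

def js_baseline(rewards, alpha):
    """Convex combination of the two baselines; alpha is a scalar or a (C, 1) array."""
    b_rloo = rloo_baseline(rewards)
    b_xctx = cross_context_baseline(rewards)
    return alpha * b_rloo + (1.0 - alpha) * b_xctx

rewards = np.array([[1.0, 3.0], [5.0, 7.0]])  # 2 contexts, 2 rollouts each
advantages = rewards - js_baseline(rewards, alpha=0.5)
```

Setting `alpha` to 1 recovers the RLOO baseline and setting it to 0 recovers the cross-context baseline, so the weight interpolates between per-context and cross-context information.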

Theorems & Definitions (2)

  • Proposition 1: Variance–optimal baseline
  • Theorem 3.1: MSE improvement and empirical Bayes optimality