BranchGRPO: Stable and Efficient GRPO with Structured Branching in Diffusion Models
Yuming Li, Yikai Wang, Yuying Zhu, Zhongyu Zhao, Ming Lu, Qi She, Shanghang Zhang
TL;DR
This work tackles inefficient rollouts and unstable credit assignment in GRPO for diffusion-based image and video generation. It introduces BranchGRPO, which has three main contributions: a tree-structured rollout framework with shared prefixes, a reward fusion mechanism that produces dense step-level rewards through depth-wise normalization, and pruning schemes that reduce backpropagation cost. Together they yield faster convergence and higher alignment scores on HPDv2.1 (image) and WanX (video). Scaling experiments show that larger branch sizes improve performance, and the video results show improved temporal coherence and sharper frames.
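The cost argument behind shared prefixes can be checked with a short count. The Python sketch below compares denoising-step evaluations for a branching rollout against independent rollouts that produce the same number of final samples. It assumes a hypothetical schedule: a single trajectory starts at step 0, and at every step listed in `split_steps` each active trajectory forks into `branch_factor` children. The function names and schedule are illustrative and are not BranchGRPO's actual configuration. The count also ignores extra passes such as classifier-free guidance.

```python
from typing import List, Tuple


def branch_rollout_cost(num_steps: int, split_steps: List[int],
                        branch_factor: int) -> Tuple[int, int]:
    """Return (num_leaves, step_evaluations) for a hypothetical branching rollout.

    One trajectory starts at step 0. At every step index in `split_steps`,
    each active trajectory forks into `branch_factor` children that reuse
    the prefix already computed, so the shared prefix is denoised only once.
    """
    splits = set(split_steps)
    active = 1        # trajectories currently being denoised
    evaluations = 0   # total denoising-step evaluations across the tree
    for t in range(num_steps):
        if t in splits:
            active *= branch_factor  # fork before taking step t
        evaluations += active        # every active trajectory takes one step
    return active, evaluations


def independent_rollout_cost(num_leaves: int, num_steps: int) -> int:
    """Cost of producing the same number of final samples with no shared prefixes."""
    return num_leaves * num_steps


if __name__ == "__main__":
    leaves, tree_cost = branch_rollout_cost(20, [4, 8, 12, 16], branch_factor=2)
    flat_cost = independent_rollout_cost(leaves, 20)
    print(leaves, tree_cost, flat_cost)
```

With these assumed settings, the tree yields 16 final samples for 124 step evaluations. Sixteen independent 20-step rollouts would need 320. In this toy model, placing splits later lengthens the shared prefix and saves more compute, but the resulting samples are less diverse.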
Abstract
Recent progress in aligning image and video generative models with Group Relative Policy Optimization (GRPO) has improved human preference alignment, but existing variants have two weaknesses. They are inefficient because of sequential rollouts and large numbers of sampling steps. Their credit assignment is also unreliable: sparse terminal rewards are propagated uniformly across timesteps and fail to capture the varying criticality of decisions during denoising. In this paper, we present BranchGRPO, a method that restructures the rollout process into a branching tree, where shared prefixes amortize computation and pruning removes low-value paths and redundant depths. BranchGRPO makes three contributions: (1) a branching scheme that amortizes rollout cost through shared prefixes while preserving exploration diversity; (2) a reward fusion and depth-wise advantage estimator that transforms sparse terminal rewards into dense step-level signals; and (3) pruning strategies that cut gradient computation while leaving forward rollouts and exploration unaffected. On HPDv2.1 image alignment, BranchGRPO improves alignment scores by up to 16% over DanceGRPO while reducing per-iteration training time by nearly 55%. A hybrid variant, BranchGRPO-Mix, trains 4.7× faster than DanceGRPO without degrading alignment. On WanX video generation, BranchGRPO achieves higher Video-Align scores than DanceGRPO, with sharper and more temporally consistent frames. Code is available at https://fredreic1849.github.io/BranchGRPO-Webpage/ (BranchGRPO).
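The Python sketch below shows one way to turn sparse terminal rewards into per-node signals on a rollout tree, in the spirit of contribution (2). It makes two assumptions. First, reward fusion gives each internal node the uniform mean of its children's fused rewards. Second, depth-wise normalization standardizes fused rewards across all nodes at the same depth. The paper's actual fusion weighting may differ. `Node`, `fuse_rewards`, and `depthwise_advantages` are hypothetical names, and pruning is not modeled.

```python
from collections import deque
from dataclasses import dataclass, field
from statistics import mean, pstdev
from typing import Dict, List, Optional


@dataclass
class Node:
    """One segment of a branching rollout; leaves hold a terminal reward."""
    children: List["Node"] = field(default_factory=list)
    reward: Optional[float] = None  # terminal reward on leaves, fused reward elsewhere
    depth: int = 0
    advantage: float = 0.0


def fuse_rewards(node: Node, depth: int = 0) -> float:
    """Assumed uniform fusion: an internal node's reward is the mean of its children's."""
    node.depth = depth
    if not node.children:
        if node.reward is None:
            raise ValueError("every leaf needs a terminal reward")
        return node.reward
    node.reward = mean(fuse_rewards(child, depth + 1) for child in node.children)
    return node.reward


def depthwise_advantages(root: Node, eps: float = 1e-6) -> Dict[int, List[float]]:
    """Standardize fused rewards among nodes at the same depth (left-to-right order)."""
    fuse_rewards(root)
    by_depth: Dict[int, List[Node]] = {}
    queue = deque([root])
    while queue:  # breadth-first traversal lists each depth left to right
        node = queue.popleft()
        by_depth.setdefault(node.depth, []).append(node)
        queue.extend(node.children)
    result: Dict[int, List[float]] = {}
    for depth, nodes in sorted(by_depth.items()):
        rewards = [n.reward for n in nodes]
        mu, sigma = mean(rewards), pstdev(rewards)
        for n in nodes:
            n.advantage = (n.reward - mu) / (sigma + eps)  # eps guards sigma == 0
        result[depth] = [n.advantage for n in nodes]
    return result


if __name__ == "__main__":
    tree = Node(children=[
        Node(children=[Node(reward=1.0), Node(reward=0.6)]),
        Node(children=[Node(reward=0.2), Node(reward=0.0)]),
    ])
    for depth, advs in depthwise_advantages(tree).items():
        print(depth, [round(a, 3) for a in advs])
```

In this toy tree, the two depth-1 branches receive advantages of about +1 and -1. A shared early segment is therefore credited according to how well its descendants do. This differs from flat GRPO, where every step of a trajectory receives the same terminal advantage.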
