Training Large Language Models to Reason via EM Policy Gradient
Tianbing Xu
TL;DR
The paper tackles the challenge of improving large language model reasoning by learning to generate latent reasoning trajectories. It proposes EM Policy Gradient (EMPG), an off-policy, EM-style RL method that alternates between sampling diverse latent rationales and performing reward-guided fine-tuning, avoiding the complexity of importance weights and clipping. On GSM8K and MATH (HARD) with Qwen2.5 base models, EMPG matches or slightly surpasses the state-of-the-art GRPO while offering advantages in simplicity, scalability, and concise reasoning, and it reveals emergent cognitive behaviors such as sub-problem decomposition and self-verification. Overall, the approach provides a scalable and robust framework for enhancing LLM reasoning with interpretable internal reasoning processes.
Abstract
Recently, foundation models such as OpenAI's o1 and o3, along with DeepSeek's R1, have demonstrated strong reasoning capabilities and problem-solving skills acquired through large-scale reinforcement learning (RL), with wide applications in mathematics, coding, science, intelligent agents, and virtual assistants. In this work, we introduce an off-policy reinforcement learning algorithm, EM Policy Gradient, aimed at enhancing LLM reasoning by optimizing the expected return over reasoning trajectories. We frame the reasoning task as an Expectation-Maximization (EM) optimization problem, alternating between sampling diverse rationale trajectories and performing reward-guided fine-tuning. Unlike PPO and GRPO, which rely on complex importance weights and heuristic clipping, our method provides a simpler, more principled off-policy policy gradient approach that eliminates these complexities while maintaining strong performance. We evaluate EM Policy Gradient on the GSM8K and MATH (HARD) datasets, where it achieves performance comparable to or slightly surpassing the state-of-the-art GRPO, while offering additional advantages in scalability, simplicity, and reasoning conciseness. Moreover, models fine-tuned with our method exhibit cognitive behaviors such as sub-problem decomposition, self-verification, and backtracking, highlighting its potential to enhance both the interpretability and robustness of LLM reasoning.
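The alternation described above (sample rationales, then fine-tune on them weighted by reward, with no importance ratios or clipping) can be illustrated on a toy problem. The Python sketch below is a minimal illustration under stated assumptions, not the paper's implementation: it replaces the LLM with a categorical "policy" over four fixed candidate rationales, uses hypothetical binary rewards, and assumes a generic reward-weighted log-likelihood as the M-step objective. The function names `softmax` and `em_reward_weighted_toy` and all hyperparameters are hypothetical.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]


def em_reward_weighted_toy(rewards, num_iters=30, num_samples=64,
                           m_steps=5, lr=1.0, seed=0):
    """Toy EM-style loop on a categorical policy over candidate rationales.

    ASSUMPTION: rewards[i] is a hypothetical scalar reward (e.g. 1.0 if
    rationale i leads to a correct answer, else 0.0). This is an
    illustrative reward-weighted objective, not the paper's exact one.
    Returns the final policy as a list of probabilities.
    """
    rng = random.Random(seed)
    n = len(rewards)
    logits = [0.0] * n
    for _ in range(num_iters):
        # E-step: sample rationales from a frozen snapshot of the policy.
        behavior = softmax(logits)
        samples = rng.choices(range(n), weights=behavior, k=num_samples)
        # M-step: several gradient-ascent steps on the reward-weighted
        # log-likelihood (1/K) * sum_k r(a_k) * log pi(a_k) of the FIXED
        # samples. The policy drifts away from `behavior` during these
        # steps, yet no importance ratio or clipping is applied.
        for _ in range(m_steps):
            probs = softmax(logits)
            grad = [0.0] * n
            for a in samples:
                r = rewards[a]
                # Gradient of log softmax(logits)[a] w.r.t. logits
                # is onehot(a) - probs.
                for j in range(n):
                    grad[j] += r * ((1.0 if j == a else 0.0) - probs[j])
            logits = [x + lr * g / num_samples for x, g in zip(logits, grad)]
    return softmax(logits)


policy = em_reward_weighted_toy([0.0, 1.0, 0.0, 0.0])
print([round(p, 3) for p in policy])
```

In this toy setting, most of the probability mass should concentrate on index 1, the only rewarded rationale. Taking several M-step updates on one fixed batch is what makes the update off-policy; the sketch simply omits any correction for that drift, which mirrors, in spirit only, the abstract's claim that importance weights and clipping are not needed.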
