
Beyond Uniform Credit: Causal Credit Assignment for Policy Optimization

Mykola Khandoga, Rui Yuan, Vinay Kumar Sankarapu

TL;DR

The work tackles gradient dilution in policy-gradient training for language-model reasoning by introducing counterfactual importance weighting. The method masks reasoning spans and uses the resulting drop in answer probability to assign more credit to causally critical steps, then integrates these weights into a DAPO-based objective. On GSM8K, it yields consistent gains of +0.8 to +1.1 percentage points over uniform baselines and accelerates learning, while inverted weighting degrades performance, supporting the validity of the causal signal. Analyses reveal that calculation chains are disproportionately important and that gradient mass concentrates on high-importance spans. The approach does incur notable computational overhead and shows limited transfer to code-generation tasks, which suggests domain-specific applicability and leaves room for future refinement.
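The span-masking step summarized above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the authors' implementation: `answer_prob` stands in for the policy model's probability of the final answer, `MASK` is a hypothetical placeholder token, spans are assumed to be given in advance and to cover every reasoning token, and the `floor`, mean-1 normalization, and `invert` flag are illustrative choices rather than details taken from the paper.

```python
from typing import Callable, List, Sequence, Tuple

# Half-open [start, end) token range of one reasoning span.
Span = Tuple[int, int]

# Hypothetical placeholder used to hide a span; the paper's masking scheme may differ.
MASK = "<mask>"


def span_importance(
    tokens: Sequence[str],
    spans: Sequence[Span],
    answer_prob: Callable[[List[str]], float],
) -> List[float]:
    """Score each span by how much masking it lowers the answer probability."""
    base = answer_prob(list(tokens))
    scores = []
    for start, end in spans:
        masked = list(tokens)
        masked[start:end] = [MASK] * (end - start)
        # Clip at zero: spans whose removal helps the answer get no extra credit (assumption).
        scores.append(max(0.0, base - answer_prob(masked)))
    return scores


def token_weights(
    n_tokens: int,
    spans: Sequence[Span],
    scores: Sequence[float],
    floor: float = 0.05,
    invert: bool = False,
) -> List[float]:
    """Broadcast span scores to tokens and rescale so the mean weight is 1.

    Assumes `spans` partition all n_tokens reasoning tokens. `floor` (> 0) keeps every
    token trainable; `invert=True` flips the ranking, mimicking an inverted-weighting
    ablation. Both knobs are hypothetical.
    """
    top = max(scores)
    raw = [0.0] * n_tokens
    for (start, end), score in zip(spans, scores):
        value = floor + ((top - score) if invert else score)
        for i in range(start, end):
            raw[i] = value
    mean = sum(raw) / n_tokens
    return [r / mean for r in raw]


def toy_answer_prob(tokens: List[str]) -> float:
    """Hypothetical stand-in for the policy model's probability of the correct answer."""
    p = 0.2
    if "23+45=68" in tokens:  # the calculation carries most of the evidence
        p += 0.6
    if "think" in tokens:  # filler contributes almost nothing
        p += 0.05
    return p


tokens = ["Let", "me", "think", "23+45=68", "so", "68"]
spans = [(0, 3), (3, 4), (4, 6)]  # filler, calculation, conclusion
scores = span_importance(tokens, spans, toy_answer_prob)
weights = token_weights(len(tokens), spans, scores)
print([round(s, 2) for s in scores])   # per-span probability drop
print([round(w, 2) for w in weights])  # per-token weights, mean 1
```

In this toy setup the calculation token receives the largest weight, while the filler and conclusion tokens are down-weighted but never zeroed, because of the floor.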

Abstract

Policy gradient methods for language model reasoning, such as GRPO and DAPO, assign uniform credit to all generated tokens: the filler phrase "Let me think" receives the same gradient update as the critical calculation "23 + 45 = 68." We propose counterfactual importance weighting: mask reasoning spans, measure the drop in answer probability, and upweight tokens accordingly during policy gradient updates. Our method requires no auxiliary models or external annotation; instead, importance is estimated directly from the policy model's own probability shifts. Experiments on GSM8K with three models spanning the Qwen and Llama families demonstrate consistent improvements over uniform baselines and faster convergence to equivalent accuracy. Inverting the importance signal hurts performance, confirming that we capture genuine causal structure rather than noise. Analysis shows that the method correctly prioritizes calculation steps over scaffolding text. We view these findings as establishing counterfactual importance weighting as a foundation for further research rather than a complete solution.
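One way to use per-token importance weights during policy gradient updates is to scale each token's term in a DAPO-style clipped surrogate. The sketch below is hypothetical: it assumes the weights are multiplied into the token-level objective, uses the population standard deviation for group-normalized advantages, and picks illustrative clip ranges `eps_low=0.2` and `eps_high=0.28`. The paper's exact objective may differ.

```python
import math
from typing import List, Sequence


def group_advantages(rewards: Sequence[float], eps: float = 1e-6) -> List[float]:
    """Group-normalized advantage (r - mean) / std, with population std (assumption)."""
    n = len(rewards)
    mean = sum(rewards) / n
    std = math.sqrt(sum((r - mean) ** 2 for r in rewards) / n)
    return [(r - mean) / (std + eps) for r in rewards]


def weighted_clipped_loss(
    logp_new: Sequence[Sequence[float]],
    logp_old: Sequence[Sequence[float]],
    weights: Sequence[Sequence[float]],
    advantages: Sequence[float],
    eps_low: float = 0.2,
    eps_high: float = 0.28,
) -> float:
    """Token-level clipped surrogate with a per-token importance weight.

    Averaging over every token in the group (rather than per sequence) follows
    DAPO-style token-level aggregation. Multiplying by `w` is the assumed weighting.
    """
    total, n_tokens = 0.0, 0
    for new_seq, old_seq, w_seq, adv in zip(logp_new, logp_old, weights, advantages):
        for lp_new, lp_old, w in zip(new_seq, old_seq, w_seq):
            ratio = math.exp(lp_new - lp_old)
            clipped = min(max(ratio, 1.0 - eps_low), 1.0 + eps_high)
            # Pessimistic (min) clipped term, scaled by the token's importance weight.
            total += w * min(ratio * adv, clipped * adv)
            n_tokens += 1
    return -total / n_tokens  # negated so that minimizing the loss maximizes the surrogate


# Two sampled answers to one prompt: the first is correct (reward 1), the second is not.
adv = group_advantages([1.0, 0.0])
old = [[-1.0, -1.0, -1.0], [-1.0, -1.0, -1.0]]
new = [[-1.0, -0.9, -1.0], [-1.0, -1.0, -1.0]]  # the middle token became more likely
uniform = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
focused = [[0.5, 2.0, 0.5], [1.0, 1.0, 1.0]]  # hypothetical importance weights
print(weighted_clipped_loss(new, old, uniform, adv))
print(weighted_clipped_loss(new, old, focused, adv))
```

Setting every weight to 1.0 recovers the uniform-credit baseline, so the weighting only changes how strongly each token's probability shift contributes to the update; here the focused weights make the loss respond more to the change on the up-weighted middle token.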

Paper Structure (63 sections, 8 equations, 1 figure, 8 tables)

Figures (1)

  • Figure 1: Training curves on GSM8K. Counterfactual weighting (red) consistently outperforms vanilla DAPO (blue) throughout training. Inverted weighting (purple) underperforms, validating that importance direction matters. Shaded regions show $\pm 1$ std across seeds.