A-3PO: Accelerating Asynchronous LLM Training with Staleness-aware Proximal Policy Approximation

Xiaocan Li, Shiliang Wu, Zheng Shen

TL;DR

The paper tackles instability and low throughput in asynchronous RL for LLMs caused by data staleness. It replaces the explicit proximal policy computation in decoupled loss with a simple, staleness-aware log-space interpolation between behavior and target policies, avoiding an extra forward pass. Empirical results show substantial end-to-end training speedups (up to ~22% in the reported setup, and ~8,500x faster proximal evaluations) while maintaining comparable performance and improved stability metrics. This approach broadens the practicality of decoupled policy optimization for large-scale autoregressive models and is implemented in open-source code for reuse in LLM post-training pipelines.
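
The exact interpolation rule is not given here, so the following is a minimal sketch under stated assumptions. Policy versions are treated as integer update counters. The proximal policy is assumed to sit one update behind the current policy. The interpolation weight is assumed to be the proximal policy's relative position between the behavior and current versions. The function `approx_prox_logprob` and its parameters are hypothetical, not AReaL's actual API.

```python
import math  # kept for parity with tensor-free sketches; only builtins are used below


def approx_prox_logprob(behave_logp, target_logp, behave_version, current_version,
                        prox_version=None):
    """Hypothetical staleness-aware log-space interpolation of the proximal policy.

    behave_logp: per-token log-probs recorded when the sample was generated.
    target_logp: per-token log-probs of the current (target) policy.
    Assumption: prox_version defaults to current_version - 1.
    """
    if prox_version is None:
        prox_version = current_version - 1
    staleness = current_version - behave_version
    if staleness <= 0:
        # No staleness: fall back to the behavior log-probs (standard PPO anchor).
        return list(behave_logp)
    # alpha = 0 -> behavior policy, alpha = 1 -> current (target) policy.
    alpha = (prox_version - behave_version) / staleness
    alpha = min(max(alpha, 0.0), 1.0)
    # Linear interpolation in log space, token by token; no forward pass needed.
    return [b + alpha * (t - b) for b, t in zip(behave_logp, target_logp)]
```

Interpolating log-probabilities gives a per-token weighted geometric mean of the two probabilities, so it uses only values that are already available: behavior log-probs recorded at generation time and target log-probs from the training forward pass. With the assumed default, the weight toward the target policy is 1 - 1/staleness, so more stale samples receive an anchor closer to the current policy.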

Abstract

Decoupled loss is a successful reinforcement learning (RL) algorithm for handling high data staleness in the asynchronous RL setting. It improves the learning stability of coupled-loss algorithms (e.g., PPO, GRPO) by introducing a proximal policy that decouples the off-policy correction (importance weight) from the control of policy updates (trust region). However, the proximal policy requires an extra forward pass through the network at each training step, creating a computational bottleneck for large language models. We observe that, since the proximal policy serves only as a trust-region anchor between the behavior and target policies, it can be approximated by simple interpolation without explicit computation. We call this approach A-3PO (APproximated Proximal Policy Optimization). A-3PO eliminates this overhead, reducing training time by 18% while maintaining comparable performance. Code and an off-the-shelf example are available at: https://github.com/inclusionAI/AReaL/blob/main/docs/algorithms/prox_approx.md
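
To make the decoupling concrete, here is a minimal per-token sketch of a decoupled PPO-style objective. It assumes the common form in which an importance weight pi_prox / pi_behave multiplies a PPO clipped surrogate built on the ratio pi_theta / pi_prox. The function name, argument layout and default clip_eps = 0.2 are illustrative assumptions. Real implementations work on tensors, may cap the importance weight, and treat the proximal log-probs as constants (no gradient).

```python
import math


def decoupled_ppo_loss(logp_theta, logp_prox, logp_behave, advantages, clip_eps=0.2):
    """Token-averaged decoupled PPO loss (illustrative sketch, not AReaL's code).

    - importance weight pi_prox / pi_behave corrects for stale (off-policy) data;
    - ratio pi_theta / pi_prox is clipped to enforce the trust region.
    """
    total = 0.0
    for lt, lp, lb, adv in zip(logp_theta, logp_prox, logp_behave, advantages):
        iw = math.exp(lp - lb)        # off-policy correction
        ratio = math.exp(lt - lp)     # trust-region ratio w.r.t. the proximal anchor
        clipped_ratio = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
        # Pessimistic (min) surrogate, as in standard PPO.
        surrogate = min(ratio * adv, clipped_ratio * adv)
        total += iw * surrogate
    # Negate because optimizers minimize; average over tokens.
    return -total / len(advantages)
```

With the recomputation approach, `logp_prox` comes from an extra forward pass with the proximal policy. The point of A-3PO is that this one input can instead be approximated cheaply, leaving the rest of the objective unchanged.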

Paper Structure

This paper contains 12 sections, 4 equations, and 5 figures.

Figures (5)

  • Figure 1: Comparison of log probability computation time between loglinear approximation and full recomputing methods. The loglinear method eliminates the computational overhead of proximal policy evaluation, achieving near-instantaneous computation compared to the 10-second forward pass required by recomputing.
  • Figure 2: Training progress measured by average task reward over wall-clock time. The loglinear approximation achieves 27% faster training (2.72 vs 3.46 hours, i.e., about 21% less wall-clock time) while maintaining comparable final performance, demonstrating the practical benefit of eliminating proximal policy computation.
  • Figure 3: Policy entropy over training steps. Both methods show healthy entropy decay, with loglinear maintaining slightly higher values that may benefit exploration.
  • Figure 4: Importance weight statistics during training. Left: Maximum importance weights. Right: Minimum importance weights. The loglinear approximation exhibits more controlled importance weights, suggesting better stability under off-policy conditions.
  • Figure 5: Number of tokens clipped per training step. The recomputing method clips 6× more tokens on average, suggesting larger policy updates that require more constraint enforcement, while loglinear updates remain naturally within trust-region bounds, suggesting higher sample efficiency.
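
The diagnostics reported in Figures 2, 4 and 5 can be approximated with a few lines of Python. This is a sketch under assumed definitions: importance weights are taken as pi_prox / pi_behave per token, a token counts as clipped when pi_theta / pi_prox falls outside [1 - eps, 1 + eps], and the function names are hypothetical. The second helper shows the arithmetic behind the Figure 2 numbers: 2.72 vs 3.46 hours is a 1.27x speedup (about 27% faster) or, equivalently, about 21% less wall-clock time.

```python
import math


def stability_metrics(logp_theta, logp_prox, logp_behave, clip_eps=0.2):
    """Per-step diagnostics in the spirit of Figures 4 and 5 (assumed definitions)."""
    # Importance weights pi_prox / pi_behave for each token (Figure 4).
    iws = [math.exp(lp - lb) for lp, lb in zip(logp_prox, logp_behave)]
    # Clipped-token count (Figure 5): ratio pi_theta / pi_prox outside the trust region.
    n_clipped = sum(
        1
        for lt, lp in zip(logp_theta, logp_prox)
        if not (1.0 - clip_eps <= math.exp(lt - lp) <= 1.0 + clip_eps)
    )
    return {"iw_max": max(iws), "iw_min": min(iws), "n_clipped": n_clipped}


def wallclock_gain(baseline_hours, new_hours):
    """Return (relative speedup, relative wall-clock time saved)."""
    speedup = baseline_hours / new_hours - 1.0    # 3.46 / 2.72 - 1 ~= 0.272
    time_saved = 1.0 - new_hours / baseline_hours  # 1 - 2.72 / 3.46 ~= 0.214
    return speedup, time_saved
```

The two percentages describe the same run from different baselines, which is worth keeping in mind when comparing speedup figures quoted in different places.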