Trajectory Balance with Asynchrony: Decoupling Exploration and Learning for Fast, Scalable LLM Post-Training

Brian Bartoldson, Siddarth Venkatraman, James Diffenderfer, Moksh Jain, Tal Ben-Nun, Seanie Lee, Minsu Kim, Johan Obando-Ceron, Yoshua Bengio, Bhavya Kailkhura

TL;DR

This work tackles the inefficiency of on-policy RL for large language model post-training by introducing Trajectory Balance with Asynchrony (TBA), a distributed off-policy RL framework that decouples exploration from learning. By employing a trajectory balance objective with VarGrad and a Searcher-Trainer architecture, TBA leverages large replay buffers to learn from diverse, off-policy data in parallel, achieving substantial speedups (up to 50x) while maintaining or surpassing baseline performance across mathematical reasoning, preference-tuning, and automated red-teaming tasks. The paper also explores scalability via TBA', a simplified variant suitable for larger models, and analyzes the tradeoffs of off-policy data through the recency/reward sampling parameter m. Overall, TBA demonstrates robust, scalable, and efficient LLM post-training, enabling faster deployment and broader exploration for alignment and safety objectives.

Abstract

Reinforcement learning (RL) is a critical component of large language model (LLM) post-training. However, on-policy algorithms used for post-training are not naturally robust to the diverse content of experience replay buffers, which asynchronous off-policy actors can efficiently populate in parallel to training. We propose efficiently learning on such off-policy data via Trajectory Balance with Asynchrony (TBA), an approach to asynchronous RL for LLMs that leverages the principled off-policy TB objective. On math, preference-tuning, and automated red-teaming tasks, we post-train models ranging from Pythia 410M to Qwen 2.5 7B, finding TBA offers speed and performance boosts over strong baselines like Online DPO and Dr. GRPO. Beyond TBA's performance benefits (high accuracy even as asynchrony grows) and speedups ($4\times$ or more), we show its reward- and recency-prioritizing sampling enables further gains as data generation is scaled. Our code is available at https://github.com/bbartoldson/TBA.
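
As a rough illustration of the off-policy TB objective mentioned above, the following minimal sketch computes a trajectory balance loss for K responses to a single prompt, with the log-partition term estimated from the batch in the VarGrad style. It assumes a KL-regularized target proportional to pi_ref(y|x) * exp(r(x, y) / beta); the function name `tb_vargrad_loss` and its arguments are hypothetical and are not taken from the TBA repository. It operates on plain floats, whereas real training would compute the same quantity on tensors with automatic differentiation.

```python
from statistics import fmean


def tb_vargrad_loss(logp_policy, logp_ref, rewards, beta):
    """Sketch of a TB loss for K responses to one prompt (hypothetical helper).

    logp_policy: log pi_theta(y_j | x) for each response y_j
    logp_ref:    log pi_ref(y_j | x) for each response y_j
    rewards:     r(x, y_j) for each response y_j
    beta:        assumed KL-regularization temperature
    """
    # Log of the unnormalized target: log pi_ref(y|x) + r(x, y) / beta.
    log_target = [lr + r / beta for lr, r in zip(logp_ref, rewards)]
    # Batch estimate of log Z(x): mean gap between target and policy log-probs.
    log_z = fmean(t - lp for t, lp in zip(log_target, logp_policy))
    # TB residual per response; it is zero when the policy matches the target.
    residuals = [log_z + lp - t for lp, t in zip(logp_policy, log_target)]
    return fmean(res * res for res in residuals)
```

Because the residual depends only on log-probabilities and rewards of the sampled responses, and not on which policy generated them, a loss of this form can in principle be evaluated on stale replay-buffer data, which is the property the asynchronous setup relies on.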

Paper Structure

This paper contains 37 sections, 17 equations, 14 figures, 6 tables.

Figures (14)

  • Figure 1: TBA excels on the GSM8K mathematical reasoning task. All plotted points use 4×A100 GPUs (or comparable L40S GPUs). DPO and RLOO baselines are taken from Noukhovitch et al. (2024); PPO and VinePPO baselines are taken from Kazemnejad et al. (2024). The baseline model is the SFTed RhoMath-1B model (Lin et al., 2024), which gets 40% accuracy after SFT and before RL. Appendix \ref{app:hparams} has details.
  • Figure 2: Fast, scalable LLM post-training with TBA. Continuously (solid lines), multiple Searcher nodes (left) collect trajectories, while a Trainer node (right) samples from a replay buffer to train the policy off-policy. Periodically (dashed lines), updated policy weights are sent to Searcher nodes, and new trajectories are added to the Trainer node's buffer. This avoids bottlenecks at any given node, which can be 1 or more GPUs, keeping resource utilization high.
  • Figure 3: TBA scales search and improves RL efficiency on the PFT summarization task. All plotted points use 4×A100 GPUs, but TBA allocates 3 GPUs to search, and Online DPO allocates 1 GPU to search. TBA produces large-scale off-policy data that its trajectory balance objective can leverage, creating massive efficiency benefits. Online DPO baselines are taken from Noukhovitch et al. (2024). Dashed and solid lines use 256 and 425 updates, respectively. Appendix \ref{app:hparams} has details.
  • Figure 4: TBA defines a new KL vs. win-rate Pareto frontier on the PFT summarization task. The baseline "Online DPO" frontier is created by increasing the degree of off-policyness, starting from on-policy Online DPO, using results from Noukhovitch et al. (2024). The TBA frontier is created by altering the training steps, searcher count, and KL annealing schedule as described in Appendix \ref{app:hparams}.
  • Figure 5: TBA reaches the RT diversity-toxicity Pareto frontier and improves as search is scaled. (Left) On the GPT-2 automated red-teaming task of Lee et al. (2024), TBA produces results on the diversity vs. toxicity Pareto frontier in less training time. Baselines are taken from Lee et al. (2024). (Right) Each searcher uses one V100 GPU for generating attacks. We report means and standard errors from multiple runs of the automated red-teaming task with GPT-2 at each searcher/GPU count.
  • ...and 9 more figures
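
The Searcher/Trainer flow described in the Figure 2 caption can be sketched in a single process as follows. This is a toy illustration under stated assumptions, not the authors' implementation: the class `ReplayBuffer`, the function `run_round`, and the deterministic top-k selection are all hypothetical, and the parameter `m` is assumed here to be the probability of prioritizing recent trajectories over high-reward ones. The actual system runs Searcher and Trainer nodes on separate GPUs and may sample stochastically.

```python
import random


class ReplayBuffer:
    """Toy trajectory store; entries are (step_added, reward, trajectory)."""

    def __init__(self, seed=0):
        self.entries = []
        self.rng = random.Random(seed)

    def add(self, step, reward, trajectory):
        self.entries.append((step, reward, trajectory))

    def sample(self, k, m):
        # Assumption: with probability m prioritize recency, otherwise reward.
        if self.rng.random() < m:
            ranked = sorted(self.entries, key=lambda e: e[0], reverse=True)
        else:
            ranked = sorted(self.entries, key=lambda e: e[1], reverse=True)
        # Return the k top-ranked trajectories under the chosen criterion.
        return [trajectory for _, _, trajectory in ranked[:k]]


def run_round(buffer, searchers, policy_version, step):
    """One sync period: each searcher generates with the current weights
    and its (reward, trajectory) pairs are added to the trainer's buffer."""
    for search in searchers:
        for reward, trajectory in search(policy_version):
            buffer.add(step, reward, trajectory)
```

A brief usage example: after `run_round(buffer, searchers, policy_version=0, step=0)`, a trainer would call `buffer.sample(k, m)` to draw a batch, take gradient steps on it, and then pass the updated weights to the next round.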