Tree Training: Accelerating Agentic LLMs Training via Shared Prefix Reuse

Shaojie Wang, Jinghui Wang, Yinghan Cui, Xuxing Chen, Chao Wang, Liang Huang, Xiaojiang Zhang, Junyi Peng, Li Wan, Haotian Zhang, Bin Chen

TL;DR

This work tackles an inefficiency in training agentic LLMs: multi-turn interactions cause rollout trajectories to form tree structures. It introduces Tree Training, which combines Tree Packing, a compact representation of shared prefixes, with Gradient Restoration, which ensures correct gradient flow across reused prefixes, and complements both with optimized kernels. Shared prefixes are computed once and reused across all descendant branches, yielding end-to-end speedups of up to 3.9× on realistic agentic RL data and up to 5.7× in ideal, memory-rich scenarios, without degrading fidelity. In practice, these techniques substantially increase throughput for large-scale supervised fine-tuning and reinforcement learning of agentic LLMs, while formal gradient analyses establish that training remains correct.
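
To make the "compute once, reuse" idea concrete, here is a minimal Python sketch, not the authors' code, that counts how many tokens a batch of trajectories costs when each branch is processed as an independent sequence versus when the trajectories are merged into a prefix tree (a trie) whose nodes are computed once. The helpers `count_tree_nodes` and `prefix_reuse_ratio` are hypothetical names, and trajectories are assumed to be lists of integer token ids that share prefixes token for token.

```python
from typing import Dict, List, Sequence


def count_tree_nodes(trajectories: Sequence[Sequence[int]]) -> int:
    """Number of nodes in the prefix tree (trie) built from the trajectories.

    Each node corresponds to one distinct prefix, i.e. one token that only
    needs to be computed once under prefix reuse.
    """
    root: Dict[int, dict] = {}
    nodes = 0
    for traj in trajectories:
        node = root
        for tok in traj:
            if tok not in node:      # first time this prefix is seen: new node
                node[tok] = {}
                nodes += 1
            node = node[tok]         # descend; shared prefixes reuse existing nodes
    return nodes


def prefix_reuse_ratio(trajectories: Sequence[Sequence[int]]) -> float:
    """Tokens processed as independent sequences / tokens in the merged tree."""
    linear_tokens = sum(len(t) for t in trajectories)
    return linear_tokens / count_tree_nodes(trajectories)


# Usage: three branches sharing the prefix [1, 2]; the first two also share 3.
trajs: List[List[int]] = [[1, 2, 3, 4], [1, 2, 3, 5], [1, 2, 6]]
print(count_tree_nodes(trajs))              # 6
print(round(prefix_reuse_ratio(trajs), 2))  # 1.83
```

For these toy trajectories, 11 tokens shrink to 6 tree nodes, a ratio of about 1.83. This is only an idealized per-token bound: attention cost grows with context length, and packing, memory, and kernel overheads also matter, so the ratio should not be read as a predictor of the measured end-to-end speedups.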

Abstract

In agentic LLM scenarios, an agent's interaction process during a single rollout often branches. Because of memory retrieval and concurrent tool executions at certain decision points, the token trajectory of one task evolves into a tree-like structure rather than a linear sequence. However, current training pipelines decompose such tree-structured trajectories into separate linear segments and treat each branch as an independent sequence. As a result, prefixes shared across these branches are recomputed repeatedly during both the forward and backward passes. To address this inefficiency, we propose Tree Training, a paradigm that computes each shared prefix only once and reuses its intermediate results across related branches during both passes, substantially improving computational efficiency in large-scale agentic training. This is achieved via (i) Tree Packing, which efficiently reuses shared computations across trajectories, and (ii) Gradient Restoration, which ensures correct gradient propagation across reused prefixes. Experiments on multiple open-source models show that Tree Training reduces total training time by up to 3.9×, enabling more efficient SFT and RL training of agentic LLMs.
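
Why does reusing a prefix require any gradient correction? One way to see it, under assumptions that are ours rather than the paper's: with causal attention, the loss on a token depends only on the tokens up to and including it, so a shared prefix token contributes an identical loss term to every trajectory that contains it. If that prefix is stored only once, its loss must be weighted by the number of trajectories passing through it to reproduce the per-trajectory objective. The sketch below checks this loss equivalence with a toy per-token loss; `linear_loss`, `tree_loss`, and `toy_loss` are hypothetical names, and the paper's actual Gradient Restoration (including any normalization) may differ in detail.

```python
from typing import Callable, Dict, Sequence, Tuple

Prefix = Tuple[int, ...]
TokenLoss = Callable[[Prefix], float]  # loss of the last token given its prefix


def linear_loss(trajs: Sequence[Sequence[int]], token_loss: TokenLoss) -> float:
    # Baseline: every trajectory is an independent sequence, so each shared
    # prefix token is scored once per trajectory that contains it.
    return sum(token_loss(tuple(t[: k + 1])) for t in trajs for k in range(len(t)))


def tree_loss(trajs: Sequence[Sequence[int]], token_loss: TokenLoss) -> float:
    # Count how many trajectories pass through each tree node (distinct prefix).
    counts: Dict[Prefix, int] = {}
    for t in trajs:
        for k in range(len(t)):
            prefix = tuple(t[: k + 1])
            counts[prefix] = counts.get(prefix, 0) + 1
    # Each node is evaluated once; its loss is scaled by its sharing count.
    return sum(n * token_loss(p) for p, n in counts.items())


def toy_loss(prefix: Prefix) -> float:
    # Arbitrary deterministic stand-in for a token's cross-entropy loss.
    return (sum(prefix) % 7) / 7 + 0.1 * len(prefix)


trajs = [[1, 2, 3, 4], [1, 2, 3, 5], [1, 2, 6]]
print(abs(linear_loss(trajs, toy_loss) - tree_loss(trajs, toy_loss)) < 1e-9)  # True
```

In this reading, because the two objectives agree term by term and differentiation is linear, their parameter gradients agree as well; the gradients that several suffixes send back into a shared prefix's keys and values accumulate in a single backward pass instead of being spread over duplicated copies.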

Paper Structure

This paper contains 29 sections, 21 equations, and 8 figures.

Figures (8)

  • Figure 1: Illustration of shared prefixes. Left: Multiple trajectories share common prefix segments (e.g., all trajectories share $r \rightarrow u$), while smaller subsets may share longer prefixes (e.g., trajectories 1–3 share $r \rightarrow u \rightarrow v_1$). Right: Merging these overlapping prefixes forms a hierarchical tree in which shared computation is explicitly represented by internal nodes and unique continuations correspond to leaf branches. This tree structure allows common computation to be reused across trajectories.
  • Figure 2: Schematic of the preprocessing step (sequence packing), the forward pass, and the backward pass on a tree-structured dataset. Pink blocks represent $(X, Q, K, V, O, dO, dV)$ for the prefix, while yellow blocks represent the same quantities for the suffixes.
  • Figure 3: Comparison between single-path and multi-path tree packing. In single-path packing, Step 1 packs the shared prefix $r \rightarrow u \rightarrow v_1$ and Step 2 packs $r \rightarrow u \rightarrow v_5$ separately. The optimal strategy instead treats $r \rightarrow u \rightarrow \{v_1, v_5\}$ as a hierarchical shared prefix, enabling greater computation reuse, as discussed in the section on multi-path tree packing.
  • Figure 4: Comparison of the backward V-gradient computation under tree packing and under original packing. Pink blocks represent $(Q, K, dO, dV)$ for the prefix parts, while yellow blocks represent those of the suffix parts. Orange blocks mark (S or P) values that are active in the attention computation, white blocks mark (S or P) values that are masked out, and green blocks mark (S or P) computations that tree packing can omit.
  • Figure 5: Implementation of a flattened tree trajectory. Each flattened tree trajectory requires (1) a gradient scale tensor for prefix reuse, (2) a position embedding tensor that restores the original token positions, and (3) a shared-prefix attention mask that enables correct computation reuse across overlapping prefixes during both the forward and backward passes.
  • ...and 3 more figures
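
Figure 5 lists the per-trajectory tensors that a flattened tree needs. As an illustration only, the Python sketch below flattens a toy tree in depth-first order and derives two of them: position ids that restore each token's original position within its own root-to-leaf trajectory, and a shared-prefix attention mask in which a token may attend only to itself and its ancestors. The `Node` class and `flatten_tree` function are hypothetical, the dense boolean mask is chosen for readability rather than reflecting how an optimized kernel would store it, and the gradient scale tensor is omitted.

```python
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class Node:
    tokens: List[int]                                     # token segment stored at this node
    children: List["Node"] = field(default_factory=list)  # branches that continue it


def flatten_tree(root: Node) -> Tuple[List[int], List[int], List[List[bool]]]:
    """Return (token ids, position ids, attention mask) for a DFS-flattened tree."""
    tokens: List[int] = []
    positions: List[int] = []
    parent: List[int] = []  # flat index of the previous token on the path (-1 = none)

    def visit(node: Node, prev_idx: int, prev_pos: int) -> None:
        for tok in node.tokens:
            tokens.append(tok)
            prev_pos += 1
            positions.append(prev_pos)       # position within its own trajectory
            parent.append(prev_idx)
            prev_idx = len(tokens) - 1
        for child in node.children:          # every child continues from the same point
            visit(child, prev_idx, prev_pos)

    visit(root, -1, -1)
    n = len(tokens)
    mask = [[False] * n for _ in range(n)]
    for i in range(n):
        j = i
        while j != -1:                       # token i sees itself and all its ancestors
            mask[i][j] = True
            j = parent[j]
    return tokens, positions, mask


# Usage: shared prefix [10, 11] followed by two branches, [20, 21] and [30].
tree = Node([10, 11], [Node([20, 21]), Node([30])])
toks, pos, mask = flatten_tree(tree)
print(toks)     # [10, 11, 20, 21, 30]
print(pos)      # [0, 1, 2, 3, 2]
print(mask[4])  # [True, True, False, False, True]
```

Depth-first order places every ancestor before its descendants, so this mask is always a subset of the ordinary causal mask, while sibling branches, such as tokens 20 and 30 (both at position 2), cannot attend to each other.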