Training Large Language Models To Reason In Parallel With Global Forking Tokens

Sheng Jia, Xiao Wang, Shiva Prasad Kasiviswanathan

TL;DR

This work addresses a core challenge for parallel reasoning in large language models: generating reasoning paths that are diverse without sacrificing correctness. The authors introduce Set Supervised Fine-Tuning (SSFT), which uses a set of global forking tokens to initiate multiple reasoning traces in parallel. A set-based bipartite-matching loss aligns forking tokens with traces, which makes the loss invariant to trace order and prevents reasoning modes from collapsing. The approach yields emergent global forking tokens and consistent improvements in Pass@1 and Cons@k across reasoning benchmarks, outperforming both standard SFT and naive multi-trace fine-tuning. Trained on distilled reasoning traces, SSFT achieves better coverage than temperature scaling and offers a practical route to scalable parallel reasoning, while the Hungarian algorithm keeps the matching step in the training loop tractable. Because individual forking tokens trigger distinct reasoning modes, the method may also improve interpretability and reliability in reasoning-heavy tasks, and it invites further work on larger forking-token sets and broader evaluation domains.

Abstract

Although LLMs have demonstrated improved performance by scaling parallel test-time compute, doing so relies on generating reasoning paths that are both diverse and accurate. For challenging problems, the forking tokens that trigger diverse yet correct reasoning modes typically lie deep in the sampling tree. Consequently, common strategies to encourage diversity, such as temperature scaling, face a worse trade-off between diversity and accuracy. Motivated by this challenge, we treat parallel reasoning as a set-of-next-token-prediction problem and incorporate a set-based global loss into Supervised Fine-Tuning (SFT) using self-supervised bipartite matching between our global forking tokens and unique reasoning traces. We observe that, while naive fine-tuning with multiple reasoning traces collapses these unique reasoning modes, our proposed method, Set Supervised Fine-Tuning (SSFT), preserves these modes and produces emergent global forking tokens. Experiments on multiple reasoning benchmarks show that SSFT consistently outperforms SFT under both Pass@1 and Cons@k metrics.
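The set-based loss described above can be written out as follows. This is a hedged reconstruction based on the Figure 2 caption, not a copy of the paper's own numbered equations; the question $\mathbf{x}$, the number of traces $M$, the number of forking tokens $N \ge M$, the set $\mathfrak{S}$ of injective maps from $\{1,\dots,M\}$ to $\{1,\dots,N\}$, and the choice of a summed (rather than length-normalized) NTP loss are all assumptions.

$$
\mathcal{C}_{ij}(\bm{\theta}) = -\sum_{t=1}^{|\mathbf{r}^{(j)}|} \log p_{\bm{\theta}}\big(r^{(j)}_t \mid \mathbf{x}, \textnormal{g}^{(i)}, \mathbf{r}^{(j)}_{<t}\big)
$$

$$
\hat{\bm{\sigma}} = \operatorname*{arg\,min}_{\bm{\sigma} \in \mathfrak{S}} \sum_{j=1}^{M} \mathcal{C}_{\sigma(j)\,j}(\bm{\theta}),
\qquad
\mathcal{L}_{\mathrm{Hung}}(\bm{\theta}) = \sum_{j=1}^{M} \mathcal{C}_{\hat{\sigma}(j)\,j}(\bm{\theta})
$$

Because $\hat{\bm{\sigma}}$ is recomputed for each question, relabeling the traces only permutes the columns of the cost matrix and leaves $\mathcal{L}_{\mathrm{Hung}}$ unchanged; the argmin itself is not differentiated, so gradients flow only through the matched costs.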

Paper Structure

This paper contains 23 sections, 7 equations, 13 figures, 3 tables, 2 algorithms.

Figures (13)

  • Figure 1: An illustration of supervised fine-tuning methods that aim to instill parallel reasoning capabilities from diverse reasoning traces for the same question. Compared with (1) standard SFT and (2) SFT with randomly assigned parallel-thinking identifiers, (3) Set Supervised Fine-Tuning uses a self-supervised bipartite matching process to learn, for each question, reasoning modes that are maximally differentiated when conditioned on distinct $\texttt{<think}\,i\texttt{>}$ tokens. Self-supervised matching prevents the collapse of maximally distinct reasoning modes caused by the ordering bias that arises when a reasoning-mode identifier is assigned to each trace randomly or manually. The learned global forking tokens are considered emergent because, at convergence, we observe that similar reasoning modes in the traces of different questions cluster to match the same $\texttt{<think}\,i\texttt{>}$, even though we neither identify such modes manually nor add any regularization term to keep the matching consistent across input prompts.
  • Figure 2: An illustration of one SSFT training step. Step 1: We first construct the cost matrix by evaluating all pairwise combinations: for each trace $\mathbf{r}^{(j)} \in \lbrace\mathbf{r}^{(1)}, \dots, \mathbf{r}^{(4)}\rbrace$ and each forking token $\textnormal{g}^{(i)} \in \lbrace\textnormal{g}^{(1)}, \dots, \textnormal{g}^{(6)}\rbrace$, we compute the next-token-prediction (NTP) loss of $\mathbf{r}^{(j)}$ conditioned on $\textnormal{g}^{(i)}$ (the matching-cost equation). We then use the Hungarian algorithm to find the matching $\hat{\bm{\sigma}}$ that minimizes the total bipartite matching cost. Here, this minimum is the sum of the losses highlighted in blue, so $\hat{\bm{\sigma}} = \lbrace(\textnormal{g}^{(6)}, \mathbf{r}^{(1)}), (\textnormal{g}^{(5)}, \mathbf{r}^{(2)}), (\textnormal{g}^{(2)}, \mathbf{r}^{(3)}), (\textnormal{g}^{(3)}, \mathbf{r}^{(4)})\rbrace$. Step 2: We optimize $\bm{\theta}$ by backpropagating the set of NTP losses for $\mathbf{r}^{(j)}$, each conditioned on $\textnormal{g}^{(\hat{\bm{\sigma}}(j))}$. This sum is the Hungarian loss.
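Step 1 of Figure 2 is a rectangular assignment problem: each of the 4 traces must receive a distinct forking token out of 6. The sketch below is illustrative only. It solves the assignment exactly by brute force over injective assignments (a swapped-in method that is fine for 6 × 5 × 4 × 3 = 360 candidates but not for large sets); real training would use the Hungarian algorithm, for example SciPy's `scipy.optimize.linear_sum_assignment`, which accepts rectangular cost matrices. The helper name `min_cost_matching` is hypothetical, and the cost values are made up so that the optimum reproduces the matching in the caption.

```python
from itertools import permutations
from typing import Sequence


def min_cost_matching(cost: Sequence[Sequence[float]]) -> tuple[list[int], float]:
    """Exact minimum-cost matching of traces to forking tokens (brute force).

    cost[j][i] is the NTP loss of trace j conditioned on forking token i
    (0-based indices). Returns (sigma, total), where sigma[j] is the token
    index matched to trace j and total is the summed matched cost.
    """
    n_traces, n_tokens = len(cost), len(cost[0])
    if n_traces > n_tokens:
        raise ValueError("need at least as many forking tokens as traces")
    best_sigma: list[int] = []
    best_total = float("inf")
    # Each permutation assigns a distinct token to every trace (injective map).
    for perm in permutations(range(n_tokens), n_traces):
        total = sum(cost[j][perm[j]] for j in range(n_traces))
        if total < best_total:
            best_sigma, best_total = list(perm), total
    return best_sigma, best_total


# Hypothetical 4-trace x 6-token cost matrix (rows: traces r1..r4, columns: g1..g6).
COST = [
    [2.0, 2.1, 1.9, 2.2, 1.8, 0.5],  # r1 is cheapest under g6
    [1.7, 1.9, 2.0, 1.6, 0.7, 0.9],  # r2 is cheapest under g5
    [1.5, 0.6, 1.4, 1.3, 1.6, 1.2],  # r3 is cheapest under g2
    [1.4, 1.1, 0.8, 1.3, 1.5, 1.7],  # r4 is cheapest under g3
]

sigma, loss = min_cost_matching(COST)
# Print 1-based (token, trace) pairs to mirror the caption's notation.
print([(f"g{i + 1}", f"r{j + 1}") for j, i in enumerate(sigma)])
print(round(loss, 6))
```

The printed pairs are (g6, r1), (g5, r2), (g2, r3), (g3, r4), matching the caption. In Step 2, an autodiff framework would sum the matched NTP losses and backpropagate through those entries only; the matching itself is treated as a constant.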
  • Figure 3: (Left) $\bm{\mathfrak{S}}_p$ is the subset of bipartite matching configurations that are still computed as optimal toward the end of training; each can be visualized individually as a bipartite graph. (Middle) Learned matchings of SSFT-32B from the experiments section, obtained by aggregating all edges in $\bm{\mathfrak{S}}_p$. (Right) At test time, for $\mathrm{Pass@}1$, we prompt with the $\textnormal{g}^{(i^\star)}$ that has the most connected edges. For $\mathrm{Cons@}k$, we prompt the $i$-th parallel generation with $\texttt{<think}(i \,\%\, \mathrm{N})\texttt{>}$.
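The test-time rules in Figure 3 can be sketched as follows. This is a minimal illustration under stated assumptions, not the authors' code: matchings are encoded as hypothetical lists of 1-based (token, trace) edges, tags are assumed to be numbered 1 to N so the round-robin uses `1 + i % N` with a 0-based generation index, and Cons@k is assumed to mean majority voting over the k final answers. All function names are hypothetical.

```python
from collections import Counter
from typing import Iterable, Sequence

Edge = tuple[int, int]  # (forking-token index, trace index), 1-based; hypothetical encoding


def most_connected_token(matchings: Iterable[Sequence[Edge]]) -> int:
    """Pass@1 rule: the forking token with the most edges across retained matchings."""
    degree = Counter(token for matching in matchings for token, _ in matching)
    return degree.most_common(1)[0][0]


def round_robin_tags(k: int, n_tokens: int) -> list[int]:
    """Cons@k rule: tag for each of k parallel generations, cycling 1..n_tokens."""
    return [1 + (i % n_tokens) for i in range(k)]


def majority_vote(answers: Sequence[str]) -> str:
    """Assumed Cons@k aggregation: most frequent answer (ties go to the earliest seen)."""
    return Counter(answers).most_common(1)[0][0]


# Hypothetical stand-in for the set S_p (not data from the paper).
S_p = [
    [(6, 1), (5, 2), (2, 3), (3, 4)],
    [(6, 1), (1, 2), (2, 3), (4, 4)],
    [(6, 2), (5, 1), (3, 3), (1, 4)],
]

best_tag = most_connected_token(S_p)  # token 6 appears in 3 edges
tags = round_robin_tags(8, 6)  # cycles back to tag 1 after tag 6
final = majority_vote(["42", "17", "42", "42", "9", "17"])
print(best_tag, tags, final)
```

Counting edges per token is the simplest reading of "most connected"; weighting edges by how often each matching recurs would be an equally plausible variant.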
  • Figure 4: Coverage, measured by $\mathrm{Pass@}k$, of SSFT compared with SFT-mixed-distill-32B under temperature scaling. For convenience, we also report the $\mathrm{Cons@}6$ accuracy next to each line. On AIME25, SFT-mixed-distill-32B must raise its inference temperature to 1 and use more attempts to match SSFT's coverage, at the cost of lower $\mathrm{Pass@}1$ and $\mathrm{Cons@}6$ accuracy, which further widens the gaps.
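The caption does not say how $\mathrm{Pass@}k$ is computed. A widely used convention is the unbiased estimator $1 - \binom{n-c}{k} / \binom{n}{k}$, where $n$ samples are drawn per problem and $c$ of them are correct; assuming that convention, a per-problem sketch looks like this (the function name `pass_at_k` is hypothetical).

```python
from math import comb


def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased Pass@k estimate for one problem (assumed convention).

    n: samples drawn, c: correct samples among them, k: attempts allowed.
    Probability that at least one of k samples drawn without replacement
    from the n is correct.
    """
    if not 0 <= c <= n or not 1 <= k <= n:
        raise ValueError("require 0 <= c <= n and 1 <= k <= n")
    if n - c < k:
        return 1.0  # every size-k subset must contain a correct sample
    return 1.0 - comb(n - c, k) / comb(n, k)


# Benchmark-level Pass@k would average this value over problems.
print(pass_at_k(6, 3, 1), pass_at_k(6, 3, 4), pass_at_k(10, 0, 3))
```

Raising the temperature typically increases $c$ on hard problems at large $k$ while lowering it on easier ones, which is one way the trade-off described in the caption can show up in this metric.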
  • Figure 5: (SSFT, optimal matching) Distribution of thinking-token counts and average performance on AIME24 (left) and AIME25 (right) when prompting with each distinct $\texttt{<think1>}, \dots, \texttt{<think6>}$. Each accuracy, shown above its whisker, is averaged over 11 generations. We observe clear specialization of reasoning modes. $\texttt{<think1>}, \dots, \texttt{<think4>}$ trigger longer traces, which suit many AIME problems better, and each of them surpasses every average accuracy obtained under SFT without bipartite matching (shown in the figure on how random matching collapses reasoning modes). $\texttt{<think5>}$ and $\texttt{<think6>}$ favor concise reasoning, which can be suboptimal for some AIME problems, yet they still raise $\mathrm{Cons@}6$ above all baselines by adding diversity. The consistency of these length distributions across AIME24 and AIME25 indicates that the differences are not due to randomness and that these think tags truly initiate distinct yet consistent reasoning modes.
  • ...and 8 more figures