
SplitReason: Learning To Offload Reasoning

Yash Akhauri, Anthony Fei, Chi-Chih Chang, Ahmed F. AbouElhamayed, Yueying Li, Mohamed S. Abdelfattah

TL;DR

SplitReason tackles the high latency of reasoning in LLMs by enabling a small model to offload only the most difficult reasoning steps to a larger model. It introduces cooperative execution with control tokens and a two-stage training pipeline (SFT followed by GRPO) under a latency-aware objective (RL4E). Trained on 18k annotated reasoning traces from OpenR1-Math-220k, the method improves AIME24 accuracy by up to 28.3% while offloading only 5% of generated tokens (24% at a 1.35% offload), and pipelined simulations project 8-9× faster inference than the 32B model alone at a 1.35% offload. The approach is model-agnostic and open-source, pointing to a new direction for hardware-aware inference and efficient reasoning in LLMs.
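The TL;DR names a latency-aware objective (RL4E) optimized with GRPO but does not state its reward. As a minimal sketch, assuming a hypothetical reward of 1 for a correct answer minus a penalty proportional to the fraction of tokens offloaded, GRPO's group-relative advantage standardizes each rollout's reward against the other rollouts sampled for the same question. `Rollout`, `latency_aware_reward`, and `penalty` are illustrative names, not the paper's.

```python
import statistics
from dataclasses import dataclass
from typing import List


@dataclass
class Rollout:
    correct: bool        # final answer matched the reference
    offload_frac: float  # share of generated tokens produced by the large model


def latency_aware_reward(r: Rollout, penalty: float = 1.0) -> float:
    """HYPOTHETICAL reward: 1 for a correct answer minus an offloading penalty."""
    return (1.0 if r.correct else 0.0) - penalty * r.offload_frac


def group_relative_advantages(group: List[Rollout], penalty: float = 1.0,
                              eps: float = 1e-6) -> List[float]:
    """GRPO-style advantages: standardize rewards within one question's group."""
    rewards = [latency_aware_reward(r, penalty) for r in group]
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)
    # eps avoids division by zero when every rollout earns the same reward
    return [(x - mean) / (std + eps) for x in rewards]


group = [Rollout(True, 0.05), Rollout(True, 0.20),
         Rollout(False, 0.0), Rollout(False, 0.10)]
print([round(a, 2) for a in group_relative_advantages(group)])
```

Under this placeholder reward, correct rollouts that offload less receive the largest advantage, which is the accuracy/latency trade-off a latency-aware objective is meant to encode; the actual RL4E reward may weight these terms differently.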

Abstract

Reasoning in large language models (LLMs) tends to produce substantially longer token generation sequences than simpler language modeling tasks. This extended generation length reflects the multi-step, compositional nature of reasoning and is often correlated with higher solution accuracy. From an efficiency perspective, longer token generation amplifies the cost of the inherently sequential, memory-bound decoding phase of LLMs. However, not all parts of this expensive reasoning process are equally difficult to generate. We leverage this observation by offloading only the most challenging parts of the reasoning process to a larger, more capable model while performing most of the generation with a smaller, more efficient model. We also teach the smaller model to identify these difficult segments and trigger offloading on its own. To enable this behavior, we annotate difficult segments across 18k reasoning traces from the OpenR1-Math-220k chain-of-thought (CoT) dataset. We then apply supervised fine-tuning (SFT) and reinforcement learning fine-tuning (RLFT) to a 1.5B-parameter reasoning model, training it to offload the most challenging parts of its own reasoning process to a larger model. This approach improves AIME24 reasoning accuracy by 24% and 28.3% while offloading 1.35% and 5% of the generated tokens, respectively. We open-source our SplitReason model, data, code, and logs.
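The offload percentages in the abstract are the share of generated tokens that fall inside <bigmodel> spans. The sketch below computes that share from a tagged trace and feeds it into an idealized, purely serial latency model. Assumptions: whitespace splitting stands in for the model's real tokenizer, the tags themselves are not counted, and the per-token latencies `t_small` and `t_large` are user-supplied constants. The model ignores prefill, handoff, and pipelining costs, so it does not reproduce the paper's A6000 simulations.

```python
import re

OPEN, CLOSE = "<bigmodel>", "</bigmodel>"
_TAG = re.compile(r"(</?bigmodel>)")  # capturing group keeps the tags in split()


def offload_fraction(trace: str) -> float:
    """Share of whitespace-delimited tokens inside <bigmodel>...</bigmodel> spans.

    Whitespace splitting is a stand-in for the model's real tokenizer.
    """
    inside = False
    big = total = 0
    for piece in _TAG.split(trace):
        if piece == OPEN:
            inside = True
        elif piece == CLOSE:
            inside = False
        else:
            n = len(piece.split())
            total += n
            if inside:
                big += n
    return big / total if total else 0.0


def idealized_speedup(frac: float, t_small: float, t_large: float) -> float:
    """Speedup over the large model alone under a purely serial model.

    Assumes fixed per-token decode latencies and ignores prefill, handoff,
    and pipelining costs.
    """
    hybrid = (1.0 - frac) * t_small + frac * t_large
    return t_large / hybrid


trace = "Let x = 2 . <bigmodel> hard step here </bigmodel> so done"
print(offload_fraction(trace))                          # 0.3
print(round(idealized_speedup(0.0135, 1.0, 20.0), 1))  # 15.9
```

The 20× per-token latency ratio in the usage line is an arbitrary illustration, not a measured figure.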

Paper Structure

This paper contains 13 sections and 9 figures.

Figures (9)

  • Figure 1: SplitReason intelligently offloads token generation to a large model during difficult parts of the reasoning process. Using a small model (1.5B parameters) for the majority of the decode process yields a significant end-to-end speedup compared to the large model (32B parameters), while improving accuracy over the small model.
  • Figure 2: SplitReason utilizes two models to perform fast and high-accuracy reasoning. A small model is fine-tuned to emit a <bigmodel> tag when it detects a difficult reasoning step. This triggers a large model to step in and take over generation until a </bigmodel> tag is detected.
  • Figure 3: With SplitReason, the small model (1.5B) acts as the controller. While the small model decodes, the large model keeps pace by performing streaming prefills that keep its KV-Cache updated. Once the small model emits a <bigmodel> tag, the large model takes over generation. The small model then performs controlling prefills, which serve a dual purpose: keeping its own KV-Cache updated and checking whether it wants to take back control. If the small model emits </bigmodel> during a controlling prefill, the large model halts generation and the small model resumes decoding.
  • Figure 4: We take the entire response to a question from OpenR1-Math-220k and prompt deepseek-chat to annotate its difficult portions. These spans are enclosed in our <bigmodel> and </bigmodel> tags.
  • Figure 5: (Left) Randomly offloading sections of the decode process from a 1.5B model to a 32B model boosts AIME24 accuracy by up to 20%. Our learned offloading achieves even higher accuracy gains (24-28%) with just a 1.35-5% offload. (Right) We run pipelined performance simulations by profiling a range of models on A6000 GPUs and find that at a 1.35% offload, we can expect 8-9× faster inference than the large model.
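Figures 2 and 3 describe a control-token protocol: the small model decodes until it emits <bigmodel>, the large model then generates, and the small model's controlling prefill decides when to emit </bigmodel> and take back control. The loop below is a minimal, sequential sketch of that control flow, assuming each model is a hypothetical `Step` callable that returns the next token for a context. It omits the streaming prefills and KV-Cache overlap shown in Figure 3, and the scripted toy models exist only to exercise the handoff.

```python
from typing import Callable, List

# Hypothetical interface: given the full token context, return the next token.
Step = Callable[[List[str]], str]

OPEN, CLOSE, EOS = "<bigmodel>", "</bigmodel>", "<eos>"


def split_decode(prompt: List[str], small: Step, large: Step,
                 max_tokens: int = 256) -> List[str]:
    """Sequential sketch of SplitReason-style handoff between two models."""
    ctx = list(prompt)
    big_active = False
    for _ in range(max_tokens):
        if not big_active:
            tok = small(ctx)             # small model decodes by default
            ctx.append(tok)
            if tok == OPEN:
                big_active = True        # hand generation to the large model
        else:
            # Stand-in for the controlling prefill: ask the small model
            # whether it would close the offloaded span at this position.
            if small(ctx) == CLOSE:
                ctx.append(CLOSE)
                big_active = False       # small model resumes decoding
                continue
            ctx.append(large(ctx))       # large model generates this token
        if ctx[-1] == EOS:
            break
    return ctx[len(prompt):]


# Scripted toy models, used only to exercise the control flow.
def toy_small(ctx: List[str]) -> str:
    n_large = sum(t.startswith("L") for t in ctx)
    if OPEN not in ctx:
        return OPEN if "s1" in ctx else "s1"
    if CLOSE not in ctx:
        return CLOSE if n_large >= 2 else "?"
    return EOS if "s2" in ctx else "s2"


def toy_large(ctx: List[str]) -> str:
    return "L" + str(sum(t.startswith("L") for t in ctx))


print(split_decode(["Q"], toy_small, toy_large))
# ['s1', '<bigmodel>', 'L0', 'L1', '</bigmodel>', 's2', '<eos>']
```

Per Figure 3, the close check is folded into the small model's prefill over the large model's tokens, which also keeps its KV-Cache current; this sketch simply calls `small` again at each step.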