CoLT: Reasoning with Chain of Latent Tool Calls

Fangwei Zhu, Zhifang Sui

TL;DR

CoLT tackles the inefficiency of explicit Chain-of-Thought by reframing latent reasoning as differentiable tool calls: the model emits seed tokens whose hidden states are decoded back into explicit text. It introduces body and trigger seed tokens, two lightweight decoders, and a joint training regime with supervised objectives and optional reinforcement learning via GRPO. This design preserves the main model's token-space reasoning while speeding up inference. Empirically, CoLT achieves higher accuracy and shorter reasoning chains than baseline latent reasoning methods on GSM8K-Aug and out-of-domain math datasets, and RL can further improve its performance. The approach offers a practical path toward scalable, interpretable reasoning in LLMs and could be extended to multimodal reasoning and other domains given appropriate decoders.
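The TL;DR names GRPO as the optional RL stage. As GRPO is usually described, each prompt gets a group of sampled responses, and each response's advantage is its reward normalized by the group's mean and standard deviation, so no learned value model is needed. The sketch below computes only that normalization step. The `grpo_advantages` name, the binary correctness reward in the example, and the use of the population standard deviation are assumptions for illustration, not details taken from the paper (implementations differ on the standard-deviation convention).

```python
from statistics import mean, pstdev
from typing import List


def grpo_advantages(rewards: List[float], eps: float = 1e-6) -> List[float]:
    """Group-relative advantages in the style commonly described for GRPO.

    Each sampled response's advantage is its reward minus the group mean,
    divided by the group standard deviation (population std here, an assumption).
    """
    mu = mean(rewards)
    sigma = pstdev(rewards)
    # eps avoids division by zero when every reward in the group is identical
    return [(r - mu) / (sigma + eps) for r in rewards]


# Example: 4 sampled reasoning chains for one question,
# with a hypothetical reward of 1.0 if the final answer is correct, else 0.0.
rewards = [1.0, 0.0, 1.0, 0.0]
print([round(a, 3) for a in grpo_advantages(rewards)])  # [1.0, -1.0, 1.0, -1.0]
```

When every response in a group receives the same reward, all advantages come out as zero, so such a group contributes no policy-gradient signal under this normalization.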

Abstract

Chain-of-Thought (CoT) is a critical technique for enhancing the reasoning ability of Large Language Models (LLMs), and latent reasoning methods have been proposed to accelerate inefficient token-level reasoning chains. We observe that existing latent reasoning methods generally require augmenting the model structure and extensive training, which limits their broader applicability. In this paper, we propose CoLT, a novel framework that implements latent reasoning as "tool calls". Instead of reasoning entirely in the latent space, CoLT generates seed tokens that encode the information of a reasoning step. When a latent tool call is triggered, a smaller external model takes the hidden states of the seed tokens as input and unpacks them back into a full reasoning step. In this way, the main model continues to reason in the explicit token space, preserving its capabilities while improving efficiency. Experimental results on four mathematical datasets demonstrate that CoLT achieves higher accuracy and shorter reasoning length than baseline latent models, and is compatible with reinforcement learning algorithms and different decoder structures.
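The abstract states that the external decoder takes the hidden states of the seed tokens as its input. A minimal sketch of that selection step, assuming each latent tool call is a contiguous run of body tokens closed by a single trigger token, could look like the following. The token ids `BDY_ID` and `TRG_ID`, the function `seed_hidden_states`, and its `include_trigger` flag are all hypothetical; whether the trigger token's own hidden state is passed to the decoder is not specified here, so it is left as an option.

```python
from typing import List, Sequence

# Hypothetical special-token ids; real ids depend on the tokenizer.
BDY_ID = 50001
TRG_ID = 50002


def seed_hidden_states(token_ids: Sequence[int],
                       hidden: Sequence[Sequence[float]],
                       include_trigger: bool = True) -> List[List[float]]:
    """Return the hidden states of the latest latent tool call.

    Assumes the sequence ends with a trigger token preceded by a contiguous
    run of body tokens: ... <BDY> <BDY> ... <TRG>.
    """
    if len(token_ids) != len(hidden):
        raise ValueError("token_ids and hidden must have the same length")
    if not token_ids or token_ids[-1] != TRG_ID:
        raise ValueError("sequence does not end with a trigger token")
    # Walk backwards from the trigger over the contiguous body tokens.
    start = len(token_ids) - 1
    while start > 0 and token_ids[start - 1] == BDY_ID:
        start -= 1
    end = len(token_ids) if include_trigger else len(token_ids) - 1
    return [list(h) for h in hidden[start:end]]


# Usage with toy 1-dimensional hidden states.
ids = [11, 12, BDY_ID, BDY_ID, TRG_ID]
hid = [[0.0], [0.1], [0.2], [0.3], [0.4]]
print(seed_hidden_states(ids, hid))                         # [[0.2], [0.3], [0.4]]
print(seed_hidden_states(ids, hid, include_trigger=False))  # [[0.2], [0.3]]
```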

Paper Structure

This paper contains 42 sections, 11 equations, 4 figures, and 6 tables.

Figures (4)

  • Figure 1: The CoLT framework with a standard LLM backbone and differentiable decoders. During inference, the backbone LLM can raise a latent tool call consisting of special body tokens <BDY> and trigger tokens <TRG>. CoLT selects an external decoder according to the trigger token, and that decoder decodes the hidden states of the latent tool call back into normal tokens. The LLM then continues reasoning with the updated context until it reaches the final answer.
  • Figure 2: How accuracy and latent length change with the number of seed tokens and decoder layers.
  • Figure 3: Case study of CoLT (1 seed token) on the GSM8K-Aug dataset. For each reasoning step, the main model raises a latent tool call with the <TRG> token, and the decoder decodes the hidden states of the seed tokens into equations, which then replace the seed tokens to form the new context. The input to the main model at each reasoning step is purely textual, and so is the final reasoning chain.
  • Figure 4: How accuracy and latent length change with the number of training epochs.
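Figures 1 and 3 describe the inference loop: the main model emits body tokens, a trigger token selects a decoder, and the decoded step replaces the seed tokens so that the main model's context stays purely textual. The toy simulation below sketches that control flow under stated assumptions. `colt_generate`, `scripted_main`, the `main_step` interface, and the string tokens `<BDY>`, `<TRG>`, `<EOS>` are hypothetical stand-ins; hidden states are plain float lists, the decoders are ordinary functions rather than trained networks, and passing the trigger's hidden state to the decoder is an assumption.

```python
from typing import Callable, Dict, Iterable, List, Optional, Tuple

# Hypothetical special tokens; a real tokenizer would use reserved ids.
BDY, TRG, EOS = "<BDY>", "<TRG>", "<EOS>"

# main_step(context) -> (next_token, hidden_state_of_next_token)
MainStep = Callable[[List[str]], Tuple[str, List[float]]]
# decoder(seed_hidden_states) -> explicit tokens of one reasoning step
Decoder = Callable[[List[List[float]]], List[str]]


def colt_generate(prompt: List[str], main_step: MainStep,
                  decoders: Dict[str, Decoder], max_steps: int = 100) -> List[str]:
    """Toy CoLT-style loop: seed tokens are replaced by decoded text on each trigger."""
    context = list(prompt)
    seed_states: List[List[float]] = []  # hidden states of the pending tool call
    seed_start: Optional[int] = None     # where the pending tool call begins in context
    for _ in range(max_steps):
        token, h = main_step(context)
        if token == EOS:
            break
        if token == BDY:
            if seed_start is None:
                seed_start = len(context)
            context.append(token)
            seed_states.append(h)
        elif token in decoders:
            # Trigger token: choose the decoder keyed by this trigger, unpack the
            # seed hidden states (trigger included, an assumption) into explicit
            # tokens, and splice them in place of the body tokens.
            start = len(context) if seed_start is None else seed_start
            step_tokens = decoders[token](seed_states + [h])
            context = context[:start] + step_tokens
            seed_states, seed_start = [], None
        else:
            context.append(token)
    return context


def scripted_main(script: Iterable[str]) -> MainStep:
    """Stand-in for the main LLM: replays a fixed token script."""
    it = iter(script)

    def step(context: List[str]) -> Tuple[str, List[float]]:
        # Toy hidden state: just the current context length.
        return next(it), [float(len(context))]

    return step


# Usage: one latent tool call, then an explicit final answer.
toy_decoders = {TRG: lambda states: ["<<2*3=6>>"]}
out = colt_generate(["Q:", "2*3?"], scripted_main([BDY, TRG, "####", "6", EOS]), toy_decoders)
print(out)  # ['Q:', '2*3?', '<<2*3=6>>', '####', '6']
```

Because the decoded step is spliced in where the tool call began, `main_step` never sees seed tokens in its input after a trigger, mirroring the Figure 3 description that both the per-step input and the final reasoning chain are textual. Keying `decoders` by trigger token is one simple way to express "select an external decoder according to the trigger token"; the paper's actual dispatch mechanism may differ.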