Expediting and Elevating Large Language Model Reasoning via Hidden Chain-of-Thought Decoding

Tianqiao Liu, Zui Chen, Zitao Liu, Mi Tian, Weiqi Luo

TL;DR

The paper tackles the latency of explicit chain-of-thought reasoning in large language models by introducing Hidden Chain-of-Thought (HCoT), a two-stage framework that compresses multi-step reasoning into a semantically aligned [CoT] token via an auxiliary CoT module. The auxiliary CoT module is trained with a cross-entropy objective plus a symmetric contrastive loss to produce compact representations, which the HCoT component then uses to generate subsequent outputs conditioned on this compressed reasoning. Empirical results across GSM8K, MATH, ScienceQA, and HotpotQA show competitive or improved accuracy with substantial decoding speedups (1.5x–3.8x), and show that contrastive learning further enhances representation quality and task performance. The approach offers a path toward scalable, interpretable, and faster multi-step reasoning in real-world applications, albeit with increased training complexity and resource requirements that warrant further optimization.
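
A minimal sketch of what a symmetric contrastive objective of this kind can look like, assuming a CLIP-style in-batch InfoNCE loss: each example's [CoT] token representation should be closer to the embedding of its own ground-truth CoT than to any other CoT in the batch, and vice versa. The summary does not say how the ground-truth CoT is embedded, what temperature is used, or how this term is weighted against the cross-entropy objective, so the input vectors, the hypothetical helper `symmetric_contrastive_loss`, and `temperature=0.07` are all assumptions. Only the contrastive term is shown.

```python
import math


def _normalize(vec):
    # Scale a vector to unit L2 norm so dot products become cosine similarities.
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def _diagonal_cross_entropy(logits):
    # Mean cross-entropy where row i's correct "class" is column i (the matched pair).
    total = 0.0
    for i, row in enumerate(logits):
        m = max(row)  # subtract the row max for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_sum_exp - row[i]
    return total / len(logits)


def symmetric_contrastive_loss(cot_token_reprs, cot_text_embs, temperature=0.07):
    """CLIP-style symmetric InfoNCE over one batch (hypothetical helper, an assumption).

    cot_token_reprs[i]: representation of the [CoT] token for example i.
    cot_text_embs[i]:   embedding of the ground-truth CoT for example i.
    The other examples in the batch act as negatives.
    """
    a = [_normalize(v) for v in cot_token_reprs]
    b = [_normalize(v) for v in cot_text_embs]
    # logits[i][j] = cosine(a_i, b_j) / temperature
    logits = [[sum(x * y for x, y in zip(ai, bj)) / temperature for bj in b] for ai in a]
    logits_t = [list(col) for col in zip(*logits)]
    # Average the [CoT]-to-CoT and CoT-to-[CoT] directions (the "symmetric" part).
    return 0.5 * (_diagonal_cross_entropy(logits) + _diagonal_cross_entropy(logits_t))


# Toy 2-D vectors, purely illustrative.
tokens = [[1.0, 0.0], [0.0, 1.0]]
aligned = symmetric_contrastive_loss(tokens, [[0.9, 0.1], [0.1, 0.9]])
swapped = symmetric_contrastive_loss(tokens, [[0.1, 0.9], [0.9, 0.1]])
print(aligned < swapped)  # prints True
```

Pairing each [CoT] vector with a nearby CoT embedding gives a far lower loss than the swapped pairing; under this assumed formulation, that gap is the signal that pulls each compressed token toward its own reasoning chain and away from the others in the batch.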

Abstract

Large language models (LLMs) have demonstrated remarkable capabilities in tasks requiring reasoning and multi-step problem-solving through the use of chain-of-thought (CoT) prompting. However, generating the full CoT process results in significantly longer output sequences, leading to increased computational costs and latency during inference. To address this challenge, we propose a novel approach to compress the CoT process through semantic alignment, enabling more efficient decoding while preserving the benefits of CoT reasoning. Our method introduces an auxiliary CoT model that learns to generate and compress the full thought process into a compact special token representation semantically aligned with the original CoT output. This compressed representation is then integrated into the input of the Hidden Chain-of-Thought (HCoT) model. The training process follows a two-stage procedure. First, the CoT model is optimized to generate the compressed token representations aligned with the ground-truth CoT outputs using a contrastive loss. Then, with the CoT model parameters frozen, the HCoT model is fine-tuned to generate accurate subsequent predictions conditioned on the prefix instruction and the compressed CoT representations from the CoT model. Extensive experiments across three challenging domains (mathematical reasoning, agent invocation, and question answering) demonstrate that our semantic compression approach achieves competitive or improved performance compared to the full CoT baseline, while providing significant speedups of at least 1.5x in decoding time. Moreover, incorporating contrastive learning objectives further enhances the quality of the compressed representations, leading to better CoT prompting and improved task accuracy. Our work paves the way for more efficient exploitation of multi-step reasoning capabilities in LLMs across a wide range of applications.
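
Under a deliberately simplified cost model, the source of the decoding speedup can be made concrete: full CoT decodes every reasoning token plus the answer, while HCoT replaces the reasoning tokens with the compressed [CoT] representation fed into its input. The sketch assumes a constant per-token decode cost, ignores prompt prefill, and charges the auxiliary CoT model a hypothetical `aux_steps` units; the function name and token counts are illustrative assumptions, not figures from the paper.

```python
def estimated_speedup(cot_len: int, answer_len: int, aux_steps: float = 1.0) -> float:
    """Decode-time ratio of full CoT to HCoT under a toy cost model.

    Assumptions (illustrative only, not from the paper): every decoded token
    costs one unit, prompt prefill is ignored, and producing the compressed
    [CoT] representation with the auxiliary model costs `aux_steps` units.
    """
    full_cot_cost = cot_len + answer_len      # reasoning and answer decoded token by token
    hidden_cot_cost = aux_steps + answer_len  # reasoning replaced by the [CoT] representation
    return full_cot_cost / hidden_cot_cost


# Hypothetical token counts, not taken from the paper.
print(round(estimated_speedup(cot_len=120, answer_len=40), 2))  # prints 3.9
print(round(estimated_speedup(cot_len=40, answer_len=80), 2))   # prints 1.48
```

In this toy model the ratio grows with the length of the reasoning relative to the answer, which is one plausible reason speedups differ across tasks. Real measurements also include prefill and the auxiliary model's actual cost, so these numbers are not a prediction of the reported speedups.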

Paper Structure (37 sections, 2 equations, 20 figures, 3 tables)

Figures (20)

  • Figure 1: Real-world examples of CoT prompting in tasks such as mathematical reasoning, question answering, and agent invocation. Green text represents actual user queries, blue strikethroughs indicate compressed thought processes, and black text denotes non-CoT content, which corresponds to the output users expect.
  • Figure 2: Data construction and two-stage training of Hidden Chain-of-Thought (HCoT) models for math reasoning tasks. Training instances are synthetically generated from raw data using GPT-4, then used separately to train the Auxiliary Chain-of-Thought Model and the HCoT Model.