Parallel Continuous Chain-of-Thought with Jacobi Iteration

Haoyi Wu, Zhihao Teng, Kewei Tu

TL;DR

This paper proposes Parallel Continuous Chain-of-Thought (PCCoT), which performs Jacobi iteration on the latent thought tokens, updating them iteratively in parallel instead of sequentially, thereby improving both the training and inference efficiency of continuous CoT.

Abstract

Continuous chain-of-thought has been shown to be effective in saving reasoning tokens for large language models. By reasoning with continuous latent thought tokens, continuous CoT is able to perform implicit reasoning in a compact manner. However, the sequential dependencies between latent thought tokens prevent parallel training, leading to long training times. In this paper, we propose Parallel Continuous Chain-of-Thought (PCCoT), which performs Jacobi iteration on the latent thought tokens, updating them iteratively in parallel instead of sequentially and thus improving both the training and inference efficiency of continuous CoT. Experiments demonstrate that, with a properly chosen number of iterations, PCCoT achieves comparable or even better performance while saving nearly 50% of the training and inference time. Moreover, PCCoT shows better stability and robustness during training. Our code is available at https://github.com/whyNLP/PCCoT.
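
One way to write the difference between the two schemes, in our own notation rather than necessarily the paper's: let $x$ be the input prefix, $z_1, \dots, z_c$ the latent thought tokens, and $F_j$ the hidden state the model produces at position $j$ ($j = 0$ being the position just before the first latent token). Under causal attention $F_j$ depends only on $x$ and $z_{1:j}$. Assuming, as in continuous CoT, that each latent token is the hidden state from the preceding position, and writing $z^{(0)}$ for some initialization:

```latex
\begin{aligned}
\text{sequential:}\quad & z_i = F_{i-1}\big(x, z_{1:i-1}\big), && i = 1, \dots, c \ \text{(one at a time)},\\
\text{Jacobi:}\quad & z_i^{(t+1)} = F_{i-1}\big(x, z^{(t)}_{1:i-1}\big), && i = 1, \dots, c \ \text{(all in parallel)}.
\end{aligned}
```

In the Jacobi form every right-hand side reads only the previous iterate $z^{(t)}$, so one forward pass can refresh all $c$ latent tokens at once instead of requiring $c$ dependent passes.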

Paper Structure

This paper contains 47 sections, 1 theorem, 12 equations, 6 figures, 7 tables.

Key Result

Theorem 1

The computation graph of PCCoT with $c$ latent thought tokens and $T$ extra iterations is equivalent to that of continuous CoT with $c$ latent thought tokens if $T \geq c$.
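
A quick numerical check of Theorem 1 can be written with a toy causal model. This is a sketch under our own assumptions, not the authors' implementation: `forward`, `sequential_cot`, `jacobi_cot`, `W`, `PREFIX` and `D` are hypothetical names, the "model" is a tanh of a running sum rather than a transformer, the latent tokens are zero-initialized, and, following Figure 1, PCCoT is taken to be a first forward pass followed by $T$ extra iterations in which the hidden state at position $j-1$ is fed back as latent token $j$. The intuition is an induction: after the $t$-th extra iteration the first $t$ latent tokens already equal their sequential values, because each depends only on earlier positions, so $T \geq c$ reproduces the sequential result. The toy only compares values, which is weaker than the theorem's statement about computation graphs.

```python
import numpy as np

# Hypothetical toy model, NOT the paper's transformer: a causal map in which the
# hidden state at position j depends only on the prefix and latents[:j].
rng = np.random.default_rng(0)
D = 8                                    # hidden size (arbitrary)
W = rng.normal(size=(D, D)) / np.sqrt(D)
PREFIX = rng.normal(size=D)              # stands in for the encoded question


def forward(prefix, latents):
    """One forward pass. Index 0 is the position before the first latent token;
    index j (1..len(latents)) is the j-th latent position."""
    state = prefix.copy()
    outputs = [np.tanh(W @ state)]
    for z in latents:
        state = state + z                # causal: only earlier inputs are mixed in
        outputs.append(np.tanh(W @ state))
    return outputs


def sequential_cot(prefix, c):
    """Continuous CoT: latent token i is produced only after tokens 1..i-1 exist."""
    latents = []
    for i in range(c):
        latents.append(forward(prefix, latents)[i])  # last hidden state -> next input
    return forward(prefix, latents)[c]               # final state after all c latents


def jacobi_cot(prefix, c, T):
    """Jacobi-style update: one first pass, then T extra passes that refresh
    all c latent tokens at once from the previous pass (zero initialization)."""
    latents = [np.zeros(D) for _ in range(c)]
    outputs = forward(prefix, latents)               # first forward pass
    for _ in range(T):
        latents = outputs[:c]                        # output at j-1 becomes latent j
        outputs = forward(prefix, latents)           # all positions in one pass
    return outputs[c]


if __name__ == "__main__":
    c = 4
    ref = sequential_cot(PREFIX, c)
    for T in range(c + 2):
        print(T, np.allclose(jacobi_cot(PREFIX, c, T), ref))
```

In this toy the comparison becomes True once $T$ reaches $c = 4$ and is False for smaller $T$. In a real transformer each call to `forward` would be a single parallel pass over all positions; the Python loop here is only for readability.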

Figures (6)

  • Figure 1: An illustration of Continuous Chain-of-Thought (left) and Parallel Continuous Chain-of-Thought (right). The figure shows $c=3$ latent thought tokens with the first forward pass and $T=2$ extra iterations. The <eot> token and the answer tokens are not shown in the figure. Each dashed box represents a single forward pass.
  • Figure 2: Test set accuracy (%) of PCCoT with different numbers of extra iterations $T$ and latent thought tokens $c$ on GSM8K-Aug. The figure shows the mean over 3 random runs with the standard deviation.
  • Figure 3: Test set accuracy (%) of PCCoT with different numbers of latent thought tokens $c$ and extra iterations $T$ on GSM8K-Aug. The figure shows the mean over 3 random runs with the standard deviation.
  • Figure 4: MSE of the latent thought tokens before and after the $t$-th extra iteration. "rand" denotes a randomly initialized model; the other models are trained with $c=24$ and different values of $T$. All models are tested on random samples from the test set of GSM8K.
  • Figure 5: MSE between the latent thought tokens. The darker the block is, the more similar the latent thought tokens are. The model is tested on random samples from the test set of GSM8K.
  • ...and 1 more figure
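
The two diagnostics in Figures 4 and 5 can be sketched as follows. This is not the authors' evaluation code: `iteration_mse` and `pairwise_token_mse` are hypothetical helper names, and we assume the latent thought tokens of one example are stacked into an array of shape `(c, d)`, with "MSE" averaging squared differences over the hidden dimension (and, for Figure 4, over the `c` tokens as well).

```python
import numpy as np


def iteration_mse(iterates):
    """MSE between consecutive latent-token iterates (Figure 4 style).

    iterates: sequence of arrays of shape (c, d); iterates[t] holds the c latent
    thought tokens after the t-th update. Entry t-1 of the result is the MSE
    between iterates[t-1] and iterates[t], averaged over tokens and dimensions.
    """
    return [float(np.mean((curr - prev) ** 2))
            for prev, curr in zip(iterates[:-1], iterates[1:])]


def pairwise_token_mse(latents):
    """c x c matrix whose (i, j) entry is the MSE between latent tokens i and j
    (Figure 5 style); smaller values mean more similar tokens."""
    latents = np.asarray(latents, dtype=float)
    diff = latents[:, None, :] - latents[None, :, :]   # shape (c, c, d)
    return np.mean(diff ** 2, axis=-1)
```

An `iteration_mse` value near zero means an extra iteration barely changed the latent tokens, i.e. the Jacobi iteration has approximately reached a fixed point for that example.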

Theorems & Definitions (2)

  • Theorem 1
  • Proof