CoTFormer: A Chain-of-Thought Driven Architecture with Budget-Adaptive Computation Cost at Inference

Amirkeivan Mohtashami, Matteo Pagliardini, Martin Jaggi

TL;DR

The paper introduces CoTFormer, a chain-of-thought-inspired transformer that, each time it re-applies its shared blocks, lets tokens attend to past intermediate representations. This distinguishes it from plain weight-tied repetition as in the Block Universal Transformer. CoTFormer achieves lower perplexity than the Block Universal Transformer at favorable compute cost, moving the perplexity-compute Pareto frontier forward. To enable inference-time budget control, the authors develop Mixture of Repeats and an adaptive training scheme with depth embeddings and reserved layers, achieving substantial compute savings with minimal accuracy loss. The work provides detailed ablations and comparisons and outlines practical directions, such as longer training and improved adaptive sampling, to further close the gap to much larger models.
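The difference between re-applying shared blocks and the CoTFormer mechanism comes down to which representations a token may attend to. The Python sketch below enumerates them under assumptions that reflect a reading of Figure 1, not the authors' code. It assumes that representations are interleaved token-major and that an ordinary causal mask is applied over that interleaved sequence. Level 0 denotes the input embedding and level k the output of the k-th pass. All function names are hypothetical.

```python
from typing import List, Tuple

# A representation is identified by (token position, level).
# Level 0 is the input embedding; level k is the output of the k-th pass
# through the shared stack of blocks.
Rep = Tuple[int, int]


def visible_block_universal(t: int, level: int) -> List[Rep]:
    """Keys/values token t attends to while computing level+1 in a
    Block Universal Transformer: tokens 0..t at the same level only."""
    return [(s, level) for s in range(t + 1)]


def interleaved_sequence(n_tokens: int, level: int) -> List[Rep]:
    """Assumed token-major interleaving of levels 0..level:
    (0,0), (0,1), ..., (1,0), (1,1), ..."""
    return [(s, q) for s in range(n_tokens) for q in range(level + 1)]


def visible_cotformer(t: int, level: int, n_tokens: int) -> List[Rep]:
    """Keys/values token t attends to while computing level+1 under the
    assumed CoTFormer scheme: a causal mask over the interleaved sequence,
    i.e. everything up to and including (t, level)."""
    seq = interleaved_sequence(n_tokens, level)
    return seq[: seq.index((t, level)) + 1]


if __name__ == "__main__":
    # Token 2 computing its level-3 representation (third pass).
    print(visible_block_universal(2, 2))
    print(visible_cotformer(2, 2, n_tokens=4))
```

In the Block Universal Transformer, token 2 sees only three same-level representations. Under the assumed CoTFormer mask, it sees all nine representations of tokens 0 to 2 at levels 0 to 2. This is the "earlier tokens remain visible at every depth" property that Figure 1 attributes to CoT.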

Abstract

Scaling language models to larger and deeper sizes has led to significant boosts in performance. Even though the size of these models limits their application in compute-constrained environments, the race to continually develop ever larger and deeper foundational models is underway. At the same time -- regardless of the model size -- task-specific techniques continue to play a pivotal role in achieving optimal downstream performance. One of these techniques, called Chain-of-Thought (CoT), is particularly interesting since, as we point out in this work, it resembles employing a deeper transformer through re-applying the model multiple times. However, a key subtlety in computing the attention of past tokens differentiates CoT from simply applying the model several times. Based on this insight, we propose CoTFormer, a novel architecture which closely mimics CoT at the token level, allowing us to obtain significantly improved accuracies close to much larger models. While applying CoT introduces additional computation costs, we compensate for it by leveraging CoTFormer's special compatibility with token-wise variable depth. Through a compute adaptive model -- which automatically allocates the compute to tokens that need it most -- we show that it is possible to reduce the computation cost significantly without any reduction in accuracy, and with further compute cost reductions possible while maintaining a competitive accuracy.
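The compute-adaptive model assigns more repeats to the tokens that need them most. The following is a minimal sketch of budget-constrained allocation. It assumes hypothetical per-repeat router scores and a fixed per-repeat capacity, and at each repeat it keeps the highest-scoring tokens that are still active. This greedy top-k rule is only in the spirit of Mixture of Repeats; the paper's exact routing and training procedure may differ.

```python
from typing import List, Sequence


def allocate_depths(n_tokens: int,
                    router_scores: Sequence[Sequence[float]],
                    capacities: Sequence[int]) -> List[int]:
    """Greedy, budget-constrained depth allocation (illustrative sketch).

    router_scores[r][t] -- hypothetical router score of token t for taking
                           one more repeat after repeat r+1.
    capacities[r]       -- how many tokens may continue after repeat r+1
                           (the inference-time budget knob).
    Returns the number of repeats applied to each token (at least 1).
    """
    depth = [1] * n_tokens           # every token gets the first repeat
    active = list(range(n_tokens))   # tokens still eligible for deeper repeats
    for scores, capacity in zip(router_scores, capacities):
        # Keep only the `capacity` highest-scoring tokens among those still
        # active; a token that stopped earlier cannot resume (nested depths).
        active = sorted(active, key=lambda t: scores[t], reverse=True)[:capacity]
        for t in active:
            depth[t] += 1
    return depth


if __name__ == "__main__":
    scores = [
        [0.9, 0.1, 0.5, 0.7],  # scores after repeat 1
        [0.2, 0.8, 0.6, 0.4],  # scores after repeat 2
    ]
    depths = allocate_depths(4, scores, capacities=[2, 1])
    print(depths, "total repeats:", sum(depths))
```

Here 7 block-stack applications are spent instead of the 12 that a fixed depth of 3 would need. Lowering `capacities` shrinks the budget further. Token 1 has the highest second-repeat score but had already stopped, which illustrates the nested-depth assumption.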

Paper Structure (18 sections, 3 equations, 5 figures, 4 tables)

Figures (5)

  • Figure 1: Block universal transformer vs. CoTFormer vs. Chain-of-Thought (CoT) reasoning. In (a) we represent the chain-of-thought mechanism, in which a model iteratively generates reasoning tokens to help solve downstream tasks. Based on the existing input (red) tokens, a next token (blue) is generated and appended to the sequence; repeating this process yields the green and yellow tokens. Here we emphasize that (i) the last red token is "pushed" through the model several times (the yellow token is the red token after three successive applications of the model), and (ii) new (e.g., blue) tokens can attend to previous (e.g., red) tokens; this observation is the basis of CoTFormer. In (b) we represent the block universal transformer, which recursively applies the same $N$ transformer blocks to the input tokens. This approach contrasts with the CoTFormer architecture (c), which interleaves old and new representations between each block. In the figure this process is performed twice (${n_\text{repeat}}=2$), but it could be repeated many more times. As in CoT, and unlike block universal transformers, later (e.g., blue) tokens can attend to earlier (e.g., red) tokens.
  • Figure 2: Comparison of Block Universal Transformer and CoTFormer in terms of accuracy-computation tradeoff. At both ${n_{\text{layer}}}=12$ and ${n_{\text{layer}}}=24$, CoTFormers lie closer to the Pareto frontier. The gap widens with a larger number of repeats, suggesting better scaling properties for CoTFormers.
  • Figure 3: CoTFormer is less compute-intensive than a Block Universal Transformer of comparable performance. Comparison of a $12$-layer CoTFormer with $3$ repeats ($12\textsf{x}3$) and a $12$-layer Block Universal Transformer with $5$ repeats ($12\textsf{x}5$) in terms of computation cost. The CoTFormer's accuracy is better than that of the Block Universal Transformer (see Figure 2). Despite the increase in context length when processing the input with CoTFormer, its computational cost remains below that of the Block Universal Transformer for sequence lengths as high as 8192.
  • Figure 4: Perplexity for different compute budgets chosen at inference. The adaptive CoTFormer can adapt to different budgets, reducing compute at the cost of a modest loss in accuracy. Furthermore, using the router weights to allocate the available compute (Router) is much more effective than reducing computation cost by fixing the depth to a smaller value at inference time (Fixed Depth).
  • Figure 5: Distribution of router weights for the last repeat for different numbers of training steps. With longer training, the model learns to make more use of the deepest repeat, leading to higher router weights.
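The claim in Figure 3 can be sanity-checked with a back-of-the-envelope FLOP model. The sketch below rests on several assumptions, none of which come from the paper:

- standard blocks with a 4x MLP and 2 FLOPs per multiply-add
- no savings from causal masking; norms, embeddings and the LM head ignored
- at repeat r, CoTFormer computes only the n new representations, which attend to r*n cached keys (all levels produced so far)

```python
"""Rough forward-FLOP comparison: CoTFormer (12x3) vs Block Universal Transformer (12x5).

All constants are illustrative assumptions, not values from the paper.
"""


def layer_flops(n_queries: int, n_keys: int, d_model: int) -> int:
    """Approximate forward FLOPs of one transformer block.

    Assumes a 4x MLP expansion and 2 FLOPs per multiply-add:
      - Q/K/V/output projections + MLP: 24 * d_model^2 per new token
      - attention scores (QK^T) and weighted sum (AV): 4 * n_queries * n_keys * d_model
    Causal-mask savings, norms, embeddings and the LM head are ignored.
    """
    linear = 24 * n_queries * d_model ** 2
    attention = 4 * n_queries * n_keys * d_model
    return linear + attention


def block_universal_flops(n: int, n_layer: int, n_repeat: int, d_model: int) -> int:
    # Each repeat re-processes the n tokens, which attend only to the n
    # representations of the same repeat.
    return n_repeat * n_layer * layer_flops(n, n, d_model)


def cotformer_flops(n: int, n_layer: int, n_repeat: int, d_model: int) -> int:
    # Assumption: at repeat r only the n new representations are computed,
    # and they attend to the r * n cached representations produced so far.
    return sum(n_layer * layer_flops(n, r * n, d_model) for r in range(1, n_repeat + 1))


if __name__ == "__main__":
    d_model = 768  # assumed model width
    for n in (1024, 4096, 8192, 16384):
        cot = cotformer_flops(n, n_layer=12, n_repeat=3, d_model=d_model)
        but = block_universal_flops(n, n_layer=12, n_repeat=5, d_model=d_model)
        print(f"n={n:>6}: CoTFormer / Block UT FLOP ratio = {cot / but:.3f}")
```

Summed over repeats, each layer of the 12x3 CoTFormer costs 72 n d^2 + 24 n^2 d under these assumptions, versus 120 n d^2 + 20 n^2 d for the 12x5 Block Universal Transformer. CoTFormer is therefore cheaper exactly when n < 12 d. For an assumed d_model = 768, that is about 9.2k tokens. This is a rough plausibility check, not a reproduction of Figure 3.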