CoTFormer: A Chain-of-Thought Driven Architecture with Budget-Adaptive Computation Cost at Inference
Amirkeivan Mohtashami, Matteo Pagliardini, Martin Jaggi
TL;DR
The paper introduces CoTFormer, a chain-of-thought-inspired transformer that re-applies the same block of layers several times and lets each token attend to the intermediate representations produced by earlier repeats, which distinguishes it from simply re-applying a weight-tied block and effectively mimics chain-of-thought at the token level. It shows that CoTFormer outperforms Block Universal Transformer in perplexity while offering favorable compute characteristics, pushing the perplexity-compute Pareto frontier forward. To enable budget control at inference time, the authors develop Mixture of Repeats and an adaptive training scheme with depth embeddings and reserved layers, achieving substantial compute savings with minimal accuracy loss. The work provides detailed ablations and comparisons and outlines practical directions, such as longer training and improved adaptive sampling, to further close the gap to much larger models.
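To make the attention difference concrete, here is a minimal sketch comparing attention masks. It assumes (our reading of the summary, not the authors' code) that CoTFormer places each token's intermediate representations from successive repeats right after that token, the way CoT writes intermediate thoughts into the context, and applies an ordinary causal mask to this interleaved sequence. Block Universal Transformer, by contrast, lets each repeat see only the current-depth representations. All function names are hypothetical.

```python
import numpy as np


def interleaved_positions(n_tokens: int, n_reps: int) -> list[tuple[int, int]]:
    """Token-major layout of (token, repeat) pairs: each token is followed by the
    representations produced for it by successive repeats, like CoT 'thoughts'."""
    return [(t, r) for t in range(n_tokens) for r in range(n_reps)]


def cotformer_mask(n_tokens: int, n_reps: int) -> np.ndarray:
    """Ordinary causal mask over the interleaved sequence (hypothetical sketch):
    mask[q, k] is True when position q may attend to position k."""
    length = n_tokens * n_reps
    return np.tril(np.ones((length, length), dtype=bool))


def block_ut_mask(n_tokens: int) -> np.ndarray:
    """Block Universal Transformer at a single repeat: each token attends only to
    the current-depth representations of itself and earlier tokens."""
    return np.tril(np.ones((n_tokens, n_tokens), dtype=bool))


def visible_to_latest(n_tokens: int, n_reps: int, t: int) -> list[tuple[int, int]]:
    """(token, repeat) pairs that the newest representation of token t attends to."""
    positions = interleaved_positions(n_tokens, n_reps)
    q = positions.index((t, n_reps - 1))  # row of token t's last repeat
    mask = cotformer_mask(n_tokens, n_reps)
    return [positions[k] for k in np.flatnonzero(mask[q])]


if __name__ == "__main__":
    print(visible_to_latest(3, 2, 1))  # [(0, 0), (0, 1), (1, 0), (1, 1)]
    print(int(block_ut_mask(3)[1].sum()))  # 2
```

With 3 tokens and 2 repeats, the newest representation of token 1 sees both repeats of tokens 0 and 1 (4 keys), whereas under the Block Universal Transformer mask it sees only the 2 current-depth keys. In this sketch the richer context comes at the price of an attention sequence that is `n_reps` times longer.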
Abstract
Scaling language models to larger and deeper sizes has led to significant boosts in performance. Even though the size of these models limits their application in compute-constrained environments, the race to continually develop ever larger and deeper foundational models is underway. At the same time, regardless of the model size, task-specific techniques continue to play a pivotal role in achieving optimal downstream performance. One of these techniques, called Chain-of-Thought (CoT), is particularly interesting since, as we point out in this work, it resembles employing a deeper transformer through re-applying the model multiple times. However, a key subtlety in computing the attention of past tokens differentiates CoT from simply applying the model several times. Based on this insight, we propose CoTFormer, a novel architecture which closely mimics CoT at the token level, allowing us to obtain significantly improved accuracies close to those of much larger models. While applying CoT introduces additional computation costs, we compensate for it by leveraging CoTFormer's special compatibility with token-wise variable depth. Through a compute-adaptive model, which automatically allocates compute to the tokens that need it most, we show that it is possible to reduce the computation cost significantly without any reduction in accuracy, and to reduce it further while maintaining competitive accuracy.
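The token-wise variable depth behind the compute-adaptive model can be illustrated with a small sketch. This is a simplified, hypothetical stand-in, not the paper's Mixture of Repeats training procedure: we assume a router has already produced a score in [0, 1] for each token before each repeat, and a token stops being re-processed by the shared block once its score falls below a threshold. In this sketch, raising the threshold lowers the compute budget at inference time without retraining.

```python
import numpy as np


def adaptive_depths(scores: np.ndarray, threshold: float) -> np.ndarray:
    """Number of repeats applied to each token.

    scores[t, r] is a hypothetical router score for token t before repeat r.
    Repeat 0 always runs; token t continues to repeat r only while every score
    from repeat 1 up to r is at least `threshold`.
    """
    n_tokens, max_reps = scores.shape
    depths = np.ones(n_tokens, dtype=int)
    for t in range(n_tokens):
        for r in range(1, max_reps):
            if scores[t, r] < threshold:
                break  # token exits early; its remaining repeats are skipped
            depths[t] += 1
    return depths


def compute_fraction(depths: np.ndarray, max_reps: int) -> float:
    """Share of block applications used, relative to running every token to max depth."""
    return float(depths.sum()) / (len(depths) * max_reps)


if __name__ == "__main__":
    scores = np.array([[1.0, 0.9, 0.8],
                       [1.0, 0.2, 0.9],
                       [1.0, 0.6, 0.4]])
    for threshold in (0.0, 0.5, 0.85):
        d = adaptive_depths(scores, threshold)
        print(threshold, d.tolist(), round(compute_fraction(d, 3), 3))
```

With these example scores, thresholds 0.0, 0.5 and 0.85 give depths [3, 3, 3], [3, 1, 2] and [2, 1, 1], i.e. 100%, about 67% and about 44% of the full compute. Whether accuracy holds up at a given budget depends on how well the router is trained, which this sketch does not model.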
