
Thinking Slow, Fast: Scaling Inference Compute with Distilled Reasoners

Daniele Paliotta, Junxiong Wang, Matteo Pagliardini, Kevin Y. Li, Aviv Bick, J. Zico Kolter, Albert Gu, François Fleuret, Tri Dao

TL;DR

This work investigates whether subquadratic, faster-to-infer architectures can surpass their Transformer teachers when inference-time compute is scaled for reasoning tasks. It introduces distillation pipelines that transfer reasoning capabilities from pretrained Transformers into pure Mamba (Llamba) and hybrid MambaInLlama models, trained on only 8B tokens. Under fixed time budgets, the distilled models exhibit faster generation, broader coverage, and competitive or superior accuracy on math reasoning benchmarks (MATH and GSM8K), with additional gains from supervised fine-tuning. The results suggest that scaling inference compute with distilled subquadratic reasoners can push the Pareto front beyond Transformer teachers, offering a practical path to efficient, scalable reasoning systems.

Abstract

Recent advancements have demonstrated that the performance of large language models (LLMs) can be significantly enhanced by scaling computational resources at test time. A common strategy involves generating multiple Chain-of-Thought (CoT) trajectories and aggregating their outputs through various selection mechanisms. This raises a fundamental question: can models with lower complexity leverage their superior generation throughput to outperform similarly sized Transformers for a fixed computational budget? To address this question and overcome the lack of strong subquadratic reasoners, we distill pure and hybrid Mamba models from pretrained Transformers. Trained on only 8 billion tokens, our distilled models show strong performance and scaling on mathematical reasoning datasets while being much faster at inference for large batches and long sequences. Despite the zero-shot performance hit due to distillation, both pure and hybrid Mamba models can scale their coverage and accuracy performance past their Transformer teacher models under fixed time budgets, opening a new direction for scaling inference compute.
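The abstract refers to aggregating multiple CoT samples through selection mechanisms; Figure 4 evaluates two of them, majority voting and weighted Best-of-N. The following is a minimal sketch, not the authors' code, assuming each completion has already been reduced to a final answer string and, for weighted Best-of-N, scored by a reward model. The names `majority_vote` and `weighted_best_of_n` are hypothetical.

```python
from collections import Counter, defaultdict
from typing import Sequence


def majority_vote(answers: Sequence[str]) -> str:
    """Return the most frequent final answer among the sampled completions.

    Counter.most_common orders equal counts by first occurrence, so ties go
    to the answer that appeared first.
    """
    return Counter(answers).most_common(1)[0][0]


def weighted_best_of_n(answers: Sequence[str], scores: Sequence[float]) -> str:
    """Return the answer whose completions have the largest summed reward score.

    scores[i] is an (assumed) reward-model score for the completion that
    produced answers[i].
    """
    if len(answers) != len(scores):
        raise ValueError("answers and scores must have the same length")
    totals: dict[str, float] = defaultdict(float)
    for answer, score in zip(answers, scores):
        totals[answer] += score  # sum scores over identical final answers
    # max returns the first maximal key, i.e. ties go to the earliest answer
    return max(totals, key=totals.__getitem__)
```

For example, with answers `["12", "7", "12", "7", "7"]` and scores `[0.9, 0.1, 0.8, 0.2, 0.3]`, majority voting picks `"7"` (three votes), while weighted Best-of-N picks `"12"` (summed score 1.7 versus about 0.6). The two rules can disagree whenever a minority answer is scored much higher by the reward model.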

Paper Structure

This paper contains 21 sections, 1 equation, 12 figures, and 1 algorithm.

Figures (12)

  • Figure 1: Distilled models have better coverage on MATH for most time budgets. In (b), we show the coverage as we increase the number of sampled answers $k$. Compared to their associated Llama baselines, our distilled models provide lower coverage for a given $k$. In (a), we show the shortest time required to reach a given coverage: for each curve in (b), we map the $k$-values on the x-axis to the time the corresponding model needs to generate that many samples. Ideally, we want to reach the highest coverage within a short time budget. For a given time budget, our distilled models can generate many more completions than their respective baselines, so their higher throughput, shown in Figure 2, allows them to overcome their lower per-sample coverage. As a result, our models push the Pareto front forward for most time budgets.
  • Figure 2: Faster generation of distilled models. In (a) and (b), we show the inference time measured for the baseline Llama models as well as our distilled Llamba (pure Mamba) and MambaInLlama (hybrid) models at the 1B and 3B scales. We annotate the speedup of our MambaInLlama model at each batch size. We use prompts of $512$ tokens and measure the time required to generate $512$ tokens; the measured times exclude prompt prefilling. Overall, distilled models generate tokens much faster, and the speedup grows with batch size. Our distilled models are also more memory efficient: as shown in (b), a batch size of $512$ causes an Out of Memory (OOM) error for Llama-3B but not for our models. Obtaining $512$ completions with Llama-3B therefore requires two batches of $256$, for a total inference time of $58.8$s, whereas our MambaInLlama model would take $11.6$s, a $5.1\times$ speedup.
  • Figure 3: Negligible effect of finetuning Llama baselines on the distillation dataset. Because our distillation dataset, OpenMathInstruct-2, includes math content, we also finetune the Llama models on it; these models are labeled "+FT". We plot the coverage (a) and majority-voting accuracy (b) as a function of the number of completions, and observe that finetuning the Llama models on the distillation dataset has a negligible effect on both metrics.
  • Figure 4: Distilled models provide better accuracies on MATH for most time budgets. The layout follows Figure 1(a). In (a) and (b), we show the majority-voting accuracy and the weighted Best-of-N accuracy (the selected answer is the one with the highest sum of reward model scores, as introduced in Beeching et al., 2024) for different time budgets. As in Figure 1(a), the higher throughput of our distilled models allows them to push the Pareto front forward for most time budgets. Interestingly, when comparing models of a given size, Llama models are better for larger time budgets. However, looking at both model sizes together reveals that larger distilled models can compensate for the lower accuracies of smaller models. As a result, the Pareto front is defined by our hybrid models. While Llama-3B is still more efficient for large time budgets, one could imagine distilling a larger subquadratic model that nonetheless generates faster.
  • Figure 5: Distilled models have better coverage on GSM8K for most time budgets. Observations are similar to Figure 1: despite a small degradation in coverage per number of completions, the higher throughput of distilled models pushes the Pareto front forward. Hybrid models are more efficient for most time budgets.
  • ...and 7 more figures
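Coverage in Figures 1, 3, and 5 is the fraction of problems for which at least one of the $k$ sampled answers is correct. The captions do not say which estimator the authors use, so the following is a minimal sketch assuming $n \ge k$ completions per problem and the commonly used unbiased pass@k estimator $1 - \binom{n-c}{k} / \binom{n}{k}$, where $c$ is the number of correct completions. `coverage_at_k` is a hypothetical name.

```python
from math import comb
from typing import Sequence


def coverage_at_k(num_correct: Sequence[int], n: int, k: int) -> float:
    """Estimate coverage (pass@k) averaged over problems.

    num_correct[i] is the number of correct answers among the n completions
    sampled for problem i. The per-problem estimate is the probability that a
    uniformly random size-k subset of those n completions contains at least
    one correct answer: 1 - C(n - c, k) / C(n, k).
    """
    if not 1 <= k <= n:
        raise ValueError("need 1 <= k <= n")
    total = 0.0
    for c in num_correct:
        if n - c < k:
            # Too few incorrect completions to fill a subset: always covered.
            total += 1.0
        else:
            total += 1.0 - comb(n - c, k) / comb(n, k)
    return total / len(num_correct)
```

For instance, `coverage_at_k([0, 1, 4], n=4, k=2)` averages per-problem estimates of 0, 0.5, and 1 and returns 0.5. Estimating from a larger pool of $n$ samples gives smoother coverage curves than drawing exactly $k$ samples once.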
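Figure 1(a) maps each number of samples $k$ to a generation time, and Figure 2 shows that memory limits the batch size (Llama-3B cannot run a batch of $512$). A minimal sketch of that mapping and of the Pareto front, assuming completions are generated in sequential batches capped at the largest batch that fits in memory and that per-batch decoding times (prefill excluded) have been measured beforehand. `time_for_k`, `batch_time`, and `pareto_front` are hypothetical names, not the authors' code.

```python
import math
from typing import Callable, Sequence


def time_for_k(k: int, max_batch: int, batch_time: Callable[[int], float]) -> float:
    """Seconds needed to generate k completions in sequential batches.

    max_batch is the largest batch that fits in memory; batch_time(b) returns
    the (assumed measured) decoding time for one batch of size b.
    """
    full_batches, remainder = divmod(k, max_batch)
    seconds = full_batches * batch_time(max_batch)
    if remainder:
        seconds += batch_time(remainder)  # one final, smaller batch
    return seconds


def pareto_front(points: Sequence[tuple[float, float]]) -> list[tuple[float, float]]:
    """Keep the (time, score) points not beaten by a faster-or-equal point.

    Points are pooled across models. After sorting by time (ties: higher score
    first), a point is on the front iff it raises the best score seen so far.
    """
    front: list[tuple[float, float]] = []
    best = -math.inf
    for seconds, score in sorted(points, key=lambda p: (p[0], -p[1])):
        if score > best:
            front.append((seconds, score))
            best = score
    return front
```

With the Figure 2 numbers, `time_for_k(512, 256, lambda b: 29.4)` gives 58.8 s for Llama-3B (two batches of 256; the 29.4 s per batch is inferred from the caption's total) and `time_for_k(512, 512, lambda b: 11.6)` gives 11.6 s for MambaInLlama-3B, a ratio of about 5.07, reported as $5.1\times$. The constant `batch_time` lambdas are only valid here because each model runs full batches of a single size. Evaluating a metric such as coverage at the $k$ reachable within each time budget, for every model, and passing the pooled points to `pareto_front` yields the front plotted in Figures 1, 4, and 5.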