Understanding the Role of Training Data in Test-Time Scaling

Adel Javanmard, Baharan Mirzasoleiman, Vahab Mirrokni

TL;DR

This work analyzes how properties of the training data shape the benefits of test-time scaling for chain-of-thought (CoT) reasoning in transformers. Focusing on in-context weight prediction for linear regression, it shows that test-time CoT approximates a multi-step (pseudo-)Newton method and that gradient-descent training of a one-layer linear self-attention (LSA) network converges to a global optimum, with a closed-form characterization involving a regularized covariance operator $\Gamma$. The authors introduce a hardness measure based on the spectrum of the feature covariance and prove that, at a fixed downstream error, longer CoT can reduce the training context length required, while skills that are underrepresented in training can cause overthinking. They also derive a tractable quadratic program for optimizing task selection, which favors diverse, relevant, and hard training tasks. Experiments on both linear (LSA) and nonlinear (GPT-2) transformers confirm the scaling laws and the benefit of carefully chosen task mixtures, with implications for data curation and inference-time compute. The results clarify when test-time thinking helps or hurts and offer principled guidelines for designing training curricula that maximize test-time gains in reasoning tasks.
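To make the "multi-step (pseudo-)Newton" reading concrete, here is a minimal NumPy sketch under stated assumptions; it is not the paper's construction or the trained transformer itself. Each CoT step is modeled as one preconditioned update on a single prompt's least-squares problem, using a hypothetical preconditioner $P = (S + \lambda I)^{-1}$ built from the prompt's empirical covariance $S$. The function name `cot_weight_estimates` and the regularizer `lam` are illustrative choices.

```python
import numpy as np


def cot_weight_estimates(X, y, k, lam=0.1):
    """Return the iterates w_1, ..., w_k of a toy CoT for one in-context prompt.

    Hypothetical model: each "thinking step" applies one preconditioned
    (pseudo-Newton) update w <- w - P (S w - b), where S = X^T X / n,
    b = X^T y / n, and P = (S + lam * I)^{-1} is an assumed regularized inverse.
    """
    n, d = X.shape
    S = X.T @ X / n                          # empirical feature covariance of the prompt
    b = X.T @ y / n                          # empirical feature-label cross-covariance
    P = np.linalg.inv(S + lam * np.eye(d))   # assumed regularized preconditioner
    w = np.zeros(d)
    iterates = []
    for _ in range(k):
        grad = S @ w - b                     # gradient of the in-context squared loss / 2
        w = w - P @ grad                     # one pseudo-Newton step
        iterates.append(w.copy())
    return iterates


rng = np.random.default_rng(0)
d, n = 5, 40
w_star = rng.standard_normal(d)
X = rng.standard_normal((n, d))
y = X @ w_star                               # noiseless in-context labels
for k in (1, 2, 4, 8):
    w_k = cot_weight_estimates(X, y, k)[-1]
    print(k, np.linalg.norm(w_k - w_star))   # error after k thinking steps
```

On noiseless prompts with more examples than features, the error shrinks with every extra step, because each eigen-direction of $S$ contracts by the factor $\lambda/(s_i + \lambda) < 1$. Directions with small eigenvalues $s_i$ contract most slowly, which loosely mirrors the smallest-eigenvalue view of task hardness.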

Abstract

Test-time scaling improves the reasoning capabilities of large language models (LLMs) by allocating extra compute to generate longer Chains-of-Thought (CoTs). This enables models to tackle more complex problems by breaking them down into additional steps, backtracking, and correcting mistakes. Despite its strong performance, demonstrated by OpenAI's o1 and DeepSeek R1, the conditions in the training data under which long CoTs emerge, and when such long CoTs improve performance, remain unclear. In this paper, we study the performance of test-time scaling for transformers trained on an in-context weight prediction task for linear regression. Our analysis provides a theoretical explanation for several intriguing observations. First, at any fixed test error, increasing test-time compute allows us to reduce the number of in-context examples (context length) in training prompts. Second, if the skills required to solve a downstream task are not sufficiently present in the training data, increasing test-time compute can harm performance. Finally, we characterize task hardness via the smallest eigenvalue of its feature covariance matrix and show that training on a diverse, relevant, and hard set of tasks results in the best performance for test-time scaling. We confirm our findings with experiments on large, nonlinear transformer architectures.
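The second and third observations can be illustrated with a toy population-level model. This is an assumption for illustration, not the paper's analysis: a preconditioner $P$ is fixed once from a hypothetical training covariance and then reused on test tasks. `task_hardness` uses $1/\lambda_{\min}(\Sigma)$ as a stand-in proxy that grows as the smallest eigenvalue shrinks; the paper's exact hardness formula may differ. All spectra and helper names here are hypothetical.

```python
import numpy as np


def task_hardness(Sigma):
    """Hypothetical hardness proxy: reciprocal of the smallest eigenvalue
    of a task's feature covariance (larger value = harder task)."""
    return 1.0 / np.linalg.eigvalsh(Sigma).min()


def cot_error(P, Sigma_test, w_star, k):
    """Population-level error ||w_k - w_star|| after k preconditioned steps
    w <- w - P Sigma_test (w - w_star), starting from w = 0."""
    w = np.zeros_like(w_star)
    for _ in range(k):
        w = w - P @ (Sigma_test @ (w - w_star))
    return np.linalg.norm(w - w_star)


# Assumed training spectrum: the third feature direction is barely covered.
Sigma_train = np.diag([1.0, 1.0, 0.05])
lam = 0.01
P = np.linalg.inv(Sigma_train + lam * np.eye(3))  # fixed at "training time"

w_star = np.ones(3)
Sigma_covered = np.diag([1.0, 1.0, 0.05])  # test task resembling the training data
Sigma_shifted = np.diag([1.0, 1.0, 1.0])   # test task relying on the rare direction

for k in (1, 2, 4, 8):
    print(k, cot_error(P, Sigma_covered, w_star, k), cot_error(P, Sigma_shifted, w_star, k))
print("hardness proxy of training covariance:", task_hardness(Sigma_train))
```

With these numbers, the covered task's error falls at every step, while the shifted task's error grows by roughly a factor of 15 per step: in the third direction the update overshoots, since $1 - P_{33}\Sigma_{33} \approx -15.7$. Extra thinking steps then amplify rather than correct the error, a simple caricature of overthinking on underrepresented skills.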

Paper Structure

This paper contains 23 sections, 12 theorems, 84 equations, and 4 figures.

Key Result

Theorem 3.1

Consider the linear self-attention network trained on the population loss (eq:pop-loss) with the initialization … for some real-valued $c$. Also define the regularized covariance operator $\Gamma$ … . We run gradient descent on the population loss with a constant step size $\eta \le 1/(c^2\left\|\Gamma\right\|_{\rm op})$, keeping $W_{24}(t) = -c$ fixed throughout training. Then gradient descent converges to a global minimum of the loss, given by …
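The step-size condition in Theorem 3.1 is the familiar one for gradient descent on a smooth quadratic. The sketch below assumes a generic quadratic objective whose Hessian is $c^2\Gamma$, with a randomly generated positive-definite stand-in for $\Gamma$; the paper's population loss and its definition of $\Gamma$ are not reproduced, and `gd_on_quadratic` is a hypothetical helper. With $\eta = 1/(c^2\|\Gamma\|_{\rm op})$, every eigen-direction of the error contracts by a factor in $[0, 1)$, so the iterates approach the unique minimizer.

```python
import numpy as np


def gd_on_quadratic(Gamma, b, c, steps):
    """Gradient descent on the toy objective f(w) = 0.5 * c^2 * w^T Gamma w - b^T w,
    whose Hessian is c^2 * Gamma, using the largest step allowed by
    eta <= 1 / (c^2 * ||Gamma||_op)."""
    eta = 1.0 / (c**2 * np.linalg.norm(Gamma, 2))  # ord=2: spectral (operator) norm
    w = np.zeros_like(b)
    for _ in range(steps):
        grad = c**2 * Gamma @ w - b
        w = w - eta * grad
    return w


rng = np.random.default_rng(1)
A = rng.standard_normal((4, 4))
Gamma = A @ A.T + 0.5 * np.eye(4)   # hypothetical symmetric positive-definite Gamma
b = rng.standard_normal(4)
c = 2.0
w_hat = gd_on_quadratic(Gamma, b, c, steps=5000)
w_opt = np.linalg.solve(c**2 * Gamma, b)  # closed-form global minimizer of the toy objective
print(np.linalg.norm(w_hat - w_opt))
```

Taking the step at its upper limit zeroes out the stiffest direction in one step; the slowest direction contracts by $1 - \lambda_{\min}(\Gamma)/\lambda_{\max}(\Gamma)$ per step.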

Figures (4)

  • Figure 1: Test-time scaling for the in-context learning. Here, $n$ is the number of in-context examples (context length) in training prompts, and $H$ is the task hardness.
  • Figure 2: More test-time compute reduces training-time requirements for (a) a one-layer transformer and (c) GPT-2. However, insufficient task coverage in the training data makes longer CoTs harmful for (b) the one-layer transformer and (d) GPT-2. For GPT-2, error bars show the standard deviation over 10 runs. For LSA, the standard deviation is negligible because we start from the fixed initialization in Eq. (eq:Vs-Ws).
  • Figure 3: Task selection in a multi-task setup. (a) Each color corresponds to a task type, with solid lines indicating the average selection probability per type. Harder and more diverse tasks receive higher selection probabilities, while easier, more concentrated tasks are weighted lower. (b) Task selection probabilities versus task hardness: harder tasks are favored in the selection.
  • Figure 4: Transformer with a single linear self-attention layer. (a), (b) At a fixed test error, increasing the test-time compute $k$ lets us decrease the prompt length $n$ during training. (c) When some directions of the test distribution are underrepresented in the training data, more test-time compute hurts performance.
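The task-selection step in Figure 3 comes from a quadratic program over selection probabilities. As a hedged sketch, assume an objective of the form $\min_{p \in \Delta} \tfrac12 p^\top Q p + q^\top p$ over the probability simplex $\Delta$, with a positive-semidefinite $Q$; the matrices below are placeholders, since the paper's objective is not given in this summary, and `select_tasks` is a hypothetical helper. It is solved here by projected gradient descent with the standard sort-based Euclidean projection onto the simplex; a dedicated QP solver would work equally well.

```python
import numpy as np


def project_to_simplex(v):
    """Euclidean projection of v onto {p : p >= 0, sum(p) = 1} (sort-based method)."""
    u = np.sort(v)[::-1]                          # entries in decreasing order
    css = np.cumsum(u)
    ks = np.arange(1, len(v) + 1)
    rho = np.nonzero(u * ks > css - 1.0)[0][-1]   # last index passing the threshold test
    theta = (css[rho] - 1.0) / (rho + 1)          # shift making the kept entries sum to 1
    return np.maximum(v - theta, 0.0)


def select_tasks(Q, q, steps=500):
    """Projected gradient descent for min_p 0.5 p^T Q p + q^T p over the simplex.
    Q and q are hypothetical inputs; Q must be symmetric positive semidefinite."""
    step = 1.0 / np.linalg.norm(Q, 2)             # 1/L, L = largest eigenvalue of Q
    p = np.full(len(q), 1.0 / len(q))             # start from the uniform task mixture
    for _ in range(steps):
        p = project_to_simplex(p - step * (Q @ p + q))
    return p


Q = 2.0 * np.eye(3)               # placeholder curvature
q = np.array([-1.0, -0.5, 0.0])   # placeholder linear term
p = select_tasks(Q, q)
print(p)
```

For the placeholder $Q = 2I$ and $q = (-1, -0.5, 0)$, the KKT conditions give $p^\star = (7/12, 1/3, 1/12)$, which the solver recovers.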

Theorems & Definitions (15)

  • Theorem 3.1
  • Proposition 3.2
  • Remark 3.3
  • Theorem 3.3
  • Theorem 3.4
  • Corollary 3.5
  • Remark 3.5
  • Theorem 4.1
  • Proposition 4.2
  • Remark 4.1
  • ...and 5 more