A Simple Model of Inference Scaling Laws

Noam Levi

TL;DR

This work investigates how model performance scales at inference time as the number of attempts increases. It proposes a simple memorization-based framework to predict pass@k and an inference loss, deriving analytic forms under both iid and correlated-trial assumptions and linking them to total inference cost via FLOPs. The key insight is that inference quality can improve as a power law in the number of attempts, governed by task-difficulty parameters, and that this behavior can be captured in a universal, model-agnostic way. Empirical validation on a VAE reconstruction task supports the proposed scaling forms and shows how inference dynamics might be integrated with broader neural scaling laws.

Abstract

Neural scaling laws have garnered significant interest due to their ability to predict model performance as a function of increasing parameters, data, and compute. In this work, we propose a simple statistical ansatz based on memorization to study scaling laws in the context of inference, specifically how performance improves with multiple inference attempts. We explore the coverage, or pass@k metric, which measures the chance of success over repeated attempts, and provide a motivation for the observed functional form of the inference scaling behavior of the coverage in large language models (LLMs) on reasoning tasks. We then define an "inference loss", which exhibits a power-law decay as the number of trials increases, and connect this result with prompting costs. We further test our construction by conducting experiments on a simple generative model, and find that our predictions agree with the empirical coverage curves in a controlled setting. Our simple framework sets the ground for incorporating inference scaling with other known scaling laws.
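As a rough illustration of the coverage idea above, the following sketch computes pass@k when each question has its own per-attempt failure probability and attempts are iid. It assumes the failure probability is drawn from a Beta(α, β) distribution in the standard parametrization (density proportional to p^(α-1) (1-p)^(β-1)), so that pass@k = 1 - E[p^k] = 1 - B(α+k, β)/B(α, β). The function names are hypothetical, and the roles of α and β may be swapped relative to the paper's own convention.

```python
import math
import random


def log_beta(a: float, b: float) -> float:
    """Logarithm of the Beta function B(a, b), via log-gamma for stability."""
    return math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)


def pass_at_k_beta(k: float, alpha: float, beta: float) -> float:
    """Closed-form pass@k assuming failure probability p ~ Beta(alpha, beta).

    Assumption (not taken from the paper): attempts are iid given p, so all
    k attempts fail with probability p**k, and averaging over questions gives
    E[p**k] = B(alpha + k, beta) / B(alpha, beta).
    """
    log_fail = log_beta(alpha + k, beta) - log_beta(alpha, beta)
    return 1.0 - math.exp(log_fail)


def pass_at_k_monte_carlo(k: int, alpha: float, beta: float,
                          n_questions: int = 100_000, seed: int = 0) -> float:
    """Monte Carlo estimate of the same quantity, as a sanity check."""
    rng = random.Random(seed)
    total_fail = 0.0
    for _ in range(n_questions):
        p_fail = rng.betavariate(alpha, beta)  # per-question failure probability
        total_fail += p_fail ** k              # chance that all k attempts fail
    return 1.0 - total_fail / n_questions


if __name__ == "__main__":
    # Illustrative parameter values only; they are not fitted to any data.
    for k in (1, 10, 100, 1000):
        exact = pass_at_k_beta(k, alpha=5.0, beta=1.0)
        approx = pass_at_k_monte_carlo(k, alpha=5.0, beta=1.0)
        print(f"k={k:5d}  closed form={exact:.4f}  monte carlo={approx:.4f}")
```

Working in log space with `math.lgamma` keeps the closed form stable for large k, where the individual Gamma functions would overflow.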

Paper Structure

This paper contains 12 sections, 16 equations, 7 figures.

Figures (7)

  • Figure 1: Pass@k and failure distribution curves for various LLMs on difficult tasks, against theoretical scaling predictions. Left: Pass@k versus the number of samples for several coding and maths tasks and different models, as reported in Brown et al. (2024), compared with the analytical predictions of Eqs. \ref{eq:main_result} and \ref{eq:pass@k}. The solid curves are data, while the dashed curves are the predictions from \ref{eq:main_result}, where $\alpha$ and $\beta$ correspond to the concentration of easy and hard problems, respectively. The dotted curves are the results of \ref{eq:pass@k}. In both cases the functional form captures the LLM pass@k curves well for various models, by adjusting $\alpha,\beta$ or $p,\kappa$. Right: The $\text{Beta}(\alpha,\beta)$ distributions of the failure probabilities for the different models. Most of the questions are "difficult", while the existence of a left tail implies that more trials are required to obtain a correct answer.
  • Figure 2: Inference attempts loss $\mathcal{L}_{\text{inference}}(k)$ for repeated attempts on the memorizing model. Top: Inference loss as a function of trials for the LLM experiments in Brown et al. (2024). Bottom: Inference loss for different $\beta$ and $k$ values. Different colors indicate inference loss values at fixed $\alpha=5$ (left) and at fixed $k=10^4$ (right), illustrating the behavior of \ref{eq:loss}.
  • Figure 3: Pass@k as a function of total inference cost for Llama-3-8B MATH (Oracle Verifier). Left and Center: The pass@k metric as a function of total inference cost and either the number of FLOPs per token $F$ or the number of prompt/decode tokens $N_p=N_d$, on log-log axes. With one of the parameters held fixed, there is a clear trade-off in total inference cost, which is predictable from \ref{eq:cost}. Right: A slice of the contour plots at fixed $N_p=N_d$, varying the number of FLOPs per token. The parameters chosen for these figures are fitted from \ref{eq:main_result} applied to the data taken from Brown et al. (2024).
  • Figure 4: Visualization of the task described in \ref{sec:experiments}. Here, a VAE is tasked with generating samples from its training data, where a "failure" occurs when the reconstruction error exceeds a certain threshold $\epsilon$.
  • Figure 5: Results for the VAE reconstruction task, compared with semi-analytical predictions. Left: The pass@k metric as a function of the number of attempts $k$, for different threshold values, with temperature $T=1.1$. The curves have been normalized to asymptote at 1 for visual clarity. Right: The reconstruction error across multiple trials, indicated by different colors. The errors obey a quasi-power-law behavior.
  • ...and 2 more figures