
Dynamic Loss-Based Sample Reweighting for Improved Large Language Model Pretraining

Daouda Sow, Herbert Woisetschläger, Saikiran Bulusu, Shiqiang Wang, Hans-Arno Jacobsen, Yingbin Liang

TL;DR

This work tackles the inefficiency of uniform data sampling in large language model pretraining by introducing fully online, instance-level, loss-based data reweighting. It develops a family of lightweight reweighting strategies (notably LinUpper) that modulate per-sample weights according to their current losses while keeping the maximum weight bounded to ensure stable convergence. A new theoretical framework characterizes how loss-based weighting affects convergence bounds for gradient-based optimization in the convex interpolation regime, providing a justification for down-weighting low-loss samples. Empirically, the methods yield faster convergence and improved performance on GPT-2-scale models and large Llama models, with stronger gains for larger models and when combined with domain-level reweighting (DoGE/DoReMi). The results suggest substantial data-efficiency benefits for LLM pretraining and point to practical, low-overhead integration into existing pipelines.
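
To make "per-sample weights with a bounded maximum" concrete, here is a minimal sketch under our own assumptions: losses are min-max normalized to [0, 1] and turned into weights with a temperature-$r$ softmax. The helper names (`normalize`, `softmax_weights`, `max_weight_bound`) and the loss values are hypothetical, and this is not presented as the paper's LinUpper rule. Because every score lies in [0, 1], no weight can exceed $1/(1 + (M-1)e^{-1/r})$, which tends to $1/M$ as $r$ grows.

```python
import math
import random


def normalize(losses):
    """Min-max normalize losses to [0, 1]; a constant batch maps to all zeros."""
    lo, hi = min(losses), max(losses)
    if hi == lo:
        return [0.0] * len(losses)
    return [(l - lo) / (hi - lo) for l in losses]


def softmax_weights(scores, r):
    """Temperature-r softmax over per-sample scores; the weights sum to 1."""
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp((s - m) / r) for s in scores]
    z = sum(exps)
    return [e / z for e in exps]


def max_weight_bound(num_samples, r):
    """Upper bound on any single weight when all scores lie in [0, 1].

    Every term exp((s_j - s_max) / r) is at least exp(-1 / r) and the
    arg-max term equals 1, so the softmax denominator is at least
    1 + (num_samples - 1) * exp(-1 / r).
    """
    return 1.0 / (1.0 + (num_samples - 1) * math.exp(-1.0 / r))


# Hypothetical per-sample losses for one batch of 128 sequences.
rng = random.Random(0)
losses = [rng.uniform(0.5, 6.0) for _ in range(128)]
scores = normalize(losses)
for r in (0.2, 0.5, 1.0):
    w = softmax_weights(scores, r)
    print(f"r={r}: max weight {max(w):.5f} <= bound {max_weight_bound(len(w), r):.5f}")
```

Smaller $r$ lets high-loss samples dominate the batch; larger $r$ pushes every weight toward $1/M$.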

Abstract

Pretraining large language models (LLMs) on vast and heterogeneous datasets is crucial for achieving state-of-the-art performance across diverse downstream tasks. However, current training paradigms treat all samples equally, overlooking the importance or relevance of individual samples throughout the training process. Existing reweighting strategies, which primarily focus on group-level data importance, fail to leverage fine-grained instance-level information and do not adapt dynamically to individual sample importance as training progresses. In this paper, we introduce novel algorithms for dynamic, instance-level data reweighting aimed at improving both the efficiency and effectiveness of LLM pretraining. Our methods adjust the weight of each training sample based on its loss value in an online fashion, allowing the model to dynamically focus on more informative or important samples at the current training stage. In particular, our framework allows us to systematically devise reweighting strategies that deprioritize redundant or uninformative data, which we find tend to work best. Furthermore, we develop a new theoretical framework for analyzing the impact of loss-based reweighting on the convergence of gradient-based optimization, providing the first formal characterization of how these strategies affect convergence bounds. We empirically validate our approach across a spectrum of tasks, from pretraining 7B- and 1.4B-parameter LLMs to smaller-scale language models and linear regression problems, demonstrating that our loss-based reweighting approach can lead to faster convergence and significantly improved performance.
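
The abstract lists linear regression among the test problems. The sketch below illustrates the online loop on a noise-free toy regression problem (so the interpolation regime holds): at every step, weights are recomputed from the current per-sample losses and held fixed while forming the weighted gradient. The weighting rule (min-max normalization followed by a softmax with a hypothetical temperature $r = 0.5$), the data, and the full-batch step are all our assumptions, not the paper's experimental setup.

```python
import math


def softmax(xs, r):
    """Temperature-r softmax; weights sum to 1."""
    m = max(xs)
    exps = [math.exp((x - m) / r) for x in xs]
    z = sum(exps)
    return [e / z for e in exps]


def loss_based_weights(losses, r=0.5):
    """Weights from current losses: min-max normalize, then temperature-r softmax."""
    lo, hi = min(losses), max(losses)
    if hi == lo:
        return [1.0 / len(losses)] * len(losses)
    return softmax([(l - lo) / (hi - lo) for l in losses], r)


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


# Noise-free toy data: y_i = <x_i, theta_star>, so one theta fits every sample.
theta_star = [1.5, -0.5]
xs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, -1.0], [2.0, 1.0], [0.5, 2.0]]
ys = [dot(x, theta_star) for x in xs]


def per_sample_losses(theta):
    """Squared loss 0.5 * (<x_i, theta> - y_i)^2 for each sample."""
    return [0.5 * (dot(x, theta) - y) ** 2 for x, y in zip(xs, ys)]


theta = [0.0, 0.0]
lr = 1.0 / max(dot(x, x) for x in xs)  # step size 1/L with L = max_i ||x_i||^2
initial_loss = sum(per_sample_losses(theta)) / len(xs)

for _ in range(1000):
    # Online step: weights come from the losses at the current iterate
    # and are treated as constants when forming the gradient.
    w = loss_based_weights(per_sample_losses(theta))
    grad = [0.0, 0.0]
    for wi, x, y in zip(w, xs, ys):
        residual = dot(x, theta) - y
        for j in range(2):
            grad[j] += wi * residual * x[j]
    theta = [t - lr * g for t, g in zip(theta, grad)]

final_loss = sum(per_sample_losses(theta)) / len(xs)
print(f"mean loss: {initial_loss:.4f} -> {final_loss:.2e}")
```

With nonnegative weights that sum to one and step size $1/L$, no step can increase the distance to $\theta^*$; the weighting only changes which samples dominate each update.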

Paper Structure

This paper contains 28 sections, 7 theorems, 53 equations, 8 figures, 9 tables, 1 algorithm.

Key Result

Theorem 1

Consider $M$ data points and let each loss $f(\mathbf{x}_i; \cdot)$ be convex. Further, assume the interpolation regime holds, i.e., $\exists \theta^* \in \mathbb{R}^d$ such that $\theta^* \in \arg\min_{\theta \in \mathbb{R}^d} f(\mathbf{x}_i; \theta) ~ \forall i$. Then, for a reweighting scheme that … where $\delta^t = \sum_{i=1}^{M} \left(\frac{1}{M} - w(\mathbf{x}_i; \theta^t)\right)\left(f(\ldots)\ldots\right)$ …
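
Because the second factor of $\delta^t$ is not spelled out here, the following check treats it as a generic per-sample quantity $a_i$ and assumes, for illustration only, that $a_i$ is ordered like the losses. By Chebyshev's sum inequality, weights that sum to one and increase with $a_i$ (so low-loss samples are down-weighted) make $\sum_i (1/M - w_i)\,a_i \le 0$, while weights that decrease with $a_i$ make it nonnegative. The function name `delta_like` and the toy losses are hypothetical.

```python
import math
import random


def delta_like(weights, values):
    """Return sum_i (1/M - w_i) * a_i for weights w_i and per-sample values a_i."""
    m = len(weights)
    return sum((1.0 / m - w) * a for w, a in zip(weights, values))


def softmax(xs, r=1.0):
    """Temperature-r softmax; weights increase monotonically with xs."""
    m = max(xs)
    exps = [math.exp((x - m) / r) for x in xs]
    z = sum(exps)
    return [e / z for e in exps]


rng = random.Random(0)
losses = [rng.uniform(0.0, 3.0) for _ in range(32)]  # toy per-sample losses
up = softmax(losses)                  # low-loss samples get the smallest weights
down = softmax([-l for l in losses])  # high-loss samples get the smallest weights
uniform = [1.0 / len(losses)] * len(losses)

print(delta_like(up, losses))       # non-positive
print(delta_like(down, losses))     # non-negative
print(delta_like(uniform, losses))  # exactly 0.0
```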

Figures (8)

  • Figure 1: Left: Shapes of the different reweighting functions. Right: Shape of the LinUpper strategy after applying the normalization of Eq. (eq:softm) on top of it, for different values of $r$. These plots are obtained for a batch of 128 uniformly drawn losses. As $r$ increases, LinUpper converges to uniform averaging.
  • Figure 2: Per-domain perplexities on held-out validation sets under uniform domain sampling for the GPT2-medium model. Our LinUpper reweighting strategy achieves better or comparable perplexity on 5 of 7 domains.
  • Figure 3: Per-domain perplexities on held-out validation sets under uniform domain sampling for the GPT2-mini model. Our LinUpper reweighting strategy achieves better or comparable perplexity on 6 of 7 domains.
  • Figure 4: Per-domain perplexities on held-out validation sets under uniform domain sampling for the GPT2-small model. Our LinUpper reweighting strategy achieves better or comparable perplexity on 6 of 7 domains.
  • Figure 5: Sensitivity of LinUpper to the value of $r$. When $r$ is large (e.g., $r=1$), the performance of our method approaches that of the uniform baseline. Decreasing $r$ reduces the influence of low-loss samples, but performance can degrade when $r$ is too small (e.g., $r=0.2$).
  • ...and 3 more figures
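
Figures 1 and 5 describe how the temperature $r$ controls the weights. The sketch below mirrors the Figure 1 setup with 128 losses drawn uniformly from [0, 1] (the range is our assumption) and sweeps $r$. The score `lin_upper_score`, a linear score clipped at a hypothetical cap of 0.8, is only a stand-in for LinUpper, whose exact form is not given here. For any bounded score, the largest weight shrinks and the smallest grows as $r$ increases, so the weights approach uniform averaging.

```python
import math
import random


def softmax_weights(scores, r):
    """Temperature-r softmax; larger r flattens the weights toward uniform."""
    m = max(scores)
    exps = [math.exp((s - m) / r) for s in scores]
    z = sum(exps)
    return [e / z for e in exps]


def lin_upper_score(loss, cap=0.8):
    """HYPOTHETICAL score: linear in the loss, clipped at `cap`.

    An illustrative stand-in, not the paper's LinUpper definition.
    """
    return min(loss, cap)


def max_deviation_from_uniform(weights):
    """Largest |w_i - 1/M| over the batch."""
    m = len(weights)
    return max(abs(w - 1.0 / m) for w in weights)


# A batch of 128 losses drawn uniformly from [0, 1] (assumed range).
rng = random.Random(0)
losses = [rng.uniform(0.0, 1.0) for _ in range(128)]
scores = [lin_upper_score(l) for l in losses]

temperatures = [0.1, 0.2, 0.5, 1.0, 5.0, 100.0]
deviations = [max_deviation_from_uniform(softmax_weights(scores, r)) for r in temperatures]
for r, d in zip(temperatures, deviations):
    print(f"r={r:>6}: max |w_i - 1/128| = {d:.2e}")
```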

Theorems & Definitions (13)

  • Theorem 1
  • Proposition 1
  • proof
  • Theorem 2: Minibatch SGD with momentum
  • proof
  • Lemma 1
  • proof
  • Proposition 1
  • proof
  • Theorem 3: Restatement of the minibatch SGD theorem (thm:minibatch_SGD_convex)
  • ...and 3 more