Sampling and Loss Weights in Multi-Domain Training

Mahdi Salmani, Pratik Worah, Meisam Razaviyayn, Vahab Mirrokni

TL;DR

The paper investigates how multi-domain data should be weighted during training, arguing that a single scalar domain weight conflates two distinct roles: loss weights, which shape the optimization objective, and sampling weights, which affect gradient variance during SGD. It develops practical estimators for each: one-shot feasible GLS (FGLS) for loss weights in linear regression, ERMA weighting for general models, and Variance-Aware (VA) sampling to stabilize gradient estimates. Through linear, logistic, and neural-net experiments, it demonstrates that loss weights and sampling weights provide complementary improvements, and that using them together yields additional gains in estimation and optimization. This two-dimensional weighting framework offers clearer theory and actionable guidance for balancing data quality and optimization dynamics in large-scale pretraining pipelines.
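The decoupling of the two roles can be illustrated with a small sketch. It assumes (our assumption, not necessarily the paper's estimator) that a domain drawn with sampling weight $p_k$ has its gradient rescaled by $w_k / p_k$, where $w_k$ is the loss weight, and it holds each domain's gradient fixed for simplicity. Under this importance-weighting correction the expected gradient is $\sum_k w_k \nabla L_k(\theta)$ for any choice of $p$, so the loss weights set the objective and the sampling weights only change the variance. The function name `estimator_moments` and all numbers are hypothetical.

```python
import numpy as np

def estimator_moments(domain_grads, loss_w, sample_p):
    """Exact mean and variance (trace of covariance) of a one-draw SGD gradient.

    The draw picks domain k with probability p_k (sampling weight) and returns
    (w_k / p_k) * g_k, where w_k is the loss weight and g_k the domain gradient.
    """
    scaled = [(w / p) * g for w, p, g in zip(loss_w, sample_p, domain_grads)]
    mean = sum(p * s for p, s in zip(sample_p, scaled))
    second_moment = sum(p * float(s @ s) for p, s in zip(sample_p, scaled))
    return mean, second_moment - float(mean @ mean)

# Two domains with fixed gradients and equal loss weights (illustrative numbers).
grads = [np.array([1.0, 0.0]), np.array([0.0, 4.0])]
loss_w = [0.5, 0.5]
target = sum(w * g for w, g in zip(loss_w, grads))  # gradient of sum_k w_k * L_k

mean_u, var_u = estimator_moments(grads, loss_w, [0.5, 0.5])  # uniform sampling
mean_s, var_s = estimator_moments(grads, loss_w, [0.2, 0.8])  # skewed sampling
```

Both calls return the same mean, $[0.5, 2.0]$, which is the gradient of the loss-weighted objective, while the variance drops from 4.25 under uniform sampling to 2.0 under the skewed sampling weights. Changing $p$ moves only the variance; changing $w$ moves the target.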

Abstract

Training large deep neural networks requires vast amounts of data. To meet this need, data is collected from multiple domains, such as Wikipedia and GitHub. These domains are heterogeneous in both data quality and the diversity of information they provide. This raises the question of how much we should rely on each domain. Several methods have attempted to address this issue by assigning sampling weights to each data domain using heuristics or approximations. As a first step toward a deeper understanding of the role of data mixing, this work revisits the problem by studying two kinds of weights: sampling weights, which control how much each domain contributes to a batch, and loss weights, which scale the loss from each domain during training. Through a rigorous study of linear regression, we show that these two weights play complementary roles. First, they can reduce the variance of gradient estimates in iterative methods such as stochastic gradient descent (SGD). Second, they can improve generalization performance by reducing the generalization gap. We provide both theoretical and empirical support for these claims. We further study the joint dynamics of sampling weights and loss weights, examining how they can be combined to capture both contributions.
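To make the variance-reduction role concrete, here is a sketch of one standard importance-sampling rule. This is not necessarily the paper's VA sampling. If domain $k$ is drawn with probability $p_k$ and a per-example gradient $g$ from that domain is rescaled to $(w_k / p_k)\, g$, the estimator's second moment is $\sum_k w_k^2 G_k^2 / p_k$ with $G_k^2 = \mathbb{E}\|g\|^2$ for domain $k$. The mean does not depend on $p$, so minimizing this quantity subject to $\sum_k p_k = 1$ minimizes the variance. By Cauchy-Schwarz the minimizer is $p_k \propto w_k G_k$, with minimum value $(\sum_k w_k G_k)^2$. The function names and the values of $G_k$ are hypothetical.

```python
import numpy as np

def variance_minimizing_sampling(loss_w, grad_rms):
    """Sampling weights p_k proportional to w_k * G_k.

    G_k = sqrt(E||g||^2) over per-example gradients g from domain k. This choice
    minimizes sum_k (w_k * G_k)**2 / p_k, the second moment of the estimator
    (w_k / p_k) * g, subject to sum_k p_k = 1.
    """
    scores = np.asarray(loss_w, dtype=float) * np.asarray(grad_rms, dtype=float)
    return scores / scores.sum()

def estimator_second_moment(loss_w, grad_rms, sample_p):
    """E||(w_k / p_k) * g||^2 when domain k is drawn with probability p_k."""
    w, G, p = (np.asarray(a, dtype=float) for a in (loss_w, grad_rms, sample_p))
    return float(np.sum((w * G) ** 2 / p))

loss_w = [0.5, 0.5]
grad_rms = [1.0, 4.0]  # assumed per-domain gradient RMS norms (illustrative)
p_vm = variance_minimizing_sampling(loss_w, grad_rms)
m_uniform = estimator_second_moment(loss_w, grad_rms, [0.5, 0.5])
m_vm = estimator_second_moment(loss_w, grad_rms, p_vm)
```

With these numbers `p_vm` is $[0.2, 0.8]$ and the second moment falls from 8.5 to 6.25, which equals the lower bound $(\sum_k w_k G_k)^2$. In practice $G_k$ would have to be estimated during training, for example from running averages of per-domain gradient norms; that estimation step is an assumption here, not a detail taken from the source.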

Paper Structure

This paper contains 32 sections, 9 theorems, 71 equations, 6 figures, 1 table, 2 algorithms.

Key Result

Theorem 3.1

Consider the linear model $\mathbf{y} = \mathbf{X}\theta + \mathbf{\epsilon}$, where $\mathbb{E}[\mathbf{\epsilon}] = 0$ and $\operatorname{Var}(\mathbf{\epsilon}) = \mathbf{\Sigma}$, with $\mathbf{\Sigma}$ a positive definite matrix. The generalized least squares (GLS) estimator is the best linear unbiased estimator, achieving the minimum variance among linear unbiased estimators.
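For reference, the GLS estimator in Theorem 3.1 is $\hat{\theta}_{\mathrm{GLS}} = (\mathbf{X}^\top \mathbf{\Sigma}^{-1} \mathbf{X})^{-1} \mathbf{X}^\top \mathbf{\Sigma}^{-1} \mathbf{y}$, with covariance $(\mathbf{X}^\top \mathbf{\Sigma}^{-1} \mathbf{X})^{-1}$. When the noise is independent across samples and each domain $k$ has its own variance $\sigma_k^2$, $\mathbf{\Sigma}$ is diagonal and GLS reduces to weighted least squares with per-example loss weight $1/\sigma_k^2$. This is one way to read how loss weights enter the linear setting. The sketch below assumes that "one-shot FGLS" means a single OLS fit to estimate per-domain noise variances from residuals, followed by a single reweighted fit. The paper's exact procedure may differ, and the helper names and synthetic data are hypothetical.

```python
import numpy as np

def weighted_least_squares(X, y, weights):
    """Minimize sum_i weights[i] * (y[i] - X[i] @ theta)**2 in closed form.

    With weights[i] = 1 / sigma_i**2 and independent noise, this is the GLS estimator
    (X^T Sigma^{-1} X)^{-1} X^T Sigma^{-1} y for a diagonal Sigma.
    """
    Xw = X * weights[:, None]                   # rows of X scaled by their weights (W X)
    return np.linalg.solve(X.T @ Xw, Xw.T @ y)  # solve (X^T W X) theta = X^T W y

def one_shot_fgls(X, y, domain):
    """Hypothetical one-shot FGLS: one OLS fit, per-domain noise variances, one WLS refit."""
    theta_ols = np.linalg.lstsq(X, y, rcond=None)[0]
    resid = y - X @ theta_ols
    sigma2 = {k: float(np.mean(resid[domain == k] ** 2)) for k in np.unique(domain)}
    weights = np.array([1.0 / sigma2[k] for k in domain])  # loss weight 1 / sigma_k^2
    return weighted_least_squares(X, y, weights), sigma2

# Synthetic two-domain data: domain 0 is clean, domain 1 is noisy (illustrative values).
rng = np.random.default_rng(0)
n_per_domain, d = 200, 5
theta_gt = rng.normal(size=d)
domain = np.repeat([0, 1], n_per_domain)
X = rng.normal(size=(2 * n_per_domain, d))
noise_std = np.where(domain == 0, 0.1, 3.0)
y = X @ theta_gt + noise_std * rng.normal(size=2 * n_per_domain)

theta_ols = np.linalg.lstsq(X, y, rcond=None)[0]
theta_fgls, sigma2 = one_shot_fgls(X, y, domain)
err_ols = np.linalg.norm(theta_ols - theta_gt)
err_fgls = np.linalg.norm(theta_fgls - theta_gt)
```

On data like this, the estimated variances should rank the noisy domain above the clean one, so the refit downweights it and `err_fgls` should come out well below `err_ols`. With all weights equal, `weighted_least_squares` reduces to ordinary least squares.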

Figures (6)

  • Figure 1: Performance of different methods in the linear regression example. Panels (a) to (c) correspond to $(C_1, C_2) = (100, 1)$, and panels (d) to (f) to $(C_1, C_2) = (1, 100)$. (a, d): Distance between the estimated parameter and the ground truth $\theta_{\mathrm{gt}}$ for each method. (b, e): Evolution of the loss weight for domain one during training. (c, f): Evolution of the sampling weight for domain one during training.
  • Figure 2: Performance of different methods in the logistic regression example. Panels (a) to (c) correspond to $(C_1, C_2) = (100, 100)$, and panels (d) to (f) to $(C_1, C_2) = (10, 100)$. (a, d): Cosine distance between the estimated parameter and the ground truth $\theta_{\mathrm{gt}}$ for each method. (b, e): Evolution of the loss weight for domain one during training. (c, f): Evolution of the sampling weight for domain one during training.
  • Figure 3: Performance of different methods in the neural net example.
  • Figure 4: Performance of different methods in the linear regression example. Panels (a) to (c) correspond to $(C_1, C_2) = (100, 1)$ and $(\sigma_1^2, \sigma_2^2) = (1, 1)$, and panels (d) to (f) to $(C_1, C_2) = (1, 1)$ and $(\sigma_1^2, \sigma_2^2) = (1, 20)$. (a, d): Distance between the estimated parameter and the ground truth $\theta_{\mathrm{gt}}$ for each method. (b, e): Evolution of the loss weight for domain one during training. (c, f): Evolution of the sampling weight for domain one during training.
  • Figure 5: Performance of different methods in the logistic regression example, measured by accuracy. Panel (a) corresponds to $(C_1, C_2) = (100, 100)$, and panel (b) to $(C_1, C_2) = (10, 100)$.
  • ...and 1 more figure

Theorems & Definitions (14)

  • Theorem 3.1 (Aitken, 1935)
  • Corollary 3.2
  • Theorem 3.3 (informal)
  • Theorem 3.4 (informal)
  • Lemma A.1
  • Proof
  • Lemma A.2
  • Proof
  • Lemma A.3
  • Proof
  • ...and 4 more