Domain Mixture Design via Log-Likelihood Differences for Aligning Language Models with a Target Model

Ryo Kishino, Riku Shiomi, Hiroaki Yamagiwa, Momose Oyama, Hidetoshi Shimodaira

Abstract

Instead of directly distilling a target language model, this study aligns the distribution of a base model with that of a target model by designing the domain mixture of the training data, which then serves as a fixed training recipe for pretraining or continued pretraining. We propose a method that determines domain weights by viewing models as points in log-likelihood space and aligning the training update direction with the direction toward the target model. Experiments with NanoGPT show that, compared with uniform weighting over the Pile domains, the proposed method consistently reduces the KL divergence to the target model. Although knowledge distillation remains more effective when it is available, the proposed method still achieves meaningful alignment, and downstream task performance also tends to move closer to that of the target model.
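The idea of aligning the update direction with the direction toward the target can be illustrated with a toy computation. The sketch below is not the authors' algorithm. It relies only on what the abstract states: each model is represented by its vector of log-likelihoods over a fixed set of texts, and the direction toward the target is the difference between these vectors. Averaging these log-likelihood differences (LLDs) within each domain and turning the averages into weights is one hypothetical reading of "aggregated-LLD". The helper names (`direction_to_target`, `aggregate_by_domain`, `weights_from_aggregated_lld`, `mixture_update`), the clip-and-normalize mapping, the made-up log-likelihoods, and the toy update model (training on domain k raises only the log-likelihoods of domain-k texts) are all illustrative assumptions.

```python
import math
from collections import defaultdict

# A model is represented by its log-likelihood vector over fixed texts x_1..x_N:
# ell(p) = [log p(x_1), ..., log p(x_N)].

def direction_to_target(target_logp, base_logp):
    """Direction from the base model toward the target model in log-likelihood space."""
    return [lq - lp for lq, lp in zip(target_logp, base_logp)]

def aggregate_by_domain(direction, domains):
    """Average the log-likelihood difference (LLD) over the texts of each domain."""
    sums, counts = defaultdict(float), defaultdict(int)
    for value, dom in zip(direction, domains):
        sums[dom] += value
        counts[dom] += 1
    return {dom: sums[dom] / counts[dom] for dom in sums}

def weights_from_aggregated_lld(agg):
    """Hypothetical mapping: clip negative scores to zero, then normalize to sum to 1."""
    clipped = {dom: max(0.0, v) for dom, v in agg.items()}
    total = sum(clipped.values())
    if total == 0.0:  # no domain points toward the target; fall back to uniform
        return {dom: 1.0 / len(agg) for dom in agg}
    return {dom: v / total for dom, v in clipped.items()}

def mixture_update(weights, domains):
    """Toy update model (assumption): domain k raises only domain-k log-likelihoods."""
    return [weights[dom] for dom in domains]

def cosine(u, v):
    """Cosine similarity between two nonzero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

# Made-up log-likelihoods for four texts from two hypothetical domains.
target_logp = [-10.0, -12.0, -8.0, -9.0]
base_logp = [-14.0, -13.0, -8.5, -9.5]
domains = ["code", "code", "pubmed", "pubmed"]

d = direction_to_target(target_logp, base_logp)   # [4.0, 1.0, 0.5, 0.5]
agg = aggregate_by_domain(d, domains)             # {'code': 2.5, 'pubmed': 0.5}
weights = weights_from_aggregated_lld(agg)        # code ~0.833, pubmed ~0.167
uniform = {dom: 1.0 / len(agg) for dom in agg}

print(weights)
print(cosine(mixture_update(uniform, domains), d))  # alignment of uniform weighting
print(cosine(mixture_update(weights, domains), d))  # alignment of LLD-based weighting
```

Under this toy update model, the mixture update has entry π_{k(i)} for text x_i, and a Cauchy–Schwarz argument shows that choosing π_k proportional to max(0, mean LLD of domain k) maximizes its cosine with the target direction over nonnegative weights. In a real model, training on one domain also shifts log-likelihoods on other domains, so this correspondence should not be expected to hold exactly.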

Paper Structure (32 sections, 48 equations, 10 figures, 4 tables, 1 algorithm)

Figures (10)

  • Figure 1: t-SNE visualization of training trajectories in the log-likelihood vector space during continued pretraining from pretrained NanoGPT toward Gemma-2B. The trajectories are obtained using either uniform domain weights over the Pile domains or weights determined by the proposed method (aggregated-LLD). Each point represents a checkpoint taken every 1k training steps, and color intensity indicates the number of training steps. The star denotes the target model. Compared with uniform weighting, the proposed method yields a trajectory that moves more consistently toward the target model. See Section \ref{subsec:ft-kl-target} for details.
  • Figure 2: t-SNE visualization of NanoGPT training trajectories for two training datasets independently sampled from each of two domain mixtures over the Pile. Each point represents a checkpoint taken every 1k training steps, and color intensity indicates the number of training steps. Trajectories differ much more across domain mixtures than across datasets sampled from the same mixture. The two mixtures are obtained by doubling the weights of either the code-related domains or the PubMed-related domains relative to the original Pile weights. See Appendix \ref{app:experimental-setup-detail} for details.
  • Figure 3: Illustration of the update direction in log-likelihood space when the base model $p_{\bm\theta}$ is updated along the gradient direction in parameter space using a text sampled from the mixture distribution $r_{\bm\pi}$. Our goal is to estimate $\bm\pi$ so that the blue update direction aligns with the red direction toward the target model.
  • Figure 4: Comparison of the domain weights estimated by aggregated-LLD and of uniform weights with the ground-truth weights. The target model is CodeNanoGPT, pretrained with known domain weights, and the base model is a randomly initialized NanoGPT. The KL values in the legend denote the KL divergence from the ground-truth domain weights.
  • Figure 5: KL divergence to the target model over training steps when a randomly initialized or pretrained NanoGPT is trained with different domain weights.
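Figures 2 and 4 rest on two small computations over domain weights: building a mixture by doubling selected domains relative to a base mixture, and scoring a weight vector by its KL divergence from ground-truth weights. The sketch below shows both under stated assumptions. The domain names and weights are made up (they are not the actual Pile weights), the doubled mixture is assumed to be renormalized so that it sums to 1, the direction KL(ground truth ‖ estimate) is an assumption because the caption does not specify it, and the helpers `kl_divergence` and `upweight` are hypothetical.

```python
import math

def kl_divergence(p, q):
    """KL(p || q) between two distributions over the same set of domains."""
    total = 0.0
    for dom, p_d in p.items():
        if p_d == 0.0:
            continue  # 0 * log(0 / q) is taken to be 0
        if q[dom] == 0.0:
            return math.inf  # q assigns no mass where p does
        total += p_d * math.log(p_d / q[dom])
    return total

def upweight(weights, selected, factor=2.0):
    """Multiply the weights of the selected domains by `factor`, then renormalize."""
    scaled = {dom: w * factor if dom in selected else w for dom, w in weights.items()}
    total = sum(scaled.values())
    return {dom: w / total for dom, w in scaled.items()}

# Made-up weights for three hypothetical domains (not the actual Pile weights).
base = {"code": 0.1, "pubmed": 0.2, "web": 0.7}
code_mix = upweight(base, {"code"})      # code ~0.182, pubmed ~0.182, web ~0.636
pubmed_mix = upweight(base, {"pubmed"})
uniform = {dom: 1.0 / len(base) for dom in base}

print(code_mix)
print(kl_divergence(base, uniform))      # uniform weights scored against the "truth"
print(kl_divergence(code_mix, pubmed_mix))
```

Because KL divergence is asymmetric, swapping the arguments of `kl_divergence` generally changes the reported value, so the direction used for a figure legend matters when comparing numbers.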