
Efficient Training of Language Models with Compact and Consistent Next Token Distributions

Ashutosh Sathe, Sunita Sarawagi

TL;DR

This paper tackles the inefficiency of corpus-level $n$-gram regularization in language-model pre-training by introducing CoCoNTs, a compact and consistent surrogate for the next-token distribution that can be precomputed and stored with the data. By truncating the $n$-gram distribution to a Top-$r$ set and carefully transferring mass to rarer tokens, CoCoNTs preserves the statistical intent of the full $n$-gram-augmented objective (AllNTs) while reducing variance and avoiding expensive runtime $n$-gram queries. The approach matches or exceeds the model quality of AllNTs with substantially lower training time and a fixed storage overhead, and its gains carry over to both pre-training and fine-tuning, including parameter-efficient fine-tuning (PEFT) and domain-specific datasets. The results show faster convergence, improved efficiency, and robust performance across multiple models and tasks, with practical implications for large-scale LLM training and deployment.
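The TL;DR describes precomputing a truncated Top-$r$ next-token distribution and storing it alongside the data. As a minimal sketch (not the authors' implementation), the Python below counts $n$-gram continuations over a toy token list and stores, for every position, the $r$ most frequent next tokens and their probabilities in two fixed-size arrays. Storage therefore grows as (number of tokens) × $r$, independent of the size of the $n$-gram table. The function names, the padding id `-1`, and the `int32`/`float32` layout are illustrative assumptions, and the mass-transfer step toward rarer tokens is deliberately left out.

```python
from collections import Counter, defaultdict

import numpy as np


def ngram_next_token_counts(tokens, n):
    """Count the next tokens observed after every (n-1)-token context."""
    counts = defaultdict(Counter)
    for i in range(n - 1, len(tokens)):
        counts[tuple(tokens[i - n + 1:i])][tokens[i]] += 1
    return counts


def precompute_top_r(tokens, n, r, pad_id=-1):
    """Precompute fixed-size Top-r next-token distributions for every position.

    Row i describes the distribution of tokens[i] given its (n-1)-token context.
    Positions without a full context, and unused slots, keep pad_id / 0.0.
    """
    counts = ngram_next_token_counts(tokens, n)
    ids = np.full((len(tokens), r), pad_id, dtype=np.int32)
    probs = np.zeros((len(tokens), r), dtype=np.float32)
    for i in range(n - 1, len(tokens)):
        dist = counts[tuple(tokens[i - n + 1:i])]
        total = sum(dist.values())
        # Truncate to the r most frequent continuations; probabilities are
        # relative to the full count, so each row sums to at most 1.
        for j, (tok, c) in enumerate(dist.most_common(r)):
            ids[i, j] = tok
            probs[i, j] = c / total
    return ids, probs


# The arrays could be saved next to the tokenized data (e.g. with np.save)
# and read back during training without any n-gram queries.
ids, probs = precompute_top_r([1, 2, 3, 1, 2, 4, 1, 2, 3], n=3, r=2)
print(ids[2], probs[2])
```

In this toy corpus the context `(1, 2)` is followed by `3` twice and `4` once, so position 2 stores ids `[3, 4]` with probabilities of about 2/3 and 1/3.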

Abstract

Maximizing the likelihood of the next token is an established, statistically sound objective for pre-training language models. In this paper we show that we can train better models faster by pre-aggregating the corpus with a collapsed $n$-gram distribution. Previous studies have proposed corpus-level $n$-gram statistics as a regularizer; however, constructing and querying such $n$-grams naively is costly and significantly slows training, which limits their use in modern large language model pre-training. We introduce an alternative compact representation of the next token distribution that, in expectation, aligns with the complete $n$-gram distribution while markedly reducing variance across mini-batches compared to the standard next-token loss. Empirically, we demonstrate that both the $n$-gram regularized model and our approximation yield substantial improvements in model quality and convergence rate compared to existing methods. Furthermore, our approximation allows these gains to scale to larger datasets and models than the straightforward $n$-gram regularization method does.
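The abstract claims a compact target that matches the full $n$-gram distribution in expectation while having lower variance than the one-hot next-token target. The sketch below illustrates one construction with both properties; it is an assumption for illustration, not necessarily the rule CoCoNTs uses. For each occurrence of a context, if the observed next token is in the Top-$r$ set $R$ (total mass $S$), the target is the renormalized Top-$r$ distribution $p_R/S$; otherwise it stays one-hot on the observed, rarer token. Because the first branch is taken with probability $S$, the expected target is $S \cdot p_R/S + \sum_{v \notin R} p(v)\, e_v = p$. The helper names are hypothetical.

```python
from collections import Counter

import numpy as np


def compact_target(dist, observed, r):
    """Hypothetical compact target for one occurrence of a context.

    dist: Counter of next-token counts for the context (full n-gram statistics).
    observed: the next token actually seen at this occurrence.
    Returns {token: probability} with at most r entries, summing to 1.
    """
    top = dist.most_common(r)
    if observed in {tok for tok, _ in top}:
        mass = sum(c for _, c in top)
        # Frequent continuation: use the renormalized Top-r distribution.
        return {tok: c / mass for tok, c in top}
    # Rare continuation: keep all mass on the observed token (one-hot).
    return {observed: 1.0}


def targets_for_context(next_tokens, r, vocab_size):
    """Full, compact, and one-hot targets for every occurrence of one context."""
    dist = Counter(next_tokens)
    full = np.zeros(vocab_size)
    for tok, c in dist.items():
        full[tok] = c / len(next_tokens)
    compact = np.zeros((len(next_tokens), vocab_size))
    onehot = np.zeros((len(next_tokens), vocab_size))
    for i, y in enumerate(next_tokens):
        onehot[i, y] = 1.0
        for tok, p in compact_target(dist, y, r).items():
            compact[i, tok] = p
    return full, compact, onehot


full, compact, onehot = targets_for_context([3, 3, 3, 4, 4, 5, 6, 3], r=2, vocab_size=8)
print(np.allclose(compact.mean(axis=0), full))  # mean compact target equals full distribution
print(onehot.var(axis=0).sum(), compact.var(axis=0).sum())
```

For these eight occurrences, the summed per-token variance is 0.65625 for the one-hot targets and about 0.323 for the compact targets, while both average to the same full distribution.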

Paper Structure (19 sections, 6 equations, 7 figures, 5 tables)

Figures (7)

  • Figure 1: Comparison of various training methods. Standard next-token likelihood reads inputs as well as targets from the disk. $n$-gram augmented methods (AllNTs) obtain targets ($\mathbf{y}^\text{all}_{i}$) by querying an $n$-gram model which can be slow during training. Our proposed method, CoCoNTs, truncates and approximates the $\mathbf{y}^\text{all}_{i}$ and stores the preprocessed distribution ($\mathbf{y}^\text{CC}_{i}$) along with the dataset itself for faster retrieval during training.
  • Figure 2: Comparison of training efficiency on WikiText-103 (top) and PubMed (bottom). AllNTs with higher values of $k$ can easily run out of memory with a naive implementation. Both AllNTs and CoCoNTs reach NTL's validation perplexity faster than NTL does. The total wall time (TWT) to finish the entire training is also significantly lower with CoCoNTs than with AllNTs, because CoCoNTs performs no $n$-gram queries during training. A gpt2-125m model is used for all experiments, with $r=8$ for CoCoNTs.
  • Figure 3: Effect of sharded CoCoNTs on large datasets. Oversharding can make the $n$-gram distribution unreasonably sparse, which can lead to an overly optimistic approximation and KL penalty and hurt performance when the indices are extremely small.
  • Figure 4: Ablation studies on $k$ and $r$ for AllNTs and CoCoNTs. All experiments fine-tune a gpt2-125m model on WikiText-103. Higher values of both $k$ and $r$ improve perplexity before plateauing. For CoCoNTs, $r=8$ is fixed when varying $k$, and $k=8$ is fixed when varying $r$.
  • Figure 5: Perplexity as a function of position in the sequence. Both AllNTs and CoCoNTs show smooth changes in perplexity despite applying the loss to only a short $k$-token prefix.
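The figure captions refer to a KL penalty (Figure 3) and to applying the loss to only a short $k$-token prefix (Figure 5). A minimal NumPy sketch of such an objective follows, assuming the standard next-token cross-entropy is combined with $\lambda \cdot \mathrm{KL}(q_t \,\|\, p_\theta(\cdot \mid x_{<t}))$ toward a stored Top-$r$ distribution $q_t$ on the first $k$ positions only. The KL direction, the weight `lam`, the renormalization of the truncated target, and the function names are assumptions rather than details taken from the paper.

```python
import numpy as np


def log_softmax(logits):
    """Numerically stable log-softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def regularized_loss(logits, targets, top_ids, top_probs, k, lam):
    """Next-token loss plus a KL penalty toward stored Top-r targets.

    logits: (T, V) model scores; targets: (T,) observed next tokens.
    top_ids / top_probs: (T, r) stored distributions (padding: id -1, prob 0).
    The KL term is applied only to the first k positions (hypothetical choice).
    """
    logp = log_softmax(logits)
    T = logits.shape[0]
    ntl = -logp[np.arange(T), targets].mean()  # standard next-token loss
    kl_terms = []
    for t in range(min(k, T)):
        valid = (top_ids[t] >= 0) & (top_probs[t] > 0)
        ids, q = top_ids[t][valid], top_probs[t][valid]
        if q.size == 0:
            continue
        q = q / q.sum()  # renormalize the truncated target
        kl_terms.append(np.sum(q * (np.log(q) - logp[t, ids])))
    kl = np.mean(kl_terms) if kl_terms else 0.0
    return ntl + lam * kl


logits = np.zeros((3, 4))  # a uniform model over a 4-token vocabulary
targets = np.array([1, 2, 3])
top_ids = np.array([[1, -1], [2, 3], [3, -1]])
top_probs = np.array([[1.0, 0.0], [0.5, 0.5], [1.0, 0.0]])
print(regularized_loss(logits, targets, top_ids, top_probs, k=2, lam=1.0))
```

With a uniform model, the next-token term is $\log 4$, and the KL term averages $\log 4$ (one-hot target) and $\log 2$ (two-way target) over the first two positions. Storing only $r$ ids and probabilities per position keeps this penalty cheap to evaluate compared with querying an $n$-gram model at every step.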