
Beyond Next Token Prediction: Patch-Level Training for Large Language Models

Chenze Shao, Fandong Meng, Jie Zhou

TL;DR

This work introduces patch-level training for large language models. $K$ consecutive tokens are aggregated into a patch, and the model learns to predict the next patch, which substantially reduces training compute. The method has two stages. Patch-level training is applied to most of the data, and token-level training is applied to the remainder, with the token-level model initialized from the patch-level parameters. The method was evaluated on Transformers of 370M–2.7B parameters trained on the Pile. Across these models, patch-level training reduces overall training compute to about $0.5\times$ that of token-level training. It maintains perplexity and often improves zero-shot and instruction-following performance. The paper also analyzes scaling, patch size, the fraction of data used for patch-level training, architecture, and neuron activation. It points to scaling laws and multi-epoch training as future directions.
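Figure 1 visualizes overall training cost as a function of the patch-level data fraction $\lambda$ and patch size $K$. The arithmetic behind the roughly $0.5\times$ figure can be sketched as follows. This sketch assumes that a patch-level pass over some data costs about $1/K$ of a token-level pass over the same data. In other words, cost scales linearly with sequence length, and attention's quadratic term and fixed overheads are ignored. Under that assumption, the relative cost is $\lambda/K + (1-\lambda)$. The function name `relative_training_cost` is hypothetical. $K = 4$, $\lambda = 2/3$ is shown only as one illustrative combination that yields $0.5\times$.

```python
from fractions import Fraction


def relative_training_cost(lam: Fraction, k: int) -> Fraction:
    """Hypothetical helper: cost of patch-level + token-level training
    relative to training on all data token by token.

    Assumes a patch-level pass over a fraction `lam` of the data costs
    `lam / k`, and the remaining `1 - lam` is trained at full token-level cost.
    """
    if not 0 <= lam <= 1:
        raise ValueError("lam must lie in [0, 1]")
    if k < 1:
        raise ValueError("patch size k must be a positive integer")
    return lam / k + (1 - lam)


print(relative_training_cost(Fraction(2, 3), 4))  # 1/2
print(relative_training_cost(Fraction(1, 2), 4))  # 5/8
```

Exact fractions are used so that the $0.5\times$ case comes out as exactly `1/2` rather than a rounded float.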

Abstract

The prohibitive training costs of Large Language Models (LLMs) have emerged as a significant bottleneck in the development of next-generation LLMs. In this paper, we show that it is possible to significantly reduce the training costs of LLMs without sacrificing their performance. Specifically, we introduce patch-level training for LLMs, in which multiple tokens are aggregated into a unit of higher information density, referred to as a 'patch', to serve as the fundamental text unit for training LLMs. During patch-level training, we feed the language model shorter sequences of patches and train it to predict the next patch, thereby processing the majority of the training data at a significantly reduced cost. Following this, the model continues token-level training on the remaining training data to align with the inference mode. Experiments on a diverse range of models (370M-2.7B parameters) demonstrate that patch-level training can reduce the overall training costs to 0.5$\times$, without compromising the model performance compared to token-level training. Source code: https://github.com/shaochenze/PatchTrain.
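The aggregation step turns a sequence of $T$ tokens into a shorter sequence of about $T/K$ patches, and it can be sketched as follows. This is a minimal illustration with three assumptions. Token embeddings are plain Python lists. A patch embedding is the mean of its $K$ token embeddings, as the Figure 3 caption describes. Trailing tokens that do not fill a whole patch are dropped. The function names `make_patches` and `next_patch_targets` are hypothetical and are not taken from the PatchTrain repository.

```python
from typing import List, Sequence, Tuple

Vector = List[float]


def make_patches(token_embeddings: Sequence[Vector], k: int) -> List[Vector]:
    """Average every K consecutive token embeddings into one patch embedding.

    Hypothetical helper; leftover tokens that do not fill a full patch are
    dropped (an assumption of this sketch).
    """
    n_patches = len(token_embeddings) // k
    patches = []
    for p in range(n_patches):
        group = token_embeddings[p * k:(p + 1) * k]
        dim = len(group[0])
        # Element-wise mean over the K embeddings in this patch.
        patches.append([sum(vec[d] for vec in group) / k for d in range(dim)])
    return patches


def next_patch_targets(token_ids: Sequence[int], k: int) -> List[Tuple[int, ...]]:
    """Hypothetical helper: the target for patch i is the K token ids of patch i + 1.

    The last full patch has no successor, so it gets no target.
    """
    n_patches = len(token_ids) // k
    return [tuple(token_ids[(p + 1) * k:(p + 2) * k]) for p in range(n_patches - 1)]


embeddings = [[1.0, 0.0], [3.0, 2.0], [0.0, 4.0], [2.0, 2.0]]
print(make_patches(embeddings, 2))                  # [[2.0, 1.0], [1.0, 3.0]]
print(next_patch_targets([5, 6, 7, 8, 9, 10], 2))   # [(7, 8), (9, 10)]
```

Because the sequence model only ever sees the shorter patch sequence, each training step processes $K$ times as many tokens for roughly the same sequence length.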

Paper Structure (17 sections, 3 equations, 11 figures, 7 tables)

Figures (11)

  • Figure 1: Visualization of overall training costs with patch compression for a fraction $\lambda$ of training data and patch size $K$.
  • Figure 2: Negative log-likelihood (NLL) loss on the test set w.r.t. the number of processed tokens during the training of 370M-parameter Transformers.
  • Figure 3: Overview of patch-level training. Every $K$ consecutive token embeddings are averaged to form a patch embedding. The sequence model is fed the patch sequence and trained to predict the next patch. The cross-entropy loss is computed between each patch prediction vector and each of the $K$ tokens in the next patch.
  • Figure 4: Instruction-following abilities evaluated on MT-bench, a multi-turn question set.
  • Figure 5: Test losses of Transformer-370M w.r.t. the number of processed tokens. Models are initialized by patch-level training with patch size $K$.
  • ...and 6 more figures
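The loss described in the Figure 3 caption scores a single patch prediction vector against every token of the next patch, and it can be sketched as follows. This sketch makes three assumptions. The prediction vector is a list of vocabulary logits. One softmax distribution is shared by all $K$ target tokens. The $K$ negative log-likelihoods are averaged rather than summed; that normalization is a choice made here, not something the caption specifies. The function name `next_patch_loss` is hypothetical.

```python
import math
from typing import Sequence


def next_patch_loss(logits: Sequence[float], next_patch_ids: Sequence[int]) -> float:
    """Hypothetical helper: cross-entropy between one patch prediction vector
    and the K token ids of the next patch.

    The same softmax distribution is used for all K targets; averaging the K
    negative log-likelihoods is an assumption of this sketch.
    """
    # Numerically stable log of the softmax normalizer (log-sum-exp).
    max_logit = max(logits)
    log_z = max_logit + math.log(sum(math.exp(x - max_logit) for x in logits))
    # -log softmax(logits)[t] for each target token t in the next patch.
    nlls = [log_z - logits[t] for t in next_patch_ids]
    return sum(nlls) / len(nlls)


# Uniform logits over a 4-word vocabulary: every target has probability 1/4.
print(round(next_patch_loss([0.0, 0.0, 0.0, 0.0], [1, 2]), 4))  # 1.3863
```

Because the prediction vector is not tied to a position within the patch, the model is effectively trained to assign high probability to every token that will appear in the next patch, regardless of its order within that patch.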