Critical Data Size of Language Models from a Grokking Perspective

Xuekai Zhu, Yao Fu, Bowen Zhou, Zhouhan Lin

TL;DR

The paper investigates how data size governs the transition from memorization to generalization in language models by introducing the Critical Data Size (CDS) and the Data Efficiency Hypothesis. It develops a grokking configuration based on initialization rescaling and weight decay to reproduce data-dependent phase transitions on modular addition, IMDB, Yelp, and instruction-tuning tasks, and it validates both sample-wise and model-wise grokking. The results show that the CDS shifts upward as model size grows and reveal interpretable weight-norm dynamics across learning stages, underscoring the nuanced role of data quantity and regularization. Practically, the work offers data-pruning and initialization-control techniques for surfacing learning dynamics, informing how data and model capacity should be balanced in real-world language model training.
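Here is a minimal Python sketch of the data-pruning idea. It rests on two assumptions. First, the modular addition task enumerates every pair (a, b) with label (a + b) mod p. Second, "pruning" means drawing a uniform random subset of those equations for training and holding out the rest for testing. The helper names, the modulus p = 97, and the seed are illustrative choices, not values taken from the paper.

```python
import random
from typing import List, Tuple

# One equation a + b = c (mod p), stored as (a, b, c).
Example = Tuple[int, int, int]


def modular_addition_dataset(p: int) -> List[Example]:
    """Hypothetical helper: enumerate all p * p equations a + b = c (mod p)."""
    return [(a, b, (a + b) % p) for a in range(p) for b in range(p)]


def uniform_prune(data: List[Example], n_train: int,
                  seed: int = 0) -> Tuple[List[Example], List[Example]]:
    """Hypothetical helper: uniformly sample n_train examples for training.

    The remaining examples become the test set. Uniform sampling keeps the
    pruned training set distributed like the full set, so only the quantity
    of training data changes between runs.
    """
    if not 0 <= n_train <= len(data):
        raise ValueError("n_train must lie between 0 and len(data)")
    rng = random.Random(seed)
    indices = list(range(len(data)))
    rng.shuffle(indices)
    train = [data[i] for i in indices[:n_train]]
    test = [data[i] for i in indices[n_train:]]
    return train, test


if __name__ == "__main__":
    full = modular_addition_dataset(p=97)            # 97 * 97 = 9409 equations
    train, test = uniform_prune(full, n_train=2000)  # 2000 training samples, as in Figure 1A
    print(len(full), len(train), len(test))
```

Sweeping `n_train` while holding everything else fixed isolates data quantity as the only variable. That is the kind of comparison the step-wise analyses rely on.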

Abstract

We explore the critical data size in language models, a threshold that marks a fundamental shift from quick memorization to slow generalization. We formalize the phase transition under the grokking configuration into the Data Efficiency Hypothesis and identify data insufficiency, sufficiency, and surplus regimes in language model training dynamics. We develop a grokking configuration that stably reproduces grokking on simple language models by rescaling initialization and weight decay. We show that generalization occurs only when the training data reach a critical size. We analyze grokking both sample-wise and model-wise, verifying the proposed data efficiency hypothesis. Our experiments reveal smoother phase transitions at the critical dataset size for language datasets. As model size increases, this critical point also becomes larger, indicating that larger models require more data. Our results deepen the understanding of language model training and offer a novel perspective on the role of data in the learning mechanism of language models.
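The grokking configuration combines two controls: rescaling the initialization and applying weight decay. The plain-Python sketch below illustrates how they interact, under stated assumptions. Rescaling is assumed to multiply every initial weight by a constant factor alpha. Weight decay is assumed to be ordinary L2-regularized gradient descent. A zero gradient stands in for the phase after memorization, when training loss is near zero and weight decay dominates the update. The values alpha = 10, lr = 0.1, and wd = 0.01 are hypothetical. The sketch only shows the mechanism behind the decreasing L2 norm described in Figure 5's caption; it is not the authors' training code.

```python
import math
import random
from typing import List


def init_weights(n: int, alpha: float, seed: int = 0) -> List[float]:
    """Hypothetical scheme: draw n weights from N(0, 1/n), then rescale them by alpha."""
    rng = random.Random(seed)
    std = 1.0 / math.sqrt(n)
    return [alpha * rng.gauss(0.0, std) for _ in range(n)]


def l2_norm(w: List[float]) -> float:
    """Euclidean norm of a flat weight vector."""
    return math.sqrt(sum(x * x for x in w))


def sgd_step(w: List[float], grad: List[float], lr: float, wd: float) -> List[float]:
    """One gradient step on loss + (wd / 2) * ||w||^2, i.e. w <- w - lr * (g + wd * w)."""
    return [x - lr * (g + wd * x) for x, g in zip(w, grad)]


if __name__ == "__main__":
    n, alpha, lr, wd = 256, 10.0, 0.1, 0.01  # illustrative values only
    w = init_weights(n, alpha)
    w0 = l2_norm(w)
    # Assumed stand-in for the near-zero loss gradient once training data are memorized.
    zero_grad = [0.0] * n
    for step in range(1, 3001):
        w = sgd_step(w, zero_grad, lr, wd)
        if step % 1000 == 0:
            # With g = 0 the norm shrinks geometrically: ||w_t|| = (1 - lr * wd) ** t * ||w_0||
            print(step, round(l2_norm(w) / w0, 4))
```

A larger alpha starts the weights at a larger norm. Weight decay then needs more steps to shrink that norm, which is one plausible reading of why rescaling the initialization lengthens the gap between memorization and generalization.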

Paper Structure

This paper contains 37 sections, 5 equations, 14 figures, and 2 tables.

Figures (14)

  • Figure 1: Comprehensive analysis of training dynamics and accuracy curves verifies the data efficiency hypothesis on vanilla grokking (power2022grokking; varma2023explaining). A: Reproduced grokking phenomenon on modular addition using a 1-layer decoder-only Transformer trained on 2000 samples. Delayed generalization ($\approx100\%$ test acc) occurs during continued training after memorization is complete ($\approx100\%$ train acc, overfitting). B: Step-wise Analysis of Test Accuracy. We observe a clear peak indicating slow generalization at the critical data size, while more training samples markedly speed up generalization. Below the critical data size, no generalization happens. C: Step-wise Analysis of Training Accuracy. The model memorizes all training data within 400 steps, and the number of memorization steps varies very little across dataset sizes. D: 1D PCA visualization of the modular addition datasets. Data pruning samples uniformly from the initial distribution. E and F: Test and training accuracy over the whole training process. The detailed training process is presented in the figure labeled fig:diff_data_size_on_modular.
  • Figure 2: A: We induce the grokking phenomenon on IMDB (maas-EtAl:2011:ACL-HLT2011) using a 1-layer encoder-only Transformer. After training for much longer, the model suddenly flips from memorizing the training data to generalizing to unseen test data. B: Step-wise Analysis of Test Accuracy in IMDB. The number of steps grows as generalization moves from very weak (light dots) to full-fledged (dark dots). As the data fraction increases, the number of generalization steps quickly reaches a stable value. C: Step-wise Analysis of Training Accuracy in IMDB. The number of memorization steps gradually increases with the data fraction, and the model eventually reaches 100% training accuracy.
  • Figure 3: A: We employ a 1-layer encoder-only Transformer to trigger the grokking phenomenon on $10\%$ of the Yelp data (zhangCharacterlevelConvolutionalNetworks2015). Delayed generalization occurs after overfitting. B: Step-wise Analysis of Test Accuracy in Yelp. The number of generalization steps first increases and then decreases as the data fraction grows, consistent with the results on modular addition and IMDB. C: Step-wise Analysis of Training Accuracy in Yelp. As in the modular addition and IMDB experiments, the number of memorization steps increases as the dataset grows. The detailed training process is presented in the figure labeled fig:acc_on_yelp.
  • Figure 4: Model-wise grokking experiments on IMDB demonstrate that the critical data size increases with model size. A: Test accuracy by hidden layer size and data fraction of the IMDB dataset. The data fraction required for higher accuracy increases as the model size increases. Training accuracy is visualized in the figure labeled fig:model_wise_grokking_train_acc. B: Average accuracy across all data fractions from 10% to 100%. The white arrows indicate that average accuracy decreases as model size increases, suggesting that larger models require more data to maintain performance. The light blue area represents a 95% confidence interval. C: Training curves for models with different numbers of layers under various data fractions. As the number of layers increases, the models require more data to generalize effectively.
  • Figure 5: Visualization of how the model transitions from memorization to generalization during training. We visualize the classification layer's weights during learning for a 1-layer encoder-only Transformer on the IMDB dataset. Notably, the parameter distribution evolves from a randomly initialized state to a fixed range of values, which we divide into stages A to F. The transition from memorization to generalization is influenced by weight decay and the loss, leading to a decrease in the L2 norm. The L2 norm evolution is explained further in the figure labeled fig:grokking_stages.
  • ...and 9 more figures
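The step-wise analyses in Figures 1 to 3 record, for each training-set size, how many steps the model needs to memorize (training accuracy) or to generalize (test accuracy). The Python sketch below shows that bookkeeping under three assumptions. Accuracy is logged at fixed steps. "Reaching" a stage means first crossing a threshold such as 0.99. The critical data size is read off as the smallest data size whose test accuracy ever crosses that threshold, matching the captions' statement that no generalization happens below it. The function names, the threshold, and the toy curves are hypothetical and are not results from the paper.

```python
from typing import Dict, List, Optional, Tuple

# (training step, accuracy) pairs, in increasing step order.
Curve = List[Tuple[int, float]]


def steps_to_threshold(curve: Curve, threshold: float = 0.99) -> Optional[int]:
    """Hypothetical helper: first logged step whose accuracy reaches threshold, else None."""
    for step, acc in curve:
        if acc >= threshold:
            return step
    return None


def critical_data_size(test_curves: Dict[int, Curve],
                       threshold: float = 0.99) -> Optional[int]:
    """Hypothetical helper: smallest training-set size whose test accuracy reaches threshold.

    Under the data efficiency hypothesis, smaller sizes never generalize
    (insufficiency), and generalization is slowest right at this size.
    """
    generalizing = [n for n, curve in test_curves.items()
                    if steps_to_threshold(curve, threshold) is not None]
    return min(generalizing) if generalizing else None


if __name__ == "__main__":
    # Toy test-accuracy curves keyed by number of training samples (illustrative only).
    test_curves = {
        1000: [(1000, 0.30), (5000, 0.35), (20000, 0.40)],   # never generalizes
        2000: [(1000, 0.30), (5000, 0.60), (20000, 0.995)],  # slow generalization
        4000: [(1000, 0.50), (5000, 0.995), (20000, 1.00)],  # faster generalization
    }
    for n, curve in sorted(test_curves.items()):
        print(n, steps_to_threshold(curve))
    print("critical data size:", critical_data_size(test_curves))
```

Applying `steps_to_threshold` to training curves instead gives the memorization steps plotted in panels C. Where the logging interval is coarse, the reported step is only an upper bound on when the threshold was actually crossed.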