Time Transfer: On Optimal Learning Rate and Batch Size In The Infinite Data Limit

Oleg Filatov, Jan Ebert, Jiangtao Wang, Stefan Kesselheim

TL;DR

The paper studies how the learning rate $\eta$ and batch size $B$ should be scaled jointly in the infinite data limit when pretraining transformers, using μP and μTransfer to test transfer across data horizons. It introduces a parametric fit for the optimal learning rate $\eta^*(T,B)$ with a time-dependent critical batch size $B_{\mathrm{crit}}(T)$ that grows roughly as $B_{\mathrm{crit}} \propto T$, and shows that the optimal batch size $B^*(T)$ increases with the token budget $T$ while remaining below $B_{\mathrm{crit}}(T)$. The sensitivity of the loss to a suboptimal $\eta$ decreases as $T$ increases and is largely unchanged under μP. The optimal $\eta$ and $B$ dynamics also persist under μP scaling, suggesting a unified joint data-model scaling picture and hinting at invariants, such as the noise scale, that could govern hyperparameter transfer in the combined infinite data and model size limit. The findings have practical implications for hyperparameter tuning in large-scale pretraining and motivate future theoretical work to formalize joint data-horizon and width-based invariants for scalable learning.
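To make the shape of such a parametric fit concrete, the sketch below uses an assumed surge-type form, $\eta^*(B) = \eta_\mathrm{crit} / \left[\tfrac{1}{2}\left(\sqrt{B/B_\mathrm{crit}} + \sqrt{B_\mathrm{crit}/B}\right)\right]$, in the spirit of Li et al. (2024). This summary does not reproduce the paper's fitted equation, so the exact form is an assumption, as are the constants `ETA_CRIT`, `SLOPE`, and `SMALL_BATCH` and the choice to hold $\eta_\mathrm{crit}$ fixed. Under those assumptions the curve peaks at $B = B_\mathrm{crit}$, and for $B \ll B_\mathrm{crit}$ with $B_\mathrm{crit} \propto T$ it decays approximately as $1/\sqrt{T}$.

```python
import math


def eta_opt(batch, eta_crit, b_crit):
    """Assumed surge-shaped curve (not confirmed to be the paper's equation).

    Rises like sqrt(B) for B << B_crit, peaks at B = B_crit with value
    eta_crit, and decays like 1/sqrt(B) for B >> B_crit.
    """
    ratio = batch / b_crit
    return eta_crit / (0.5 * (math.sqrt(ratio) + 1.0 / math.sqrt(ratio)))


def b_crit_linear(tokens, slope):
    """Hypothetical linear law B_crit = slope * T, with no offset."""
    return slope * tokens


# Hypothetical constants, chosen only for illustration.
ETA_CRIT = 1e-2        # peak learning rate, held fixed here by assumption
SLOPE = 2.0 ** -10     # gives B_crit = 2^20 tokens at T = 2^30 tokens
SMALL_BATCH = 2 ** 10  # deliberately far below B_crit


def small_batch_ratio(tokens, factor=4):
    """Return eta*(factor * T) / eta*(T) at SMALL_BATCH.

    In the small-batch limit this tends to 1 / sqrt(factor).
    """
    before = eta_opt(SMALL_BATCH, ETA_CRIT, b_crit_linear(tokens, SLOPE))
    after = eta_opt(SMALL_BATCH, ETA_CRIT, b_crit_linear(factor * tokens, SLOPE))
    return after / before


if __name__ == "__main__":
    b_crit = b_crit_linear(2 ** 30, SLOPE)
    print("eta* at the peak:", eta_opt(b_crit, ETA_CRIT, b_crit))
    print("eta* ratio after a 4x larger budget:", small_batch_ratio(2 ** 30))
```

With these illustrative numbers, quadrupling the token budget roughly halves the optimal learning rate at the small batch size, which is the $1/\sqrt{T}$ behavior described for small batches.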

Abstract

One of the main challenges in optimal scaling of large language models (LLMs) is the prohibitive cost of hyperparameter tuning, particularly learning rate $\eta$ and batch size $B$. While techniques like $\mu$P (Yang et al., 2022) provide scaling rules for optimal $\eta$ transfer in the infinite model size limit, the optimal scaling behavior in the infinite data size limit remains unknown. We fill in this gap by observing for the first time an intricate dependence of optimal $\eta$ scaling on the pretraining token budget $T$, $B$ and its relation to the critical batch size $B_\mathrm{crit}$, which we measure to evolve as $B_\mathrm{crit} \propto T$. Furthermore, we show that the optimal batch size is positively correlated with $B_\mathrm{crit}$: keeping it fixed becomes suboptimal over time even if learning rate is scaled optimally. Surprisingly, our results demonstrate that the observed optimal $\eta$ and $B$ dynamics are preserved with $\mu$P model scaling, challenging the conventional view of $B_\mathrm{crit}$ dependence solely on loss value. Complementing optimality, we examine the sensitivity of loss to changes in learning rate, where we find the sensitivity to decrease with increase of $T$ and to remain constant with $\mu$P model scaling. We hope our results make the first step towards a unified picture of the joint optimal data and model scaling.
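A scaling law such as $B_\mathrm{crit} \propto T$ is typically read off a power-law fit. The Figure 2 caption gives the fitted form as $p_\mathrm{crit} = a_p T^{\alpha_p} + b_p$. The sketch below shows one simple way to fit that form: scan the exponent over a grid and solve for the two linear coefficients in closed form at each grid point. This grid scan is a stand-in chosen for illustration and is not the authors' fitting procedure, which their appendix describes. The data are synthetic and the function name `fit_offset_power_law` is hypothetical.

```python
def fit_offset_power_law(ts, ps, alphas):
    """Fit p = a * t**alpha + b by scanning alpha over a grid.

    For a fixed alpha the model is linear in (a, b), so ordinary least
    squares has a closed form. The alpha with the smallest sum of squared
    errors wins. Returns (alpha, a, b).
    """
    n = len(ts)
    p_mean = sum(ps) / n
    best = None
    for alpha in alphas:
        xs = [t ** alpha for t in ts]
        x_mean = sum(xs) / n
        sxx = sum((x - x_mean) ** 2 for x in xs)
        sxp = sum((x - x_mean) * (p - p_mean) for x, p in zip(xs, ps))
        a = sxp / sxx
        b = p_mean - a * x_mean
        sse = sum((a * x + b - p) ** 2 for x, p in zip(xs, ps))
        if best is None or sse < best[0]:
            best = (sse, alpha, a, b)
    _, alpha, a, b = best
    return alpha, a, b


# Synthetic, noise-free data in arbitrary units (NOT measurements from the
# paper): token budgets in multiples of a base budget, and a critical batch
# size that grows linearly with an offset.
token_budgets = [1.0, 2.0, 4.0, 8.0, 16.0, 32.0]
critical_batches = [3.0 * t + 0.5 for t in token_budgets]

alpha_grid = [i / 100 for i in range(10, 201)]  # 0.10, 0.11, ..., 2.00
alpha_fit, a_fit, b_fit = fit_offset_power_law(
    token_budgets, critical_batches, alpha_grid
)

if __name__ == "__main__":
    print("alpha:", alpha_fit, "a:", a_fit, "b:", b_fit)
```

Expressing $T$ in multiples of a base budget keeps $T^{\alpha}$ well scaled across the grid. With noisy real measurements the recovered exponent would carry an uncertainty that this sketch does not estimate.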

Paper Structure

This paper contains 38 sections, 11 equations, and 21 figures.

Figures (21)

  • Figure 1: (a) Optimal learning rate $\eta^*$ per batch size for a set of pretraining token budgets (see Appendix \ref{app:fig1-full} for the full set). Each point is obtained by averaging the experimentally observed optimal learning rates across the $\mu$P model family and random seeds, as described in Sec. \ref{sec:exp-crit}, with color bands showing the corresponding standard deviation. Solid lines represent the fitted theoretical model of Li et al. (2024) (Eq. \ref{eq:fit-model}), as described in Sec. \ref{sec:exp-crit}; dashed lines only connect the data points for visualization purposes. We observe an approximately linear growth of $B_\mathrm{crit}$ (see also the dedicated Fig. \ref{fig:crit-fit}), defined as the peak position of the fitted curve, in the limit of increased token budget. (b) Transposition of Fig. \ref{fig:lbl-token-lr}: evolution of the optimal learning rate $\eta^*(T)$ with increasing pretraining token budget for a representative set of batch sizes, in tokens. We observe that the fitted model describes the scaling behavior of low ($B=2^{18}$) and high ($B=2^{26}$) batch sizes, as well as intermediate batch sizes in the high token budget regime. For $B = 2^{18}$, the model reduces to $\eta^* \propto 1/\sqrt{T}$ as discussed in Sec. \ref{sec:lr-scaling}, matching the observations.
  • Figure 2: Critical batch size $B_\mathrm{crit}$ (left) and critical learning rate $\eta_\mathrm{crit}$ (right), as extracted from the fit with the power law $p_\mathrm{crit} = a_pT^{\alpha_p} + b_p$, where $p \in \{\eta, B\}$, following the procedure from Appendix \ref{app:fitting}, as a function of token budget. Solid line represents the fit result. Dashed line corresponds to the fit with the power exponent fixed to $\alpha_B = 0.5$ (left) and $\alpha_\eta = -0.5$ (right). This model fit is visualized only to illustrate the model variation with the exponent change and its parameters are not used in the main analysis.
  • Figure 3: Validation loss $\mathcal{L}_\mathrm{val}$ for training a model with $d_\mathrm{model} = d_\mathrm{model}^\mathrm{base} = 1024$ (354M parameters) with an optimally tuned learning rate, shown as a function of (a) batch size, split by pretraining token budget, and (b) pretraining token budget, split by batch size, both measured in tokens. Inset plots zoom into the optimum region. We observe that (a) the optimal batch size (circled markers in the inset plot) evolves in time, increasing by $\times 2^2$ ($B = 2^{18} \to 2^{20}$ tokens) as the budget increases by $\times 2^5$ ($T = 2^{30} \to 2^{35}$ tokens), and (b) smaller batch sizes are gradually phased out, becoming suboptimal as the token budget increases.
  • Figure 4: Learning rate sensitivity $\mathcal{L}_\mathrm{val} - \mathcal{L}_\mathrm{val}^\mathrm{min}$ as a function of the learning rate deviation from the optimal value $\eta / \eta_\mathrm{optimal}$, measured for batch sizes of $B=2^{18}$ (left column), $2^{20}$ (middle column), and $2^{22}$ (right column) tokens, separately for the $\mu$P base models with width $d_\mathrm{model}^\mathrm{base} = 256$ (top row) and $1024$ (bottom row). The former model amounts to 32M and the latter to 354M trainable parameters. With an increase of the pretraining token budget (different marker styles) we observe a general decrease in the learning rate sensitivity, which is more pronounced for batch sizes $B \in \{2^{20}, 2^{22}\}$ in the critical region (Sec. \ref{sec:term}) and for the 354M model. At the largest probed token budget $T = 2^{35}$ tokens, the sensitivity equalizes across the models and batch sizes.
  • Figure 5: Learning rate sensitivity $\mathcal{L}_\mathrm{val} - \mathcal{L}_\mathrm{val}^\mathrm{min}$ as a function of learning rate $\eta$, measured for batch sizes of $B=2^{18}$ (leftmost column), $2^{20}$ (middle left column), $2^{22}$ (middle right column) and $2^{24}$ (rightmost column) tokens, separately for the $\mu$P base models with the width $d_\mathrm{model}^\mathrm{base} = 256$ (top row) and $1024$ (bottom row). Different marker styles correspond to different models within the $\mu$P family, with all the models being evaluated at the data horizon of $T = 2^{35}$ tokens. For the base model with $d_\mathrm{model}^\mathrm{base} = 256$, we scale the width only downwards, while for the base model with $d_\mathrm{model}^\mathrm{base} = 1024$, we scale it both upwards and downwards. We observe no significant difference in the sensitivity across all the $(d_\mathrm{model}^\mathrm{base}, d_\mathrm{model})$ configurations. Note that for the configuration $(B=2^{24}, ~d_\mathrm{model}^\mathrm{base} = 1024)$, the base and $d_\mathrm{model} = 4 \times d_\mathrm{model}^\mathrm{base}$ models share a different random seed compared to all the other models, to illustrate the loss penalty arising from the learning rate optimum variation.
  • ...and 16 more figures
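
The Figure 3 caption reports that the optimal batch size for the 354M model moves from $2^{18}$ to $2^{20}$ tokens while the token budget grows from $2^{30}$ to $2^{35}$ tokens. A quick way to read these two endpoints is as a slope in log-log space, as in the sketch below. This is a two-point estimate on what appears to be a power-of-two batch grid, so it is only a rough indicator and not a fitted exponent from the paper. The helper name `implied_exponent` is hypothetical.

```python
import math


def implied_exponent(t_start, t_end, b_start, b_end):
    """Slope of log2(B) against log2(T) between two (T, B) points."""
    return math.log2(b_end / b_start) / math.log2(t_end / t_start)


# Endpoints quoted in the Figure 3 caption (354M-parameter model).
alpha_opt = implied_exponent(2 ** 30, 2 ** 35, 2 ** 18, 2 ** 20)

if __name__ == "__main__":
    print("implied exponent of B*(T):", alpha_opt)
```

The batch size grows by $2^2$ while the budget grows by $2^5$, so the implied exponent is $2/5 = 0.4$. That is well below the roughly linear growth reported for $B_\mathrm{crit}$, which is consistent with the optimal batch size increasing with $T$ more slowly than the critical batch size.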