Collage: Light-Weight Low-Precision Strategy for LLM Training

Tao Yu, Gaurav Gupta, Karthick Gopalswamy, Amith Mamidala, Hao Zhou, Jeffrey Huynh, Youngsuk Park, Ron Diamant, Anoop Deoras, Luke Huan

TL;DR

The paper tackles the high resource demands of training large language models by introducing Collage, a low-precision training framework that uses multi-component floating-point (MCF) representations to capture and compensate for rounding errors without FP32 master weights. It formalizes imprecision through metrics such as effective descent quality (EDQ) and shows that a precision-aware AdamW variant operating entirely in BF16 with MCF components can match or exceed FP32-master-weight baselines while delivering substantial throughput gains (up to $3.7\times$) and memory savings (up to ~$22.8\%$ on large GPT/OpenLLaMA models). The approach is validated on BERT, RoBERTa, GPT-scale models, and OpenLLaMA-7B, showing robust performance across $\beta_2$ settings and model sizes, with Collage-plus often giving the best trade-off among accuracy, speed, and memory. The work points to practical benefits for scalable, energy-efficient LLM training and sets the stage for extending to even lower precision (e.g., BF16 with FP8) and broader hardware support.

Abstract

Large-model training is plagued by intense compute cost and limited hardware memory. A practical solution is low-precision representation, but it is troubled by loss of numerical accuracy and unstable training, which render the model less useful. We argue that low-precision floating points can perform well provided the error is properly compensated at the critical locations in the training process. We propose Collage, which utilizes a multi-component float representation in low precision to perform operations accurately with numerical errors accounted for. To understand the impact of imprecision on training, we propose a simple and novel metric which tracks the information lost during training and differentiates various precision strategies. Our method works with commonly used low-precision formats such as half precision ($16$-bit floating points) and can be naturally extended to even lower precision such as $8$-bit. Experimental results show that pre-training with Collage removes the requirement of keeping $32$-bit floating-point copies of the model and attains similar or better training performance compared to the $(16, 32)$-bit mixed-precision strategy, with up to $3.7\times$ speedup and $\sim 15\%$ to $23\%$ less memory usage in practice.
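
The abstract's central idea, keeping the rounding error of a low-precision update in a second low-precision component instead of in an FP32 master copy, can be illustrated with a minimal Python sketch. Assumptions: BF16 is emulated in software (`to_bf16` rounds through float32 with round-to-nearest-even), the helper names `to_bf16`, `fast2sum`, `train_plain`, and `train_mcf` are hypothetical, and the two-component update is a generic error-compensated accumulation built on Fast2Sum, not the paper's Collage-light or Collage-plus algorithm.

```python
import struct
from typing import Tuple


def to_bf16(x: float) -> float:
    """Round x to the nearest BF16 value (ties to even) and return it as a float.

    Software emulation: struct first rounds x to float32, then the low 16 bits
    are rounded away. NaN, infinity, and overflow are not handled.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    lsb = (bits >> 16) & 1                     # lowest bit that BF16 keeps
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000  # round to nearest, ties to even
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def fast2sum(a: float, b: float) -> Tuple[float, float]:
    """Emulated-BF16 Fast2Sum: x + y == a + b, assuming |a| >= |b|."""
    x = to_bf16(a + b)              # rounded sum
    y = to_bf16(b - to_bf16(x - a)) # rounding error of x
    return x, y


def train_plain(w: float, delta: float, steps: int) -> float:
    """Plain BF16 weight: every update is rounded straight back into BF16."""
    for _ in range(steps):
        w = to_bf16(w + delta)
    return w


def train_mcf(w: float, delta: float, steps: int) -> Tuple[float, float]:
    """Two-component BF16 weight (hi, lo): lo keeps what rounding removed from hi."""
    hi, lo = to_bf16(w), 0.0
    for _ in range(steps):
        lo = to_bf16(lo + delta)   # add the update to the small component
        hi, lo = fast2sum(hi, lo)  # renormalize; |hi| >= |lo| holds in this example
    return hi, lo


print(train_plain(450.0, 0.5, 100))  # plain BF16 weight never moves
hi, lo = train_mcf(450.0, 0.5, 100)
print(hi + lo)                       # two-component weight tracks the true sum
```

With a weight of 450 and an update of 0.5 (loosely mirroring the parameter-norm vs. update-norm scales in Figure 2), the plain BF16 weight stays at 450 because every update rounds away, while `hi + lo` reaches exactly 500 after 100 steps using only BF16-representable values.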

Paper Structure (46 sections, 1 theorem, 5 equations, 17 figures, 12 tables, 7 algorithms)

Key Result

Theorem 4.1

Let $a, b$ be two floating-point numbers with $|a| \geq |b|$. Fast2Sum produces an MCF expansion $(x, y)$ such that $a + b = x + y$, where $x \gets \mathcal{F}^{P}(a \oplus b)$ is the floating-point sum with precision $P$ and $y \gets \mathcal{F}^{P}\left(b \ominus \mathcal{F}^{P}(x \ominus a)\right) = a + b - \mathcal{F}^{P}(a \oplus b)$ is the rounding error. Also, $y$ is upper-bounded such that …
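
Theorem 4.1 can be checked directly in IEEE-754 binary64, since CPython floats use that format ($P = 53$) with round-to-nearest on standard platforms. A minimal sketch; the helper names `fast2sum` and `is_exact` are hypothetical, and exact rational arithmetic from `fractions.Fraction` is used only to verify that $x + y = a + b$ holds with no rounding.

```python
from fractions import Fraction
from typing import Tuple


def fast2sum(a: float, b: float) -> Tuple[float, float]:
    """Fast2Sum on Python floats (binary64): returns (x, y) with x + y == a + b."""
    if abs(a) < abs(b):
        raise ValueError("Fast2Sum requires |a| >= |b|")
    x = a + b  # F^P(a (+) b): rounded sum
    z = x - a  # F^P(x (-) a): the part of b that made it into x
    y = b - z  # F^P(b (-) z): the rounding error of x
    return x, y


def is_exact(a: float, b: float, x: float, y: float) -> bool:
    """Check x + y == a + b in exact rational arithmetic."""
    return Fraction(x) + Fraction(y) == Fraction(a) + Fraction(b)


x, y = fast2sum(0.2, 0.1)
print(x, y, is_exact(0.2, 0.1, x, y))
```

For `fast2sum(0.2, 0.1)`, `x` is `0.30000000000000004` and `y` is `-2**-55`, the amount by which the rounded sum overshoots the exact sum of the two stored values. In an MCF pair, `y` is the second component that a single float would discard.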

Figures (17)

  • Figure 1: Left: Collage runs a strictly low-precision floating-point (e.g., BF16) optimization loop and never needs to upcast to FP32, unlike mixed precision with FP32 master weights (thick red loop). The model weights in Collage are represented as multi-component floats (MCF) instead of "standard floats". Right: Total bytes/parameter savings for Collage from avoiding the FP32 upcasting route. The memory savings and uncompromising use of low precision result in the speedups reported in Table \ref{tab:speedup}.
  • Figure 2: BERT-base-uncased phase-1 pre-training with settings as described in Section \ref{subsec:pretrain-bert-roberta}. Left: Model parameter L2 norm vs. iterations for the BF16 and FP32 master-weights strategies. Right: Update $\Delta\bm{\theta}_t$ L2 norm across iterations. The parameter norm and update norm are at very different scales, for example $\sim 450$ vs. $\sim 0.5$ at $14$k iterations, a factor of about $900$, which causes lost arithmetic.
  • Figure 3: BERT phase-$1$ pre-training (see Appendix \ref{appssec:bert_roberta} for details). Left: Imprecision percentage ($\%$), measured as the percentage of lost arithmetic over all model parameters (i.e., parameters not updated), vs. iterations for BF$16$. Middle: Training perplexity vs. iterations for various precision strategies (see Table \ref{tab:precision-strategies-breakdown}). Additionally, we evaluate "FP32" as the 32-bit counterpart of option A, and BF16-Kahan as Kahan summation zamirai2020revisiting with BF16 parameters. Right: Effective descent quality ($\mathrm{EDQ}$) in Eq. \ref{eqn:edq_defn} vs. iterations, measuring the information lost at the optimizer step for different precision strategies. BF16-Collage-plus training perplexity and $\mathrm{EDQ}$ overlap with the best strategies, "FP32" and "BF16 + FP32 MW", while using fewer bytes per parameter.
  • Figure 4: GPU peak memory (GB) vs. model size. GPT-$125$M is hosted on $1$ NVIDIA A100 40GB GPU, while all other models are hosted on $8\times$ A100 40GB GPUs with tensor parallelism of $8$.
  • Figure 5: OpenLLaMA-7B pre-training (see settings in Section \ref{ssec:llama_gpt_results}) with $\beta_2=0.95$. Left: Training perplexity for the precision strategies listed in Table \ref{tab:precision-strategies-breakdown}. Right: Model gradient L2 norm across iterations for different strategies. The Collage variants overlap with the heavyweight FP32 master-weights strategy.
  • ...and 12 more figures
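
Figures 2 and 3 describe lost arithmetic: a BF16 update that is too small relative to its weight rounds away, so the parameter is not updated. A minimal sketch of an imprecision percentage in that spirit, assuming software BF16 emulation and that the percentage is taken over all parameters; `to_bf16` and `lost_update_percentage` are hypothetical helpers, and the paper's exact Definition 3.2 is not reproduced here.

```python
import struct
from typing import Sequence


def to_bf16(x: float) -> float:
    """Round x to the nearest BF16 value (ties to even); emulated via float32 bits."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    lsb = (bits >> 16) & 1
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def lost_update_percentage(weights: Sequence[float], updates: Sequence[float]) -> float:
    """Percentage of parameters whose nonzero update is rounded away in BF16."""
    lost = 0
    for w, d in zip(weights, updates):
        w_bf16 = to_bf16(w)
        if d != 0.0 and to_bf16(w_bf16 + d) == w_bf16:
            lost += 1  # the update vanished: lost arithmetic
    return 100.0 * lost / len(weights)


print(lost_update_percentage([450.0, 1.0, 3.0], [0.5, 0.5, 0.001]))
```

Here two of the three updates are lost (0.5 on a weight of 450, and 0.001 on a weight of 3), while 1.0 + 0.5 = 1.5 is exactly representable, so the sketch reports about 66.7%.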

Theorems & Definitions (7)

  • Definition 2.1
  • Definition 3.1: $\mathrm{ulp}$ muller2018handbook
  • Definition 3.2: Lost Arithmetic
  • Definition 3.3: Effective Descent Quality
  • Theorem 4.1: Fast2Sum dekker1971float
  • Remark 5.1
  • Remark 5.2
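
Definition 3.1 relies on the unit in the last place. A small sketch, assuming the common binade-spacing convention $\mathrm{ulp}(x) = 2^{\lfloor \log_2 |x| \rfloor - p + 1}$ (which may differ from muller2018handbook at powers of two and for subnormals) and a hypothetical helper `ulp(x, precision)`, where `precision` is the significand width including the implicit bit: 8 for BF16, 24 for FP32, 53 for FP64.

```python
import math


def ulp(x: float, precision: int) -> float:
    """Spacing of binary floats with `precision` significand bits in x's binade.

    Ignores subnormals and exponent-range limits of the target format.
    """
    if x == 0.0 or not math.isfinite(x):
        raise ValueError("ulp is only computed here for finite, nonzero x")
    _, e = math.frexp(x)  # x = m * 2**e with 0.5 <= |m| < 1
    return math.ldexp(1.0, e - precision)


for name, p in [("BF16", 8), ("FP32", 24), ("FP64", 53)]:
    print(name, ulp(450.0, p))
```

In BF16, $\mathrm{ulp}(450) = 2$, so any per-element update smaller than half an ulp (1 here) added to a weight near 450 rounds away, which is the lost arithmetic that the scale gap in Figure 2 points to.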