Collage: Light-Weight Low-Precision Strategy for LLM Training
Tao Yu, Gaurav Gupta, Karthick Gopalswamy, Amith Mamidala, Hao Zhou, Jeffrey Huynh, Youngsuk Park, Ron Diamant, Anoop Deoras, Luke Huan
TL;DR
The paper tackles the high resource demands of training large language models by introducing Collage, a low-precision training framework that uses multi-component floating-point (MCF) to capture and compensate for rounding errors without FP32 master weights. It formalizes imprecision through metrics like effective descent quality (EDQ) and demonstrates how a precision-aware AdamW variant operating entirely in BF16 with MCF components can match or exceed FP32-master baselines while delivering substantial throughput gains (up to $3.7\times$) and memory savings (up to $\sim 22.8\%$ on large GPT/OpenLLaMA models). The approach is validated across BERT, RoBERTa, GPT-scale models, and OpenLLaMA-7B, showing robust performance across varying $\beta_2$ settings and model sizes, with Collage-plus often providing the best trade-off between accuracy, speed, and memory. The work implies practical benefits for scalable, energy-efficient LLM training and sets the stage for extending to even lower precision (e.g., BF16 with FP8) and broader hardware support.
Abstract
Training large models is plagued by intense compute cost and limited hardware memory. A practical solution is low-precision representation, but it suffers from loss of numerical accuracy and unstable training, rendering the model less useful. We argue that low-precision floating points can perform well provided the error is properly compensated at the critical locations in the training process. We propose Collage, which uses a multi-component float representation in low precision to perform operations accurately with numerical errors accounted for. To understand the impact of imprecision on training, we propose a simple and novel metric that tracks the information lost during training and differentiates between various precision strategies. Our method works with commonly used low-precision formats such as half-precision ($16$-bit floating points) and can be naturally extended to even lower precision such as $8$-bit. Experimental results show that pre-training with Collage removes the requirement for $32$-bit floating-point copies of the model and attains similar or better training performance compared to the $(16, 32)$-bit mixed-precision strategy, with up to $3.7\times$ speedup and $\sim 15\%$ to $23\%$ less memory usage in practice.
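To make the idea of compensating rounding error concrete, the following is a minimal sketch, not the paper's implementation. It assumes a simplified software model of BF16 (8 significant bits, round-to-nearest-even, exponent range ignored) and uses the standard Fast2Sum error-free transformation to keep a two-component (value, error) pair. The helper names `to_bf16`, `fast2sum`, and `grow` are chosen for illustration only. A small update of $2^{-10}$ added repeatedly to a weight of $1.0$ is lost entirely in plain BF16, while the two-component pair retains it.

```python
import math

def to_bf16(x: float) -> float:
    """Round x to 8 significant bits (BF16 precision), ties to even.
    Simplification: BF16's limited exponent range is ignored."""
    if x == 0.0:
        return 0.0
    m, e = math.frexp(x)          # x = m * 2**e with 0.5 <= |m| < 1
    r = round(m * 256)            # Python's round() is round-half-to-even
    return math.ldexp(float(r), e - 8)

def fast2sum(a: float, b: float):
    """Fast2Sum: for |a| >= |b|, return (s, t) with s = fl(a + b)
    and t the rounding error, so that s + t == a + b exactly."""
    s = to_bf16(a + b)
    z = to_bf16(s - a)
    t = to_bf16(b - z)
    return s, t

def grow(hi: float, lo: float, u: float):
    """Add a small update u to the two-component number (hi, lo),
    folding the rounding error into the low component."""
    s, t = fast2sum(hi, u)
    return fast2sum(s, to_bf16(lo + t))

update = 2.0 ** -10               # smaller than BF16 spacing near 1.0 (2**-7)

# Plain BF16 accumulation: every update is rounded away.
plain = 1.0
for _ in range(128):
    plain = to_bf16(plain + update)

# Two-component accumulation: the error term carries the lost bits.
hi, lo = 1.0, 0.0
for _ in range(128):
    hi, lo = grow(hi, lo, update)

print(plain)      # 1.0
print(hi, lo)     # 1.125 0.0
```

In this toy setting the plain BF16 weight never moves, whereas the pair `(hi, lo)` reaches the exact sum $1 + 128 \cdot 2^{-10} = 1.125$ using only BF16-precision operations, which is the kind of behavior that lets a low-precision strategy avoid a higher-precision copy of the weights.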
