Stable and low-precision training for large-scale vision-language models
Mitchell Wortsman, Tim Dettmers, Luke Zettlemoyer, Ari Morcos, Ali Farhadi, Ludwig Schmidt
TL;DR
The paper tackles fast and stable training of large-scale vision-language models. For speed, it introduces SwitchBack, a quantized linear layer that runs the forward and input-gradient matrix multiplications in 8-bit while keeping the weight-gradient multiplication in higher precision. With int8, SwitchBack speeds up CLIP ViT-Huge training by 13–25% with minimal accuracy loss, and the paper also analyzes FP8 training through simulation. For stability, it traces loss spikes to out-of-date AdamW second-moment estimates and proposes StableAdamW, a hybrid optimizer that adds Adafactor-style update clipping, which reduces loss spikes and outperforms gradient clipping at the scales tested. It also finds that standard FP8 training works when layer-scale is initialized to zero, which discourages large feature magnitudes, and it releases open-source Triton kernels to support further community work. Overall, the work offers practical strategies and resources for training large multi-modal models faster and more stably.
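To make the precision split in SwitchBack concrete, here is a minimal pure-Python simulation. It assumes row-wise absmax quantization for inputs and output gradients and tensor-wise absmax quantization for weights; this section does not state the quantization granularity, so treat that as an assumption. All function names are hypothetical, and the released implementation uses fused Triton kernels rather than anything like this. The forward pass and the input-gradient product use simulated int8 values with exact integer accumulation, while the weight gradient is computed in full precision.

```python
# Pure-Python simulation of a SwitchBack-style linear layer (hypothetical names).
# Shapes: x is (batch, in), w is (out, in), y = x @ w^T is (batch, out).

def quantize_rowwise(x):
    """Absmax int8 quantization with one scale per row (assumed granularity)."""
    q, scales = [], []
    for row in x:
        absmax = max(abs(v) for v in row) or 1.0  # guard against all-zero rows
        scale = absmax / 127.0
        q.append([max(-127, min(127, round(v / scale))) for v in row])
        scales.append(scale)
    return q, scales

def quantize_tensorwise(w):
    """Absmax int8 quantization with a single scale for the whole matrix."""
    absmax = max(abs(v) for row in w for v in row) or 1.0
    scale = absmax / 127.0
    q = [[max(-127, min(127, round(v / scale))) for v in row] for row in w]
    return q, scale

def matmul(a, b):
    """Plain a @ b; exact for integer inputs, mimicking int32 accumulation."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def transpose(m):
    return [list(col) for col in zip(*m)]

def switchback_forward(x, w):
    """y = x @ w^T on int8 values, then dequantize with row and tensor scales."""
    qx, sx = quantize_rowwise(x)
    qw, sw = quantize_tensorwise(w)
    acc = matmul(qx, transpose(qw))
    return [[acc[i][j] * sx[i] * sw for j in range(len(acc[0]))] for i in range(len(acc))]

def switchback_backward(x, w, grad_y):
    """Input gradient in simulated int8; weight gradient in full precision."""
    qg, sg = quantize_rowwise(grad_y)
    qw, sw = quantize_tensorwise(w)
    acc = matmul(qg, qw)  # grad_x = grad_y @ w
    grad_x = [[acc[i][j] * sg[i] * sw for j in range(len(acc[0]))] for i in range(len(acc))]
    grad_w = matmul(transpose(grad_y), x)  # grad_w = grad_y^T @ x, kept in high precision
    return grad_x, grad_w
```

Comparing `switchback_forward(x, w)` against `matmul(x, transpose(w))` on a small batch shows outputs that agree up to quantization error, while `grad_w` matches the full-precision product exactly. Keeping the weight-gradient product in high precision reflects the idea that this product sums over the whole batch, so int8 rounding errors there would accumulate most.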
Abstract
We introduce new methods for 1) accelerating and 2) stabilizing training for large language-vision models. 1) For acceleration, we introduce SwitchBack, a linear layer for int8 quantized training which provides a speed-up of 13-25% while matching the performance of bfloat16 training within 0.1 percentage points for the 1B parameter CLIP ViT-Huge, the largest int8 training to date. Our main focus is int8, as GPU support for float8 is rare, though we also analyze float8 training through simulation. While SwitchBack proves effective for float8, we show that standard techniques are also successful if the network is trained and initialized so that large feature magnitudes are discouraged, which we accomplish via layer-scale initialized with zeros. 2) For stability, we analyze loss spikes and find they consistently occur 1-8 iterations after the squared gradients become under-estimated by their AdamW second moment estimator. As a result, we recommend an AdamW-Adafactor hybrid which avoids loss spikes when training a CLIP ViT-Huge model and outperforms gradient clipping at the scales we test.
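The AdamW-Adafactor hybrid (StableAdamW) can be sketched as AdamW whose per-tensor learning rate is divided by max(1, RMS), where RMS = sqrt(mean(g²/u)) over the tensor, g is the current gradient, and u is the bias-corrected second-moment estimate. An RMS well above 1 signals exactly the under-estimation of squared gradients that the abstract links to loss spikes. This is a minimal pure-Python sketch for one flattened parameter tensor, not the authors' implementation: the class name `StableAdamWSketch`, the hyperparameter defaults, the eps² floor on u, and applying decoupled weight decay with the clipped rate are all assumptions.

```python
import math

class StableAdamWSketch:
    """Hypothetical single-tensor sketch: AdamW plus Adafactor-style update clipping."""

    def __init__(self, lr=1e-3, betas=(0.9, 0.99), eps=1e-6, weight_decay=0.0):
        self.lr, self.eps, self.wd = lr, eps, weight_decay
        self.beta1, self.beta2 = betas
        self.t = 0
        self.m = None  # first-moment EMA
        self.v = None  # second-moment EMA
        self.last_rms = None

    def step(self, params, grads):
        """Return updated params (flat list) given grads of the same length."""
        if self.m is None:
            self.m = [0.0] * len(params)
            self.v = [0.0] * len(params)
        self.t += 1
        b1, b2 = self.beta1, self.beta2
        self.m = [b1 * m + (1 - b1) * g for m, g in zip(self.m, grads)]
        self.v = [b2 * v + (1 - b2) * g * g for v, g in zip(self.v, grads)]
        m_hat = [m / (1 - b1 ** self.t) for m in self.m]
        v_hat = [v / (1 - b2 ** self.t) for v in self.v]
        # RMS of g / sqrt(u): values > 1 mean u under-estimates the current g^2.
        ratios = [g * g / max(v, self.eps ** 2) for g, v in zip(grads, v_hat)]
        rms = math.sqrt(sum(ratios) / len(ratios))
        self.last_rms = rms
        # Update clipping: shrink the step only when rms exceeds 1.
        eta = self.lr / max(1.0, rms)
        # Assumption: decoupled weight decay uses the clipped rate eta.
        return [p - eta * self.wd * p - eta * m / (math.sqrt(v) + self.eps)
                for p, m, v in zip(params, m_hat, v_hat)]
```

With a steady small gradient the ratio stays near 1 and no clipping occurs. When a gradient suddenly becomes much larger than the running second-moment estimate, RMS jumps above 1 and the step shrinks accordingly, rather than plain AdamW taking a full-size step against a stale denominator.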
