Stable and low-precision training for large-scale vision-language models

Mitchell Wortsman, Tim Dettmers, Luke Zettlemoyer, Ari Morcos, Ali Farhadi, Ludwig Schmidt

TL;DR

The paper tackles fast and stable training of large-scale vision-language models. For speed, it introduces SwitchBack, a quantized linear layer that runs the forward and input-gradient matrix multiplications in 8-bit precision (int8, with FP8 studied in simulation) while keeping the weight-gradient multiplication in 16-bit. With int8, this yields speedups of 13–25% in end-to-end CLIP training while matching bfloat16 zero-shot accuracy within 0.1 percentage points on the 1B-parameter CLIP ViT-Huge. For stability, it traces loss spikes to second-moment estimates that under-estimate the squared gradients, and proposes StableAdamW, an AdamW variant with Adafactor-style update clipping that avoids loss spikes and outperforms gradient clipping at the scales tested. The paper also shows that standard FP8 training becomes viable with zero-initialized layer-scale, and it releases open-source Triton kernels for the community. Overall, the work provides practical strategies for training large multi-modal models faster and more stably at scale, along with resources to extend this progress.
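The quantization layout behind SwitchBack can be sketched in a few lines. The block below is a minimal pure-Python illustration, not the authors' Triton implementation. It simulates int8 by rounding to integers in [-127, 127] and assumes the layout described in the paper: row-wise quantization for inputs and output gradients, tensor-wise quantization for weights, and a weight-gradient product kept in high precision. All function names are hypothetical.

```python
# Hypothetical pure-Python sketch of a SwitchBack-style linear layer (no autograd).
# int8 is simulated with rounded Python ints; real kernels use int8 tensor cores.


def quantize_rowwise(m):
    """Scale each row into [-127, 127] by its own absmax; return ints and per-row scales."""
    q, scales = [], []
    for row in m:
        s = max(abs(v) for v in row) or 1.0  # guard against an all-zero row
        scales.append(s)
        q.append([round(127 * v / s) for v in row])
    return q, scales


def quantize_tensorwise(m):
    """Scale the whole matrix into [-127, 127] by a single absmax."""
    s = max(abs(v) for row in m for v in row) or 1.0
    return [[round(127 * v / s) for v in row] for row in m], s


def matmul(a, b):
    """Plain matrix product a @ b for nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m):
    return [list(col) for col in zip(*m)]


def int8_matmul_dequant(a_q, a_scales, b_q, b_scale):
    """Integer product of row-quantized a and tensor-quantized b, rescaled to floats."""
    acc = matmul(a_q, b_q)  # exact integer accumulation (int32 on real hardware)
    return [[v * sa * b_scale / (127 * 127) for v in row] for row, sa in zip(acc, a_scales)]


def switchback_forward(x, w):
    """Y = X W^T with X: [batch, in] and W: [out, in], both operands simulated in int8."""
    x_q, sx = quantize_rowwise(x)
    w_q, sw = quantize_tensorwise(w)
    return int8_matmul_dequant(x_q, sx, transpose(w_q), sw)


def switchback_backward(x, w, g):
    """Return (grad_x, grad_w) for an upstream gradient G: [batch, out]."""
    g_q, sg = quantize_rowwise(g)
    w_q, sw = quantize_tensorwise(w)
    grad_x = int8_matmul_dequant(g_q, sg, w_q, sw)  # G W, simulated int8
    grad_w = matmul(transpose(g), x)  # G^T X stays in high precision
    return grad_x, grad_w
```

One property of this layout: because the weights carry a single scale, the same quantized weight matrix serves both the forward product and the transposed product in the backward pass. Row-wise scales for activations and gradients confine the effect of an outlier to its own row.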

Abstract

We introduce new methods for 1) accelerating and 2) stabilizing training for large language-vision models. 1) For acceleration, we introduce SwitchBack, a linear layer for int8 quantized training which provides a speed-up of 13-25% while matching the performance of bfloat16 training within 0.1 percentage points for the 1B parameter CLIP ViT-Huge -- the largest int8 training to date. Our main focus is int8 as GPU support for float8 is rare, though we also analyze float8 training through simulation. While SwitchBack proves effective for float8, we show that standard techniques are also successful if the network is trained and initialized so that large feature magnitudes are discouraged, which we accomplish via layer-scale initialized with zeros. 2) For stability, we analyze loss spikes and find they consistently occur 1-8 iterations after the squared gradients become under-estimated by their AdamW second moment estimator. As a result, we recommend an AdamW-Adafactor hybrid which avoids loss spikes when training a CLIP ViT-Huge model and outperforms gradient clipping at the scales we test.
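The AdamW-Adafactor hybrid can be sketched for a single parameter tensor as shown below. This is a minimal pure-Python illustration under stated assumptions, not the authors' code. It uses standard AdamW bias correction and computes RMS_t = sqrt(mean(g_t^2 / û_t)) over the tensor, where û_t is the bias-corrected second-moment estimate. It then divides the learning rate by max(1, RMS_t), as in Adafactor's update clipping. The class name and default hyperparameters are hypothetical, and the paper's algorithm may place epsilon and bias correction differently.

```python
import math


class StableAdamWSketch:
    """Hypothetical sketch: AdamW plus Adafactor-style update clipping for one flat tensor."""

    def __init__(self, n, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0):
        self.lr = lr
        self.b1, self.b2 = betas
        self.eps = eps
        self.wd = weight_decay
        self.m = [0.0] * n  # first-moment estimate
        self.u = [0.0] * n  # second-moment estimate
        self.t = 0

    def step(self, params, grads):
        """Return (new_params, rms) after one update; rms is RMS_t for this tensor."""
        self.t += 1
        for i, g in enumerate(grads):
            self.m[i] = self.b1 * self.m[i] + (1 - self.b1) * g
            self.u[i] = self.b2 * self.u[i] + (1 - self.b2) * g * g
        c1 = 1 - self.b1 ** self.t  # bias-correction factors
        c2 = 1 - self.b2 ** self.t
        u_hat = [u / c2 for u in self.u]
        # RMS_t well above 1 means u_hat under-estimates the current squared gradients.
        ratios = [g * g / max(u, self.eps ** 2) for g, u in zip(grads, u_hat)]
        rms = math.sqrt(sum(ratios) / len(ratios))
        lr = self.lr / max(1.0, rms)  # update clipping: shrink the step only when rms > 1
        new_params = []
        for p, m, u in zip(params, self.m, u_hat):
            p = p - lr * self.wd * p  # decoupled weight decay, as in AdamW
            p = p - lr * (m / c1) / (math.sqrt(u) + self.eps)  # Adam step
            new_params.append(p)
        return new_params, rms
```

While RMS_t stays near 1, the update is plain AdamW. A sudden gradient that û_t under-estimates pushes RMS_t above 1 and shrinks the step. This is the situation the abstract associates with the onset of loss spikes.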

Paper Structure (30 sections, 7 equations, 21 figures, 5 algorithms)


Figures (21)

  • Figure 1: We introduce SwitchBack, a linear layer for low-precision training. (Left) SwitchBack for int8 training matches the zero-shot ImageNet [deng2009imagenet] accuracy of standard bfloat16 training within 0.1 percentage points for CLIP ViT-Huge [radford2021learning; dosovitskiy2021an] and outperforms LLM.int8() [dettmers2022llm]. (Right) For float8 (fp8) training [micikevicius2022fp8], a baseline that uses tensor-wise quantization diverges for large models, while SwitchBack matches the baseline. In these large-model, small-data experiments, our focus is on comparing methods rather than on final model accuracy, so we use short runs, which makes it feasible to run many experiments.
  • Figure 2: Loss curves for the CLIP ViT-Base and CLIP ViT-Huge models evaluated in Figure 1. The left two plots show results for int8 training; the right two show results for float8 (fp8) training.
  • Figure 3: (Left) Individual profiles of the operations that make up the forward and backward pass of a linear layer for i) SwitchBack using Triton kernels and ii) an fp16 baseline using torch.matmul. Times are averaged over a linear layer from dim to 4·dim and a linear layer from 4·dim to dim, which are representative of the linear layers in a transformer MLP. (Right) The % speedup of SwitchBack over a standard fp16 linear layer when all operations in Figure 3 (left) are summed.
  • Figure 4: (Left) The % of time occupied by quantize operations in a SwitchBack linear layer, which is usually below 20% and decreases as dim grows. (Right) Speedups for end-to-end CLIP training on a single node (4 A100 GPUs, per-GPU batch size 256, gradient checkpointing) for various model sizes when all linear operations in the transformer are replaced with SwitchBack (i.e., the key, query, value, and out projections as well as the MLP). Speedups are reported over i) a custom linear layer implemented with torch.autograd (Algorithm [alg:code5]), which matches our torch.autograd-based implementation of SwitchBack, and ii) the standard PyTorch nn.Linear, which includes additional background C++/CUDA optimizations that we do not replicate. LLM.int8() [dettmers2022llm] does not provide speedups over the torch.autograd or nn.Linear baseline at this scale; we compare the speed of SwitchBack and LLM.int8() directly in Figure [fig:speedllmint8].
  • Figure 5: (Left) Training CLIP ViT-Large models with simulated fp8 precision using tensor-wise quantization for the inputs, weights, and gradients. All methods we try diverge except zero-init layer-scale [touvron2021going], which multiplies the output of each self-attention or MLP block by a learnable vector initialized to zero. (Right) Feature magnitudes (i.e., the average absolute value of the output of transformer block k) for CLIP ViT-Huge at the beginning (init) and end of training. This suggests why zero-init layer-scale enables float8 training: it prevents high feature magnitudes, which may cause issues for low-precision training [dettmers2022llm]. Without the intervention, the average feature magnitude becomes large for later blocks.
  • ...and 16 more figures
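The zero-init layer-scale intervention from Figure 5 amounts to one learnable vector per residual branch. The sketch below is illustrative only. The branch is a hypothetical stand-in (a constant-weight linear map rather than real attention or MLP), and the sketch shows only the effect at initialization: zero-init keeps the residual stream at its input magnitude, while unit init lets it compound with depth. It does not reproduce the end-of-training measurements in the figure.

```python
class LayerScaleResidual:
    """Residual block y = x + gamma * f(x), with gamma a learnable per-channel vector."""

    def __init__(self, dim, init_value=0.0):
        # Zero init switches the residual branch off at the start of training.
        self.gamma = [init_value] * dim

    def __call__(self, x, branch):
        fx = branch(x)
        return [xi + gi * fi for xi, gi, fi in zip(x, self.gamma, fx)]


def toy_branch(x, weight=0.1):
    """Hypothetical stand-in for an attention/MLP branch: dense map with constant weights."""
    total = sum(x)
    return [weight * total for _ in x]


def mean_abs_after(depth, dim, init_value):
    """Average absolute feature value after `depth` stacked residual blocks."""
    x = [1.0] * dim
    for block in [LayerScaleResidual(dim, init_value) for _ in range(depth)]:
        x = block(x, toy_branch)
    return sum(abs(v) for v in x) / dim
```

With 12 blocks and dim=4, `mean_abs_after(12, 4, 0.0)` returns 1.0. In contrast, `mean_abs_after(12, 4, 1.0)` grows to 1.4**12, roughly 56.7, because each toy branch adds 40% to every channel.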