Cut Your Losses in Large-Vocabulary Language Models

Erik Wijmans, Brody Huval, Alexander Hertzberg, Vladlen Koltun, Philipp Krähenbühl

TL;DR

This paper introduces Cut Cross-Entropy (CCE), a memory-efficient method for training large-vocabulary language models that never materializes the full logit matrix. CCE splits the loss into an indexed matrix multiplication for the ground-truth token and a log-sum-exp over the vocabulary that is computed on the fly in on-chip SRAM. This reduces the memory used by the loss computation from tens of gigabytes to around a megabyte while preserving training speed and convergence. The method combines memory-efficient forward and backward kernels with gradient filtering and vocabulary sorting, both of which exploit the sparsity of softmax, and it enables substantially larger batch sizes with minimal overhead. The approach is directly useful for training frontier LLMs, where it trades no throughput for the memory it saves, and it generalizes to other classification problems with very large label sets.
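The decomposition described above can be sketched in plain Python. This is a minimal illustration, not the authors' kernel: the names `cce_loss`, `naive_loss`, `E` (token embeddings), `C` (classifier rows), and the `block` size are assumptions introduced here, and the loop stands in for work that the paper performs inside a fused GPU kernel. The point is only that the per-token loss equals the log-sum-exp over all logits minus the logit of the correct token, and that the log-sum-exp can be accumulated one vocabulary block at a time.

```python
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def cce_loss(E, C, targets, block=4):
    """Mean cross-entropy without storing the full N x V logit matrix.

    E: N embeddings (lists of length D), C: V classifier rows (length D),
    targets: N ground-truth token ids. All names are illustrative.
    """
    total = 0.0
    for e, t in zip(E, targets):
        # Indexed matmul: only the logit of the correct token is computed.
        correct_logit = dot(C[t], e)

        # Streaming log-sum-exp: keep a running maximum m and a running
        # sum s of exp(logit - m), updated one vocabulary block at a time.
        m, s = -math.inf, 0.0
        for start in range(0, len(C), block):
            logits = [dot(c, e) for c in C[start:start + block]]
            new_m = max(m, max(logits))
            s = s * math.exp(m - new_m) + sum(math.exp(z - new_m) for z in logits)
            m = new_m
        total += (m + math.log(s)) - correct_logit
    return total / len(E)


def naive_loss(E, C, targets):
    """Reference version that builds every logit for each token."""
    total = 0.0
    for e, t in zip(E, targets):
        logits = [dot(c, e) for c in C]
        m = max(logits)
        total += m + math.log(sum(math.exp(z - m) for z in logits)) - logits[t]
    return total / len(E)


# Small deterministic example: 5 tokens, vocabulary of 11, hidden size 8.
E = [[math.sin(3 * i + d) for d in range(8)] for i in range(5)]
C = [[math.cos(j + 2 * d) for d in range(8)] for j in range(11)]
targets = [0, 3, 10, 7, 2]
print(cce_loss(E, C, targets), naive_loss(E, C, targets))
```

The two printed values should agree up to floating-point rounding. At any moment the blockwise version holds only `block` logits, which is what allows a real kernel to keep them in fast on-chip memory.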

Abstract

As language models grow ever larger, so do their vocabularies. This has shifted the memory footprint of LLMs during training disproportionately to one single layer: the cross-entropy in the loss computation. Cross-entropy builds up a logit matrix with entries for each pair of input tokens and vocabulary items and, for small models, consumes an order of magnitude more memory than the rest of the LLM combined. We propose Cut Cross-Entropy (CCE), a method that computes the cross-entropy loss without materializing the logits for all tokens into global memory. Rather, CCE only computes the logit for the correct token and evaluates the log-sum-exp over all logits on the fly. We implement a custom kernel that performs the matrix multiplications and the log-sum-exp reduction over the vocabulary in flash memory, making global memory consumption for the cross-entropy computation negligible. This has a dramatic effect. Taking the Gemma 2 (2B) model as an example, CCE reduces the memory footprint of the loss computation from 24 GB to 1 MB, and the total training-time memory consumption of the classifier head from 28 GB to 1 GB. To improve the throughput of CCE, we leverage the inherent sparsity of softmax and propose to skip elements of the gradient computation that have a negligible (i.e., below numerical precision) contribution to the gradient. Experiments demonstrate that the dramatic reduction in memory consumption is accomplished without sacrificing training speed or convergence.
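The gradient-skipping idea at the end of the abstract can be illustrated with a small sketch. Everything here is an assumption made for illustration: the function names, the block size, and the threshold `eps = 2 ** -12` are not taken from the text above, and the helper `log_sum_exp` builds all logits for one token purely for brevity. The sketch recomputes the softmax block by block from a stored log-sum-exp and skips any block whose gradient entries all fall below the threshold.

```python
import math


def log_sum_exp(e, C):
    # Helper for the demo only; it builds all logits for a single token.
    logits = [sum(a * b for a, b in zip(c, e)) for c in C]
    m = max(logits)
    return m + math.log(sum(math.exp(z - m) for z in logits))


def grad_wrt_embedding(e, C, target, lse, block=4, eps=2.0 ** -12):
    """Gradient of (lse - C[target] . e) with respect to e.

    Blocks whose entries of (softmax - one_hot) are all below eps in
    magnitude are skipped. Returns (gradient, number_of_skipped_blocks).
    """
    grad = [0.0] * len(e)
    skipped = 0
    for start in range(0, len(C), block):
        rows = C[start:start + block]
        # Recompute this block's softmax from the stored log-sum-exp.
        g = [math.exp(sum(a * b for a, b in zip(c, e)) - lse) for c in rows]
        if start <= target < start + len(rows):
            g[target - start] -= 1.0  # softmax minus one-hot
        if max(abs(x) for x in g) < eps:
            skipped += 1  # negligible block: no accumulation is done
            continue
        for x, c in zip(g, rows):
            for d in range(len(e)):
                grad[d] += x * c[d]
    return grad, skipped


# Peaked distribution: the logit of token j is 0.5 - 3 * j.
e = [1.0, 0.5]
C = [[-3.0 * j, 1.0] for j in range(16)]
target = 1
lse = log_sum_exp(e, C)

exact, skipped_exact = grad_wrt_embedding(e, C, target, lse, eps=0.0)
approx, skipped_filtered = grad_wrt_embedding(e, C, target, lse)
print(skipped_exact, skipped_filtered)
print(max(abs(a - b) for a, b in zip(exact, approx)))
```

In this toy case only the first block of four tokens carries probabilities above the threshold, so the remaining blocks are skipped and the filtered gradient differs from the exact one by a small amount. A real implementation would apply the same test to tiles inside the backward kernel rather than to Python lists.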

Paper Structure

This paper contains 21 sections, 6 equations, 7 figures, 5 tables, 4 algorithms.

Figures (7)

  • Figure 1: Memory use and maximum attainable batch size (in millions of tokens) for a variety of frontier models on a 16-GPU (80 GB each) fully-sharded data-parallel setup [zero2020] with activation checkpointing [chen2016checkpointing] and a mixed-precision 16-bit (fp16/bf16) AdamW optimizer [kingma2015adam; loshchilov2019adamw]. For each model, we break its memory use down into weights and optimizer states, activation checkpoints, and the log-probabilities computed by the cross-entropy loss layer. Our Cut Cross-Entropy (CCE) enables increasing the batch size by 1.5x (Llama 2 13B) to 10x (GPT 2, Gemma 2 2B), with no sacrifice in speed or convergence. Exact values are given in the table labeled tab:teaser-data.
  • Figure 2: Access patterns and computation of blockwise (a) indexed matrix multiplication, (b) linear-log-sum-exp forward pass, and (c) linear-log-sum-exp backward pass. See the algorithms labeled alg:imatmul, alg:lse_fwd, and alg:lse_bck for the corresponding procedures.
  • Figure 3: Average probability for the $i$th most likely token, log-log plot. The probabilities very quickly vanish below numerical precision.
  • Figure 4: Training loss curves for four models on the Alpaca dataset [taori2023alpaca]. The loss curves for CCE and torch.compile are nearly indistinguishable, showing that the gradient filtering in CCE does not impair convergence. Results averaged over 5 seeds.
  • Figure 5: Validation perplexity curves for four models trained on 5% of the Open WebText dataset [gokaslan2019openweb]. The validation set is a 0.25% subset of Open WebText that does not overlap with the train set. We find that CCE-Kahan-FullC matches torch.compile. Results averaged over 5 seeds.
  • ...and 2 more figures