From Projection to Prediction: Beyond Logits for Scalable Language Models
Jianbing Dong, Jianbin Chang
TL;DR
This paper tackles a fundamental bottleneck in large language model training: the need to materialize the full logits tensor $Z$ when projecting hidden states to vocabulary logits. It introduces a fused projection-prediction kernel that computes the loss directly from hidden states and target tokens. A streaming, numerically stable softmax avoids storing $Z$ and reduces the memory of this stage from $O(BTV)$ to $O(BT)$, where $B$ is the batch size, $T$ the sequence length, and $V$ the vocabulary size. Window-based strategies and compatibility with data, tensor, and sequence parallelism further improve GPU occupancy and scalability. Empirical results show substantial latency reductions (e.g., >40%) and dramatic memory savings (often >95%) at scale, enabling larger batch sizes and longer sequences without accuracy loss. The approach is a practical systems optimization that applies to loss objectives beyond cross-entropy, and it points to compiler-assisted kernel generation as a route to broader, cross-platform deployment.
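To make the "streaming, numerically stable softmax" concrete, here is a minimal plain-Python sketch of the general online log-sum-exp technique for a single position. It is not the authors' kernel. It assumes the cross-entropy identity $\ell = \log \sum_v e^{z_v} - z_y$ for target token $y$. The function names, the `chunk` parameter, and the layout of `W` (one lm_head row per vocabulary entry) are hypothetical. The sketch walks the vocabulary in chunks and keeps only a running maximum $m$ and a running rescaled sum $s$, so at most `chunk` logits exist at any time. A real GPU kernel would also parallelize over positions and fuse the backward pass, which this sketch omits.

```python
import math
from typing import Sequence


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def fused_ce_one_position(h: Sequence[float], W: Sequence[Sequence[float]],
                          target: int, chunk: int = 4) -> float:
    """Hypothetical sketch: cross-entropy for one position from hidden state h
    and lm_head rows W, holding at most `chunk` logits at a time."""
    if not 0 <= target < len(W):
        raise IndexError("target token id out of range")
    m = -math.inf   # running maximum of the logits seen so far
    s = 0.0         # running sum of exp(logit - m)
    z_target = 0.0  # target logit, captured when its chunk is visited
    for start in range(0, len(W), chunk):
        # Project h onto this vocabulary chunk only.
        z = [dot(h, w) for w in W[start:start + chunk]]
        m_new = max(m, max(z))
        # Rescale the old sum to the new maximum, then add this chunk's terms.
        s = s * math.exp(m - m_new) + sum(math.exp(v - m_new) for v in z)
        m = m_new
        if start <= target < start + chunk:
            z_target = z[target - start]
    # loss = logsumexp(z) - z_target, with logsumexp(z) = m + log(s)
    return m + math.log(s) - z_target


def reference_ce(h: Sequence[float], W: Sequence[Sequence[float]],
                 target: int) -> float:
    """Unfused baseline: materializes the full logits row first."""
    z = [dot(h, w) for w in W]
    m = max(z)
    return m + math.log(sum(math.exp(v - m) for v in z)) - z[target]


if __name__ == "__main__":
    h = [0.5, -1.0, 2.0]
    W = [[math.sin(i + j) for j in range(3)] for i in range(10)]
    for target in (0, 7):
        print(fused_ce_one_position(h, W, target, chunk=3),
              reference_ce(h, W, target))
```

Subtracting the running maximum before exponentiating is what keeps the computation stable: with logits in the thousands, a naive `exp` would overflow, while the rescaled sum stays near 1.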
Abstract
Training Large Language Models (LLMs) typically involves a two-stage pipeline at the output layer: hidden states are projected into vocabulary logits via a linear transformation (lm_head), followed by cross-entropy loss computation against target tokens. While conceptually simple, this design incurs substantial overhead. The intermediate logits tensor, whose size scales with the product of batch size, sequence length, and vocabulary size, must be fully materialized in GPU memory, even though the loss at each position depends on those logits only through the target token's logit and a single normalizing log-sum-exp. This leads to a large memory footprint and high bandwidth consumption, limiting scalability and slowing training throughput. In this work, we introduce a novel approach that integrates the output projection and loss prediction into a single operation. By computing the loss directly from hidden states and target tokens, our approach bypasses explicit logits materialization, which reduces memory usage and alleviates bandwidth pressure. Experiments on LLM training demonstrate that our method achieves substantial memory savings and measurable speedups compared to the standard two-stage pipeline, enabling larger batch sizes and longer sequences without sacrificing accuracy. Our work highlights the benefits of rethinking the boundary between projection and prediction, offering a practical systems optimization for efficient LLM training.
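A quick back-of-the-envelope calculation shows why materializing the logits tensor is costly. The sketch below compares a bf16 logits tensor of shape $B \times T \times V$ (batch size, sequence length, vocabulary size) with storing one fp32 scalar per position, such as a saved log-sum-exp. The configuration ($B=8$, $T=4096$, $V=128{,}000$) and the element sizes are hypothetical choices for illustration, not numbers from the paper. The paper's own accounting may differ. Note also that the backward pass of the unfused pipeline produces a gradient of the same $B \times T \times V$ shape, so the figure below counts only one such tensor.

```python
def logits_bytes(B: int, T: int, V: int, bytes_per_elem: int = 2) -> int:
    """Size of a materialized logits tensor Z of shape (B, T, V); bf16 by default."""
    return B * T * V * bytes_per_elem


def per_position_bytes(B: int, T: int, bytes_per_elem: int = 4) -> int:
    """Size of one fp32 scalar per position (e.g. a saved log-sum-exp)."""
    return B * T * bytes_per_elem


GiB = 2 ** 30
KiB = 2 ** 10

B, T, V = 8, 4096, 128_000  # hypothetical training configuration
print(f"logits Z (bf16): {logits_bytes(B, T, V) / GiB:.4f} GiB")             # 7.8125 GiB
print(f"per-position stats (fp32): {per_position_bytes(B, T) / KiB:.1f} KiB")  # 128.0 KiB
```

The ratio between the two is $V \cdot 2 / 4 = 64{,}000$ here, which is why avoiding the $O(BTV)$ tensor dominates any savings elsewhere in the output layer.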
