Scaling Embedding Layers in Language Models
Da Yu, Edith Cohen, Badih Ghazi, Yangsibo Huang, Pritish Kamath, Ravi Kumar, Daogao Liu, Chiyuan Zhang
TL;DR
SCONE introduces a scalable, off-accelerator expansion of the embedding layer by learning contextualized $n$-gram embeddings with a separate transformer ($\mathcal{A}_{\mathrm{f\text{-}gram}}$) and caching their outputs for inference via $\mathcal{F}$. By decoupling the $n$-gram embeddings from the token vocabulary, SCONE opens two new scaling axes, namely more $f$-grams (frequent $n$-grams) and larger $\mathcal{A}_{\mathrm{f\text{-}gram}}$ models, without increasing inference-time FLOPS or accelerator memory. Empirical results on GPT-2–scale pretraining show perplexity improvements and strong zero-shot gains on downstream tasks, with sizable reductions in inference cost relative to larger baselines. The approach allows efficient capacity expansion for latency-sensitive deployments: the heavy work of learning embeddings happens during training, the cached embeddings live in off-accelerator storage, and the inference footprint stays fixed.
Abstract
We propose SCONE (Scalable, Contextualized, Offloaded, N-gram Embedding), a new method for extending input embedding layers to enhance language model performance. To avoid increased decoding costs, SCONE retains the original vocabulary while introducing embeddings for a set of frequent $n$-grams. These embeddings provide a contextualized representation for each input token and are learned with a separate model during training. After training, the embeddings are precomputed and stored in off-accelerator memory; during inference, querying them has minimal impact on latency because embedding lookups have low complexity. SCONE enables two new scaling strategies: increasing the number of $n$-gram embeddings and scaling the model used to learn them, both while keeping accelerator usage during inference fixed (in terms of FLOPS and memory). We show that scaling both aspects enables a model with 1B accelerator-resident parameters to outperform a 1.9B-parameter baseline across diverse corpora, while using only about half the FLOPS and accelerator memory during inference.
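The inference-time lookup described in the abstract can be illustrated with a small sketch. It is not the authors' implementation: the longest-match rule, the fallback to the ordinary token embedding, and the names `lookup_embeddings`, `token_table`, `fgram_table`, and `max_n` are all assumptions made for illustration. The dictionary `fgram_table` stands in for the precomputed embeddings held in off-accelerator memory.

```python
from typing import Dict, List, Sequence, Tuple

Vector = List[float]


def lookup_embeddings(
    tokens: Sequence[int],
    token_table: Dict[int, Vector],
    fgram_table: Dict[Tuple[int, ...], Vector],
    max_n: int = 3,
) -> List[Vector]:
    """Return one input embedding per token position.

    Assumed rule (hypothetical): for each position, use the cached embedding
    of the longest frequent n-gram that ends at that position; if no such
    n-gram is cached, fall back to the plain token embedding.
    """
    out: List[Vector] = []
    for i in range(len(tokens)):
        chosen = token_table[tokens[i]]  # fallback: ordinary token embedding
        # Try the longest candidate first, down to bigrams (n = 2).
        for n in range(min(max_n, i + 1), 1, -1):
            key = tuple(tokens[i - n + 1 : i + 1])
            if key in fgram_table:
                chosen = fgram_table[key]
                break
        out.append(chosen)
    return out


# Toy tables with 1-dimensional embeddings (made-up values).
token_table = {1: [1.0], 2: [2.0], 3: [3.0]}
fgram_table = {(1, 2): [12.0], (1, 2, 3): [123.0]}

result = lookup_embeddings([1, 2, 3, 1], token_table, fgram_table)
print(result)
```

In this toy run the second position picks up the cached bigram `(1, 2)`, the third picks up the trigram `(1, 2, 3)`, and the first and last positions fall back to the token table. Each position costs at most `max_n - 1` dictionary probes regardless of how many entries the table holds, which is consistent with the abstract's point that growing the table does not add accelerator FLOPS or memory.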
