Vectorizing the Trie: Efficient Constrained Decoding for LLM-based Generative Retrieval on Accelerators

Zhengyang Su, Isay Katsman, Yueqi Wang, Ruining He, Lukasz Heldt, Raghunandan Keshavan, Shao-Chuan Wang, Xinyang Yi, Mingyan Gao, Onkar Dalal, Lichan Hong, Ed Chi, Ningren Han

TL;DR

This work introduces STATIC (Sparse Transition Matrix-Accelerated Trie Index for Constrained Decoding), an efficient and scalable constrained decoding technique designed for high-throughput LLM-based generative retrieval on TPUs/GPUs, and uses it to enable the first production-scale deployment of strictly constrained generative retrieval.

Abstract

Generative retrieval has emerged as a powerful paradigm for LLM-based recommendation. However, industrial recommender systems often benefit from restricting the output space to a constrained subset of items based on business logic (e.g., enforcing content freshness or product category), which standard autoregressive decoding cannot natively support. Moreover, existing constrained decoding methods that make use of prefix trees (tries) incur severe latency penalties on hardware accelerators (TPUs/GPUs). In this work, we introduce STATIC (Sparse Transition Matrix-Accelerated Trie Index for Constrained Decoding), an efficient and scalable constrained decoding technique designed specifically for high-throughput LLM-based generative retrieval on TPUs/GPUs. By flattening the prefix tree into a static Compressed Sparse Row (CSR) matrix, we transform irregular tree traversals into fully vectorized sparse matrix operations, unlocking massive efficiency gains on hardware accelerators. We deploy STATIC on a large-scale industrial video recommendation platform serving billions of users. STATIC produces significant product metric impact with minimal latency overhead (0.033 ms per step and 0.25% of inference time), achieving a 948x speedup over a CPU trie implementation and a 47x to 1033x speedup over a hardware-accelerated binary-search baseline. Furthermore, the runtime overhead of STATIC remains extremely low across a wide range of practical configurations. To the best of our knowledge, STATIC enables the first production-scale deployment of strictly constrained generative retrieval. In addition, evaluation on academic benchmarks demonstrates that STATIC can considerably improve cold-start performance for generative retrieval. Our code is available at https://github.com/youtube/static-constraint-decoding.
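To make the flattening step concrete, here is a minimal pure-Python sketch that turns a constraint set of fixed-length semantic-ID tuples into CSR arrays. It is an illustration under our own assumptions, not the released implementation: the function name `build_csr_trie`, the breadth-first node numbering (root = node 0), and the three-array layout (`row_ptr`, `col_tokens`, `next_state`) are hypothetical choices. Row `s` of the resulting sparse matrix lists the tokens that may follow trie node `s` and the node each of those tokens leads to.

```python
from collections import deque


def build_csr_trie(constraints):
    """Flatten a prefix tree over fixed-length token tuples into CSR arrays.

    Row s covers col_tokens[row_ptr[s]:row_ptr[s + 1]]: the tokens allowed after
    trie node s. Taking col_tokens[k] moves the decoder to node next_state[k].
    """
    # 1) Build a dict-of-dicts trie; node 0 is the root (the empty prefix).
    children = [{}]
    for seq in sorted(set(constraints)):
        node = 0
        for tok in seq:
            if tok not in children[node]:
                children[node][tok] = len(children)
                children.append({})
            node = children[node][tok]

    # 2) Renumber nodes breadth-first, so each trie level is a contiguous block of rows.
    new_id = {0: 0}
    order = []
    queue = deque([0])
    while queue:
        old = queue.popleft()
        order.append(old)
        for tok in sorted(children[old]):
            new_id[children[old][tok]] = len(new_id)
            queue.append(children[old][tok])

    # 3) Emit one CSR row per node (leaves get empty rows), children sorted by token.
    row_ptr, col_tokens, next_state = [0], [], []
    for old in order:
        for tok in sorted(children[old]):
            col_tokens.append(tok)
            next_state.append(new_id[children[old][tok]])
        row_ptr.append(len(col_tokens))
    return row_ptr, col_tokens, next_state


# Usage on the constraint set from Figure 1: C = {(1, 2, 1), (3, 1, 2), (3, 1, 3)}.
row_ptr, col_tokens, next_state = build_csr_trie([(1, 2, 1), (3, 1, 2), (3, 1, 3)])
print(row_ptr)     # [0, 2, 3, 4, 5, 7, 7, 7, 7]
print(col_tokens)  # [1, 3, 2, 1, 1, 2, 3]
print(next_state)  # [1, 2, 3, 4, 5, 6, 7]
```

Because nodes are numbered breadth-first with children sorted by token, `next_state[k]` in this sketch always equals `k + 1`; the array is kept explicit only so the lookup reads like a general sparse transition table.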

Paper Structure

This paper contains 38 sections, 6 equations, 4 figures, 4 tables, 2 algorithms.

Figures (4)

  • Figure 1: This figure showcases the full STATIC pipeline. Figures 1a and 1b present the prefix tree construction for the case $\mathcal{V} = \{1,2,3\}$, $L = 3$, and restricted vocabulary $\mathcal{C} = \{(1, 2, 1), (3, 1, 2), (3, 1, 3)\}$. We prepend the letters $A, B, C$ to the semantic tokens to denote the first, second, and third levels, respectively. The corresponding transition matrix is shown in Figure 1c, with explicit labels and color-coding. Figure 1d presents the sparse matrix (CSR) representation of the full transition matrix. Figure 1e shows how the sparse matrix is applied to constrain the vocabulary at decoding time.
  • Figure 2: Scaling of different constrained decoding methods with respect to constraint set size $|\mathcal{C}|$ (log-log scale), with $|\mathcal{V}|=2048$ fixed. STATIC considerably outperforms existing methods.
  • Figure 3: Scaling of different constrained decoding methods with respect to semantic ID (SID) vocabulary size $|\mathcal{V}|$ (log-log scale) at $|\mathcal{C}|=10^7$. STATIC compares favorably for all tested SID vocabulary sizes.
  • Figure 4: Scaling of the STATIC masking kernel with respect to the maximum branch factor $B$ (log-log scale). We plot means and standard deviations (shaded region) over $100$ trials. For each trial, the token vocabulary size is set equal to the branch factor, while the number of constraints is fixed at $|\mathcal{C}|=10^6$. STATIC exhibits asymptotically linear $\mathcal{O}(B)$ scaling.
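Figures 1e and 4 concern how the CSR matrix masks the vocabulary at each decoding step. The NumPy sketch below shows one way to do this, under stated assumptions: the CSR arrays hard-code Figure 1's constraint set using a breadth-first node numbering of our own (which need not match Figure 1d), and `gather_children`, `allowed_mask`, and `step` are hypothetical helper names, not the released API. Each beam reads a fixed window of `MAX_BRANCH` (the maximum branch factor B) CSR slots and discards padding with a validity mask, so the per-beam work grows linearly in B with no data-dependent loops. This is consistent with the O(B) scaling reported in Figure 4, though the actual kernel may differ.

```python
import numpy as np

# Hypothetical CSR layout for Figure 1's constraint set
# C = {(1, 2, 1), (3, 1, 2), (3, 1, 3)}; nodes are numbered breadth-first from the
# root (node 0), an assumption that need not match Figure 1d. The children of
# node s are COL_TOKENS[ROW_PTR[s]:ROW_PTR[s + 1]], and edge k leads to NEXT_STATE[k].
ROW_PTR = np.array([0, 2, 3, 4, 5, 7, 7, 7, 7])
COL_TOKENS = np.array([1, 3, 2, 1, 1, 2, 3])
NEXT_STATE = np.array([1, 2, 3, 4, 5, 6, 7])
VOCAB_SIZE = 4                               # token ids 0..3; id 0 never occurs here
MAX_BRANCH = int(np.diff(ROW_PTR).max())     # B, the largest child count (2 here)


def gather_children(states):
    """Index a fixed window of B CSR slots per beam and flag which slots are real."""
    start = ROW_PTR[states]                                       # (batch,)
    count = ROW_PTR[states + 1] - start                           # (batch,)
    offsets = np.arange(MAX_BRANCH)                               # (B,)
    valid = offsets[None, :] < count[:, None]                     # (batch, B)
    idx = np.where(valid, start[:, None] + offsets[None, :], 0)   # padding -> slot 0
    return idx, valid


def allowed_mask(states):
    """Boolean (batch, VOCAB_SIZE) mask of tokens that extend each beam's prefix."""
    idx, valid = gather_children(states)
    # Send padding slots to an extra dummy column so every array keeps a static shape.
    toks = np.where(valid, COL_TOKENS[idx], VOCAB_SIZE)
    mask = np.zeros((len(states), VOCAB_SIZE + 1), dtype=bool)
    mask[np.arange(len(states))[:, None], toks] = True
    return mask[:, :VOCAB_SIZE]


def step(states, tokens):
    """Advance each beam along the edge labelled tokens[i] (assumed to be allowed)."""
    idx, valid = gather_children(states)
    hit = valid & (COL_TOKENS[idx] == tokens[:, None])            # one True per row
    slot = idx[np.arange(len(states)), hit.argmax(axis=1)]
    return NEXT_STATE[slot]


# Usage: two beams start at the root; disallowed tokens get a logit of -inf.
states = np.array([0, 0])
logits = np.zeros((2, VOCAB_SIZE))
masked_logits = np.where(allowed_mask(states), logits, -np.inf)
states = step(states, np.array([1, 3]))      # prefixes (1,) and (3,)
states = step(states, np.array([2, 1]))      # prefixes (1, 2) and (3, 1)
print(allowed_mask(states).tolist())
# [[False, True, False, False], [False, False, True, True]]
```

The fixed-width window wastes a little work on nodes with fewer than B children in exchange for static array shapes, which is the property that lets this style of lookup be compiled for TPUs/GPUs; the released code may make different trade-offs.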