Adaptive Sampled Softmax with Inverted Multi-Index: Methods, Theory and Applications

Jin Chen, Jin Zhang, Xu Huang, Yi Yang, Defu Lian, Enhong Chen

TL;DR

The paper tackles the computational bottleneck of the full softmax with very large class counts by introducing the MIDX-Sampler, which leverages an inverted multi-index to decompose the softmax into multiple multinomial stages and replaces the final stage with a uniform distribution to achieve sub-linear sampling complexity. It provides a rigorous theoretical framework, including KL-divergence bounds, gradient bias analyses, convergence rates, and generalization error bounds, highlighting how smaller residual logits after quantization yield tighter bias controls. The authors present two practical variants, MIDX-pq (product quantization) and MIDX-rq (residual quantization), and show that learnable codebooks guided by KL-divergence objectives further reduce bias and improve perplexity and ranking metrics. Comprehensive experiments across language modeling, sequential recommendation, and extreme classification demonstrate that MIDX samplers consistently outperform static and kernel-based baselines in both accuracy and efficiency, with the residual-quantization variant often providing the strongest gains. The work thus offers a scalable, theory-grounded solution for large-scale multi-class problems and includes an implementation to enable adoption in real-world systems.
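The staged decomposition can be written out as follows. This is a sketch under stated assumptions, not the paper's exact statement: the logit is taken to be the inner product $\bm{z}^\top \bm{q}_i$, the quantizer is a two-codebook product quantizer with $\bm{z} = [\bm{z}^1 \oplus \bm{z}^2]$ and $\bm{q}_i = [\bm{c}^1_{k_1} \oplus \bm{c}^2_{k_2}] + \tilde{\bm{q}}_i$ for $i \in \Omega_{k_1,k_2}$, and the symbols $\omega_{k_1,k_2}$ and $Q$ are introduced here for illustration.

$$
\begin{aligned}
P(i \mid \bm{z}) &= \frac{\exp(\bm{z}^\top \bm{q}_i)}{\sum_{j}\exp(\bm{z}^\top \bm{q}_j)} = P^1(k_1)\,P^2(k_2 \mid k_1)\,P^3(i \mid k_1,k_2), \\
P^1(k_1) &= \frac{\exp(\bm{z}^{1\top}\bm{c}^1_{k_1})\sum_{k_2}\exp(\bm{z}^{2\top}\bm{c}^2_{k_2})\,\omega_{k_1,k_2}}{\sum_{k_1'}\exp(\bm{z}^{1\top}\bm{c}^1_{k_1'})\sum_{k_2'}\exp(\bm{z}^{2\top}\bm{c}^2_{k_2'})\,\omega_{k_1',k_2'}}, \\
P^2(k_2 \mid k_1) &= \frac{\exp(\bm{z}^{2\top}\bm{c}^2_{k_2})\,\omega_{k_1,k_2}}{\sum_{k_2'}\exp(\bm{z}^{2\top}\bm{c}^2_{k_2'})\,\omega_{k_1,k_2'}}, \\
P^3(i \mid k_1,k_2) &= \frac{\exp(\bm{z}^\top\tilde{\bm{q}}_i)}{\omega_{k_1,k_2}}, \qquad \omega_{k_1,k_2} = \sum_{j \in \Omega_{k_1,k_2}}\exp(\bm{z}^\top\tilde{\bm{q}}_j).
\end{aligned}
$$

Multiplying the three factors cancels $\omega_{k_1,k_2}$ and the inner sum over $k_2$, which recovers the full softmax. Replacing $P^3$ with a uniform distribution corresponds to setting $\omega_{k_1,k_2} = |\Omega_{k_1,k_2}|$, so the resulting proposal satisfies $Q(i \mid \bm{z}) \propto \exp(\bm{z}^{1\top}\bm{c}^1_{k_1} + \bm{z}^{2\top}\bm{c}^2_{k_2})$. It differs from the softmax only through the residual factor $\exp(\bm{z}^\top\tilde{\bm{q}}_i)$, which is consistent with the remark that smaller residual logits give tighter bias control.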

Abstract

The softmax function is a cornerstone of multi-class classification, integral to a wide range of machine learning applications, from large-scale retrieval and ranking models to advanced large language models. However, its computational cost grows linearly with the number of classes, which becomes prohibitively expensive in scenarios with millions or even billions of classes. The sampled softmax, which relies on self-normalized importance sampling, has emerged as a powerful alternative, significantly reducing computational complexity. Yet, its estimator remains unbiased only when the sampling distribution matches the true softmax distribution. To improve both approximation accuracy and sampling efficiency, we propose the MIDX Sampler, a novel adaptive sampling strategy based on an inverted multi-index approach. Concretely, we decompose the softmax probability into several multinomial probabilities, each associated with a specific set of codewords and the last associated with the residual score of queries, thus reducing time complexity to the number of codewords instead of the number of classes. To further boost efficiency, we replace the query-specific residual probability with a simple uniform distribution, simplifying the computation while retaining high performance. Our method is backed by rigorous theoretical analysis, addressing key concerns such as sampling bias, gradient bias, convergence rates, and generalization error bounds. The results demonstrate that a smaller divergence from the ideal softmax distribution leads to faster convergence and improved generalization. Extensive experiments on large-scale language models, sequential recommenders, and extreme multi-class classification tasks confirm that the MIDX-Sampler delivers superior effectiveness and efficiency compared to existing approaches.
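The role of the sampling distribution in self-normalized importance sampling can be illustrated with a small NumPy sketch. It is not taken from the paper's implementation: the function names, the toy logits, and the two proposals (uniform and the exact softmax) are assumptions chosen for illustration. The weights computed here are the ones that would stand in for the softmax probabilities of the sampled classes in a gradient estimate.

```python
import numpy as np

def softmax(x):
    # Numerically stable softmax over a 1-D array.
    e = np.exp(x - np.max(x))
    return e / e.sum()

def snis_weights(logits, log_q, sample_ids):
    # Self-normalized importance weights for classes drawn from a proposal Q:
    # each sampled logit is corrected by -log Q, then normalized over the sample.
    corrected = logits[sample_ids] - log_q[sample_ids]
    return softmax(corrected)

rng = np.random.default_rng(0)
num_classes, num_samples = 1000, 20
logits = rng.normal(size=num_classes)   # toy logits (hypothetical data)
p = softmax(logits)                     # the ideal proposal: the softmax itself

# Static proposal: uniform over all classes.
q_uniform = np.full(num_classes, 1.0 / num_classes)
ids_u = rng.choice(num_classes, size=num_samples, p=q_uniform)
w_uniform = snis_weights(logits, np.log(q_uniform), ids_u)

# Ideal proposal: sample from the softmax distribution.
ids_p = rng.choice(num_classes, size=num_samples, p=p)
w_ideal = snis_weights(logits, np.log(p), ids_p)
```

When the proposal equals the softmax, every corrected logit collapses to the same constant (the log partition function), so all weights equal `1 / num_samples` and the weighted sum reduces to a plain Monte Carlo average. Under the uniform proposal the weights are uneven, which is the source of the bias that an adaptive sampler aims to shrink.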

Paper Structure

This paper contains 43 sections, 14 theorems, 74 equations, 7 figures, 10 tables, 1 algorithm.

Key Result

Theorem 1

Given the query embedding $\bm{z} = [\bm{z}^1 \oplus \bm{z}^2]$, the $i$-th class embedding is denoted as $\bm{q}_i = [\bm{c}^1_{k_1} \oplus \bm{c}^2_{k_2}] + \tilde{\bm{q}}_i$, where $\tilde{\bm{q}}_i$ denotes the residual vector. $\Omega_{k_1, k_2}$ denotes the set of classes grouped to $\bm{c}^1_

Figures (7)

  • Figure 1: Procedure for sampling a class given a query embedding with MIDX samplers. The subvectors $\bm{z}_1$ and $\bm{z}_2$ are derived differently depending on the quantizer, e.g., product quantizers or residual quantizers. The example proceeds as follows: 1. Select the first codeword according to the probability $P^1(\cdot)$; 2. Select the second codeword according to the probability $P^2(\cdot|c_5^1)$; 3. Sample classes from the set $\Omega(c_5^1, c_3^2)$, which contains the classes assigned to the 5th codeword for the first subvector and to the 3rd codeword for the second subvector.
  • Figure 2: Comparison with samplers
  • Figure 3: Effect of codeword numbers
  • Figure 4: Sampling probabilities with random initialization.
  • Figure 5: Sampling probabilities with well-trained embeddings.
  • ...and 2 more figures
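
The three steps in the Figure 1 caption can be sketched in NumPy as below. This is a minimal illustration rather than the authors' implementation. It assumes a product quantizer that splits the query into two halves, inner-product scores against each codebook, and the uniform last stage described in the abstract. The function `midx_uniform_sample`, the `buckets` dictionary, and the random codeword assignment are hypothetical names and data.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - np.max(x))
    return e / e.sum()

def midx_uniform_sample(z, codebook1, codebook2, buckets, rng):
    """Draw one class; return (class_id, proposal probability of that class).

    z:          query embedding of shape (2 * d,), split into two halves
    codebook1/2: arrays of shape (K, d)
    buckets:    dict mapping (k1, k2) -> list of class ids assigned to that pair
    """
    K, d = codebook1.shape
    s1 = codebook1 @ z[:d]          # scores of the first subvector, shape (K,)
    s2 = codebook2 @ z[d:]          # scores of the second subvector, shape (K,)

    # counts[k1, k2] = number of classes in the bucket (0 for empty buckets).
    # A practical sampler would precompute this once, not per query.
    counts = np.zeros((K, K))
    for (a, b), ids in buckets.items():
        counts[a, b] = len(ids)

    # Step 1: pick the first codeword.
    e1 = np.exp(s1 - s1.max())
    e2 = np.exp(s2 - s2.max())
    p1 = e1 * (counts @ e2)
    p1 /= p1.sum()
    k1 = int(rng.choice(K, p=p1))

    # Step 2: pick the second codeword given the first.
    p2 = e2 * counts[k1]
    p2 /= p2.sum()
    k2 = int(rng.choice(K, p=p2))

    # Step 3: pick a class uniformly inside the selected bucket.
    ids = buckets[(k1, k2)]
    cls = int(rng.choice(ids))
    return cls, p1[k1] * p2[k2] / len(ids)

rng = np.random.default_rng(0)
K, d, num_classes = 4, 3, 50
codebook1 = rng.normal(size=(K, d))
codebook2 = rng.normal(size=(K, d))
assign = rng.integers(0, K, size=(num_classes, 2))   # hypothetical assignment
buckets = {}
for i, (a, b) in enumerate(assign):
    buckets.setdefault((int(a), int(b)), []).append(i)

z = rng.normal(size=2 * d)
cls, prob = midx_uniform_sample(z, codebook1, codebook2, buckets, rng)
```

Under these assumptions the returned probability equals a softmax over the quantized logits (the codeword scores of each class with the residual dropped), while the per-query work scales with the codebook size rather than with the number of classes.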

Theorems & Definitions (14)

  • Theorem 1
  • Theorem 2
  • Theorem 3
  • Theorem 4
  • Theorem 5
  • Theorem 6
  • Theorem 7
  • Theorem 8
  • Theorem 9
  • Lemma 1
  • ...and 4 more