Amortizing Maximum Inner Product Search with Learned Support Functions

Theo X. Olausson, João Monteiro, Michal Klein, Marco Cuturi

Abstract

Maximum inner product search (MIPS) is a crucial subroutine in machine learning, requiring the identification of key vectors that align best with a given query. We propose amortized MIPS: a learning-based approach that trains neural networks to directly predict MIPS solutions, amortizing the computational cost of matching queries (drawn from a fixed distribution) to a fixed set of keys. Our key insight is that the MIPS value function, the maximal inner product between a query and the keys, is also known as the support function of the set of keys. Support functions are convex and 1-homogeneous, and their gradient w.r.t. the query is exactly the optimal key in the database. We approximate the support function using two complementary approaches: (1) we train an input-convex neural network (SupportNet) to model the support function directly, so that the optimal key can be recovered via (autodiff) gradient computation, and (2) we directly regress the optimal key from the query using a vector-valued network (KeyNet), bypassing gradient computation entirely at inference time. To learn a SupportNet, we combine score regression with gradient matching losses, and propose homogenization wrappers that enforce the positive 1-homogeneity of a neural network, theoretically linking function values to gradients. To train a KeyNet, we introduce a score consistency loss derived from Euler's theorem for homogeneous functions. Our experiments show that learned SupportNet and KeyNet models achieve high match rates and open up new directions to compress databases with a specific query distribution in mind.
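The properties stated in the abstract can be checked numerically on a toy example. The sketch below is only an illustration under stated assumptions: the five keys, the query, and all function names are hypothetical, the gradient is estimated by finite differences rather than autodiff, and the radial wrapper $g(\mathbf{x}) = \|\mathbf{x}\| f(\mathbf{x}/\|\mathbf{x}\|)$ is one common way to obtain a positively 1-homogeneous function, not necessarily the wrapper proposed in the paper (it does not by itself preserve convexity).

```python
import math

# Hypothetical toy database: 5 keys in the plane (values chosen for illustration).
KEYS = [(1.0, 0.0), (0.0, 1.0), (-1.0, 0.5), (0.5, -1.0), (0.8, 0.8)]

def dot(x, y):
    """Inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(x, y))

def support(x, keys=KEYS):
    """Support function: the maximal inner product between x and any key."""
    return max(dot(x, y) for y in keys)

def best_key(x, keys=KEYS):
    """MIPS solution: the key attaining the maximal inner product with x."""
    return max(keys, key=lambda y: dot(x, y))

def numerical_gradient(f, x, eps=1e-6):
    """Central finite differences, used here only to check the gradient claim."""
    grad = []
    for i in range(len(x)):
        hi = list(x)
        lo = list(x)
        hi[i] += eps
        lo[i] -= eps
        grad.append((f(hi) - f(lo)) / (2 * eps))
    return tuple(grad)

def homogenize(f):
    """Assumed radial wrapper g(x) = ||x|| * f(x / ||x||).

    For any f, g satisfies g(t * x) = t * g(x) for all t > 0.
    """
    def g(x):
        norm = math.sqrt(dot(x, x))
        if norm == 0.0:
            return 0.0
        return norm * f([a / norm for a in x])
    return g

x = (0.3, 2.0)                          # hypothetical query
y_star = best_key(x)                    # optimal key for x
grad = numerical_gradient(support, x)   # should match y_star away from ties
euler_gap = support(x) - dot(x, grad)   # Euler: sigma(x) = <x, grad sigma(x)>
```

Away from ties between keys, the finite-difference gradient coincides with the optimal key, and the Euler identity $\sigma_{\mathcal{Y}}(\mathbf{x}) = \langle \mathbf{x}, \nabla \sigma_{\mathcal{Y}}(\mathbf{x}) \rangle$ holds up to numerical error. This is the link between function values and gradients that the abstract refers to.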

Paper Structure (31 sections, 17 equations, 8 figures, 1 table)

Figures (8)

  • Figure 1: The set $\mathcal{Y}$ of keys (shown as dots) consists of 5 points. The support function $\sigma_{\mathcal{Y}}$, shown as a contour plot on the left, is a convex, piecewise-linear function of the query $\mathbf{x}$. Its gradient $\nabla_{\mathbf{x}}\sigma_{\mathcal{Y}}(\mathbf{x})$ at any query equals the database point $\mathbf{y}^\star$ that maximizes $\langle \mathbf{x}, \mathbf{y} \rangle$, i.e., exactly the key to which the query is matched (here represented as small squares). While SupportNet is trained to model the contour plot on the left (a piecewise-affine function), KeyNet is trained to map inputs to optimal keys directly, as shown in the right plot (a piecewise-constant vector-valued function).
  • Figure 2: Results for the routing experiment described in Section \ref{sec:exp-routing}, on Quora (left, $n\approx 500$k) and NQ (right, $n\approx 2.5$M). Here SupportNet and KeyNet are used to predict the support function of a query on $c=10$ clusters of the entire set of keys $\mathcal{Y}$, followed by exhaustive search within the top-scored cluster. For the baseline clustering approach, the cluster is selected as the one whose centroid scores highest against the query. These results show that multiple models (trained with various depths $L$, sizes $\rho$, peak learning rates, etc.) consistently achieve better routing accuracy at a lower FLOPs budget, highlighting the robustness of the SupportNet and KeyNet training pipelines to hyperparameter choices. Our recommendation, in any practical deployment scenario, would be to select one of the best-performing models depending on the trade-off sought in practice. Markers outlined with a black line correspond to settings in which the input $\mathbf{x}$ is re-injected every 4 layers ($n_x \approx L/4$), whereas no outline indicates re-injection at every layer ($n_x = L$).
  • Figure 3: $\mathcal{E}_{\text{rel}}$ vs. MRR for various model sizes and depths for SupportNet and KeyNet at the end of training on FIQA (left) and Quora (right; only KeyNet is reported on that task, as it is used later in Section \ref{sec:exp-faiss}). The lower right corner (i.e., high MRR and low $\mathcal{E}_\text{rel}$) is best.
  • Figure 4: Results for approximate search integration as described in Section \ref{sec:exp-faiss}. We use KeyNet to predict the actual top key for $\mathbf{x}$ found in the entire HotpotQA set of $\approx5.2$M keys, and run approximate search with a FAISS IVF index. We then query the database with $\hat{\mathbf{y}}(\mathbf{x})$ instead of $\mathbf{x}$ under growing search budgets. Compared to submitting the query directly to FAISS, our approach always pays the forward-pass cost needed to evaluate the prediction (which grows roughly linearly with the model size $\rho$) and then runs fast MIPS search on that prediction with a varying number of probes $n_{\text{probe}}$. We report cumulative FLOP counts for our method and raw FLOP counts for FAISS applied directly to $\mathbf{x}$. Note that 0.01% is here equivalent to recall at approximately 500, given the original size of HotpotQA.
  • Figure 5: Results for approximate search integration as described in Section \ref{sec:exp-faiss} on Quora.
  • ...and 3 more figures
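
The routing setup described in the Figure 2 caption can be sketched on toy data. This is a minimal illustration under stated assumptions, not the paper's pipeline: the two clusters, the query, and all function names are hypothetical, and the exact per-cluster support function stands in for what a trained SupportNet or KeyNet would only approximate.

```python
# Hypothetical toy clusters (2 clusters of 2 keys each), chosen for illustration.
CLUSTERS = [
    [(2.0, 0.0), (-2.0, 0.0)],
    [(0.5, 0.5), (0.5, 0.3)],
]

def dot(x, y):
    """Inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(x, y))

def centroid(keys):
    """Coordinate-wise mean of the keys in one cluster."""
    dim = len(keys[0])
    return tuple(sum(k[i] for k in keys) / len(keys) for i in range(dim))

def cluster_support(x, keys):
    """Exact support function of one cluster (stand-in for a learned predictor)."""
    return max(dot(x, y) for y in keys)

def search_in_cluster(x, keys):
    """Exhaustive MIPS restricted to the given keys."""
    return max(keys, key=lambda y: dot(x, y))

def route_by_centroid(x, clusters):
    """Baseline: pick the cluster whose centroid scores highest against x."""
    scores = [dot(x, centroid(c)) for c in clusters]
    return scores.index(max(scores))

def route_by_support(x, clusters, score_fn=cluster_support):
    """Pick the cluster with the highest (predicted) support value at x."""
    scores = [score_fn(x, c) for c in clusters]
    return scores.index(max(scores))

x = (1.0, 0.1)  # hypothetical query
exact = search_in_cluster(x, [k for c in CLUSTERS for k in c])
via_centroid = search_in_cluster(x, CLUSTERS[route_by_centroid(x, CLUSTERS)])
via_support = search_in_cluster(x, CLUSTERS[route_by_support(x, CLUSTERS)])
```

In this toy case the first cluster's centroid is the origin, so centroid routing sends the query to the second cluster and misses the true top key, while routing by per-cluster support value recovers it. With a learned predictor in place of `cluster_support`, routing accuracy would depend on how well the model approximates those values.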