A Fresh Take on Stale Embeddings: Improving Dense Retriever Training with Corrector Networks

Nicholas Monath, Will Grathwohl, Michael Boratko, Rob Fergus, Andrew McCallum, Manzil Zaheer

TL;DR

The approach matches state-of-the-art results even when no target embedding updates are made during training beyond an initial cache from the unsupervised pre-trained model, providing a 4-80x reduction in re-embedding computational cost.

Abstract

In dense retrieval, deep encoders provide embeddings for both inputs and targets, and the softmax function is used to parameterize a distribution over a large number of candidate targets (e.g., textual passages for information retrieval). Significant challenges arise in training such encoders in the increasingly prevalent scenario of (1) a large number of targets, (2) a computationally expensive target encoder model, and (3) cached target embeddings that are out of date due to ongoing training of the target encoder parameters. This paper presents a simple and highly scalable response to these challenges: training a small parametric corrector network that adjusts stale cached target embeddings, enabling an accurate softmax approximation and thereby the sampling of up-to-date, high-scoring "hard negatives." We theoretically investigate the generalization properties of our proposed target corrector, relating the complexity of the network, the staleness of cached representations, and the amount of training data. We present experimental results on large benchmark dense retrieval datasets as well as on QA with retrieval-augmented language models. Our approach matches state-of-the-art results even when no target embedding updates are made during training beyond an initial cache from the unsupervised pre-trained model, providing a 4-80x reduction in re-embedding computational cost.
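To make the mechanism concrete, here is a minimal pure-Python sketch of scoring every cached target through a corrector and then sampling hard negatives from the resulting softmax. It assumes logits of the form beta * <f(x), h(g'(y))>, a hypothetical affine corrector `corrector(weight, bias, e)` initialized to the identity, and the Gumbel-top-k trick for drawing distinct negatives; these names and choices are illustrative assumptions, not the paper's implementation.

```python
import math
import random


def dot(a, b):
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def softmax(logits):
    """Numerically stable softmax."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def corrector(weight, bias, stale_emb):
    """Hypothetical affine corrector h(e) = W e + b applied to one cached embedding."""
    return [dot(row, stale_emb) + offset for row, offset in zip(weight, bias)]


def corrected_logits(query, cache, weight, bias, beta):
    """Assumed logits beta * <f(x), h(g'(y))> for every cached target y."""
    return [beta * dot(query, corrector(weight, bias, e)) for e in cache]


def sample_hard_negatives(query, cache, weight, bias, beta, positive, k, rng):
    """Draw k distinct negatives from the corrected softmax via Gumbel-top-k."""
    keyed = []
    for idx, z in enumerate(corrected_logits(query, cache, weight, bias, beta)):
        if idx == positive:
            continue  # never return the positive target as a negative
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)  # keep u strictly inside (0, 1)
        keyed.append((z - math.log(-math.log(u)), idx))  # logit + Gumbel(0, 1) noise
    keyed.sort(reverse=True)
    return [idx for _, idx in keyed[:k]]


# Toy usage: 50 cached (stale) target embeddings g'(y) in 4 dimensions.
rng = random.Random(0)
dim, n_targets, beta = 4, 50, 20.0
cache = [[rng.gauss(0, 1) for _ in range(dim)] for _ in range(n_targets)]
weight = [[1.0 if i == j else 0.0 for j in range(dim)] for i in range(dim)]  # identity init
bias = [0.0] * dim
query = [rng.gauss(0, 1) for _ in range(dim)]  # stands in for the input embedding f(x)

probs = softmax(corrected_logits(query, cache, weight, bias, beta))
negatives = sample_hard_negatives(query, cache, weight, bias, beta, positive=0, k=5, rng=rng)
print(negatives)
```

Gumbel-top-k is used here only because it samples from the softmax without replacement in a single pass; any sampler over the corrected distribution would serve the same illustrative purpose.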

Paper Structure (19 sections, 6 theorems, 22 equations, 5 figures, 7 tables, 2 algorithms)

Key Result

Lemma 4.0

Given a target encoder $g$ and its stale approximation $g'$, the gap between the true population risk and the stale population risk is bounded in the following way: where $\mathcal{W}$ is the Wasserstein distance. Furthermore, if the approximation $g'$ comes from the same neural model as $g$ with parameters perturbed by $u$, as in the aforementioned stale approximation, we have: $\|g-g'\|_1 \leq L\|u\|$ w
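The second part of the lemma can be checked numerically in a toy setting. The sketch below assumes a hypothetical scalar encoder $g_\theta(y)=\tanh(\langle\theta, y\rangle)$, reads $\|g-g'\|_1$ as the mean absolute output gap over a finite set of targets, and takes $L=\max_y\|y\|$, which is valid here because $\tanh$ is 1-Lipschitz and Cauchy-Schwarz gives $|\langle u, y\rangle| \leq \|u\|\|y\|$. It also computes a one-dimensional empirical Wasserstein-1 distance between fresh and stale outputs, which cannot exceed the mean absolute gap because pairing each target with itself is one admissible coupling. This illustrates the quantities involved; it is not the lemma's exact bound.

```python
import math
import random


def encoder(theta, y):
    """Hypothetical scalar target encoder g_theta(y) = tanh(<theta, y>)."""
    return math.tanh(sum(t * v for t, v in zip(theta, y)))


def l2_norm(v):
    return math.sqrt(sum(x * x for x in v))


def wasserstein_1d(a, b):
    """W1 between two equal-size 1-D empirical distributions (sorted matching is optimal)."""
    assert len(a) == len(b)
    return sum(abs(x - y) for x, y in zip(sorted(a), sorted(b))) / len(a)


rng = random.Random(0)
dim, n = 8, 500
targets = [[rng.gauss(0, 1) for _ in range(dim)] for _ in range(n)]
theta = [rng.gauss(0, 1) for _ in range(dim)]    # current parameters (g)
u = [rng.gauss(0, 0.05) for _ in range(dim)]     # parameter drift since caching
theta_stale = [t - d for t, d in zip(theta, u)]  # stale parameters (g')

fresh = [encoder(theta, y) for y in targets]
stale = [encoder(theta_stale, y) for y in targets]

mean_abs_gap = sum(abs(f - s) for f, s in zip(fresh, stale)) / n
lipschitz = max(l2_norm(y) for y in targets)  # |tanh a - tanh b| <= |<u, y>| <= ||u|| ||y||
bound = lipschitz * l2_norm(u)
w1 = wasserstein_1d(fresh, stale)

print(f"W1 = {w1:.4f} <= mean|g - g'| = {mean_abs_gap:.4f} <= L||u|| = {bound:.4f}")
```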

Figures (5)

  • Figure 1: Target Corrector Networks. The corrector network, $h(\cdot)$, moves the approximate representations of targets, $g'(\cdot)$, closer to their true positions, $g(\cdot)$. The corrector network is trained to approximate how the targets are transformed from $g'(\cdot)$ to $g(\cdot)$.
  • Figure 2: NQ Test Recall@1. We show the computational trade-offs between the amount of re-embedding during training and the task performance (GTR initialization). Our proposed target corrector approach achieves matching task performance at a fraction of the computational expense.
  • Figure 3: Training Sample Size. The figure shows the trade-offs among the complexity of $h$ (parameter count), the approximation error $\operatorname{KL}(P \Vert P_h)$, and the fraction of samples used for training. The left-hand side shows only somewhat stale representations; the right-hand side shows significantly stale representations. More staleness requires a higher fraction of training samples.
  • Figure 4: Parameter Count. We plot the KL divergence using the stale embeddings ($\operatorname{KL}(P \Vert P_{g'})$, on the x-axis) against that of the trained correction models ($\operatorname{KL}(P \Vert P_h)$, on the y-axis). The dashed red line indicates parity; the trained correction models fall well below it, showing that they are significantly better than the stale embeddings. Increasing the parameter count of $h$ matters more when the discrepancy between stale and current embeddings is higher, as indicated by the larger improvement toward the right of the plot.
  • Figure 5: A toy experiment in which stale and true targets are distributed around the unit circle; the corrected targets produced by the learned approximation are also depicted, along with the associated distribution over targets ($\beta=20$) for the identified point.
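The toy setting of Figure 5 can be approximated in a few lines. This sketch assumes that the stale targets are the true unit-circle targets rotated by a fixed drift plus small noise, that $h$ is an affine map fitted by gradient descent on squared error over half of the targets, and that the distribution over targets is a softmax of $\beta$-scaled inner products with a query point on the circle ($\beta=20$, as in the caption). It then compares $\operatorname{KL}(P \Vert P_{g'})$ with $\operatorname{KL}(P \Vert P_h)$, the two quantities on the axes of Figure 4. The drift, noise level, and regression objective are assumptions, not the paper's setup.

```python
import math
import random


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl(p, q):
    """KL(p || q) for discrete distributions; terms with p_i == 0 contribute nothing."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def correct(W, b, e):
    """Affine corrector h(e) = W e + b in two dimensions."""
    return [W[0][0] * e[0] + W[0][1] * e[1] + b[0],
            W[1][0] * e[0] + W[1][1] * e[1] + b[1]]


def fit_corrector(stale, fresh, lr=0.2, steps=500):
    """Full-batch gradient descent on mean squared error ||h(g'(y)) - g(y)||^2."""
    W, b = [[1.0, 0.0], [0.0, 1.0]], [0.0, 0.0]  # start from the identity map
    n = len(stale)
    for _ in range(steps):
        gW, gb = [[0.0, 0.0], [0.0, 0.0]], [0.0, 0.0]
        for e, t in zip(stale, fresh):
            pred = correct(W, b, e)
            for i in range(2):
                err = 2.0 * (pred[i] - t[i]) / n
                gb[i] += err
                for j in range(2):
                    gW[i][j] += err * e[j]
        for i in range(2):
            b[i] -= lr * gb[i]
            for j in range(2):
                W[i][j] -= lr * gW[i][j]
    return W, b


rng = random.Random(0)
beta, drift, n = 20.0, 0.6, 200
angles = [rng.uniform(0.0, 2.0 * math.pi) for _ in range(n)]
fresh = [[math.cos(a), math.sin(a)] for a in angles]                  # g(y)
stale = [[math.cos(a - drift) + rng.gauss(0, 0.01),
          math.sin(a - drift) + rng.gauss(0, 0.01)] for a in angles]  # g'(y)

train = rng.sample(range(n), n // 2)  # the corrector sees only half of the targets
W, b = fit_corrector([stale[i] for i in train], [fresh[i] for i in train])
corrected = [correct(W, b, e) for e in stale]                         # h(g'(y))

query = [math.cos(1.0), math.sin(1.0)]  # a point on the unit circle


def target_dist(embs):
    """Softmax over targets of beta * <query, embedding>."""
    return softmax([beta * (query[0] * e[0] + query[1] * e[1]) for e in embs])


p, p_stale, p_h = target_dist(fresh), target_dist(stale), target_dist(corrected)
kl_stale, kl_h = kl(p, p_stale), kl(p, p_h)
print(f"KL(P||P_g') = {kl_stale:.3f}   KL(P||P_h) = {kl_h:.3f}")
```

Because the drift here is a pure rotation, an affine corrector can undo most of it; larger or nonlinear drift would likely call for a more expressive $h$, in line with the parameter-count trend described for Figure 4.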

Theorems & Definitions (7)

  • Lemma 4.0
  • Lemma 4.0
  • Theorem 4.1
  • Lemma 1.0
  • Lemma 1.0
  • Theorem 1.1
  • Proof