Long-Tail Crisis in Nearest Neighbor Language Models
Yuto Nishida, Makoto Morishita, Hiroyuki Deguchi, Hidetaka Kamigaito, Taro Watanabe
TL;DR
This work investigates why $k$NN-LM fails to consistently improve predictions for low-frequency target tokens. Analyzing GPT2-XL with a large datastore on a resplit WikiText-103, it links prediction performance to retrieval fidelity and datastore properties. It finds that $k$NN probabilities are often lower than base LM probabilities for rare tokens, and that retrieval and quantization errors are amplified for long-tail targets. The authors identify four factors that undermine gains for low-frequency tokens while high-frequency tokens still benefit: sparse datastore distributions, neighbor contamination, retrieval gaps, and larger product quantization (PQ) reconstruction errors. These findings challenge the assumption that explicit memory universally aids long-tail phenomena, and they suggest concrete directions (e.g., frequency-aware weighting, inverse document frequency, Zipfian whitening) for improving how retrieval-augmented LMs handle rare tokens.
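The comparison between $k$NN and base LM probabilities is easier to follow with the standard $k$NN-LM formulation (Khandelwal et al., 2020) in mind. The final distribution is $p(y \mid x) = \lambda\, p_{k\mathrm{NN}}(y \mid x) + (1-\lambda)\, p_{\mathrm{LM}}(y \mid x)$, where $p_{k\mathrm{NN}}(y \mid x) \propto \sum_{(k_i, v_i) \in \mathcal{N}} \mathbb{1}[y = v_i] \exp(-d(k_i, f(x)))$ aggregates a softmax over negative distances of the retrieved neighbors $\mathcal{N}$ by their target token. Below is a minimal pure-Python sketch, assuming the neighbors are already retrieved as `(token, squared_l2_distance)` pairs. The function names, the `temperature` parameter, and the default `lam=0.25` are illustrative choices, not the paper's settings. The sketch also shows one simple way a rare token can lose probability: if it never appears among the neighbors, $p_{k\mathrm{NN}}$ gives it zero mass and interpolation can only scale its LM probability down by $(1-\lambda)$.

```python
import math
from collections import defaultdict


def knn_distribution(neighbors, temperature=1.0):
    """Turn retrieved (token, squared_l2_distance) pairs into p_kNN over tokens.

    The input format and the temperature knob are illustrative assumptions.
    """
    # Softmax over negative distances; subtract the max logit for numerical stability.
    logits = [-dist / temperature for _, dist in neighbors]
    top = max(logits)
    weights = [math.exp(logit - top) for logit in logits]
    total = sum(weights)
    # Neighbors that share the same target token pool their probability mass.
    p_knn = defaultdict(float)
    for (token, _), weight in zip(neighbors, weights):
        p_knn[token] += weight / total
    return dict(p_knn)


def interpolate(p_lm, p_knn, lam=0.25):
    """p(y|x) = lam * p_kNN(y|x) + (1 - lam) * p_LM(y|x) over the union of supports."""
    vocab = set(p_lm) | set(p_knn)
    return {y: lam * p_knn.get(y, 0.0) + (1 - lam) * p_lm.get(y, 0.0) for y in vocab}


if __name__ == "__main__":
    # Toy distributions (hypothetical values, not results from the paper).
    p_lm = {"the": 0.6, "cat": 0.3, "axolotl": 0.1}
    neighbors = [("the", 1.0), ("the", 1.5), ("cat", 2.0)]
    p = interpolate(p_lm, knn_distribution(neighbors))
    # "axolotl" is never retrieved, so its probability shrinks from 0.1 to about 0.075.
    print({token: round(prob, 4) for token, prob in sorted(p.items())})
```

Pooling mass by token is what lets frequent tokens, which appear as the target of many stored contexts, accumulate $p_{k\mathrm{NN}}$ across several neighbors, whereas a rare token must be retrieved directly to receive any $k$NN mass at all.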
Abstract
The $k$-nearest-neighbor language model ($k$NN-LM), a retrieval-augmented language model, improves perplexity on a given text by directly accessing, during inference, a large datastore that can be built from any text data. A widely held hypothesis for the success of $k$NN-LM is that its explicit memory, i.e., the datastore, enhances predictions for long-tail phenomena. However, prior work has primarily shown its ability to retrieve long-tail contexts, leaving underexplored how well the model estimates the probabilities of long-tail target tokens during inference. In this paper, we investigate the behavior of $k$NN-LM on low-frequency tokens, examining prediction probability, retrieval accuracy, token distribution in the datastore, and the approximation error of product quantization. Our experimental results reveal that $k$NN-LM does not improve prediction performance for low-frequency tokens but mainly benefits high-frequency tokens, regardless of the long-tail contexts in the datastore.
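The approximation error of product quantization mentioned in the abstract is the gap between a datastore key and the vector PQ reconstructs from its compressed code. PQ splits each $D$-dimensional key into $M$ subvectors, learns a $K$-entry codebook per subspace, and stores only the index of the nearest codeword in each subspace. Below is a minimal pure-Python sketch of that idea, assuming plain Lloyd's k-means with random-sample initialization and random Gaussian toy keys. The helper names are hypothetical, and real systems use optimized libraries such as FAISS rather than code like this.

```python
import random


def _sq_dist(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def _nearest(point, centroids):
    """Index of the centroid closest to point."""
    return min(range(len(centroids)), key=lambda j: _sq_dist(point, centroids[j]))


def kmeans(points, k, iters, rng):
    """Plain Lloyd's k-means, initialized with k points sampled without replacement."""
    centroids = [list(p) for p in rng.sample(points, k)]
    for _ in range(iters):
        clusters = [[] for _ in range(k)]
        for p in points:
            clusters[_nearest(p, centroids)].append(p)
        for j, members in enumerate(clusters):
            if members:  # keep the old centroid if its cluster is empty
                dim = len(members[0])
                centroids[j] = [sum(m[d] for m in members) / len(members) for d in range(dim)]
    return centroids


def train_pq(vectors, m, k, iters=20, seed=0):
    """Learn one k-entry codebook for each of the m subspaces."""
    dim = len(vectors[0])
    assert dim % m == 0, "dimension must be divisible by the number of subspaces"
    sub = dim // m
    rng = random.Random(seed)
    return [kmeans([v[s * sub:(s + 1) * sub] for v in vectors], k, iters, rng) for s in range(m)]


def encode(v, codebooks):
    """Compress v into one codeword index per subspace."""
    sub = len(v) // len(codebooks)
    return [_nearest(v[s * sub:(s + 1) * sub], book) for s, book in enumerate(codebooks)]


def decode(codes, codebooks):
    """Rebuild an approximate vector by concatenating the selected codewords."""
    out = []
    for code, book in zip(codes, codebooks):
        out.extend(book[code])
    return out


def reconstruction_error(v, codebooks):
    """Squared distance between a key and its PQ reconstruction."""
    return _sq_dist(v, decode(encode(v, codebooks), codebooks))


if __name__ == "__main__":
    rng = random.Random(42)
    keys = [[rng.gauss(0, 1) for _ in range(8)] for _ in range(200)]  # toy keys
    books = train_pq(keys, m=4, k=16)
    errors = [reconstruction_error(v, books) for v in keys]
    print(f"mean squared reconstruction error: {sum(errors) / len(errors):.4f}")
```

Because each key is stored only as its code, distances computed during search reflect the reconstructed vector rather than the original key, so a larger reconstruction error for a key can change which neighbors are retrieved for a query.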
