Test-Time Training on Nearest Neighbors for Large Language Models
Moritz Hardt, Yu Sun
TL;DR
The paper introduces test-time training on nearest neighbors (TTT-NN) for language modeling: for each test instance, the system retrieves neighbors from a large-scale embedding index and fine-tunes the model on their text. A distributed FAISS-based index over the Pile enables near real-time neighbor retrieval, and the model is fine-tuned sequentially on the neighbors using its default training setup, with no additional hyperparameter tuning. Across 22 Pile tasks and multiple GPT-family models, TTT-NN yields substantial perplexity reductions with as few as 20 neighbors, often approaching the performance of much larger models, at the cost of additional inference compute. The work provides a practical baseline for test-time training in language modeling, with the strongest gains on code tasks, and a clear framework for reasoning about scalability and cost.
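The procedure summarized above can be illustrated with a minimal, self-contained sketch. It is a toy under stated assumptions, not the authors' implementation: a unigram character model stands in for the GPT-family model, normalized character counts stand in for the text embedding, brute-force search stands in for the distributed FAISS index, and the nearest-first training order and the learning rate are illustrative choices. All function names are hypothetical.

```python
import math
from collections import Counter

VOCAB = "abcdefghijklmnopqrstuvwxyz "

def char_freq(text):
    """Normalized character counts over VOCAB."""
    counts = Counter(c for c in text.lower() if c in VOCAB)
    total = sum(counts.values()) or 1
    return [counts[c] / total for c in VOCAB]

def embed(text):
    """Assumption: a toy embedding; the paper uses learned text embeddings."""
    return char_freq(text)

def nearest_neighbors(query, corpus, k):
    """Brute-force L2 search; a stand-in for a large-scale index."""
    q = embed(query)
    return sorted(corpus, key=lambda doc: math.dist(q, embed(doc)))[:k]

def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

def loss(logits, text):
    """Mean negative log-likelihood of text under the unigram model."""
    probs = softmax(logits)
    chars = [c for c in text.lower() if c in VOCAB]
    return -sum(math.log(probs[VOCAB.index(c)]) for c in chars) / len(chars)

def gradient_step(logits, text, lr):
    """One gradient step on the mean NLL.

    For a softmax unigram model the gradient with respect to the logits
    is softmax(logits) minus the empirical character frequencies.
    """
    probs = softmax(logits)
    freq = char_freq(text)
    return [z - lr * (p - f) for z, p, f in zip(logits, probs, freq)]

def ttt_nn_loss(base_logits, test_text, corpus, k=20, lr=1.0):
    """Fine-tune a copy of the model on k neighbors, then score the test text."""
    logits = list(base_logits)  # copy: the base model is reused for every test instance
    for neighbor in nearest_neighbors(test_text, corpus, k):  # assumed order: nearest first
        logits = gradient_step(logits, neighbor, lr)  # one gradient iteration per neighbor
    return loss(logits, test_text)
```

A usage example: with `base = [0.0] * len(VOCAB)` and a small corpus such as `["aab abb", "abba ab", "zzzz qqq", "xyz xyz"]`, calling `ttt_nn_loss(base, "aaaa bbbb", corpus, k=2)` returns a lower loss than `loss(base, "aaaa bbbb")`, while `base` itself is left unchanged. Copying the parameters per test instance reflects the idea that each test input gets its own temporary fine-tuned model.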
Abstract
Many recent efforts augment language models with retrieval, by adding retrieved data to the input context. For this approach to succeed, the retrieved data must be added at both training and test time. Moreover, as input length grows linearly with the size of retrieved data, cost in computation and memory grows quadratically for modern Transformers. To avoid these complications, we simply fine-tune the model on retrieved data at test time, using its standard training setup. We build a large-scale distributed index based on text embeddings of the Pile dataset. For each test input, our system retrieves its neighbors and fine-tunes the model on their text. Surprisingly, retrieving and training on as few as 20 neighbors, each for only one gradient iteration, drastically improves performance across more than 20 language modeling tasks in the Pile. For example, test-time training with nearest neighbors significantly narrows the performance gap between a small GPT-2 and a GPT-Neo model more than 10 times larger. Sufficient index quality and size, however, are necessary. Our work establishes a first baseline of test-time training for language modeling.
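The cost argument in the abstract can be made concrete with a rough back-of-the-envelope comparison. The symbols are assumptions introduced here for illustration: n is the length of the test input in tokens, k is the number of retrieved neighbors, and m is the length of each neighbor. The estimate counts only how self-attention cost depends on sequence length and ignores constant factors such as the backward pass, so it is a sketch rather than a measurement.

```latex
% In-context retrieval: all k neighbors are concatenated with the input,
% so the sequence length grows linearly and attention cost quadratically.
\[
  C_{\text{in-context}} = O\!\left((n + k m)^2\right)
\]
% Test-time training: k separate gradient steps, each on one neighbor of
% length m, followed by ordinary inference on the input of length n.
\[
  C_{\text{TTT}} = O\!\left(k\, m^2 + n^2\right)
\]
```

Under these assumptions, the in-context cost contains a term that is quadratic in k, while the test-time training cost grows only linearly in k.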
