Test-Time Training on Nearest Neighbors for Large Language Models

Moritz Hardt, Yu Sun

TL;DR

The paper introduces test-time training on nearest neighbors (TTT-NN) for language modeling: for each test instance, it retrieves neighbors from a large-scale embedding index and fine-tunes the model on their text. A distributed FAISS-based index over the Pile enables near real-time neighbor retrieval, and the model is fine-tuned sequentially on the neighbors with no hyperparameter tuning beyond defaults. Across 22 Pile tasks and multiple GPT-family models, TTT-NN yields substantial perplexity reductions with as few as 20 neighbors, often approaching the performance of much larger models at the cost of additional inference compute. The work provides a practical baseline for test-time training in language modeling, with the strongest gains on code (the pile_github task) and a clear framework for scalability and cost considerations.

Abstract

Many recent efforts augment language models with retrieval, by adding retrieved data to the input context. For this approach to succeed, the retrieved data must be added at both training and test time. Moreover, as input length grows linearly with the size of retrieved data, cost in computation and memory grows quadratically for modern Transformers. To avoid these complications, we simply fine-tune the model on retrieved data at test time, using its standard training setup. We build a large-scale distributed index based on text embeddings of the Pile dataset. For each test input, our system retrieves its neighbors and fine-tunes the model on their text. Surprisingly, retrieving and training on as few as 20 neighbors, each for only one gradient iteration, drastically improves performance across more than 20 language modeling tasks in the Pile. For example, test-time training with nearest neighbors significantly narrows the performance gap between a small GPT-2 and a GPT-Neo model more than 10 times larger. Sufficient index quality and size, however, are necessary. Our work establishes a first baseline of test-time training for language modeling.
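The retrieve-then-fine-tune loop described in the abstract can be sketched on a toy problem. The following is a minimal illustration under stated assumptions, not the authors' implementation: the "model" is a unigram softmax over a three-character vocabulary, the "embedding" is a normalized character count, and the names `embed`, `retrieve`, `ttt_nn`, and `bits_per_char` are hypothetical. Copying the base parameters for each query, so that every test instance starts from the same model, is also an assumption of this sketch.

```python
import math

VOCAB = "abc"  # toy vocabulary (assumption); a real system uses the model's tokenizer


def embed(text):
    """Toy embedding: normalized character counts, a stand-in for a learned text embedding."""
    return [text.count(ch) / len(text) for ch in VOCAB]


def retrieve(index, query, k):
    """Return the k documents closest to the query in embedding space, nearest first."""
    q = embed(query)

    def sq_dist(doc):
        return sum((a - b) ** 2 for a, b in zip(embed(doc), q))

    return sorted(index, key=sq_dist)[:k]


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def bits_per_char(logits, text):
    """Average negative log2-likelihood of the text under the unigram model."""
    probs = softmax(logits)
    return -sum(math.log2(probs[VOCAB.index(ch)]) for ch in text) / len(text)


def ttt_nn(base_logits, index, query, k=2, lr=1.0):
    """Take one gradient step per retrieved neighbor, starting from a copy of the base model."""
    logits = list(base_logits)  # copy, so the base model is left untouched
    for doc in retrieve(index, query, k):
        probs = softmax(logits)
        target = embed(doc)  # empirical character frequencies of the neighbor
        # Gradient of the mean cross-entropy (in nats) w.r.t. the logits is probs - target.
        logits = [z - lr * (p - t) for z, p, t in zip(logits, probs, target)]
    return logits


index = ["aaaaab", "aaaabb", "bbbbbc", "cccccb", "ccccca"]
query = "aaaaba"
base = [0.0, 0.0, 0.0]  # uniform base model

adapted = ttt_nn(base, index, query, k=2)
print("before:", bits_per_char(base, query))
print("after: ", bits_per_char(adapted, query))
```

In this toy setting the adapted model assigns the query a lower loss than the uniform base model, because both retrieved neighbors share the query's character statistics. In the paper's setting, a large distributed index over the Pile plays the role of `retrieve` and a GPT-family model plays the role of the logits.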

Paper Structure (19 sections, 14 figures, 1 table)

Figures (14)

  • Figure 1: System architecture for test-time training with nearest neighbors (TTT-NN).
  • Figure 2: Comparison of different models before and after TTT-NN with $50$ neighbors. Left: All Pile tasks. Center: Best performing Pile task (pile_github). Right: Largest Pile task (pile_pile-cc).
  • Figure 3: Left: Histogram of distances to nearest neighbor for $10{,}000$ random queries from the validation set. Center: Focusing on the $400$ smallest distances. Right: Distances between $10{,}000$ pairs of random queries.
  • Figure 4: Left: Mean distances to the nearest neighbor as index size grows. Right: Mean distances with standard deviation shown as shaded region for the full size index.
  • Figure 5: Bits per byte results on all Pile tasks for a small GPT-2 model (117M parameters) before and after test-time training on $50$ nearest neighbors.
  • ...and 9 more figures