Q-Probe: A Lightweight Approach to Reward Maximization for Language Models

Kenneth Li, Samy Jelassi, Hugh Zhang, Sham Kakade, Martin Wattenberg, David Brandfonbrener

TL;DR

The paper presents Q-Probe, a lightweight approach to reward maximization for language models that freezes the base model and trains a small linear probe on its embeddings to reweight candidate completions. At inference, it draws $k$ samples from the base LM and selects one by sampling from a softmax over their $Q_\theta$ values; as $k$ grows, this procedure is theoretically linked to KL-constrained optimization. Training can proceed via reward modeling or direct policy learning (including importance-weighted policy gradients), and the method extends to learning from human preferences. Results show gains on benchmarks with ground-truth rewards (MBPP, GSM-8K) and favorable performance when combined with other methods, even on API-based models. The method is particularly appealing in data- and compute-constrained settings, offering a practical middle ground between prompting and full finetuning, with potential applicability across tasks and modalities. Overall, Q-Probe demonstrates that a small linear discriminator operating on embeddings can substantially improve task-specific rewards with limited training and flexible deployment.
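
The inference step described above can be sketched in a few lines. This is a minimal illustration, not the authors' implementation: the function names (`q_values`, `softmax`, `q_probe_select`), the toy two-dimensional embeddings, and the probe weights `theta` are all hypothetical, and the embeddings are assumed to have already been obtained from the base model for each of the $k$ prompt-completion pairs.

```python
import math
import random
from typing import List, Optional, Sequence


def q_values(theta: Sequence[float], embeddings: Sequence[Sequence[float]]) -> List[float]:
    """Linear probe: one scalar score per candidate embedding (dot product with theta)."""
    return [sum(t * x for t, x in zip(theta, emb)) for emb in embeddings]


def softmax(scores: Sequence[float], beta: float) -> List[float]:
    """Softmax of scores / beta, stabilized by subtracting the maximum logit."""
    scaled = [s / beta for s in scores]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def q_probe_select(
    theta: Sequence[float],
    embeddings: Sequence[Sequence[float]],
    beta: float = 0.1,
    rng: Optional[random.Random] = None,
) -> int:
    """Sample the index of one of the k candidates from softmax(Q / beta)."""
    rng = rng or random.Random()
    probs = softmax(q_values(theta, embeddings), beta)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


# Toy example with k = 3 candidates (hypothetical embeddings and probe weights).
theta = [1.0, -1.0]
embeddings = [[0.2, 0.1], [0.9, 0.0], [0.1, 0.8]]
probs = softmax(q_values(theta, embeddings), beta=0.1)
choice = q_probe_select(theta, embeddings, beta=0.1, rng=random.Random(0))
```

A smaller `beta` concentrates the distribution on the highest-scoring candidate, while a larger `beta` keeps the selection closer to a uniform draw over the base model's samples.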

Abstract

We present an approach called Q-probing to adapt a pre-trained language model to maximize a task-specific reward function. At a high level, Q-probing sits between heavier approaches such as finetuning and lighter approaches such as few-shot prompting, but can also be combined with either. The idea is to learn a simple linear function on a model's embedding space that can be used to reweight candidate completions. We theoretically show that this sampling procedure is equivalent to a KL-constrained maximization of the Q-probe as the number of samples increases. To train the Q-probes we consider either reward modeling or a class of novel direct policy learning objectives based on importance-weighted policy gradients. With this technique, we see gains in domains with ground-truth rewards (code generation) as well as implicit rewards defined by preference data, even outperforming finetuning in data-limited regimes. Moreover, a Q-probe can be trained on top of an API since it only assumes access to sampling and embeddings. Code: https://github.com/likenneth/q_probe .
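
As a rough illustration of the reward-modeling route mentioned in the abstract, the sketch below fits a linear probe to rewards by stochastic gradient descent on a squared-error loss. The squared-error objective, the function name `train_probe_mse`, the hyperparameters, and the toy data are assumptions made for this example; it does not cover the importance-weighted policy-gradient objectives, and it is not taken from the authors' code.

```python
import random
from typing import List, Sequence, Tuple

Example = Tuple[Sequence[float], float]  # (embedding, observed reward)


def train_probe_mse(
    data: Sequence[Example],
    dim: int,
    lr: float = 0.1,
    epochs: int = 200,
    seed: int = 0,
) -> List[float]:
    """Fit theta so that dot(theta, embedding) approximates the reward."""
    rng = random.Random(seed)
    theta = [0.0] * dim
    examples = list(data)
    for _ in range(epochs):
        rng.shuffle(examples)
        for emb, reward in examples:
            pred = sum(t * x for t, x in zip(theta, emb))
            err = pred - reward
            # Gradient of 0.5 * err**2 with respect to theta is err * emb.
            theta = [t - lr * err * x for t, x in zip(theta, emb)]
    return theta


# Toy data (hypothetical): the reward depends only on the first embedding coordinate.
toy = [([1.0, 0.0], 1.0), ([0.0, 1.0], 0.0), ([1.0, 1.0], 1.0)]
theta = train_probe_mse(toy, dim=2)
```

Because only the probe weights are updated, training of this kind needs nothing from the base model beyond sampled completions and their embeddings, which is what makes the approach compatible with API access.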

Paper Structure

This paper contains 31 sections, 4 theorems, 18 equations, 6 figures, 5 tables.

Key Result

Theorem 4.1

Our policy approaches the following limit
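
For orientation, the limit can be written in the standard form of a KL-constrained maximizer, which is the equivalence the abstract describes. The notation here is assumed rather than taken from the text above: $s$ is the prompt, $a$ a completion, $\pi_0$ the base model, $\pi_{\theta,k}$ the Q-probe policy using $k$ samples, and $\beta$ the softmax temperature.

$$
\lim_{k \to \infty} \pi_{\theta,k}(a \mid s) = \frac{\pi_0(a \mid s)\, \exp\!\big(Q_\theta(s,a)/\beta\big)}{\mathbb{E}_{b \sim \pi_0(\cdot \mid s)}\big[\exp\!\big(Q_\theta(s,b)/\beta\big)\big]}
$$

Under the same assumptions, the right-hand side is the distribution that maximizes $\mathbb{E}_{a \sim \pi}\big[Q_\theta(s,a)\big] - \beta\, \mathrm{KL}\big(\pi(\cdot \mid s) \,\|\, \pi_0(\cdot \mid s)\big)$ over policies $\pi$.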

Figures (6)

  • Figure 1: An illustration of the Q-probe inference procedure. Given a prompt, we use the language model to generate $k=3$ completions (in this case, programs) and the respective embeddings of the $k$ prompt-completion pairs. Then the linear Q-probe maps the features into the logits of a softmax distribution. We obtain our final sample from the Q-probe by sampling from this distribution.
  • Figure 2: How MBPP test reward scales with the size of the training dataset. At inference we fix $K = 48$ and $\beta = 0.1$. Error bars show $95\%$ confidence interval over 10 training runs.
  • Figure 3: How MBPP test reward scales with inference-time compute when sweeping over $K$ with $\beta =0.1$. Error bars show $95\%$ confidence interval over 10 training runs.
  • Figure 4: Per-problem correlation between base model expected reward and Q-probe value (centered and normalized by standard deviation). Each point corresponds to a prompt in the training set and averages across the 200 sampled completions. $L_{PG}$ learns by contrasting completions to the same prompt, so it learns a probe that is less prompt-dependent.
  • Figure 5: How the win rate of Q-probe scales with inference-time compute on preference learning benchmarks. The skyline shows the performance of a perfect oracle selector. The shaded area represents $95\%$ confidence interval for $10$ runs.
  • ...and 1 more figure

Theorems & Definitions (7)

  • Theorem 4.1
  • Corollary 4.2
  • Remark 5.1
  • Remark 5.2
  • Theorem 1.1
  • proof
  • Corollary 1.2