
LMPriors: Pre-Trained Language Models as Task-Specific Priors

Kristy Choi, Chris Cundy, Sanjari Srivastava, Stefano Ermon

TL;DR

Language Model Priors (LMPriors) extract task-specific priors from pretrained LMs, using natural-language metadata, to bias downstream learning, particularly in low-data regimes. The method formulates prompts that translate metadata into priors that guide feature selection, causal discovery, and safe reinforcement learning. Empirical results show substantial gains in feature selection under data corruption, improved safety in RL via reward shaping, and near state-of-the-art performance in causal direction discovery when combined with data-driven methods. The work highlights both the promise and the risks of prompt-based priors, underscoring the need for careful prompt design and human oversight.
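For causal direction discovery, the LM prior is combined with a data-driven method. This section does not give the combination rule. The sketch below shows one standard way to merge a prior with data evidence, Bayes' rule in log-odds form, and it is not the paper's exact procedure. It makes two assumptions. First, the LM yields a prior probability that A causes B. Second, the data-driven method scores each direction with a log-likelihood. The function names and inputs are hypothetical.

```python
import math


def logit(p: float) -> float:
    """Log-odds of a probability strictly between 0 and 1."""
    if not 0.0 < p < 1.0:
        raise ValueError("probability must lie strictly between 0 and 1")
    return math.log(p) - math.log1p(-p)


def sigmoid(x: float) -> float:
    """Numerically stable logistic function (inverse of logit)."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def combine_direction(
    lm_prior_a_to_b: float,
    loglik_a_to_b: float,
    loglik_b_to_a: float,
) -> float:
    """Posterior probability that A -> B rather than B -> A.

    lm_prior_a_to_b: hypothetical LM-derived prior probability of A -> B.
    loglik_*: hypothetical log-likelihoods of each direction from any
              data-driven causal-direction method that scores both directions.
    """
    # Bayes' rule in log-odds form:
    # posterior log-odds = prior log-odds + log-likelihood ratio.
    log_odds = logit(lm_prior_a_to_b) + (loglik_a_to_b - loglik_b_to_a)
    return sigmoid(log_odds)


# A confident LM prior (0.8) plus mild data evidence for A -> B.
print(combine_direction(0.8, -120.0, -121.0))
```

With a neutral prior of 0.5, the prior log-odds are zero, so the result depends only on the data. A confident prior can still be overturned when the likelihood ratio is strong enough.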

Abstract

Particularly in low-data regimes, an outstanding challenge in machine learning is developing principled techniques for augmenting our models with suitable priors that encourage them to learn in ways compatible with our understanding of the world. In contrast to generic priors such as shrinkage or sparsity, we draw inspiration from the recent successes of large-scale language models (LMs) to construct task-specific priors distilled from the rich knowledge of LMs. Our method, Language Model Priors (LMPriors), incorporates auxiliary natural language metadata about the task, such as variable names and descriptions, to encourage downstream model outputs to be consistent with the LM's common-sense reasoning based on the metadata. Empirically, we demonstrate that LMPriors improve model performance in settings where such natural language descriptions are available, and perform well on several tasks that benefit from such prior knowledge, such as feature selection, causal inference, and safe reinforcement learning.
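To make "metadata such as variable names and descriptions" concrete, here is a minimal sketch of an LM prior for feature selection. It loosely follows the prompt layout described for Figure 2: a task description, a worked example with an answer and an explanation, then the queried NAME and DESCRIPTION.

Several pieces are assumptions, not the paper's exact method:

- the template wording and its worked example;
- the decision rule, which keeps a feature when the LM's log-odds of " Y" over " N" exceed a threshold;
- the `token_logprob` hook;
- the `keyword_logprob` toy model, which stands in for a real LM such as GPT-3.

```python
import math
from typing import Callable, Dict, List

# Hypothetical template: task description, one worked example (answer plus
# explanation), then the queried variable's NAME and DESCRIPTION.
PROMPT_TEMPLATE = (
    "Task: decide whether each variable is useful for predicting {target}.\n\n"
    "Variable: favorite_color\n"
    "Description: respondent's favorite color\n"
    "Answer: N, because favorite color has no plausible link to {target}.\n\n"
    "Variable: {name}\n"
    "Description: {description}\n"
    "Answer:"
)


def build_prompt(target: str, name: str, description: str) -> str:
    """Substitute the target, variable NAME and DESCRIPTION into the template."""
    return PROMPT_TEMPLATE.format(target=target, name=name, description=description)


def lm_prior_select(
    features: Dict[str, str],
    target: str,
    token_logprob: Callable[[str, str], float],
    threshold: float = 0.0,
) -> List[str]:
    """Keep features whose LM log-odds of " Y" over " N" exceed `threshold`.

    token_logprob(prompt, token) is a hypothetical hook returning the LM's
    log-probability of `token` as the next token after `prompt`.
    """
    selected = []
    for name, description in features.items():
        prompt = build_prompt(target, name, description)
        log_odds = token_logprob(prompt, " Y") - token_logprob(prompt, " N")
        if log_odds > threshold:
            selected.append(name)
    return selected


def keyword_logprob(prompt: str, token: str) -> float:
    """Toy stand-in for an LM: inspects only the last (queried) variable block."""
    query = prompt.rsplit("Variable:", 1)[-1]
    relevant = "education" in query or "hours" in query
    p_yes = 0.8 if relevant else 0.2
    return math.log(p_yes if token == " Y" else 1.0 - p_yes)


features = {
    "education_num": "number of years of education",
    "hours_per_week": "hours worked per week",
    "wine_acidity": "fixed acidity of the wine",
}
print(lm_prior_select(features, "whether income exceeds 50K", keyword_logprob))
# prints ['education_num', 'hours_per_week']
```

In a real setting, `token_logprob` would wrap an LM API that exposes next-token log-probabilities. Raising `threshold` makes the prior more conservative about which features it keeps.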

Paper Structure (29 sections, 3 equations, 8 figures, 2 tables)

Figures (8)

  • Figure 1: A flowchart of the Language Model Prior (LMPrior) framework. We leverage the rich knowledge base of a pretrained LM to incorporate task-relevant prior knowledge into our learning algorithm $f$. Our method uses natural language metadata $\mathcal{D}_{\textrm{meta}}$ to return a specialized learner $\tilde{f}$, whose outputs given the dataset ${\mathcal{D}}$ are encouraged to remain consistent with both the metadata and real-world knowledge as distilled in the LM.
  • Figure 2: An example of a prompt used in LMPriors for the feature selection task in the census experiments. The prompt $\mathbf{c}$ consists of a textual description of the feature selection task, the variable name, a short description of the variable, and the correct answer followed by an explanation. We substitute NAME and DESCRIPTION with the appropriate values when querying GPT-3.
  • Figure 3: Results for the variable separation experiment. For the UCI dataset combinations (a) Housing Prices-Wine Quality, (b) Housing Prices-Adult Income, and (c) Breast Cancer-Housing Prices, LMPrior correctly separates all features by data source. For (d) Breast Cancer-Adult Income, LMPrior mixes up a few features across the two datasets, but the ones it selects from the auxiliary dataset are semantically relevant to the primary task.
  • Figure 4: Comparison of LassoNet (Lemhadri et al., 2021) with LMPrior on the feature separation task for the UCI Breast Cancer-Wine Quality dataset combination. Features are ordered by importance. LassoNet selects a larger fraction of nuisance features (in pink) than LMPrior, and the features LMPrior selects are semantically relevant to the downstream task. Some features returned by LassoNet are tied in importance.
  • Figure 5: The Island Navigation gridworld from Leike et al. (2017). The RL agent must navigate to the goal (G) without touching the water, which is considered an "unsafe" action.
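The Island Navigation setting can illustrate LM-based reward shaping for safe RL. This section does not give the paper's prompts or shaping values. The sketch below assumes the LM judges a natural-language description of the agent's next cell. It also assumes that shaping means subtracting a fixed penalty whenever the LM flags that cell as unsafe. The grid layout, cell descriptions, `penalty` value, and the keyword-matching `lm_flags_unsafe` stand-in are all hypothetical.

```python
from functools import lru_cache
from typing import Callable, List, Tuple

State = Tuple[int, int]

# Hypothetical toy layout in the spirit of Island Navigation:
# 'A' start, '.' land, 'W' water (unsafe), 'G' goal.
GRID: List[str] = [
    "A.W",
    "..G",
]

CELL_TEXT = {"W": "a water cell", "G": "the goal cell"}


def describe(state: State) -> str:
    """Natural-language description of a cell: the metadata shown to the LM."""
    row, col = state
    return CELL_TEXT.get(GRID[row][col], "a land cell")


@lru_cache(maxsize=None)
def lm_flags_unsafe(description: str) -> bool:
    """Stand-in for an LM query such as "Is stepping onto {description} unsafe?".

    A real version would call an LM; this toy version matches a keyword.
    Caching means each distinct description is judged only once.
    """
    return "water" in description


def shaped_reward(
    reward: float,
    next_state: State,
    is_unsafe: Callable[[str], bool] = lm_flags_unsafe,
    penalty: float = 1.0,
) -> float:
    """Environment reward minus a penalty when the LM flags the next cell as unsafe."""
    if is_unsafe(describe(next_state)):
        return reward - penalty
    return reward


print(shaped_reward(-1.0, (0, 1)), shaped_reward(-1.0, (0, 2), penalty=10.0))
# prints -1.0 -11.0
```

A gridworld has only a handful of distinct cell descriptions, so caching the LM's judgment keeps the number of LM queries small even over many training episodes.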