Why Do Pretrained Language Models Help in Downstream Tasks? An Analysis of Head and Prompt Tuning

Colin Wei, Sang Michael Xie, Tengyu Ma

TL;DR

The paper analyzes why pretrained language models help in downstream tasks by framing pretraining as learning a latent-variable generative text model. It examines head tuning and prompt tuning under Hidden Markov Models (HMMs) and memory-augmented HMMs, deriving theoretical guarantees for recovering downstream labels from posterior information. Key findings show that, under non-degeneracy, simple heads suffice in vanilla HMMs, while soft prompts relax these non-degeneracy conditions and enable recovery with weaker assumptions; memory-augmented models further strengthen recoverability through attention mechanisms. Empirical simulations on synthetic data corroborate the theory, highlighting practical benefits of prompt tuning and memory structures for downstream performance.

Abstract

Pretrained language models have achieved state-of-the-art performance when adapted to a downstream NLP task. However, theoretical analysis of these models is scarce and challenging since the pretraining and downstream tasks can be very different. We propose an analysis framework that links the pretraining and downstream tasks with an underlying latent variable generative model of text -- the downstream classifier must recover a function of the posterior distribution over the latent variables. We analyze head tuning (learning a classifier on top of the frozen pretrained model) and prompt tuning in this setting. The generative model in our analysis is either a Hidden Markov Model (HMM) or an HMM augmented with a latent memory component, motivated by long-term dependencies in natural language. We show that 1) under certain non-degeneracy conditions on the HMM, simple classification heads can solve the downstream task, 2) prompt tuning obtains downstream guarantees with weaker non-degeneracy conditions, and 3) our recovery guarantees for the memory-augmented HMM are stronger than for the vanilla HMM because task-relevant information is easier to recover from the long-term memory. Experiments on synthetically generated data from HMMs back our theoretical findings.
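To make "the posterior distribution over the latent variables" concrete, the following is a minimal sketch of the standard forward (filtering) recursion for a small HMM, followed by a toy label that depends only on that posterior. The function name `forward_posterior`, the toy matrices, and the particular label rule are illustrative assumptions, not taken from the paper.

```python
def forward_posterior(init, trans, emit, tokens):
    """Return P[H_t = h | x_1..x_t] for the last position t.

    init[h]     : P[H_1 = h]
    trans[h][g] : P[H_{i+1} = g | H_i = h]
    emit[h][x]  : P[X_i = x | H_i = h]
    """
    n = len(init)
    # Belief after observing the first token.
    belief = [init[h] * emit[h][tokens[0]] for h in range(n)]
    total = sum(belief)
    belief = [b / total for b in belief]
    for x in tokens[1:]:
        # Predict: push the belief through the transition matrix.
        predicted = [sum(belief[h] * trans[h][g] for h in range(n)) for g in range(n)]
        # Update: reweight by the likelihood of the observed token, then normalize.
        belief = [predicted[g] * emit[g][x] for g in range(n)]
        total = sum(belief)
        belief = [b / total for b in belief]
    return belief


# Toy 2-state, 2-token HMM (all numbers are made up for illustration).
init = [0.5, 0.5]
trans = [[0.9, 0.1], [0.1, 0.9]]
emit = [[0.8, 0.2], [0.2, 0.8]]

posterior = forward_posterior(init, trans, emit, [0, 0, 0])
# Hypothetical downstream label: a threshold on a linear function of the posterior.
label = 1 if posterior[0] - posterior[1] >= 0 else 0
```

In this sketch the label never looks at the tokens directly; it is a function of the posterior alone, which is the kind of target a downstream classifier has to recover in the framework described above.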

Paper Structure

This paper contains 25 sections, 16 theorems, 89 equations, 3 figures.

Key Result

Theorem 3.3

Assume that non-degeneracy (Assumption ass:non_degen_vanilla) and regularity (Assumption ass:regularity) hold. Then any downstream task $F^\star(x)$ of the form eq:hmm_downstream can be computed by a linear head on $G$ applied to a shifted sequence. That is, there exist linear head weights $b$ such that the head applied to $G(x')$ computes $F^\star(x)$, where $x' = (\varnothing, x_{1:t})$ is the concatenation of a special token $\varnothing$ with $x$.
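One way to see why a linear head can suffice is sketched below, under assumptions that this summary does not spell out: the model output at a position is taken to be a token distribution $g = Wp$, where $W$ is the emission matrix and $p$ is a posterior over hidden states; the label is taken to be a threshold on $\mu^\top p$; and non-degeneracy is read as $W$ having linearly independent columns. With those assumptions, $b = W (W^\top W)^{-1} \mu$ satisfies $W^\top b = \mu$, so $b^\top g = \mu^\top p$. The names `linear_head_weights` and `token_distribution` and all numbers are hypothetical.

```python
def linear_head_weights(W, mu):
    """Solve W^T b = mu via b = W (W^T W)^{-1} mu for a model with 2 hidden states.

    W[x][h] plays the role of P[X = x | H = h] (one row per token).
    """
    # Gram matrix A = W^T W, which is 2 x 2 and symmetric.
    a00 = sum(row[0] * row[0] for row in W)
    a01 = sum(row[0] * row[1] for row in W)
    a11 = sum(row[1] * row[1] for row in W)
    det = a00 * a11 - a01 * a01  # zero exactly when the columns of W are dependent
    if abs(det) < 1e-12:
        raise ValueError("emission matrix is degenerate")
    # c = A^{-1} mu using the closed-form inverse of a 2 x 2 matrix.
    c0 = (a11 * mu[0] - a01 * mu[1]) / det
    c1 = (-a01 * mu[0] + a00 * mu[1]) / det
    return [row[0] * c0 + row[1] * c1 for row in W]


def token_distribution(W, p):
    """Assumed model output: g = W p, the token distribution implied by posterior p."""
    return [row[0] * p[0] + row[1] * p[1] for row in W]


# Toy emission matrix with 3 tokens and 2 hidden states (made-up numbers).
W = [[0.7, 0.1], [0.2, 0.3], [0.1, 0.6]]
mu = [1.0, -1.0]
b = linear_head_weights(W, mu)

scores = []
for p in ([0.9, 0.1], [0.3, 0.7], [0.5, 0.5]):
    g = token_distribution(W, p)
    head_score = sum(bi * gi for bi, gi in zip(b, g))  # what the linear head sees
    target_score = mu[0] * p[0] + mu[1] * p[1]         # what the task depends on
    scores.append((head_score, target_score))
```

In this toy setting the head score and the target score agree for every posterior, so thresholding $b^\top g$ reproduces the label. When the columns of $W$ are linearly dependent, no such $b$ is guaranteed to exist, and the sketch raises an error instead.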

Figures (3)

  • Figure 1: Left: Illustration of HMM graphical model. Right: Overview of the formulation and analysis setting for prompt (and head) tuning. To abstract soft prompt tuning, we note that every token has a natural embedding, the corresponding row of the emission probability matrix. We view prompt tuning as adding a fake token $\widetilde{z}$ to the vocabulary, assigning it a row $u$ in the emission matrix, and prepending it to the input embedding sequence. More details are provided in Section \ref{sec:prompt_tune}.
  • Figure 2: Left: Memory-augmented HMM with a single memory cell. The memory $M$ and hidden state $H_i$ determine the emission probabilities for each state $X_i$. Right: Memory-augmented HMM with multiple memories $M_1, \ldots, M_{N}$. The hidden state $H_i$ consists of a cell index $J_i$ and syntax state $S_i$. To sample $X_i$, we first look up the $J_i$-th memory cell $M_{J_i}$. The token emission probability is then determined by the tuple $(M_{J_i}, J_i, S_i)$.
  • Figure 3: Left: Head vs. prompt tuning with a linear head on synthetically-generated HMM data, with varying hidden state sizes. Prompt tuning improves downstream accuracy especially when the problem is degenerate ($|\mathcal{H}| > |\mathcal{X}|$). Right: Downstream accuracy of head tuning on data from vanilla HMM vs. memory-augmented HMM, across varying values of $|{\mathcal{M}}| |\mathcal{H}|$. Long-term dependencies in the memory-augmented HMM data improve downstream recovery when using attention. Experiments average over 20 trials (left) and 5 trials (right) of pretraining and finetuning, with 95% intervals shown.
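The sampling procedure described in the Figure 2 caption (multiple memory cells, hidden state $H_i = (J_i, S_i)$, emission determined by $(M_{J_i}, J_i, S_i)$) can be sketched as follows. The emission rule, the transition rule, and all names and sizes here are hypothetical placeholders chosen only so that the sketch runs; the paper's actual distributions are not given in this summary.

```python
import random


def emission_weights(m, j, s, vocab_size):
    """Hypothetical emission rule: one favored token per (memory value, cell, syntax) tuple."""
    favored = (m + 3 * j + 5 * s) % vocab_size
    return [0.8 if x == favored else 0.2 / (vocab_size - 1) for x in range(vocab_size)]


def sample_sequence(num_cells, num_syntax, num_memory_values, vocab_size, length, seed=0):
    rng = random.Random(seed)
    # The memory cells M_1..M_N are drawn once and shared by every position.
    memory = [rng.randrange(num_memory_values) for _ in range(num_cells)]
    j = rng.randrange(num_cells)   # cell index J_1
    s = rng.randrange(num_syntax)  # syntax state S_1
    tokens, hidden = [], []
    for _ in range(length):
        # Look up the J_i-th memory cell; the emission depends on (M_{J_i}, J_i, S_i).
        weights = emission_weights(memory[j], j, s, vocab_size)
        tokens.append(rng.choices(range(vocab_size), weights=weights)[0])
        hidden.append((j, s))
        # Hypothetical Markov transition on the hidden state (J_i, S_i).
        if rng.random() < 0.5:
            j = rng.randrange(num_cells)
        s = (s + 1) % num_syntax
    return memory, hidden, tokens


memory, hidden, tokens = sample_sequence(
    num_cells=3, num_syntax=2, num_memory_values=4, vocab_size=10, length=20
)
```

Because the memory is sampled once and reused at every position, tokens that are far apart can still depend on the same memory cell, which is the long-term dependency the memory-augmented model is meant to capture.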

Theorems & Definitions (43)

  • Theorem 3.3
  • Proposition 3.4
  • Theorem 3.6
  • Example 4.1: Generating natural sentence with memory-augmented HMM
  • Theorem 4.3
  • Theorem 4.6
  • Claim A.1
  • proof
  • proof: Proof of Theorem \ref{thm:hmm_simple}
  • proof: Proof of Proposition \ref{prop:ci}
  • ...and 33 more