Adversarial Moment-Matching Distillation of Large Language Models

Chen Jia

TL;DR

This paper reframes knowledge distillation for large language models as imitation learning: the student matches the teacher's action-value moments rather than directly minimizing a distance between probability distributions. It introduces an adversarial, two-player minimax training algorithm that jointly optimizes on-policy and off-policy moment-matching distances via learnable $Q$-value functions, with the student updated by policy gradient. Experiments on task-agnostic instruction-following and task-specific benchmarks report state-of-the-art performance, supporting moment-matching over traditional distribution distances. The approach offers a principled way to capture long-horizon knowledge transfer and could make it more practical to deploy smaller, efficient LLMs without large sacrifices in performance.

Abstract

Knowledge distillation (KD) has been shown to be highly effective in guiding a student model with a larger teacher model and achieving practical benefits in improving the computational and memory efficiency for large language models (LLMs). State-of-the-art KD methods for LLMs mostly rely on minimizing explicit distribution distance between teacher and student probability predictions. Instead of optimizing these mandatory behaviour cloning objectives, we explore an imitation learning strategy for KD of LLMs. In particular, we minimize the imitation gap by matching the action-value moments of the teacher's behavior from both on- and off-policy perspectives. To achieve this action-value moment-matching goal, we propose an adversarial training algorithm to jointly estimate the moment-matching distance and optimize the student policy to minimize it. Results from both task-agnostic instruction-following experiments and task-specific experiments demonstrate the effectiveness of our method and achieve new state-of-the-art performance.
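
To make the minimax idea in the abstract concrete, here is a minimal toy sketch. It is not the paper's algorithm and rests on several simplifying assumptions: a single-step categorical policy instead of an autoregressive LLM, a tabular critic standing in for the learnable $Q$-value function, a quadratic penalty on the critic so that the inner maximization has a closed form, and an exact expected policy gradient instead of sampled trajectories. All function names are hypothetical.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    peak = max(logits)
    exps = [math.exp(z - peak) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def best_response_critic(teacher_probs, student_probs):
    """Inner maximization (assumed form, not from the paper):
    argmax_f  E_teacher[f] - E_student[f] - 0.5 * ||f||^2.
    For a tabular critic with one value per action, the maximizer
    is simply f = teacher - student.
    """
    return [t - s for t, s in zip(teacher_probs, student_probs)]

def moment_gap(teacher_probs, student_probs, critic):
    """First-moment gap E_teacher[f] - E_student[f] for a tabular critic f."""
    return sum(f * (t - s) for f, t, s in zip(critic, teacher_probs, student_probs))

def distill(teacher_probs, steps=5000, lr=0.5):
    """Alternate critic best response and student policy-gradient step."""
    logits = [0.0] * len(teacher_probs)  # student starts from a uniform policy
    for _ in range(steps):
        student_probs = softmax(logits)
        critic = best_response_critic(teacher_probs, student_probs)
        # The student shrinks the gap by raising E_student[f]. For a softmax
        # policy the expected policy gradient w.r.t. logit k is
        # pi_k * (f_k - E_student[f]).
        baseline = sum(p * f for p, f in zip(student_probs, critic))
        logits = [z + lr * p * (f - baseline)
                  for z, p, f in zip(logits, student_probs, critic)]
    return softmax(logits)

teacher = [0.6, 0.3, 0.1]
student = distill(teacher)
final_critic = best_response_critic(teacher, student)
print([round(p, 4) for p in student])
print(moment_gap(teacher, student, final_critic))
```

In this toy setting the student only ever sees the critic's values, never the teacher's probabilities directly, which mirrors the two-player structure: one player estimates the moment-matching distance and the other minimizes it.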

Paper Structure

This paper contains 25 sections, 3 theorems, 17 equations, 8 figures, 5 tables, 1 algorithm.

Key Result

Proposition 1

Let $\mathcal{F}_Q$ denote the set of $Q$-value functions induced by sampling actions from $\pi_{\theta}$, then we have:

In the following sections, we will use $\mathcal{U}^{\rm off}(\tau,f,\theta)$ to represent a sampled off-policy imitation gap with a trajectory $\tau \sim \pi_*|\boldsymbol{y}_{0}=\boldsymbol{x}$ w.r.t. a $Q$-value function $f$ and a student policy $\pi_{\theta}$.

Figures (8)

  • Figure 1: The comparison between the distribution-matching-based distillation and the action-value moment-matching distillation is outlined. $\pi_{\theta}$ and $\pi_{*}$ denote the student policy and the teacher policy, respectively. For both on-policy (using student-generated outputs) and off-policy (using teacher-generated outputs) perspectives, our approach optimizes moment-matching of action-value functions ($Q$-functions) instead of minimizing the distribution distance measured by $\mathcal{M}$ = KL, RKL, TV, etc.
  • Figure 2: Performance of different step-wise distribution distances.
  • Figure 3: Adversarial training procedure for optimizing the on-policy and off-policy moment-matching distances $d^{\rm on}_{\rm MM}$, $d^{\rm off}_{\rm MM}$ on the instruction-following dataset.
  • Figure 4: Performance of different step-wise distribution distances on five instruction-following datasets based on OpenLLaMA-3B $\rightarrow$ OpenLLaMA-7B.
  • Figure 5: Performance of different step-wise distribution distances on three task-specific datasets based on (m)T5-XL $\rightarrow$ (m)T5-Base.
  • ...and 3 more figures
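
For reference, the step-wise distribution distances named in the Figure 1 caption (KL, RKL, TV) can be computed for a single token position as sketched below. This is only an illustration: the function names are hypothetical, and the direction convention (KL as teacher-to-student, RKL as student-to-teacher) is an assumption based on common usage in the distillation literature, not something stated in this summary.

```python
import math

def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions given as lists of probabilities."""
    total = 0.0
    for pi, qi in zip(p, q):
        if pi == 0.0:
            continue          # 0 * log(0 / q) is taken as 0
        if qi == 0.0:
            return math.inf   # p puts mass where q has none
        total += pi * math.log(pi / qi)
    return total

def total_variation(p, q):
    """Total variation distance: half the L1 distance between p and q."""
    return 0.5 * sum(abs(pi - qi) for pi, qi in zip(p, q))

def step_distances(teacher_probs, student_probs):
    """Distances between teacher and student next-token distributions at one step."""
    return {
        "KL": kl_divergence(teacher_probs, student_probs),   # assumed: KL(teacher || student)
        "RKL": kl_divergence(student_probs, teacher_probs),  # assumed: KL(student || teacher)
        "TV": total_variation(teacher_probs, student_probs),
    }

def sequence_distance(teacher_steps, student_steps, name):
    """Average one of the step-wise distances over all token positions."""
    values = [step_distances(t, s)[name] for t, s in zip(teacher_steps, student_steps)]
    return sum(values) / len(values)

distances = step_distances([0.5, 0.5], [0.9, 0.1])
print(distances)
```

Note that KL and RKL generally differ for the same pair of distributions, while TV is symmetric, which is one reason the choice of $\mathcal{M}$ can affect distillation results.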

Theorems & Definitions (14)

  • Definition 1: Imitation gap
  • Proposition 1: Off-policy bound of imitation gap
  • proof
  • Proposition 2: On-policy bound of imitation gap
  • proof
  • Definition 2: Generalized step-wise distribution distance
  • Definition 3: Distribution-matching formulation of moment-matching bounds
  • Corollary 1
  • proof
  • proof
  • ...and 4 more