Optimal Completion Distillation for Sequence Learning

Sara Sabour, William Chan, Mohammad Norouzi

TL;DR

Optimal Completion Distillation (OCD) trains seq2seq models to minimize edit distance directly. For each prefix the model generates, dynamic programming yields the exact optimal suffixes and the corresponding exact Q-values. OCD builds an optimal next-token policy from those Q-values and distills it into the model with a KL loss, with no MLE pretraining and no joint likelihood objective. On WSJ and Librispeech, OCD achieves state-of-the-art end-to-end speech recognition results without language-model rescoring, demonstrating strong generalization and stability. The method has no hyperparameters of its own and accommodates on-policy or off-policy trajectories.

Abstract

We present Optimal Completion Distillation (OCD), a training procedure for optimizing sequence-to-sequence models based on edit distance. OCD is efficient, has no hyper-parameters of its own, and does not require pretraining or joint optimization with conditional log-likelihood. Given a partial sequence generated by the model, we first identify the set of optimal suffixes that minimize the total edit distance, using an efficient dynamic programming algorithm. Then, for each position of the generated sequence, we use a target distribution that puts equal probability on the first token of all the optimal suffixes. OCD achieves state-of-the-art performance on end-to-end speech recognition on both the Wall Street Journal and Librispeech datasets, with $9.3\%$ WER and $4.5\%$ WER respectively.
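
The two steps described in the abstract (finding the optimal next tokens for each generated prefix, then spreading probability equally over them) can be illustrated with a minimal Python sketch. It is not the authors' implementation: the function names `optimal_next_tokens` and `uniform_target` and the end-of-sequence marker `"</s>"` are assumptions made for illustration. The sketch assumes the standard edit-distance table between the generated prefix and every prefix of the reference; a reference position whose distance equals the row minimum contributes the reference token that follows it as an optimal continuation.

```python
def optimal_next_tokens(reference, generated, eos="</s>"):
    """For every prefix of `generated` (lengths 0..len(generated)), return
    (minimum edit distance to any reference prefix, set of optimal next tokens).

    `eos` is a hypothetical end-of-sequence marker, not taken from the paper.
    """
    T = len(reference)
    # row[j] = edit distance between the current generated prefix and reference[:j]
    row = list(range(T + 1))  # empty generated prefix: j insertions
    steps = []
    for i in range(len(generated) + 1):
        if i > 0:
            tok = generated[i - 1]
            new = [i] + [0] * T  # against an empty reference prefix: i deletions
            for j in range(1, T + 1):
                substitute = row[j - 1] + (tok != reference[j - 1])
                new[j] = min(substitute, row[j] + 1, new[j - 1] + 1)
            row = new
        m = min(row)
        # Every reference position tied for the minimum proposes the token after it;
        # the position at the very end of the reference proposes end-of-sequence.
        nxt = {reference[j] if j < T else eos for j in range(T + 1) if row[j] == m}
        steps.append((m, nxt))
    return steps


def uniform_target(tokens):
    """Target distribution with equal probability on each optimal next token."""
    p = 1.0 / len(tokens)
    return {tok: p for tok in tokens}


steps = optimal_next_tokens(list("abc"), list("ax"))
for i, (m, nxt) in enumerate(steps):
    print(i, m, sorted(nxt), uniform_target(nxt))
```

With reference `abc` and generated prefix `ax`, the last step yields two optimal continuations, `b` and `c`: the stray `x` can be read either as an insertion (continue with `b`) or as a substitution for `b` (continue with `c`), and both readings cost one edit. Each row is computed from the previous one, so the work per generated token is linear in the reference length.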

Paper Structure

This paper contains 14 sections, 1 theorem, 8 equations, 6 figures, 6 tables, 1 algorithm.

Key Result

Lemma 1

The edit distance resulting from any potential suffix $\boldsymbol{\mathbf{y}} \in \mathcal{Y}$ is lower bounded by $m_i$,

Figures (6)

  • Figure 1: Fraction of OCD training prefix tokens on WSJ that do not match the ground truth.
  • Figure 2: WSJ validation Character Error Rate (CER) per training CER for MLE and OCD.
  • Figure 3: Librispeech training and validation WER per training epoch for OCD and MLE.
  • Figure B.1: Illustration of different training strategies for autoregressive sequence models. (a) Teacher Forcing: the model conditions on correct prefixes and is taught to predict the next ground truth token. (b) Scheduled Sampling: the model conditions on tokens either from ground truth or drawn from the model and is taught to predict the next ground truth token regardless. (c) Policy Gradient: the model conditions on prefixes drawn from the model and is encouraged to reinforce sequences with a large sequence reward $R(\tilde{y})$. (d) Optimal Completion Distillation: the model conditions on prefixes drawn from the model and is taught to predict an optimal completion policy $\pi^*$ specific to the prefix.
  • Figure B.2: Word Error Rate (WER) of WSJ with MLE, SS and OCD for different beam sizes.
  • ...and 1 more figure

Theorems & Definitions (2)

  • Lemma 1
  • proof