Cascade-Aware Training of Language Models

Congchao Wang, Sean Augenstein, Keith Rush, Wittawat Jitkrittum, Harikrishna Narasimhan, Ankit Singh Rawat, Aditya Krishna Menon, Alec Go

TL;DR

Cascade-Aware Training (CAT) addresses the problem of deploying cost-efficient cascades of language models by training the small model with awareness of its place in the cascade and the downstream large model’s capabilities. CAT introduces a token-level, cascade-aware loss that selectively weights tokens based on whether either model can predict them correctly, guiding the small model to focus on easier tokens while improving deferral reliability. Empirical results across SuperGLUE, WMT22, and FLAN2021 on two PaLM-2 variants show substantial improvements in the cascade’s quality-cost curve, including notable FLOPs reductions with maintained or improved accuracy and BLEU. The work demonstrates a scalable, plug-in training objective that enhances cascaded LMs in multi-task generative and classification settings, and suggests further extensions to edge deployments and distributed training.

Abstract

Reducing serving cost and latency is a fundamental concern for the deployment of language models (LMs) in business applications. To address this, cascades of LMs offer an effective solution that conditionally employs smaller models for simpler queries. Cascaded systems are typically built with independently trained models, neglecting the advantages of considering inference-time interactions of the cascaded LMs during training. In this paper, we present cascade-aware training (CAT), an approach to optimizing the overall quality-cost performance tradeoff of a cascade of LMs. We achieve inference-time benefits by training the small LM with awareness of its place in a cascade and of downstream capabilities. We demonstrate the value of the proposed method on over 60 LM tasks from the SuperGLUE, WMT22, and FLAN2021 datasets.
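
The idea of conditionally using the small model for simpler queries can be sketched as a two-model router. This is only an illustration: the confidence score, the `threshold` value, and the toy stand-in models below are assumptions made for the example and are not the deferral rule specified by the paper.

```python
def cascade_predict(query, small_model, large_model, threshold=0.8):
    """Answer with the small model if it is confident, else defer.

    Each model is a callable returning (answer, confidence in [0, 1]).
    The confidence signal and the threshold are hypothetical choices.
    """
    answer, confidence = small_model(query)
    if confidence >= threshold:
        return answer, "small"
    large_answer, _ = large_model(query)
    return large_answer, "large"


def deferral_rate(queries, small_model, large_model, threshold):
    """Fraction of queries routed to the large model at a given threshold."""
    routed = [
        cascade_predict(q, small_model, large_model, threshold)[1]
        for q in queries
    ]
    return routed.count("large") / len(routed)


# Toy stand-ins (assumptions): short queries count as "easy".
def toy_small(query):
    confidence = 0.95 if len(query.split()) <= 3 else 0.40
    return f"small:{query}", confidence


def toy_large(query):
    return f"large:{query}", 0.99
```

Sweeping `threshold` from 0 to 1 moves the cascade from "always small" to "always large", which is one way to trace out a quality-cost curve for a fixed pair of models.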

Paper Structure (28 sections, 14 equations, 72 figures, 1 table)

Figures (72)

  • Figure 1: Left: Cascade setup at inference time. The small model is deployed alongside the large model that guided its training. Right: Our proposed cascade-aware training (CAT). The small model has access to a trained, fixed large model during training. The proposed training objective (the CAT loss equation) generalizes the standard one-hot cross-entropy and KL-divergence-based distillation losses: loss is counted only on tokens that are predicted correctly by the small or the large model (i.e., tokens that are not too difficult).
  • Figure 2: SuperGLUE
  • Figure 3: FLAN2021-Cls-Tasks
  • Figure 5: WMT22
  • Figure 6: FLAN2021-Gen-Tasks
  • ...and 67 more figures
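
The token masking described in the Figure 1 caption can be sketched in plain Python. This is a minimal sketch under stated assumptions, not the paper's exact objective: the mixing weight `distill_weight`, the averaging over kept tokens, and the function names are all hypothetical. Only the masking rule (keep a token if the small or the large model predicts it correctly) and the two loss ingredients (one-hot cross entropy and KL-based distillation) come from the caption.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one token's vocabulary logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def argmax(values):
    """Index of the largest value (first index wins on ties)."""
    return max(range(len(values)), key=values.__getitem__)


def cat_loss_sketch(small_logits, large_logits, targets, distill_weight=0.5):
    """Masked mix of cross entropy and KL(large || small) per token.

    small_logits, large_logits: one list of vocabulary logits per token.
    targets: ground-truth token ids.
    distill_weight: assumed mixing weight (0 = pure cross entropy).
    Returns (mean loss over kept tokens, number of kept tokens).
    """
    total, kept = 0.0, 0
    for s_logits, l_logits, y in zip(small_logits, large_logits, targets):
        # Skip tokens that neither model predicts correctly ("too difficult").
        if argmax(s_logits) != y and argmax(l_logits) != y:
            continue
        logp_small = log_softmax(s_logits)
        logp_large = log_softmax(l_logits)
        cross_entropy = -logp_small[y]
        kl = sum(
            math.exp(lp_l) * (lp_l - lp_s)
            for lp_l, lp_s in zip(logp_large, logp_small)
        )
        total += (1.0 - distill_weight) * cross_entropy + distill_weight * kl
        kept += 1
    return (total / kept if kept else 0.0), kept
```

With `distill_weight=0` and no masking this reduces to ordinary token-level cross entropy, and with `distill_weight=1` to a distillation loss, which matches the caption's description of the objective as a generalization of both.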