Cascade-Aware Training of Language Models
Congchao Wang, Sean Augenstein, Keith Rush, Wittawat Jitkrittum, Harikrishna Narasimhan, Ankit Singh Rawat, Aditya Krishna Menon, Alec Go
TL;DR
Cascade-Aware Training (CAT) addresses the problem of deploying cost-efficient cascades of language models by training the small model with awareness of its place in the cascade and of the downstream large model’s capabilities. CAT introduces a token-level, cascade-aware loss that selectively weights tokens based on whether either model can predict them correctly, guiding the small model to focus on easier tokens while improving deferral reliability. Empirical results across SuperGLUE, WMT22, and FLAN2021 on two PaLM 2 variants show substantial improvements in the cascade’s quality-cost curve, including notable FLOPs reductions with maintained or improved accuracy and BLEU. The work demonstrates a scalable, plug-in training objective that enhances cascaded LMs in multi-task generative and classification settings, and suggests further extensions to edge deployments and distributed training.
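As an illustration only, the token bucketing described above can be sketched in plain Python. This is a minimal sketch under stated assumptions, not the authors' loss. It assumes that a model "predicts a token correctly" when its argmax matches the reference token, treats the large model as frozen (only its predictions are used), and uses hypothetical per-bucket weights (`w_small_correct`, `w_large_only`, `w_neither`). The defaults simply drop tokens that neither model gets right. All function and parameter names are hypothetical.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over one token's vocabulary logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def argmax(xs: Sequence[float]) -> int:
    # Ties resolve to the lowest index, since max() keeps the first maximum.
    return max(range(len(xs)), key=lambda i: xs[i])


def cascade_aware_token_loss(
    small_logits: Sequence[Sequence[float]],
    large_logits: Sequence[Sequence[float]],
    labels: Sequence[int],
    w_small_correct: float = 1.0,  # hypothetical: small model's argmax matches the label
    w_large_only: float = 1.0,     # hypothetical: only the large model's argmax matches
    w_neither: float = 0.0,        # hypothetical: neither model matches; dropped by default
) -> float:
    """Weighted mean cross-entropy of the small model over one sequence.

    Each position is bucketed by whether the small and large models'
    argmax predictions equal the reference token, and the bucket picks
    the weight. The large model is treated as frozen: only its
    predictions are consulted.
    """
    total, weight_sum = 0.0, 0.0
    for s_logits, l_logits, y in zip(small_logits, large_logits, labels):
        if argmax(s_logits) == y:
            w = w_small_correct
        elif argmax(l_logits) == y:
            w = w_large_only
        else:
            w = w_neither
        total += w * -log_softmax(s_logits)[y]
        weight_sum += w
    # Normalize by total weight; return 0.0 if every token was dropped.
    return total / weight_sum if weight_sum > 0 else 0.0
```

Setting `w_neither=1.0` makes every weight equal to 1.0 and recovers ordinary token-level cross-entropy. This makes it easy to compare a cascade-aware weighting against standard training on the same data.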
Abstract
Reducing serving cost and latency is a fundamental concern for the deployment of language models (LMs) in business applications. Cascades of LMs offer an effective solution by conditionally employing smaller models for simpler queries. Cascaded systems are typically built from independently trained models, neglecting the benefits of accounting for the inference-time interactions of the cascaded LMs during training. In this paper, we present cascade-aware training (CAT), an approach to optimizing the overall quality-cost tradeoff of a cascade of LMs. We achieve inference-time benefits by training the small LM with awareness of its place in the cascade and of the downstream model's capabilities. We demonstrate the value of the proposed method on over 60 LM tasks from the SuperGLUE, WMT22, and FLAN2021 datasets.
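The abstract does not specify how a cascade decides when to use the small model, so the following sketch assumes a common rule: the small model answers when its maximum class probability reaches a threshold, and otherwise the query is deferred to the large model. Sweeping the threshold traces a quality-cost curve. The cost model is also an assumption: the small model runs on every query, each deferral costs one large-model call, and `small_cost` (0.1 by default) is a hypothetical relative cost. All names here are hypothetical.

```python
from dataclasses import dataclass
from typing import Callable, List, Sequence, Tuple


@dataclass
class CascadeResult:
    predictions: List[int]
    deferral_rate: float   # fraction of queries sent to the large model
    relative_cost: float   # mean cost per query, in units of one large-model call


def run_cascade(
    small_probs: Sequence[Sequence[float]],
    large_predict: Callable[[int], int],
    threshold: float,
    small_cost: float = 0.1,  # hypothetical small-model cost relative to the large model
) -> CascadeResult:
    """Two-model cascade with an assumed max-probability deferral rule."""
    preds: List[int] = []
    deferred = 0
    for i, probs in enumerate(small_probs):
        if max(probs) >= threshold:
            # Confident enough: keep the small model's argmax prediction.
            preds.append(max(range(len(probs)), key=lambda j: probs[j]))
        else:
            # Not confident: defer query i to the large model.
            deferred += 1
            preds.append(large_predict(i))
    n = len(small_probs)
    rate = deferred / n if n else 0.0
    # The small model always runs; the large model only runs on deferrals.
    return CascadeResult(preds, rate, small_cost + rate)


def quality_cost_curve(
    small_probs: Sequence[Sequence[float]],
    large_preds: Sequence[int],
    labels: Sequence[int],
    thresholds: Sequence[float],
    small_cost: float = 0.1,
) -> List[Tuple[float, float, float]]:
    """Return (threshold, relative_cost, accuracy) for each threshold."""
    curve = []
    for t in thresholds:
        r = run_cascade(small_probs, lambda i: large_preds[i], t, small_cost)
        acc = sum(p == y for p, y in zip(r.predictions, labels)) / len(labels)
        curve.append((t, r.relative_cost, acc))
    return curve
```

In this framing, a better-trained small model shows up as a better curve, with higher accuracy at the same relative cost. The reported improvements for CAT come from the paper's experiments, not from this toy example.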
