Matryoshka Model Learning for Improved Elastic Student Models
Chetan Verma, Aditya Srinivas Timmaraju, Cho-Jui Hsieh, Suyash Damle, Ngot Bui, Yang Zhang, Wen Chen, Xin Liu, Prateek Jain, Inderjit S Dhillon
TL;DR
MatTA introduces a Teacher-TA-Student distillation framework that nests the serving Student inside a larger TA, so that a single training run yields multiple servable models of different sizes. The method uses online distillation with a composite loss that combines Student, TA, and TA-to-Student distillation terms, and is trained with the second-order Shampoo optimizer. It reports a 20% improvement on a key live metric in production A/B tests, and relative gains of over 24% on SAT Math and over 10% on LAMBADA with GPT-2 Medium; further sub-models can be extracted via Mix'n'Match. Ablation studies show super-additive benefits from combining MatTA with Shampoo and explore parameter-sharing trade-offs. Overall, MatTA offers a practical path to elastic, higher-quality, resource-conscious serving models.
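As a rough illustration of the composite loss described above, the following sketch sums three KL-divergence terms over softmax outputs: Teacher-to-Student, Teacher-to-TA, and TA-to-Student. The use of KL divergence, the function names, and the equal default weights are assumptions made for illustration; the summary does not specify the exact form or weighting of the terms.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) for two discrete distributions given as lists."""
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q))


def composite_loss(teacher_logits, ta_logits, student_logits,
                   w_student=1.0, w_ta=1.0, w_ta_to_student=1.0):
    """Weighted sum of three distillation terms (weights are hypothetical)."""
    p_teacher = softmax(teacher_logits)
    p_ta = softmax(ta_logits)
    p_student = softmax(student_logits)

    student_term = kl_divergence(p_teacher, p_student)    # Teacher -> Student
    ta_term = kl_divergence(p_teacher, p_ta)              # Teacher -> TA
    ta_to_student_term = kl_divergence(p_ta, p_student)   # TA -> Student

    return (w_student * student_term
            + w_ta * ta_term
            + w_ta_to_student * ta_to_student_term)


# Toy logits for a single example with three classes.
loss = composite_loss([2.0, 0.0, -1.0], [1.5, 0.2, -0.5], [0.0, 0.0, 0.0])
print(loss)
```

In this toy setup the loss is zero when all three models produce identical logits and grows as the Student or TA drifts from the Teacher, or as the Student drifts from the TA.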
Abstract
Industry-grade ML models are carefully designed to meet rapidly evolving serving constraints, which requires significant resources for model development. In this paper, we propose MatTA, a framework for training multiple accurate Student models using a novel Teacher-TA-Student recipe. TA models are larger versions of the Student models with higher capacity, and thus allow Student models to better relate to the Teacher model and also bring in more domain-specific expertise. Furthermore, multiple accurate Student models can be extracted from the TA model. Therefore, despite requiring only one training run, our methodology provides multiple servable options that trade off accuracy for lower serving cost. We evaluate MatTA on proprietary datasets and models. Its practical efficacy is underscored by live A/B tests within a production ML system, which show a 20% improvement on a key metric. We also apply our method to GPT-2 Medium, a public model, and achieve relative improvements of over 24% on SAT Math and over 10% on the LAMBADA benchmark.
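One way to picture extracting several Student models from a single TA is Matryoshka-style nesting, in which a smaller model reuses a prefix of the larger model's hidden units. The sketch below assumes, purely for illustration, that nesting happens along the hidden width of each feed-forward block and that Mix'n'Match amounts to choosing a width per layer; the layer shapes, helper names, and toy weights are hypothetical and not taken from the paper.

```python
def slice_layer(w_in, w_out, width):
    """Keep only the first `width` hidden units of one feed-forward block.

    w_in has shape (d_model, d_hidden); w_out has shape (d_hidden, d_model).
    """
    return [row[:width] for row in w_in], w_out[:width]


def mix_n_match(ta_layers, widths):
    """Build a sub-model by picking a hidden width for each TA layer."""
    return [slice_layer(w_in, w_out, w)
            for (w_in, w_out), w in zip(ta_layers, widths)]


def param_count(layers):
    """Total number of weights across all feed-forward blocks."""
    return sum(len(w_in) * len(w_in[0]) + len(w_out) * len(w_out[0])
               for w_in, w_out in layers)


def ffn_forward(x, w_in, w_out):
    """ReLU feed-forward block: x -> relu(x @ w_in) @ w_out."""
    d_hidden = len(w_out)
    hidden = [max(0.0, sum(x[i] * w_in[i][j] for i in range(len(x))))
              for j in range(d_hidden)]
    return [sum(hidden[j] * w_out[j][k] for j in range(d_hidden))
            for k in range(len(w_out[0]))]


def make_toy_ta(num_layers=2, d_model=2, d_hidden=4):
    """Deterministic toy weights standing in for a trained TA."""
    layers = []
    for layer in range(num_layers):
        w_in = [[0.1 * (i + j + layer + 1) for j in range(d_hidden)]
                for i in range(d_model)]
        w_out = [[0.05 * (j - k + layer) for k in range(d_model)]
                 for j in range(d_hidden)]
        layers.append((w_in, w_out))
    return layers


ta = make_toy_ta()
student = mix_n_match(ta, [2, 2])   # smallest nested model
mixed = mix_n_match(ta, [2, 4])     # intermediate model with per-layer widths
print(param_count(ta), param_count(mixed), param_count(student))  # prints 32 24 16
```

Because every sub-model is a slice of the same TA weights, no additional training run is needed to obtain the intermediate sizes in this toy setting; how well such slices perform in practice depends on how the nested models were trained.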
