
Matryoshka Model Learning for Improved Elastic Student Models

Chetan Verma, Aditya Srinivas Timmaraju, Cho-Jui Hsieh, Suyash Damle, Ngot Bui, Yang Zhang, Wen Chen, Xin Liu, Prateek Jain, Inderjit S Dhillon

TL;DR

MatTA introduces a Teacher-TA-Student distillation framework that nests the serving Student inside a larger TA model, enabling elastic generation of multiple servable models from a single training run. The method uses online distillation with a composite loss that includes Student, TA, and TA-to-Student distillation terms, and it is trained with the second-order Shampoo optimizer. It shows significant practical impact: a 20% improvement on a key metric in live production A/B tests, and relative gains of over 24% on SAT Math and over 10% on LAMBADA with GPT-2 Medium, along with a range of sub-models extractable via Mix'n'Match. Ablation studies show super-additive benefits from combining MatTA with Shampoo and explore parameter-sharing trade-offs. Overall, MatTA offers a practical path to elastic, higher-quality, resource-conscious serving models.
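The composite loss described above can be sketched as follows. This is a minimal, hypothetical illustration, not the paper's implementation: it assumes hard integer labels (the Teacher's labels may in fact be soft), an unweighted sum of the Student and TA cross-entropy terms, and a KL-divergence distillation term from TA to Student scaled by a hypothetical weight `alpha`. The function names are illustrative.

```python
import math
from typing import Sequence


def log_softmax(logits: Sequence[float]) -> list[float]:
    """Numerically stable log-softmax over one example's logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def cross_entropy(logits: Sequence[float], label: int) -> float:
    """Negative log-likelihood of an integer label (hard-label assumption)."""
    return -log_softmax(logits)[label]


def kl_divergence(p_logits: Sequence[float], q_logits: Sequence[float]) -> float:
    """KL(p || q), with both distributions given as logits."""
    log_p = log_softmax(p_logits)
    log_q = log_softmax(q_logits)
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))


def matta_style_loss(
    student_logits: Sequence[float],
    ta_logits: Sequence[float],
    label: int,
    alpha: float = 1.0,
) -> float:
    """Hypothetical composite loss: Student CE + TA CE + alpha * KL(TA || Student)."""
    student_term = cross_entropy(student_logits, label)  # Student learns from labels
    ta_term = cross_entropy(ta_logits, label)            # TA learns from labels
    distill_term = kl_divergence(ta_logits, student_logits)  # Student distills from TA
    return student_term + ta_term + alpha * distill_term
```

With the TA on the left of the KL, the term pulls the Student toward the TA's distribution. In an autodiff framework such as JAX, one would typically wrap the TA logits in `jax.lax.stop_gradient` inside this term if the TA should not be pulled toward the Student; whether MatTA does so is an assumption here.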

Abstract

Industry-grade ML models are carefully designed to meet rapidly evolving serving constraints, which demands significant resources for model development. In this paper, we propose MatTA, a framework for training multiple accurate Student models using a novel Teacher-TA-Student recipe. TA models are larger versions of the Student models with higher capacity; they thus allow Student models to better relate to the Teacher model and also bring in more domain-specific expertise. Furthermore, multiple accurate Student models can be extracted from the TA model. Therefore, from only one training run, our methodology provides multiple servable options that trade off accuracy for lower serving cost. We demonstrate the proposed method, MatTA, on proprietary datasets and models. Its practical efficacy is underscored by live A/B tests within a production ML system, which show a 20% improvement on a key metric. We also demonstrate our method on GPT-2 Medium, a public model, and achieve relative improvements of over 24% on SAT Math and over 10% on the LAMBADA benchmark.
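The trade-off between accuracy and serving cost can be illustrated with a Mix'n'Match-style selection, where each layer is served at one of several nested widths. The sketch below is hypothetical: it assumes each layer can be served at either Student width ("S") or TA width ("T"), that per-layer parameter counts are known, and it uses total parameter count as a stand-in for quality, which a real deployment would replace with measured accuracy. The function name `mix_n_match` is illustrative.

```python
from itertools import product
from typing import Optional


def mix_n_match(
    layer_sizes: list[tuple[int, int]], budget: int
) -> Optional[tuple[tuple[str, ...], int]]:
    """Choose Student ('S') or TA ('T') width per layer under a parameter budget.

    layer_sizes[i] = (student_params, ta_params) for layer i (hypothetical numbers).
    Returns (per-layer choice, total params) for the largest configuration that
    fits, or None if even the all-Student configuration exceeds the budget.
    """
    best: Optional[tuple[tuple[str, ...], int]] = None
    # Exhaustively enumerate every per-layer combination of the two widths.
    for choice in product("ST", repeat=len(layer_sizes)):
        total = sum(s if c == "S" else t for c, (s, t) in zip(choice, layer_sizes))
        # Keep the largest model that still fits (size used as a quality proxy).
        if total <= budget and (best is None or total > best[1]):
            best = (choice, total)
    return best
```

Exhaustive enumeration grows as 2^L in the number of layers L, so this only suits small examples; a practical selector would prune the search or use a greedy heuristic.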

Paper Structure

This paper contains 24 sections, 2 equations, 5 figures, 4 tables, 2 algorithms.

Figures (5)

  • Figure 1: (a) MatTA: a novel elastic distillation framework. To generate elastic servable Student models from a single training run, we create a Teaching Assistant (TA) model from the given serving model. TA and Student are co-trained, with the Student learning from the original (Teacher's) labels as well as distilling from the TA. (b) MatTA can extract a range of compelling Student models, each of which surpasses an independently trained original Student model. Experiments show significant improvement even when the Student retains an identical architecture.
  • Figure 2: M-Nesting a Dense layer. (Left) M-Nested Dense layer with one input. This case arises when the layer is the first in the model to be M-nested. (Right) M-Nested Dense layer with two inputs. This is the general case for the rest of the model, once at least one earlier layer has been M-nested.
  • Figure 3: M-nesting of a Gated Attention Unit (GAU) [gauunit].
  • Figure 4: The Shampoo preconditioner matrix for MatTA. We observe that the preconditioner automatically captures parameter correlations within the Student and TA models, which leads to superior performance for training MatTA.
  • Figure 5: Performance of sub-models extracted from MatTA GPT-2 Medium on HellaSwag.
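One possible reading of the M-nested Dense layer in Figure 2 is that the Student's weights are the top-left block of the TA's weight matrix, so both models share parameters. The sketch below is hypothetical (the class `NestedDense` and helper `matvec` are illustrative names, not from the paper). In the one-input case, Student and TA see the same activation, and the Student output is a prefix of the TA output. In the two-input case, each model passes its own activation through the shared weights.

```python
Matrix = list[list[float]]
Vector = list[float]


def matvec(x: Vector, w: Matrix, out_dim: int) -> Vector:
    """Row vector x times the top-left (len(x) x out_dim) block of w."""
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(out_dim)]


class NestedDense:
    """Hypothetical M-nested Dense layer (no bias): the Student reuses a block of the TA weights."""

    def __init__(self, w_ta: Matrix, student_in: int, student_out: int):
        self.w = w_ta                  # full TA weights, shape (ta_in, ta_out)
        self.student_in = student_in   # Student input width (<= ta_in)
        self.student_out = student_out # Student output width (<= ta_out)

    def __call__(self, x_student: Vector, x_ta: Vector) -> tuple[Vector, Vector]:
        if len(x_student) != self.student_in or len(x_ta) != len(self.w):
            raise ValueError("input widths do not match the nested weight shapes")
        y_ta = matvec(x_ta, self.w, len(self.w[0]))              # TA uses all weights
        y_student = matvec(x_student, self.w, self.student_out)  # Student uses a sub-block
        return y_student, y_ta


# One-input case (first M-nested layer): both models receive the same activation.
w = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
first = NestedDense(w, student_in=2, student_out=2)
y_s, y_t = first([1.0, 1.0], [1.0, 1.0])  # y_s is the first two entries of y_t
```

Sharing one matrix in this way is a full parameter-sharing choice; the ablations mentioned above explore such trade-offs, and other schemes (for example, partially separate weights) would change this layer.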