Towards a theory of model distillation

Enric Boix-Adsera

TL;DR

This work formalizes model distillation as PAC-distillation, a framework for analyzing when a simpler model can approximate a larger trained model on data drawn from a distribution. It shows that distillation can be much cheaper than learning from scratch and develops a general theory of computational reductions and statistical bounds. Using the linear representation hypothesis (LRH), it gives an algorithm that distills networks into explicit decision trees in polynomial time. The paper presents two case studies, distilling networks into juntas and into decision trees, that demonstrate algorithmic feasibility, and it provides a web of reductions relating distillation across model classes. It also discusses statistical results, contrasting perfect and agnostic distillation and giving Pareto-frontier bounds, and notes limitations such as the still-open characterization of the sample complexity of agnostic distillation, while outlining extensions toward broader model classes and foundation models. Overall, the work lays theoretical and algorithmic foundations for distillation, with potential impact on interpretability, efficiency, and resource sharing in machine learning systems.

Abstract

Distillation is the task of replacing a complicated machine learning model with a simpler model that approximates the original [BCNM06,HVD15]. Despite many practical applications, basic questions about the extent to which models can be distilled, and the runtime and amount of data needed to distill, remain largely open. To study these questions, we initiate a general theory of distillation, defining PAC-distillation in an analogous way to PAC-learning [Val84]. As applications of this theory: (1) we propose new algorithms to extract the knowledge stored in the trained weights of neural networks: we show how to efficiently distill neural networks into succinct, explicit decision tree representations when possible by using the "linear representation hypothesis"; and (2) we prove that distillation can be much cheaper than learning from scratch, and make progress on characterizing its complexity.

Paper Structure (49 sections, 19 theorems, 82 equations, 6 figures, 1 algorithm)


Key Result

Proposition 3.2

Suppose that each network evaluation takes $T$ time. Then, for any $0 < \delta < 1$, we can $(\epsilon=0,\delta)$-distill ${\mathcal{F}}_{{\sf NN}, k\hbox{-}\mathrm{juntas}}$ into ${\mathcal{F}}_{k\hbox{-}\mathrm{juntas}}$ in $0$ samples and $O(T \cdot k \log d + T \cdot 2^k \log\frac{1}{\delta}) +\
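The two terms in this running time suggest a two-phase procedure: locate the relevant coordinates with about $\log d$ network evaluations each, then tabulate the function on all $2^k$ settings of those coordinates. The following is a minimal sketch of that idea, not the paper's proof. It assumes the network is available only as a black-box Boolean function `f` on $\{0,1\}^d$ that depends on at most `k` coordinates; the names `find_new_relevant`, `distill_junta`, and the `trials` stopping rule are hypothetical and chosen for illustration. It draws no samples from the data distribution and only evaluates `f` at points it constructs itself.

```python
import itertools
import random


def find_new_relevant(f, x, y):
    """Given f(x) != f(y), binary-search the coordinates where x and y
    differ and return one coordinate that f depends on."""
    x, y = list(x), list(y)
    diff = [i for i in range(len(x)) if x[i] != y[i]]
    fx = f(tuple(x))
    while len(diff) > 1:
        half, rest = diff[: len(diff) // 2], diff[len(diff) // 2:]
        z = list(x)
        for i in half:
            z[i] = y[i]  # move x halfway toward y
        if f(tuple(z)) != fx:
            y, diff = z, half  # the change happens inside `half`
        else:
            x, diff = z, rest  # f(z) == f(x), so the change is in `rest`
    return diff[0]


def distill_junta(f, d, k, trials=200, seed=0):
    """Return (relevant coordinates, distilled function g).
    `trials` is a heuristic stopping rule, not a proven bound."""
    rng = random.Random(seed)
    relevant = set()
    misses = 0
    while len(relevant) < k and misses < trials:
        x = [rng.randint(0, 1) for _ in range(d)]
        # y agrees with x on known relevant coordinates, random elsewhere
        y = [x[i] if i in relevant else rng.randint(0, 1) for i in range(d)]
        if f(tuple(x)) != f(tuple(y)):
            relevant.add(find_new_relevant(f, x, y))
            misses = 0
        else:
            misses += 1
    rel = sorted(relevant)
    # Tabulate f on every setting of the relevant coordinates (others = 0).
    table = {}
    for bits in itertools.product((0, 1), repeat=len(rel)):
        point = [0] * d
        for i, b in zip(rel, bits):
            point[i] = b
        table[bits] = f(tuple(point))

    def g(x):
        return table[tuple(x[i] for i in rel)]

    return rel, g
```

As a usage example, calling `distill_junta(lambda x: x[3] & (x[17] ^ x[42]), d=100, k=3)` should return the coordinate list `[3, 17, 42]` together with an explicit lookup-table function that agrees with the source function on every input. Each call to `find_new_relevant` halves the candidate set, which is where the $\log d$ factor comes from.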

Figures (6)

  • Figure 1: Summary of PAC-learning vs. PAC-distillation as defined in Section \ref{sec:def-pac-dist}. The difference is that in PAC-distillation, the source model $f$ is one of the inputs to the distillation algorithm. For example, if the goal is to distill a neural network, then the neural network's weights are given as input to the distillation algorithm.
  • Figure 2: Example of a depth-3 decision tree with 11 nodes and $d = 100$.
  • Figure 3: (a) Example of a depth-3 decision tree with 11 nodes and $d = 100$. (b) List of all intermediate computations in this tree; there is one $\mathrm{AND}$ function per path starting at the root. Note that these paths do not have to end at the leaves. For convenience, we also include the path of length zero, which is why $\mathrm{AND}_{\emptyset} \equiv 1$ is also one of the intermediate computations.
  • Figure 4: Example random trees of varying depths on which we benchmark our distillation algorithm. We train a 5-layer ResNet to learn the tree based on 100,000 random samples for the depth-2,3,4 trees and 1,000,000 random samples for the depth-5 trees. We then recover the tree from the trained network using the distillation procedure. The input space is $\{0,1\}^d$ with input dimension $d = 100$. See Figure \ref{fig:number-linear-probes} for numerical details.
  • Figure 5: For each tree depth in $\{2,3,4,5\}$ we generate 5 random decision trees, train a depth-5 ResNet on each one with the cross-entropy loss, and distill to a decision tree. We report the results when we vary the hyperparameter $k$, which controls the size to which we prune the set of probes as we explore the space of functions efficiently represented by the network. In the third column, we report the accuracy of the distilled decision tree, which increases with $k$ because the algorithm can execute more probes. The fourth column reports the average number of probes for a given depth and $k$; this average is stable across runs to within $\pm 0.5\%$. The final column compares this average number of probes to the total number of possible probes of AND functions that the algorithm could make if it were brute-forcing. For depth $r$, this total number of possible probes is $\sum_{i=0}^r 2^i \binom{d}{i}$. We see in the final column that our algorithm requires only a very small fraction of the brute-force number of probes to succeed in recovering the true tree, supporting the linear representation hypothesis for networks trained on random decision trees. We leave more in-depth experiments beyond synthetic settings to future work.
  • ...and 1 more figure
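To give a sense of scale for the brute-force baseline in the Figure 5 caption, the count $\sum_{i=0}^r 2^i \binom{d}{i}$ can be evaluated directly for the experimental setting $d = 100$ and depths $r \in \{2,3,4,5\}$. The short sketch below does this; the function name `brute_force_probes` is hypothetical, and reading each term as "choose $i$ variables, then a sign for each" is an interpretation of the formula rather than a statement from the caption.

```python
from math import comb


def brute_force_probes(d, r):
    """Number of AND functions over at most r literals on d Boolean
    variables: sum over i of (choose i variables) * (2^i sign patterns)."""
    return sum(2**i * comb(d, i) for i in range(r + 1))


for r in (2, 3, 4, 5):
    print(r, brute_force_probes(100, r))
```

For $d = 100$ this gives 20,001 probes at depth 2, 1,313,601 at depth 3, 64,053,201 at depth 4, and 2,473,253,841 at depth 5, so the brute-force count grows by well over an order of magnitude with each additional level of depth.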

Theorems & Definitions (64)

  • Remark 1.2: Relation to PAC-learning
  • Remark 1.3: Relation to the KLCL model of ben2011learning
  • Remark 1.6: Comparison to query learning, and learning from random examples
  • Remark 1.7: Beyond decision trees
  • Definition 2.1: PAC-learning [Val84]
  • Definition 2.2: PAC-distillation
  • Definition 2.3: Agnostic PAC-distillation
  • Definition 3.1: Junta
  • Proposition 3.2: Efficient distillation of networks into juntas; cf. bshouty2016exact
  • proof
  • ...and 54 more