TabDistill: Distilling Transformers into Neural Nets for Few-Shot Tabular Classification

Pasan Dissanayake, Sanghamitra Dutta

TL;DR

TabDistill addresses the inefficiency of transformer-based tabular models in few-shot settings by distilling their knowledge into a small neural net. It learns a weight-generating mapping $m_\eta$ that converts transformer encoder outputs $z$ into the parameters $\theta$ of a compact MLP $h_\theta$, enabling a lightweight, differentiable classifier without updating the base model during inference. The approach uses a two-phase process with permutation-based training to combat overfitting on tiny datasets and is instantiated with TabPFN and T0pp as base transformers. Across five datasets, TabDistill outperforms classical baselines in the very few-shot regime and, in some cases, matches or surpasses the original transformer model, offering a scalable and practical option for resource-constrained tabular classification.
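
To make the weight-generating idea concrete, here is a minimal sketch of how a linear mapping $m_\eta$ (the Figure 2 caption describes $m_\eta(z)$ as linear) could turn an encoder output $z$ into the flat parameter vector $\theta$ of a small MLP $h_\theta$, which is then used as an ordinary classifier. Several things are assumptions made for illustration, not details from the source:

- $z$ is treated as a single pooled vector.
- $h_\theta$ is a one-hidden-layer ReLU MLP with a softmax output.
- All dimensions are arbitrary.
- The helper names (`n_params`, `m_eta`, `unpack_mlp`, `h_theta_forward`) are hypothetical.

Only the forward mapping is shown, with no training.

```python
import math
import random

# Hypothetical sketch: h_theta is ASSUMED to be a one-hidden-layer ReLU MLP,
# and z a single pooled encoder vector. Dimensions below are illustrative.

def n_params(d_in, d_hidden, d_out):
    """Number of entries in theta for a d_in -> d_hidden -> d_out MLP."""
    return d_in * d_hidden + d_hidden + d_hidden * d_out + d_out

def m_eta(W_eta, b_eta, z):
    """Linear weight-generating map: theta = W_eta @ z + b_eta."""
    return [sum(w * zj for w, zj in zip(row, z)) + b
            for row, b in zip(W_eta, b_eta)]

def unpack_mlp(theta, d_in, d_hidden, d_out):
    """Reshape the flat vector theta into (W1, b1, W2, b2)."""
    if len(theta) != n_params(d_in, d_hidden, d_out):
        raise ValueError("theta has the wrong length for this MLP shape")
    it = iter(theta)
    W1 = [[next(it) for _ in range(d_in)] for _ in range(d_hidden)]
    b1 = [next(it) for _ in range(d_hidden)]
    W2 = [[next(it) for _ in range(d_hidden)] for _ in range(d_out)]
    b2 = [next(it) for _ in range(d_out)]
    return W1, b1, W2, b2

def h_theta_forward(params, x):
    """Class probabilities from the generated MLP (ReLU hidden layer, softmax)."""
    W1, b1, W2, b2 = params
    h = [max(0.0, sum(w * xi for w, xi in zip(row, x)) + b)
         for row, b in zip(W1, b1)]
    logits = [sum(w * hj for w, hj in zip(row, h)) + b
              for row, b in zip(W2, b2)]
    m = max(logits)  # subtract the max logit for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

random.seed(0)
d_z, d_in, d_hidden, d_out = 8, 4, 5, 2
P = n_params(d_in, d_hidden, d_out)  # 4*5 + 5 + 5*2 + 2 = 37
# eta = (W_eta, b_eta): random here; in the described method it is tuned in Phase 1.
W_eta = [[random.gauss(0.0, 0.1) for _ in range(d_z)] for _ in range(P)]
b_eta = [0.0] * P
z = [random.gauss(0.0, 1.0) for _ in range(d_z)]  # stand-in for an encoder output
theta = m_eta(W_eta, b_eta, z)
h_theta = unpack_mlp(theta, d_in, d_hidden, d_out)
probs = h_theta_forward(h_theta, [0.5, -1.0, 2.0, 0.1])
```

Because $m_\eta$ is differentiable, a loss computed on the outputs of $h_\theta$ can be backpropagated through $\theta$ into $\eta$ while the base transformer itself stays untouched. At inference time, only the small MLP is needed.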

Abstract

Transformer-based models have shown promising performance on tabular data compared to their classical counterparts such as neural networks and Gradient Boosted Decision Trees (GBDTs) in scenarios with limited training data. They utilize their pre-trained knowledge to adapt to new domains, achieving commendable performance with only a few training examples, also called the few-shot regime. However, the performance gain in the few-shot regime comes at the expense of significantly increased complexity and number of parameters. To circumvent this trade-off, we introduce TabDistill, a new strategy to distill the pre-trained knowledge in complex transformer-based models into simpler neural networks for effectively classifying tabular data. Our framework yields the best of both worlds: being parameter-efficient while performing well with limited training data. The distilled neural networks surpass classical baselines such as regular neural networks, XGBoost and logistic regression under equal training data, and in some cases, even the original transformer-based models that they were distilled from.

Paper Structure

This paper contains 11 sections, 2 equations, 4 figures, 6 tables, and 1 algorithm.

Figures (4)

  • Figure 1: Comparison of TabLLM and TabDistill frameworks. The tunable parameters, which are fine-tuned during training in each framework, are depicted in green. The example dataset contains Age and Education as features. The target is to predict whether the Income is $\geq 50$k or not.
  • Figure 2: TabDistill framework. In Phase 1 (left), the tunable parameters of the transformer model (the linear mapping $m_\eta(z)$) are fine-tuned, as depicted in green. The resulting output $h_\theta$ is depicted in amber. When T0pp is used as the base model $f$, a text serialization $g(x, y)$ is applied as shown in the figure. When TabPFN is used as the base model, $g(x, y)$ is the identity function. In Phase 2 (right), the MLP may be further fine-tuned if desired, as depicted in green.
  • Figure 3: SHAP feature attributions, computed on the Calhousing dataset with TabPFN as the base model $f$. The training set size $N$ is 16. The beeswarm plots were computed from 200 samples.
  • Figure 4: Weights & Biases sweeps used to optimize the hyperparameters of TabDistill with TabPFN as the base model on the Calhousing dataset with 64 training examples. Each column (y-axis) represents a tunable hyperparameter and its value. The rightmost column shows the loss on a validation set. Each colored line running from left to right is one set of hyperparameter values (i.e., a single run of the sweep). The line's color corresponds to the validation loss achieved with that set of hyperparameters.
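
The Figure 2 caption notes that in Phase 2 the MLP $h_\theta$ produced by Phase 1 may optionally be fine-tuned further. The sketch below shows one way that step could look. It makes the following assumptions:

- $h_\theta$ is a one-hidden-layer ReLU MLP.
- "Fine-tuning" means plain full-batch gradient descent on mean cross-entropy over the few-shot training set.
- The random starting weights merely stand in for the Phase 1 output.

The synthetic 8-example dataset, learning rate, step count, and helper names (`forward`, `mean_loss`, `gd_step`) are all hypothetical. This is not presented as the authors' training procedure.

```python
import math
import random

# Hypothetical Phase 2 sketch: plain full-batch gradient descent on mean
# cross-entropy for a one-hidden-layer ReLU MLP (architecture ASSUMED).

def forward(params, x):
    """Return pre-activations, hidden activations, and softmax probabilities."""
    W1, b1, W2, b2 = params
    pre = [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(W1, b1)]
    h = [max(0.0, p) for p in pre]
    logits = [sum(w * hj for w, hj in zip(row, h)) + b for row, b in zip(W2, b2)]
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return pre, h, [e / total for e in exps]

def mean_loss(params, X, y):
    """Mean cross-entropy over the few-shot training set."""
    return sum(-math.log(forward(params, x)[2][t]) for x, t in zip(X, y)) / len(X)

def gd_step(params, X, y, lr):
    """One full-batch gradient step with manually derived backpropagation."""
    W1, b1, W2, b2 = params
    gW1 = [[0.0] * len(W1[0]) for _ in W1]
    gb1 = [0.0] * len(b1)
    gW2 = [[0.0] * len(W2[0]) for _ in W2]
    gb2 = [0.0] * len(b2)
    n = len(X)
    for x, t in zip(X, y):
        pre, h, p = forward(params, x)
        # Softmax + cross-entropy: dL/dlogit_k = p_k - [k == t], averaged over n.
        d_logits = [(pk - (1.0 if k == t else 0.0)) / n for k, pk in enumerate(p)]
        for k, dk in enumerate(d_logits):
            gb2[k] += dk
            for j, hj in enumerate(h):
                gW2[k][j] += dk * hj
        for j in range(len(h)):
            dh = sum(dk * W2[k][j] for k, dk in enumerate(d_logits))
            dpre = dh if pre[j] > 0 else 0.0  # ReLU passes gradient only when active
            gb1[j] += dpre
            for i, xi in enumerate(x):
                gW1[j][i] += dpre * xi

    def step_matrix(M, G):
        return [[a - lr * g for a, g in zip(mr, gr)] for mr, gr in zip(M, G)]

    def step_vector(v, g):
        return [vi - lr * gi for vi, gi in zip(v, g)]

    return (step_matrix(W1, gW1), step_vector(b1, gb1),
            step_matrix(W2, gW2), step_vector(b2, gb2))

random.seed(0)
d_in, d_hidden, d_out = 2, 8, 2
# Stand-in for the MLP h_theta output by Phase 1 (small random weights here).
params = ([[random.gauss(0.0, 0.5) for _ in range(d_in)] for _ in range(d_hidden)],
          [0.0] * d_hidden,
          [[random.gauss(0.0, 0.5) for _ in range(d_hidden)] for _ in range(d_out)],
          [0.0] * d_out)
# Synthetic 8-shot training set: label 1 when x0 + x1 > 0.
X = [[1.0, 0.5], [0.8, -0.2], [2.0, 1.0], [0.3, 0.4],
     [-1.0, -0.5], [-0.7, 0.1], [-2.0, -1.0], [-0.4, -0.3]]
y = [1 if a + b > 0 else 0 for a, b in X]
loss_before = mean_loss(params, X, y)
for _ in range(200):
    params = gd_step(params, X, y, lr=0.1)
loss_after = mean_loss(params, X, y)
```

In this sketch, Phase 2 starts from the generated weights rather than a fresh initialization. With so few examples, a small learning rate, few steps, or early stopping on held-out data would likely be prudent to avoid overfitting.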