TuneTables: Context Optimization for Scalable Prior-Data Fitted Networks

Benjamin Feuer, Robin Tibor Schirrmeister, Valeriia Cherepanova, Chinmay Hegde, Frank Hutter, Micah Goldblum, Niv Cohen, Colin White

TL;DR

TuneTables is a parameter-efficient fine-tuning strategy for PFNs that compresses large datasets into a smaller learned context. It can also serve as an interpretability tool and can mitigate bias by optimizing a fairness objective.

Abstract

While tabular classification has traditionally relied on from-scratch training, a recent breakthrough called prior-data fitted networks (PFNs) challenges this approach. Similar to large language models, PFNs make use of pretraining and in-context learning to achieve strong performance on new tasks in a single forward pass. However, current PFNs have limitations that prohibit their widespread adoption. Notably, TabPFN achieves very strong performance on small tabular datasets but is not designed to make predictions for datasets of size larger than 1000. In this work, we overcome these limitations and substantially improve the performance of PFNs via context optimization. We introduce TuneTables, a parameter-efficient fine-tuning strategy for PFNs that compresses large datasets into a smaller learned context. We conduct extensive experiments on 19 algorithms over 98 datasets and find that TuneTables achieves the best performance on average, outperforming boosted trees such as CatBoost, while optimizing fewer than 5% of TabPFN's parameters. Furthermore, we show that TuneTables can be used as an interpretability tool and can even be used to mitigate biases by optimizing a fairness objective. We open-source our code and raw results at https://github.com/penfever/TuneTables.
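
As a rough illustration of what "compressing a large dataset into a smaller learned context" can mean, the sketch below keeps a predictor frozen and trains only a handful of context rows by gradient descent. It makes several assumptions: the frozen predictor is a toy softmax-attention classifier standing in for TabPFN, not the authors' model; the names `predict_proba` and `tune_context` are hypothetical; and the gradients are derived by hand so that the example needs only NumPy. It is a sketch of the general idea, not the TuneTables implementation.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def predict_proba(X, Z, V, tau=1.0):
    """Toy frozen predictor (an assumption, not TabPFN): each query row of X
    attends over context features Z and mixes the context label logits V."""
    sq = ((X[:, None, :] - Z[None, :, :]) ** 2).sum(-1)  # (n, m) squared distances
    A = softmax(-sq / (2.0 * tau), axis=1)               # attention weights over context
    return softmax(A @ V, axis=1), A

def tune_context(X, y, n_classes, m=8, steps=300, lr=0.5, tau=1.0, seed=0):
    """Learn m context rows (Z, V); nothing else is trainable."""
    rng = np.random.default_rng(seed)
    per_class = m // n_classes
    # Initialise the context from a few real rows of each class.
    idx = np.concatenate([
        rng.choice(np.flatnonzero(y == c), size=per_class, replace=False)
        for c in range(n_classes)
    ])
    Z = X[idx].astype(float)
    V = np.eye(n_classes)[y[idx]]
    Y = np.eye(n_classes)[y]
    n = len(X)
    for _ in range(steps):
        P, A = predict_proba(X, Z, V, tau)
        dlogits = (P - Y) / n                                # d(mean cross-entropy)/d(logits)
        dV = A.T @ dlogits                                   # gradient w.r.t. context labels
        dA = dlogits @ V.T
        dS = A * (dA - (A * dA).sum(axis=1, keepdims=True))  # softmax backward pass
        dZ = (dS.T @ X - dS.sum(axis=0)[:, None] * Z) / tau  # gradient w.r.t. context features
        Z -= lr * dZ
        V -= lr * dV
    return Z, V

# Synthetic two-class data: 2000 rows summarised by an 8-row learned context.
rng = np.random.default_rng(0)
y = np.repeat([0, 1], 1000)
X = rng.normal(size=(2000, 2)) + np.where(y[:, None] == 0, -2.0, 2.0)

Z, V = tune_context(X, y, n_classes=2, m=8)
P, _ = predict_proba(X, Z, V)
acc = (P.argmax(axis=1) == y).mean()
print(f"context rows: {len(Z)} of {len(X)}, training accuracy: {acc:.3f}")
```

At inference time only the 8 learned rows are passed as context, so the cost of a forward pass no longer depends on the size of the original training set. That is the property the abstract attributes to the learned context.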

Paper Structure (57 sections, 8 equations, 9 figures, 16 tables)


Figures (9)

  • Figure 1: TuneTables: a novel prompt-tuning technique for prior-data fitted networks. TuneTables performs prompt tuning on a pre-trained prior-data fitted network (TabPFN) to distill real-world datasets into learned embeddings, allowing for stronger performance and faster inference time than TabPFN in many cases. TuneTables also expands the capabilities of pre-trained PFNs; by way of example, we demonstrate its effectiveness for bias mitigation and as an interpretability tool.
  • Figure 2: TuneTables and state-of-the-art tabular models. A critical difference plot according to mean accuracy rank across the 98 datasets in Table 1 of McElfresh et al. (2023). Algorithms which are not significantly different ($p>0.05$) are connected with a horizontal black bar. TuneTables achieves the highest mean rank of any algorithm.
  • Figure 3: TuneTables addresses TabPFN's limitations. (Left) Motivating example (using the subset of McElfresh et al. (2023) on which both CatBoost and TabPFNs3000 report results): TabPFNs3000 is best on small datasets, but when scaled past 3000 datapoints and 100 features, TabPFNs3000 significantly underperforms. (Middle) CatBoost vs. TuneTables on LargeScaleTables: By contrast, TuneTables is competitive with CatBoost on all datasets, mitigating the limitations of TabPFN. (Right) TabPFNs3000 vs. TuneTables on LargeScaleTables: TuneTables outperforms TabPFNs3000 on datasets with a high number of datapoints or features. The colorbar on the y axis represents the comparative change in per-dataset accuracy between two algorithms (A: blue, B: red). Positive numbers represent the absolute gain in accuracy of B w.r.t. A; negative numbers represent the absolute gain in accuracy of A w.r.t. B.
  • Figure 4: Dataset with high accuracies from just two datapoints. Shown is a two-example prompt dataset for the breast cancer dataset misc_breast_cancer_wisconsin_(diagnostic)_17. The malignant-class example has higher values for all features than the benign-class example.
  • Figure 5: TuneTables is competitive with state-of-the-art tabular models. A critical difference plot according to mean accuracy rank across all LargeScaleTables datasets with fewer than 50 000 samples. Algorithms which are not significantly different ($p>0.05$) are connected with a horizontal black bar. TuneTables achieves the highest mean rank of any algorithm. This plot is similar to \ref{fig:critical_difference}, but the search spaces for XGBoost and CatBoost are expanded to include more trees.
  • ...and 4 more figures