MultiTab: A Scalable Foundation for Multitask Learning on Tabular Data

Dimitrios Sinodinos, Jack Yi Wei, Narges Armanfard

TL;DR

MultiTab-Net addresses the need for scalable multitask learning on tabular data by introducing a transformer-based architecture that uses a multi-token design and a multitask masked-attention mechanism to limit task interference while capturing rich feature interactions. It achieves higher multitask gains than existing MTL baselines and single-task transformers across diverse datasets, including large-scale recommendation data, census-like socioeconomic data, and physics data. To enable rigorous evaluation of multitask dynamics, the authors also introduce MultiTab-Bench, a synthetic dataset generator that allows fine-grained control over task correlations and relative difficulty for any number of tasks. Collectively, these contributions advance multitask learning for tabular domains and provide a framework for robust, scalable multitask modeling in practical applications.
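To make the idea of "control over task correlations and relative difficulty" concrete, here is a minimal sketch of a MultiTab-Bench-style generator. It is NOT the authors' generator: every name (`make_multitask_data`, `random_additive_poly`, `avg_pairwise_pearson`, `n_shared`, `n_private`, `rho`) is hypothetical. It assumes, loosely following the Figure 3 and Figure 4 captions, that task complexity is set by the polynomial degree of each task label. Correlation is controlled by mixing a shared component (built from shared features) with a task-specific component (built from each task's private features); because the feature blocks are independent, the expected pairwise correlation is roughly `rho`.

```python
import numpy as np


def random_additive_poly(x, degree, rng):
    """Hypothetical complexity knob: sum_j sum_{k=1..degree} w_jk * x_j**k, standardized.

    Additive only (no interaction terms), to keep the sketch short.
    """
    w = rng.normal(size=(x.shape[1], degree))
    f = sum((x ** k) @ w[:, k - 1] for k in range(1, degree + 1))
    return (f - f.mean()) / f.std()


def make_multitask_data(n, degrees, rho, n_shared=4, n_private=4, seed=0):
    """Sketch of a synthetic multitask regression generator (assumed design).

    degrees: polynomial degree per task (its length sets the task count).
    rho: target pairwise correlation between task labels, in [0, 1].
    """
    if not 0.0 <= rho <= 1.0:
        raise ValueError("rho must lie in [0, 1]")
    rng = np.random.default_rng(seed)
    t = len(degrees)
    # Independent Gaussian features: a shared block, then one private block per task.
    x = rng.normal(size=(n, n_shared + t * n_private))
    # Common signal shared by all tasks (degree 1 here, an arbitrary choice).
    shared = random_additive_poly(x[:, :n_shared], 1, rng)
    labels = []
    for i, p in enumerate(degrees):
        cols = slice(n_shared + i * n_private, n_shared + (i + 1) * n_private)
        private = random_additive_poly(x[:, cols], p, rng)
        # Unit-variance mixture: cov(y_i, y_j) is about rho for i != j.
        labels.append(np.sqrt(rho) * shared + np.sqrt(1.0 - rho) * private)
    return x, np.column_stack(labels)


def avg_pairwise_pearson(y):
    """Mean Pearson correlation over all distinct task pairs (columns of y)."""
    c = np.corrcoef(y, rowvar=False)
    iu = np.triu_indices(c.shape[0], k=1)
    return float(c[iu].mean())
```

For example, `make_multitask_data(50000, [1, 2, 2], rho=0.5)` yields three tasks whose average pairwise Pearson correlation should land near 0.5. One design consequence worth noting: because the shared component here is linear, pushing `rho` toward 1 also makes every task easier, so in this sketch the correlation and complexity knobs are not fully independent.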

Abstract

Tabular data is the most abundant data type in the world, powering systems in finance, healthcare, e-commerce, and beyond. As tabular datasets grow and span multiple related targets, there is an increasing need to exploit shared task information for improved multitask generalization. Multitask learning (MTL) has emerged as a powerful way to improve generalization and efficiency, yet most existing work focuses narrowly on large-scale recommendation systems, leaving its potential in broader tabular domains largely underexplored. Moreover, existing MTL approaches for tabular data predominantly rely on multi-layer perceptron-based backbones, which struggle to capture complex feature interactions and often fail to scale when data is abundant, a limitation that transformer architectures have overcome in other domains. Motivated by this, we introduce MultiTab-Net, the first multitask transformer architecture specifically designed for large tabular data. MultiTab-Net employs a novel multitask masked-attention mechanism that dynamically models feature-feature dependencies while mitigating task competition. Through extensive experiments, we show that MultiTab-Net consistently achieves higher multitask gain than existing MTL architectures and single-task transformers across diverse domains including large-scale recommendation data, census-like socioeconomic data, and physics datasets, spanning a wide range of task counts, task types, and feature modalities. In addition, we contribute MultiTab-Bench, a generalized multitask synthetic dataset generator that enables systematic evaluation of multitask dynamics by tuning task count, task correlations, and relative task complexity. Our code is publicly available at https://github.com/Armanfard-Lab/MultiTab.
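This summary does not define the multitask gain. A definition commonly used in the MTL literature, assumed here rather than confirmed by this excerpt, averages each task's relative change against a single-task baseline and flips the sign for metrics where lower is better: $\Delta_m = \frac{1}{t}\sum_{i=1}^{t} (-1)^{l_i} \frac{M_{m,i} - M_{b,i}}{M_{b,i}} \times 100\%$, where $M_{m,i}$ is the multitask model's metric on task $i$, $M_{b,i}$ is the baseline's, and $l_i = 1$ if lower is better for task $i$. The function name below is hypothetical.

```python
def multitask_gain(mtl_scores, baseline_scores, lower_is_better):
    """Average relative improvement (in %) of a multitask model over per-task baselines.

    Assumed definition: sign flips for metrics where lower is better (e.g. RMSE).
    """
    if not (len(mtl_scores) == len(baseline_scores) == len(lower_is_better)):
        raise ValueError("all inputs must have one entry per task")
    total = 0.0
    for m, b, lower in zip(mtl_scores, baseline_scores, lower_is_better):
        sign = -1.0 if lower else 1.0
        total += sign * (m - b) / b  # relative change for this task
    return 100.0 * total / len(mtl_scores)
```

For instance, a model that raises one task's AUC from 0.80 to 0.82 and lowers another task's RMSE from 0.50 to 0.40 gets `multitask_gain([0.82, 0.40], [0.80, 0.50], [False, True])`, i.e. (2.5% + 20%) / 2 = 11.25%. Because the average is over relative changes, a large gain on one task can mask a small regression on another, so per-task results remain worth inspecting.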

Paper Structure

This paper contains 33 sections, 11 equations, 6 figures, 12 tables.

Figures (6)

  • Figure 1: Overview of MultiTab-Net, our proposed multitask transformer for tabular data. A sample of $d$ features (categorical and/or numerical) and $t$ task tokens are passed through an embedding network to generate embeddings of size $e$ for each token. Task tokens are appended to the feature tokens, forming a combined input sequence, which is then processed by stacked encoder blocks. Each encoder block consists of inter-feature and inter-sample attention modules, along with feed-forward networks and residual connections. After $N$ encoder blocks, the processed task tokens $T' = \{T'_1, \dots, T'_t\}$ are passed through task-specific multilayer perceptrons (MLPs) to generate the final task predictions $T^P = \{T^P_1, \dots, T^P_t\}$.
  • Figure 2: Illustration of the attention mask $M_A$ under different masking schemes. Masked cells are shaded in grey and unmasked in white. In this example, there are two task tokens and six feature tokens. In T $\not\to$ T, task tokens do not attend to other task tokens. Similarly, in F $\not\to$ T, feature tokens do not attend to task tokens. Finally, F $\not\to$ T & T $\not\to$ T combines both schemes.
  • Figure 3: Average pairwise Pearson correlation for two and four tasks using the same polynomial degree for each task label. PD = [1, 1] indicates that two tasks were generated using a polynomial of degree 1, and similarly for other PD values and task counts.
  • Figure 4: Average pairwise Pearson correlation for three tasks using different polynomial degrees for each task label. PD = [2, 3, 4] indicates that tasks $T_1$, $T_2$, and $T_3$ were generated using polynomial degrees 2, 3, and 4, respectively.
  • Figure 5: Average multitask gain ($\Delta_m$) comparison under controlled variations of task properties on the synthetic benchmark. Error bars represent the standard error. We vary pairwise task correlation, task complexity, and task count.
  • ...and 1 more figure
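The masking schemes in the Figure 2 caption can be sketched as boolean masks over the combined sequence from Figure 1 ($d$ feature tokens followed by $t$ task tokens). This is a minimal illustration, not the authors' implementation: the scheme names, the convention that `True` means "query row may not attend to key column", the single attention head, and the identity query/key/value projections are all assumptions, and the inter-sample attention module is omitted. It also assumes a task token still attends to itself under T $\not\to$ T, since the caption only excludes *other* task tokens.

```python
import numpy as np

# Hypothetical names for the four settings: no mask, T -/-> T, F -/-> T, and both.
SCHEMES = ("none", "no_t2t", "no_f2t", "no_f2t_no_t2t")


def build_attention_mask(d, t, scheme):
    """Boolean (d+t, d+t) mask; True = query token i may NOT attend to key token j."""
    if scheme not in SCHEMES:
        raise ValueError(f"unknown scheme {scheme!r}")
    n = d + t
    is_task = np.arange(n) >= d  # task tokens are appended after the feature tokens
    mask = np.zeros((n, n), dtype=bool)
    if scheme in ("no_t2t", "no_f2t_no_t2t"):
        # T -/-> T: a task token cannot attend to other task tokens (self kept).
        mask |= is_task[:, None] & is_task[None, :] & ~np.eye(n, dtype=bool)
    if scheme in ("no_f2t", "no_f2t_no_t2t"):
        # F -/-> T: a feature token cannot attend to any task token.
        mask |= ~is_task[:, None] & is_task[None, :]
    return mask


def masked_self_attention(x, mask):
    """Single-head scaled dot-product attention with identity projections.

    x: (n, e) token embeddings. Returns (outputs, attention weights).
    The diagonal is never masked, so every row has at least one finite score.
    """
    scores = x @ x.T / np.sqrt(x.shape[1])
    scores = np.where(mask, -np.inf, scores)
    scores = scores - scores.max(axis=1, keepdims=True)  # numerical stability
    weights = np.exp(scores)  # exp(-inf) = 0 for masked pairs
    weights = weights / weights.sum(axis=1, keepdims=True)
    return weights @ x, weights
```

Using the Figure 2 example of six feature tokens and two task tokens:

```python
d, t, e = 6, 2, 4
mask = build_attention_mask(d, t, "no_f2t_no_t2t")
x = np.random.default_rng(0).normal(size=(d + t, e))
out, w = masked_self_attention(x, mask)
task_tokens = out[d:]  # analogous to T' in Figure 1, which would feed per-task MLP heads
```

Under F $\not\to$ T the feature-token outputs do not depend on the task tokens at all, so task tokens read from the shared feature representation without writing back into it; T $\not\to$ T additionally stops task tokens from mixing with each other. Which scheme MultiTab-Net uses by default is not stated in this summary.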