
PyTorch Frame: A Modular Framework for Multi-Modal Tabular Learning

Weihua Hu, Yiwen Yuan, Zecheng Zhang, Akihiro Nitta, Kaidi Cao, Vid Kocijan, Jinu Sunil, Jure Leskovec, Matthias Fey

TL;DR

PyTorch Frame tackles the challenge of learning from multi-modal tabular data by introducing Tensor Frame and a modular encoder–combiner–decoder pipeline that maps raw tables into per-column tensors, builds column embeddings, and refines them through column-wise interactions to produce row representations. It enables easy incorporation of external foundation models for text and image modalities and supports end-to-end learning with Graph Neural Networks via PyG for relational data. Empirically, it demonstrates strong gains in tabular tasks involving text and relational data, while remaining competitive on conventional numerical/categorical datasets. The framework offers a flexible, extensible toolkit to accelerate research and deployment of deep tabular learning in real-world multi-modal and relational settings.
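The four stages named above are materialization into per-column tensors, per-semantic-type encoding, column-wise interaction, and a readout decoder. The toy, dependency-free Python sketch below shows how data moves through them. It is not PyTorch Frame's API. The function names, the 4-dimensional embeddings, the fixed (untrained) encoder weights, and the parameter-free self-attention used in the interaction step are all hypothetical stand-ins for the learned modules a real model would use.

```python
import math
from typing import Dict, List, Tuple

EMB_DIM = 4  # hypothetical per-column embedding size


def materialize(rows: List[Dict[str, object]], numerical: List[str],
                categorical: List[str]) -> Tuple[list, list, dict]:
    """Stage 1 (sketch): turn raw rows into per-semantic-type column data."""
    num_raw = [[float(r[c]) for c in numerical] for r in rows]
    stats = []  # per-column (mean, std), computed once at materialization time
    for j in range(len(numerical)):
        col = [row[j] for row in num_raw]
        mu = sum(col) / len(col)
        sd = (sum((x - mu) ** 2 for x in col) / len(col)) ** 0.5 or 1.0
        stats.append((mu, sd))
    num = [[(x - mu) / sd for x, (mu, sd) in zip(row, stats)] for row in num_raw]
    # Categorical values become integer indices into a per-column vocabulary.
    vocabs = {c: sorted({str(r[c]) for r in rows}) for c in categorical}
    cat = [[vocabs[c].index(str(r[c])) for c in categorical] for r in rows]
    return num, cat, vocabs


def encode(num: list, cat: list) -> list:
    """Stage 2 (sketch): map every cell to an EMB_DIM vector, per semantic type."""
    embedded = []
    for num_row, cat_row in zip(num, cat):
        cols = []
        for x in num_row:  # numerical: value times a fixed weight vector
            cols.append([x * (k + 1) / EMB_DIM for k in range(EMB_DIM)])
        for idx in cat_row:  # categorical: deterministic stand-in for a lookup table
            cols.append([math.sin(idx + k) for k in range(EMB_DIM)])
        embedded.append(cols)  # one [num_columns][EMB_DIM] block per row
    return embedded


def interact(embedded: list) -> list:
    """Stage 3 (sketch): parameter-free scaled dot-product self-attention over
    the columns of each row, plus a residual connection."""
    out = []
    for cols in embedded:
        new_cols = []
        for q in cols:
            scores = [sum(a * b for a, b in zip(q, kv)) / math.sqrt(EMB_DIM) for kv in cols]
            m = max(scores)  # subtract the max for a numerically stable softmax
            weights = [math.exp(s - m) for s in scores]
            z = sum(weights)
            mixed = [sum(w * v[k] for w, v in zip(weights, cols)) / z for k in range(EMB_DIM)]
            new_cols.append([q[k] + mixed[k] for k in range(EMB_DIM)])
        out.append(new_cols)
    return out


def decode(embedded: list, head: List[float]) -> List[float]:
    """Stage 4 (sketch): mean-pool column embeddings into a row vector, then
    apply a linear readout head to get one prediction per row."""
    preds = []
    for cols in embedded:
        row_vec = [sum(v[k] for v in cols) / len(cols) for k in range(EMB_DIM)]
        preds.append(sum(h * r for h, r in zip(head, row_vec)))
    return preds


rows = [{"age": 31, "income": 52.0, "city": "Paris"},
        {"age": 45, "income": 71.5, "city": "Tokyo"}]
num, cat, vocabs = materialize(rows, ["age", "income"], ["city"])
emb = encode(num, cat)
mixed = interact(emb)
preds = decode(mixed, head=[0.1, -0.2, 0.3, 0.05])
print(len(preds), len(emb[0]), len(emb[0][0]))  # 2 3 4
```

In this sketch, every row carries one embedding per column through the encoding and interaction stages. Only the decoder collapses them into a single row representation, so the encoders and the interaction step can be swapped independently of each other.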

Abstract

We present PyTorch Frame, a PyTorch-based framework for deep learning over multi-modal tabular data. PyTorch Frame makes tabular deep learning easy by providing a PyTorch-based data structure to handle complex tabular data, introducing a model abstraction to enable modular implementation of tabular models, and allowing external foundation models to be incorporated to handle complex columns (e.g., LLMs for text columns). We demonstrate the usefulness of PyTorch Frame by implementing diverse tabular models in a modular way, successfully applying these models to complex multi-modal tabular data, and integrating our framework with PyTorch Geometric, a PyTorch library for Graph Neural Networks (GNNs), to perform end-to-end learning over relational databases.
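The abstract's idea of delegating complex columns to external foundation models can be illustrated with a pluggable embedder. Any callable that maps a batch of strings to fixed-size vectors turns a text column into numerical features. The sketch below assumes that signature. The `TextEmbedder` alias, `toy_hash_embedder`, and `embed_text_columns` are hypothetical names, not PyTorch Frame's interface. A hashed bag-of-words embedder stands in for an LLM so the example runs without model downloads.

```python
import zlib
from typing import Callable, Dict, List

Vector = List[float]
# Hypothetical plug-in contract: a batch of strings in, one vector per string out.
TextEmbedder = Callable[[List[str]], List[Vector]]


def toy_hash_embedder(texts: List[str], dim: int = 8) -> List[Vector]:
    """Stand-in for an external LLM: hashed bag-of-words, L2-normalized."""
    out = []
    for text in texts:
        vec = [0.0] * dim
        for token in text.lower().split():
            vec[zlib.crc32(token.encode("utf-8")) % dim] += 1.0
        norm = sum(v * v for v in vec) ** 0.5 or 1.0  # avoid dividing by 0 on empty text
        out.append([v / norm for v in vec])
    return out


def embed_text_columns(rows: List[Dict[str, str]], text_cols: List[str],
                       embedder: TextEmbedder) -> Dict[str, List[Vector]]:
    """Call the embedder once per text column on all of that column's values.
    The fixed-size vectors can then be fed to a tabular model like any
    other numerical (embedding) features."""
    return {c: embedder([r[c] for r in rows]) for c in text_cols}


rows = [{"title": "Great phone", "review": "battery lasts all day"},
        {"title": "Broken", "review": "screen cracked on day one"}]
feats = embed_text_columns(rows, ["title", "review"], toy_hash_embedder)
print(len(feats["review"]), len(feats["review"][0]))  # 2 8
```

To swap `toy_hash_embedder` for a real sentence-embedding model, you only need a callable with the same contract: a list of strings in, a list of vectors out. Passing a whole column per call, as this sketch assumes, lets such a model batch its inputs.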

Paper Structure (14 sections, 3 figures, 2 tables)

Figures (3)

  • Figure 1: Overview of PyTorch Frame's architecture, consisting of a (1) Tensor Frame materialization stage, (2) semantic type-wise model encodings, (3) column-wise interaction blocks, and a final (4) readout decoder head.
  • Figure 2: MultiNestedTensor, based on a compressed ragged tensor layout. Our ragged layout describes tensors of shape $[N, C, \cdot]$, where the size of the last dimension can vary across both rows and columns. Internally, data is stored in an efficient compressed format (val, ptr), where val holds the data in a flattened vector and ptr holds the cumulative offsets of the cells, ordered by row and then by column. $T[i,j]$ can be accessed via $\texttt{val[ptr[C\,*\,i\,+\,j]:ptr[C\,*\,i\,+\,j\,+\,1]]}$, which allows for efficient slicing and indexing along the row dimension.
  • Figure 3: Scatter plot comparison between deep tabular models and LightGBM across datasets with only numerical and categorical features. Each "x" represents a single predictive task, and its position compares the predictive performance of a deep tabular model against that of LightGBM. When an "x" lies above (resp. below) the diagonal line, LightGBM outperforms (resp. underperforms) the corresponding deep tabular model on that task. Overall, LightGBM still outperforms the existing deep tabular models on conventional numerical/categorical datasets, although the recent Trompt model [chen2023trompt] comes close.
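The (val, ptr) layout in Figure 2 is easy to reproduce. The sketch below is a hypothetical pure-Python `RaggedTable` class, not PyTorch Frame's `MultiNestedTensor` implementation. It follows only the indexing rule stated in the caption, `val[ptr[C*i+j]:ptr[C*i+j+1]]`. It also shows why slicing a contiguous range of rows is cheap: those rows occupy one contiguous span of `val`, so only the offsets need rebasing.

```python
from typing import List, Sequence


class RaggedTable:
    """Hypothetical toy version of a compressed ragged layout of shape [N, C, *].

    val: every cell's entries, flattened row by row, column by column.
    ptr: N*C + 1 cumulative offsets; cell (i, j) is val[ptr[C*i+j]:ptr[C*i+j+1]].
    """

    def __init__(self, val: List[float], ptr: List[int], num_rows: int, num_cols: int):
        assert len(ptr) == num_rows * num_cols + 1, "ptr must have N*C + 1 entries"
        self.val, self.ptr = val, ptr
        self.num_rows, self.num_cols = num_rows, num_cols

    @classmethod
    def from_nested(cls, data: Sequence[Sequence[Sequence[float]]]) -> "RaggedTable":
        num_rows = len(data)
        num_cols = len(data[0]) if data else 0
        val: List[float] = []
        ptr = [0]
        for row in data:
            assert len(row) == num_cols, "every row needs the same number of columns"
            for cell in row:
                val.extend(cell)
                ptr.append(len(val))  # offset just past this cell
        return cls(val, ptr, num_rows, num_cols)

    def cell(self, i: int, j: int) -> List[float]:
        k = self.num_cols * i + j
        return self.val[self.ptr[k]:self.ptr[k + 1]]

    def slice_rows(self, start: int, stop: int) -> "RaggedTable":
        """Rows [start, stop) occupy one contiguous span of val, so slicing only
        copies that span and shifts the matching offsets to start at 0."""
        C = self.num_cols
        lo, hi = self.ptr[C * start], self.ptr[C * stop]
        new_ptr = [p - lo for p in self.ptr[C * start:C * stop + 1]]
        return RaggedTable(self.val[lo:hi], new_ptr, stop - start, C)


t = RaggedTable.from_nested([
    [[1.0, 2.0], [3.0]],      # row 0: column 0 has 2 entries, column 1 has 1
    [[], [4.0, 5.0, 6.0]],    # row 1: column 0 is empty, column 1 has 3
])
print(t.ptr)                          # [0, 2, 3, 3, 6]
print(t.cell(1, 1))                   # [4.0, 5.0, 6.0]
print(t.slice_rows(1, 2).cell(0, 1))  # [4.0, 5.0, 6.0]
```

Selecting an arbitrary, non-contiguous set of rows would instead gather one span of `val` per selected row and rebuild `ptr` from the gathered cell lengths.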