GCondNet: A Novel Method for Improving Neural Networks on Small High-Dimensional Tabular Data

Andrei Margeloiu, Nikola Simidjievski, Pietro Lio, Mateja Jamnik

TL;DR

GCondNet tackles the challenge of learning from small, high-dimensional tabular data by introducing sample-wise multiplex graphs, one per feature, whose latent structure is learned by a Graph Neural Network (GNN) that conditions the first-layer weights of a predictor network. The first-layer weights are formed as $W^{[1]}_{\text{MLP}} = \alpha W_{\text{GNN}} + (1-\alpha) W_{\text{scratch}}$, with $\alpha$ decaying from $1$ to $0$ over $n_{\alpha}$ steps, enabling a gradual shift from graph-informed initialisation to autonomous learning. Graphs are constructed per feature using simple distance-based schemes (KNN with $k=5$ or Sparse Relative Distance) and are used only during training, so test-time prediction relies solely on the trained predictor. Empirically, GCondNet outperforms 14 baselines across 12 biomedical datasets, is robust to graph-construction choices, and extends to other architectures such as TabTransformer, highlighting its generality and its potential as a regularisation mechanism for small-sample, high-dimensional tabular tasks.
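A minimal NumPy sketch of the first-layer parameterisation described in the TL;DR: the mixing coefficient $\alpha$ decays linearly from 1 to 0 over $n_{\alpha}$ steps (the linear decay is stated in the Figure 3 caption), and the first-layer weight is the convex combination of $W_{\text{GNN}}$ and a zero-initialised $W_{\text{scratch}}$. The function names and the value n_alpha = 100 are hypothetical, and the sketch shows only the forward arithmetic, not how gradients update either term.

```python
import numpy as np


def alpha_at(step: int, n_alpha: int) -> float:
    """Linearly decay the mixing coefficient from 1 to 0 over n_alpha steps, then hold it at 0."""
    if n_alpha <= 0:
        return 0.0
    return max(0.0, 1.0 - step / n_alpha)


def first_layer_weights(w_gnn: np.ndarray, w_scratch: np.ndarray, alpha: float) -> np.ndarray:
    """Convex combination alpha * W_GNN + (1 - alpha) * W_scratch (both of shape (K, D))."""
    if w_gnn.shape != w_scratch.shape:
        raise ValueError("W_GNN and W_scratch must have the same shape (K, D)")
    return alpha * w_gnn + (1.0 - alpha) * w_scratch


# Toy example: K = 4 hidden units, D = 3 input features.
rng = np.random.default_rng(0)
w_gnn = rng.normal(size=(4, 3))   # stand-in for the GNN-produced matrix
w_scratch = np.zeros((4, 3))      # W_scratch starts at zero
n_alpha = 100                     # hypothetical decay horizon

for step in (0, 50, 100, 150):
    a = alpha_at(step, n_alpha)
    w = first_layer_weights(w_gnn, w_scratch, a)
    # With W_scratch still at zero, the mixed weight is just a scaled W_GNN.
    print(step, a, np.allclose(w, a * w_gnn))
```

At step 0 the first layer is entirely graph-derived; once step reaches n_alpha, only the scratch term remains, matching the "gradual shift" described above.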

Abstract

Neural networks often struggle with high-dimensional but small sample-size tabular datasets. One reason is that current weight initialisation methods assume independence between weights, which can be problematic when there are insufficient samples to estimate the model's parameters accurately. In such small data scenarios, leveraging additional structures can improve the model's performance and training stability. To address this, we propose GCondNet, a general approach to enhance neural networks by leveraging implicit structures present in tabular data. We create a graph between samples for each data dimension, and utilise Graph Neural Networks (GNNs) to extract this implicit structure and to condition the parameters of the first layer of an underlying predictor network. By creating many small graphs, GCondNet exploits the data's high dimensionality, and thus improves the performance of an underlying predictor network. We demonstrate GCondNet's effectiveness on 12 real-world datasets, where it outperforms 14 standard and state-of-the-art methods. The results show that GCondNet is a versatile framework for injecting graph-regularisation into various types of neural networks, including MLPs and tabular Transformers. Code is available at https://github.com/andreimargeloiu/GCondNet.
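The per-feature graph construction ("a graph between samples for each data dimension") can be sketched as follows. This assumes the KNN scheme with $k=5$ mentioned in the TL;DR, uses the absolute difference between two samples' values of feature $j$ as the distance, and symmetrises the edges; these choices and the helper names are illustrative assumptions rather than the official implementation in the linked repository, and the Sparse Relative Distance scheme is not shown.

```python
import numpy as np


def knn_feature_graph(values: np.ndarray, k: int = 5) -> np.ndarray:
    """Build an undirected KNN graph over samples from a single feature column.

    values: shape (N,), the j-th feature for all N samples.
    Returns an (N, N) symmetric 0/1 adjacency matrix without self-loops.
    Assumption: distance = |x_i - x_l| on this one feature; edges are symmetrised.
    """
    n = values.shape[0]
    k = min(k, n - 1)
    # Pairwise absolute differences between samples along this one feature.
    dist = np.abs(values[:, None] - values[None, :])
    np.fill_diagonal(dist, np.inf)  # a sample is never its own neighbour
    neighbours = np.argsort(dist, axis=1, kind="stable")[:, :k]
    adj = np.zeros((n, n), dtype=np.int8)
    rows = np.repeat(np.arange(n), k)
    adj[rows, neighbours.ravel()] = 1
    # Symmetrise: connect i and l if either picked the other as a neighbour.
    return np.maximum(adj, adj.T)


def build_feature_graphs(X: np.ndarray, k: int = 5) -> list:
    """One graph per feature (column) of X, each with N nodes (one per sample)."""
    return [knn_feature_graph(X[:, j], k) for j in range(X.shape[1])]


rng = np.random.default_rng(0)
X = rng.normal(size=(20, 50))  # toy data: N = 20 samples, D = 50 features
graphs = build_feature_graphs(X, k=5)
print(len(graphs), graphs[0].shape)  # 50 graphs, each 20 x 20
```

Because each graph only looks at one column, a dataset with many features and few samples yields many small graphs, which is the property the abstract says GCondNet exploits.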

Paper Structure (24 sections, 1 equation, 9 figures, 13 tables, 2 algorithms)

Figures (9)

  • Figure 1: $\mathsf{GCondNet}$ is a general method for leveraging implicit relationships between samples to improve the performance of any predictor network with a linear first layer, such as a standard MLP, on tabular data. (A) Given a tabular dataset ${\bm{X}} \in {\mathbb{R}}^{N \times D}$, we generate a graph ${\mathcal{G}}_j$ for each feature in the dataset (resulting in $D$ graphs), with each node representing a sample (totalling $N$ nodes per graph). (B) The resulting graphs are passed through a shared Graph Neural Network (GNN), which extracts graph embeddings ${\bm{w}}^{(j)} \in {\mathbb{R}}^{K}$ from each graph ${\mathcal{G}}_j$. We concatenate the graph embeddings into a matrix ${\bm{W}}_{\text{GNN}} = [{\bm{w}}^{(1)}, \ldots, {\bm{w}}^{(D)}]$. (C) We use ${\bm{W}}_{\text{GNN}}$ to parameterise the first layer ${\bm{W}}^{[1]}_{\text{MLP}}$ of the MLP predictor as a convex combination ${\bm{W}}^{[1]}_{\text{MLP}} = \alpha {\bm{W}}_{\text{GNN}} + (1 - \alpha) {\bm{W}}_{\text{scratch}}$, where ${\bm{W}}_{\text{scratch}}$ is initialised to zero.
  • Figure 2: The inductive bias of GCondNet robustly improves performance and cannot be replicated without GNNs. We compute the normalised test balanced accuracy across all 12 datasets and 25 runs and report the relative improvement over a baseline MLP. First, we find that $\mathsf{GCondNet}$ is robust across various graph construction methods and provides consistent improvement over an equivalent MLP. Second, to assess the usefulness of the GNNs, we propose three weight initialisation methods designed to emulate GCondNet's inductive biases but without employing GNNs. The results show that $\mathsf{GCondNet}$ outperforms such methods, highlighting the effectiveness of the GNN-extracted latent structure.
  • Figure 3: GCondNet reduces overfitting. The impact of varying the mixing coefficient $\alpha$ is illustrated through the training and validation loss curves (averaged over 25 runs) on 'toxicity'. We train $\mathsf{GCondNet}$ with linearly decaying $\alpha$, along with modified versions with fixed $\alpha$. Two observations are notable: (i) $\mathsf{GCondNet}$ exhibits less overfitting (evident from the converging validation loss) compared to an MLP ($\alpha = 0$), which overfits at the $4{,}000^{\text{th}}$ iteration; (ii) decaying $\alpha$ enhances training stability while improving test-time accuracy by at least 2%.
  • Figure 4: $\mathsf{GCondNet}$ is versatile and can enhance various models beyond MLPs. When applied to TabTransformer, $\mathsf{GCondNet}$ consistently improves performance by up to $14\%$.
  • Figure 5: Improvement of $\mathsf{GCondNet}$ (with an MLP backbone) over an equivalent MLP baseline. This figure shows the relative increase in test balanced accuracy of $\mathsf{GCondNet}$ compared to the MLP baseline, averaged across datasets. $\mathsf{GCondNet}$ consistently enhances performance across various $N/D$ ratios, demonstrating its advantages even on low-dimensional data.
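Step (B) of Figure 1, where a shared GNN turns each feature graph into an embedding ${\bm{w}}^{(j)} \in \mathbb{R}^K$ and the embeddings are stacked into ${\bm{W}}_{\text{GNN}}$, can be sketched as a forward pass in NumPy. For illustration only, this assumes a single GCN-style layer, the raw value of feature $j$ as each node's feature, mean pooling over nodes, random weights, and random placeholder graphs; GCondNet's actual GNN architecture, node features, and readout may differ.

```python
import numpy as np


def normalised_adjacency(adj: np.ndarray) -> np.ndarray:
    """GCN-style symmetric normalisation D^{-1/2} (A + I) D^{-1/2}."""
    a_hat = adj + np.eye(adj.shape[0])
    d_inv_sqrt = 1.0 / np.sqrt(a_hat.sum(axis=1))  # self-loops keep degrees >= 1
    return a_hat * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]


def graph_embedding(adj, node_feats, w1, w2):
    """Assumed GNN: one GCN layer + ReLU, mean pooling over nodes, linear map to R^K."""
    h = np.maximum(normalised_adjacency(adj) @ node_feats @ w1, 0.0)  # (N, hidden)
    pooled = h.mean(axis=0)                                           # (hidden,)
    return pooled @ w2                                                # (K,)


def gnn_weight_matrix(X, adjs, w1, w2):
    """Apply the same (shared) GNN weights to every feature graph; stack embeddings as columns."""
    cols = [graph_embedding(adjs[j], X[:, j:j + 1], w1, w2) for j in range(X.shape[1])]
    return np.stack(cols, axis=1)  # (K, D)


rng = np.random.default_rng(0)
N, D, hidden, K = 20, 50, 16, 8
X = rng.normal(size=(N, D))

# Placeholder graphs: random symmetric adjacency per feature (stand-in for KNN graphs).
adjs = []
for _ in range(D):
    a = np.triu((rng.random((N, N)) < 0.2).astype(float), 1)
    adjs.append(a + a.T)

w1 = rng.normal(size=(1, hidden))  # shared across all D graphs
w2 = rng.normal(size=(hidden, K))
W_gnn = gnn_weight_matrix(X, adjs, w1, w2)
print(W_gnn.shape)  # (8, 50): K hidden units x D input features
```

The resulting $K \times D$ matrix has the same shape as the first-layer weight of a predictor with $D$ inputs and $K$ hidden units, which is what allows it to enter the convex combination in step (C). Sharing one set of GNN weights across all $D$ graphs keeps the number of GNN parameters independent of $D$.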