Enforcing convex constraints in Graph Neural Networks

Ahmed Rashwan, Keith Briggs, Chris Budd, Lisa Kreusser

TL;DR

ProjNet addresses the need for outputs that strictly satisfy input-dependent convex constraints in graph-structured data. It integrates a GPU-accelerated Component-Averaged Dykstra (CAD) projection with sparse vector clipping to produce feasible outputs in $C=\bigcap_i C_i$ and support end-to-end differentiability via a surrogate gradient for CAD. The approach achieves strong speed-ups over traditional solvers (e.g., up to two orders of magnitude faster under favorable conditions) while maintaining competitive solution quality across linear programming, non-convex quadratic programs, and transmit power optimization. The framework leverages constrained input graphs and batched processing to scale to large graphs, offering a tunable speed-accuracy trade-off via a penalty parameter $c_h$ and confirming CAD as a viable, scalable projection tool for constrained GNNs.
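The TL;DR names sparse vector clipping as the second ingredient but does not describe it. Purely as a hypothetical illustration of what clipping a vector into a polyhedron $\{z : Az \le b\}$ can look like, the sketch below uses a ray search. It starts from a strictly feasible point `x_int` (an assumption; the summary does not say how such a point would be obtained), moves toward the network output `y`, and stops at the first constraint it would violate. The name `ray_clip` and the whole procedure are ours. This is not ProjNet's sparse variant.

```python
import numpy as np


def ray_clip(y, x_int, A, b):
    """Hypothetical ray clipping (not ProjNet's method).

    Returns x_int + t * (y - x_int) with the largest t in [0, 1] such that
    A @ z <= b still holds. Assumes x_int is strictly feasible (A @ x_int < b).
    """
    d = y - x_int                # direction from the interior point to the output
    slack = b - A @ x_int        # positive for a strictly feasible x_int
    rate = A @ d                 # how fast each constraint's slack is consumed along d
    t = 1.0
    moving = rate > 0            # only constraints the ray moves toward can become active
    if np.any(moving):
        t = min(1.0, float(np.min(slack[moving] / rate[moving])))
    return x_int + t * d
```

The step length depends only on constraints with $a_i^\top (y - x_{\mathrm{int}}) > 0$. With a sparse $A$, the cost is therefore dominated by two matrix-vector products.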

Abstract

Many machine learning applications require outputs that satisfy complex, dynamic constraints. This task is particularly challenging in Graph Neural Network models due to the variable output sizes of graph-structured data. In this paper, we introduce ProjNet, a Graph Neural Network framework which satisfies input-dependent constraints. ProjNet combines a sparse vector clipping method with the Component-Averaged Dykstra (CAD) algorithm, an iterative scheme for solving the best-approximation problem. We establish a convergence result for CAD and develop a GPU-accelerated implementation capable of handling large-scale inputs efficiently. To enable end-to-end training, we introduce a surrogate gradient for CAD that is both computationally efficient and better suited for optimization than the exact gradient. We validate ProjNet on four classes of constrained optimisation problems: linear programming, two classes of non-convex quadratic programs, and radio transmit power optimization, demonstrating its effectiveness across diverse problem settings.
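The abstract describes CAD only at a high level. Below is a minimal NumPy sketch of one way to realise a component-averaged Dykstra iteration for linear inequality constraints $a_i^\top y \le b_i$. It treats the problem as Dykstra's algorithm in a product space: each constraint keeps its own copy of the variables it touches, each copy is projected onto its halfspace, and the copies are then averaged per coordinate. The following are our assumptions, not the paper's:

- constraints are halfspaces with no all-zero rows;
- $l_j$ is the number of constraints involving coordinate $j$;
- `cad_project` is a hypothetical name.

The paper's GPU implementation, stopping rule and batching are not reproduced. Under these assumptions the iteration converges to the weighted projection $\mathop{\mathrm{arg\,min}}_{y \in C} \sum_j l_j (y_j - x_j)^2$, which has the same form as the objective in Theorem 1.

```python
import numpy as np


def cad_project(x0, A, b, iters=500):
    """Sketch of a component-averaged Dykstra iteration (hypothetical helper).

    Approximates argmin_{A y <= b} sum_j l_j (y_j - x0_j)^2, where l_j is the
    number of rows of A with a nonzero in column j (assumed weighting).
    Assumes A has no all-zero rows.
    """
    x0 = np.asarray(x0, dtype=float)
    A = np.asarray(A, dtype=float)
    b = np.asarray(b, dtype=float)
    mask = A != 0                      # which variables each constraint touches
    l = mask.sum(axis=0)               # l_j: number of constraints involving variable j
    row_norm2 = (A * A).sum(axis=1)    # ||a_i||^2, used in the halfspace projections
    x = x0.copy()
    P = np.zeros_like(A)               # Dykstra correction for each (constraint, variable) pair
    for _ in range(iters):
        # Each constraint's local copy of x (restricted to its support) plus its correction.
        V = x * mask + P
        # Project every local copy onto its own halfspace a_i . v <= b_i.
        viol = np.maximum((A * V).sum(axis=1) - b, 0.0)
        Z = V - (viol / row_norm2)[:, None] * A
        # Dykstra correction update.
        P = V - Z
        # Component averaging: variable j averages the copies held by the l_j constraints
        # that touch it; variables in no constraint keep their input value.
        x = np.where(l > 0, Z.sum(axis=0) / np.maximum(l, 1), x0)
    return x
```

Consider the constraints $y_1 + y_2 \le 0$ and $y_2 \le 10$ with input $(1, 1)$. The weights are $l = (1, 2)$ and the iteration approaches $(-1/3, 1/3)$, whereas the unweighted Euclidean projection would be $(0, 0)$. The dense $m \times n$ arrays are for readability only. A sparse implementation would store `P` only on the nonzero pattern of $A$, so each iteration costs time proportional to the number of nonzeros.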

Paper Structure

This paper contains 44 sections, 2 theorems, 29 equations, 6 figures, 1 table, 1 algorithm.

Key Result

Theorem 1

For input $x \in \mathbb{R}^n$, the CAD algorithm converges to the projection $P^l_C(x) = \mathop{\mathrm{arg\,min}}\limits_{y \in C}\; \sum_{j=1}^n l_j(y_j - x_j)^2$. In particular, for input point $(x_j / \sqrt{l_j})_{1 \leq j \leq n}$ and feasible set $\{(x_j/ \sqrt{l_j})_{1 \leq j \leq n}: \ldots\}$ …
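Rescaling each coordinate turns the weighted objective in Theorem 1 into an unweighted one. This appears to be the point of the truncated "In particular" clause. The short derivation below is our own illustration, not the paper's statement, and assumes $l_j > 0$ for all $j$. Write $\tilde{x}_j = x_j/\sqrt{l_j}$, $w_j = y_j/\sqrt{l_j}$ and $\tilde C = \{(y_j/\sqrt{l_j})_{1\le j\le n} : y \in C\}$. Then every $w \in \tilde C$ satisfies

$$\sum_{j=1}^{n} l_j\,(w_j - \tilde{x}_j)^2 = \sum_{j=1}^{n} l_j\left(\frac{y_j - x_j}{\sqrt{l_j}}\right)^2 = \sum_{j=1}^{n} (y_j - x_j)^2,$$

so $P^l_{\tilde C}(\tilde x) = \big(P_C(x)_j/\sqrt{l_j}\big)_{1 \le j \le n}$, where $P_C$ denotes the ordinary Euclidean projection onto $C$. Running the weighted scheme on rescaled inputs and multiplying the result by $\sqrt{l_j}$ therefore recovers the Euclidean projection.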

Figures (6)

  • Figure 1: An illustration of the ProjNet architecture. It shows the forward pass on a constraint polygon, with points colour-coded by the module used to compute them. The model takes as input a graph $\mathcal{G}$ with $n$ nodes and a feasible set $C \subset \mathbb R^n$ defined by linear constraints, and outputs a feasible point $y \in C$.
  • Figure 2: Comparing runtimes of the CAD algorithm and Gurobi for the linear projection problem as a function of output dimension $n$. The $\delta$ values in the legend measure the distance between the initial point and the feasible set. The legend and y-axis are identical for all panels.
  • Figure 3: Comparing runtimes of ProjNet, PDLP, and Gurobi for linear programming. We trained three ProjNet models with the different $c_h$ values shown in the legend. Error bars show upper/lower quartiles for each point. The legend and y-axis are identical for all panels.
  • Figure 4: Comparing runtimes of ProjNet, trust-constr, and programming baselines for three classes of optimisation problems. We show two ProjNet models with different values of $c_h$. Precision is set to $\epsilon = 10^{-3}$. Plots include error bars showing upper/lower quartiles for each point. The y-axis is shared.
  • Figure 5: Plots showing method runtime as a function of the number of variables $n$ for all baseline methods and application problems considered. Error bars denote upper/lower quartiles. Numerical tolerance was set to $\epsilon = 10^{-3}$.
  • …and 1 more figure

Theorems & Definitions (4)

  • Theorem 1
  • Proposition 1
  • Proof of Theorem 1
  • Proof