Efficiently Learning Branching Networks for Multitask Algorithmic Reasoning

Dongyue Li, Zhenshuo Zhang, Minxuan Duan, Edgar Dobriban, Hongyang R. Zhang

TL;DR

The paper tackles multitask algorithmic reasoning by introducing AutoBRANE, a branching-network framework that automatically learns task-sharing structures across multiple graph and text-based algorithmic tasks. It achieves this with an efficient layer-wise partitioning mechanism based on gradient-derived task affinities and convex relaxations, plus a theoretically grounded first-order approximation and JL-based dimensionality reduction to avoid retraining. Empirically, AutoBRANE outperforms strong multitask and branching baselines on CLRS, CLRS-Text, GraphQA, GraphWiz, and a large 500-task community-detection dataset, while reducing runtime and memory usage. The work demonstrates that hierarchically clustered task groups emerge in learned branching structures and underscores the practical benefits of structured parameter sharing for complex, step-by-step algorithmic reasoning.
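The JL-based dimensionality reduction mentioned above can be illustrated with a small sketch. It assumes per-task gradient vectors are already available as plain lists of floats and uses a Gaussian random projection, which approximately preserves inner products between gradients. The names `jl_project` and `dot` are hypothetical helpers for illustration, and the exact projection used in the paper may differ.

```python
import math
import random

def dot(a, b):
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))

def jl_project(vectors, d_out, seed=0):
    """Project each vector to d_out dimensions with one shared Gaussian matrix.

    Entries are N(0, 1) / sqrt(d_out), so squared norms and inner products
    are preserved in expectation (the Johnson-Lindenstrauss construction).
    """
    rng = random.Random(seed)
    d_in = len(vectors[0])
    proj = [[rng.gauss(0.0, 1.0) / math.sqrt(d_out) for _ in range(d_in)]
            for _ in range(d_out)]
    return [[dot(row, v) for row in proj] for v in vectors]

# Toy gradients for two related tasks (synthetic values, not from the paper).
rng = random.Random(1)
g1 = [rng.gauss(0, 1) for _ in range(300)]
g2 = [x + 0.5 * rng.gauss(0, 1) for x in g1]
p1, p2 = jl_project([g1, g2], d_out=256)
print(dot(g1, g2), dot(p1, p2))  # the two values should be close
```

Because the same random matrix is applied to every task, affinities computed from the projected gradients approximate those computed in the full parameter space at a fraction of the memory cost.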

Abstract

Algorithmic reasoning -- the ability to perform step-by-step logical inference -- has become a core benchmark for evaluating reasoning in graph neural networks (GNNs) and large language models (LLMs). Ideally, one would like to design a single model capable of performing well on multiple algorithmic reasoning tasks simultaneously. However, this is challenging when the execution steps of algorithms differ from one another, causing negative interference when they are trained together. We propose branching neural networks, a principled architecture for multitask algorithmic reasoning. Searching for the optimal $k$-ary tree with $L$ layers over $n$ algorithmic tasks is combinatorial, requiring exploration of up to $k^{nL}$ possible structures. We develop AutoBRANE, an efficient algorithm that reduces this search to $O(nL)$ time by solving a convex relaxation at each layer to approximate an optimal task partition. The method clusters tasks using gradient-based affinity scores and can be used on top of any base model, including GNNs and LLMs. We validate AutoBRANE on a broad suite of graph-algorithmic and text-based reasoning benchmarks. We show that gradient features estimate true task performance within 5% error across four GNNs and four LLMs (up to 34B parameters). On the CLRS benchmark, it outperforms the strongest single multitask GNN by 3.7% and the best baseline by 1.2%, while reducing runtime by 48% and memory usage by 26%. The learned branching structures reveal an intuitively reasonable hierarchical clustering of related algorithms. On three text-based graph reasoning benchmarks, AutoBRANE improves over the best non-branching multitask baseline by 3.2%. Finally, on a large graph dataset with 21M edges and 500 tasks, AutoBRANE achieves a 28% accuracy gain over existing multitask and branching architectures, along with a 4.5$\times$ reduction in runtime.
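To make the per-layer partitioning step concrete, the sketch below scores a task partition by the average within-group affinity and picks the best one by exhaustive search on a toy instance. This is an illustration only: the affinity values are hypothetical, the objective is an assumed stand-in, and the exhaustive search replaces the convex relaxation that AutoBRANE actually solves. Enumerating all $k^n$ assignments for one layer is precisely the cost the relaxation is meant to avoid.

```python
from itertools import product

def partition_score(affinity, assignment, k):
    """Sum over groups of (total pairwise affinity in the group) / (group size)."""
    score = 0.0
    for g in range(k):
        members = [i for i, a in enumerate(assignment) if a == g]
        if members:
            total = sum(affinity[i][j] for i in members for j in members)
            score += total / len(members)
    return score

def best_partition(affinity, k):
    """Exhaustive search over all k**n assignments; feasible only for tiny n."""
    n = len(affinity)
    best = max(product(range(k), repeat=n),
               key=lambda a: partition_score(affinity, a, k))
    return list(best)

# Hypothetical affinities for three tasks: 0 = BFS, 1 = Bellman-Ford, 2 = DFS.
affinity = [
    [1.00, 0.75, 0.30],
    [0.75, 1.00, 0.30],
    [0.30, 0.30, 1.00],
]
groups = best_partition(affinity, k=2)
print(groups)  # tasks 0 and 1 share a group; task 2 is separate

# Size of the naive search space for n tasks, L layers, branching factor k.
n, L, k = 12, 5, 2
print(k ** (n * L))
```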

Paper Structure

This paper contains 40 sections, 1 theorem, 29 equations, 6 figures, 12 tables, 3 algorithms.

Key Result

Proposition 3.3

Let $\mathcal{D}$ be a search space of model weights $W$ whose radius is at most $D$. Suppose the gradient of $f_{W^{(0)}}$ at the initialization $W^{(0)}$ in the training set is at most $G$ in Euclidean norm. Let $T$ be the training set of inputs. For each algorithmic reasoning task $i \in \{1, 2, \dots, n\}$, ... Let $\hat{L}_S(f_W)$ be the training loss of a subset of tasks $S \subseteq \{1,2,\dots,n\}$ ...

Figures (6)

  • Figure 1: An illustration of our problem setup and the proposed solution. Left: We study learning graph algorithms such as BFS, Bellman-Ford, and DFS, formulating the prediction of intermediate algorithmic states as a node-labeling classification problem. Center: Given $n$ algorithmic reasoning tasks and a base model (e.g., GNNs or low-rank adapters), we develop an efficient method to construct a branching network that automatically learns parameter-sharing structures across all tasks. Right: Our algorithm identifies a task partition at each layer and then searches for a corresponding tree structure over layers $1, 2, \dots, L$. Overall, the procedure runs in time $O(nL)$, dramatically faster than the naive worst case of $O(2^{nL})$.
  • Figure 2: We give three examples of algorithmic reasoning tasks: Breadth-first search (BFS), depth-first search (DFS), and Bellman-Ford. The node labels of each intermediate step encode the predecessor of each node. We illustrate the predecessors with arrows. One can see that these algorithms share some intermediate steps, but not all, and the goal of this paper is to automatically identify such similarities.
  • Figure 3: We present three examples of branching GNNs, each designed to learn a pair of algorithms. As shown in Figure 2, all three algorithms share identical node labels in the first step, so the same initial GNN layer applies to all. BFS and Bellman–Ford continue to share encodings in steps 2 and 3, thus reusing the second layer, while DFS branches out. These structures are learned automatically from data: across the training graphs, BFS and Bellman–Ford share 75% of their encodings, whereas DFS overlaps by 30% with each.
  • Figure 4: Illustration of the trade-off between error rate, GPU hours, and memory usage for AutoBRANE compared to existing multitask and branching network baselines. We show the results of using MPNN or edge transformers [muller2024towards] as the base model. AutoBRANE outperforms a single multitask network by 3.7%, demonstrating the effectiveness of branching networks in leveraging positive task transfer. It also achieves the best overall trade-off, reducing the average error rate by 1.2% compared to the strongest baseline, while using 48% fewer GPU hours and 26% less memory.
  • Figure 5: Illustration of the tree structure of the branching network learned on twelve algorithmic reasoning tasks from the CLRS benchmark. Our algorithm identifies clusters of algorithms that follow similar intermediate steps.
  • ...and 1 more figure
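The branching structure described in Figures 1 and 3 can be sketched as a sequence of per-layer task partitions in which each layer refines the one before it. The snippet below is a minimal illustration under that assumption: `is_tree` checks the nesting property, and `task_paths` lists which parameter block each task uses at each layer. Both function names and the three-layer example are hypothetical and are not taken from the paper's code.

```python
def is_tree(partitions):
    """True if every group at a deeper layer lies inside one group of the previous layer."""
    for coarse, fine in zip(partitions, partitions[1:]):
        for group in fine:
            if not any(set(group) <= set(parent) for parent in coarse):
                return False
    return True

def task_paths(partitions):
    """Map each task to the (layer, block) pairs it passes through."""
    paths = {}
    for layer, groups in enumerate(partitions):
        for block, group in enumerate(groups):
            for task in group:
                paths.setdefault(task, []).append((layer, block))
    return paths

# Hypothetical three-layer structure: one shared first layer, then DFS
# branches off, then every task gets its own final layer.
partitions = [
    [("BFS", "Bellman-Ford", "DFS")],
    [("BFS", "Bellman-Ford"), ("DFS",)],
    [("BFS",), ("Bellman-Ford",), ("DFS",)],
]
print(is_tree(partitions))          # True
print(task_paths(partitions)["DFS"])  # [(0, 0), (1, 1), (2, 2)]
```

Tasks whose paths share a prefix share the corresponding layers' parameters, which is the sense in which the learned tree encodes a hierarchical clustering of algorithms.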

Theorems & Definitions (4)

  • Remark 3.1
  • Remark 3.2
  • Proposition 3.3
  • Proof of Proposition 3.3