Model Parallelism With Subnetwork Data Parallelism

Vaibhav Singh, Zafir Khalid, Edouard Oyallon, Eugene Belilovsky

TL;DR

This work tackles the memory bottlenecks of pre-training large neural networks by introducing Subnetwork Data Parallelism (SDP), which partitions a model into structurally complete subnetworks trained across workers without exchanging activations. SDP supports two masking regimes (forward masking and backward masking) and two subnetwork constructions (Neuron-Level SDP and Block-Level SDP). The paper provides theoretical guarantees (convergence of backward masking under $L$-smoothness and a spectral-gap deviation bound) and reports concrete reductions in memory and communication cost. Empirically, SDP achieves 30%-75% per-device memory savings while maintaining or even improving accuracy across CNNs, vision transformers, and LLM pre-training, with forward masking sometimes outperforming the dense data-parallel baseline in FLOP-matched settings. The results position SDP as a practical approach to scaling model capacity under fixed hardware constraints, enabling training of larger models with reduced communication and memory overhead.

Abstract

Pre-training large neural networks at scale imposes heavy memory demands on accelerators and often requires costly communication. We introduce Subnetwork Data Parallelism (SDP), a distributed training framework that partitions a model into structured subnetworks trained across workers without exchanging activations. We study two complementary masking regimes: backward masking, which applies sparsity only in the backward step to retain unbiased gradients, and forward masking, which also removes parameters in the forward pass to deliver stronger efficiency gains while providing additional regularization. We further explore two subnetwork construction strategies: neuron level and block level, applied across both CNNs and transformers. In experiments spanning CNNs and transformers on CIFAR and ImageNet, as well as LLM pre-training on FineWeb, SDP reduces per-device memory usage by 30%-75% while maintaining or improving performance. Notably, in FLOP-matched settings, forward masking can sometimes achieve better performance.
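As a rough illustration of the two masking regimes described in the abstract, the following sketch contrasts them on a single linear unit with squared loss. It is not the authors' implementation: the function names, the toy model, and the 0/1 `mask` list are assumptions made purely for illustration.

```python
# Toy model (assumption): prediction = w . x, loss = 0.5 * (w . x - t)^2.
# `mask` is a hypothetical 0/1 list marking which parameters this worker owns.

def dot(a, b):
    """Inner product of two equal-length lists."""
    return sum(x * y for x, y in zip(a, b))


def backward_masked_grad(w, x, t, mask):
    """Backward masking: the forward pass uses ALL parameters;
    sparsity is applied only to the resulting gradient."""
    residual = dot(w, x) - t  # dense forward pass
    return [m * residual * xi for m, xi in zip(mask, x)]


def forward_masked_grad(w, x, t, mask):
    """Forward masking: masked-out parameters are also removed
    from the forward pass, so the residual itself changes."""
    w_sub = [m * wi for m, wi in zip(mask, w)]
    residual = dot(w_sub, x) - t  # sparse forward pass
    return [m * residual * xi for m, xi in zip(mask, x)]


w = [1.0, 2.0, 3.0]
x = [1.0, 1.0, 1.0]
mask = [1, 0, 1]
g_bm = backward_masked_grad(w, x, 0.0, mask)
g_fm = forward_masked_grad(w, x, 0.0, mask)
```

In this toy case the backward-masked gradient agrees with the dense gradient on every retained coordinate, whereas the forward-masked gradient differs because the masked weight no longer contributes to the prediction.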

Paper Structure

This paper contains 28 sections, 4 theorems, 26 equations, 2 figures, 5 tables.

Key Result

Proposition 1

Let $\rho \geq 0$ be the spectral gap defined above. Then for any collection of vectors $\mathbf{g}_1,\ldots,\mathbf{g}_n \in \mathbb{R}^{|J|}$,

Figures (2)

  • Figure 1: Data Parallelism (DDP) vs. Subnetwork Data Parallelism (SDP). Left: In data parallelism each GPU hosts a full replica, computes all layer gradients $\{\nabla L_1,\nabla L_2,\nabla L_3,\nabla L_4\}$, and all-reduces all parameters each step; per-GPU memory is approximately the full model (parameters $+$ gradients $+$ optimizer state $+$ activations). Right: In SDP each GPU trains an end-to-end subnetwork (a subset of layers/neurons) with a local loss $L_k(\theta)$; only gradients of shared parameters are synchronized via masked averaging (dashed arcs). For a coverage ratio $\mathcal{C} = p/n$ (each parameter resides on $p$ of $n$ GPUs), both memory and communication per GPU scale as $\approx \mathcal{C} \times$ DP, with no cross-GPU activation exchange. This enables fitting larger models or longer sequences under the same hardware budget and improves scalability when bandwidth or memory are bottlenecks; when $\mathcal{C} = 1$ (all parameters on all GPUs), SDP reduces to standard DDP.
  • Figure 2: Cosine similarity between subnetworks with N-SDP, B-SDP, B$_{b}$-SDP and a full ResNet-18 model's gradients, across various convolutional layers. The subnetworks constructed above have a coverage ratio ($\mathcal{C} = 4/8$) for N-SDP and same active blocks ($\mathcal{A} = 4/8$) for B-SDP and B$_{b}$-SDP.
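The masked averaging and coverage ratio described in the Figure 1 caption can be sketched in plain Python. The cyclic assignment of parameters to workers and both function names are hypothetical, chosen only so that each parameter resides on exactly $p$ of $n$ workers; the paper's actual subnetwork construction and normalization may differ.

```python
def cyclic_masks(n_workers, n_params, p):
    """Hypothetical assignment: parameter j lives on workers
    j, j+1, ..., j+p-1 (mod n_workers), i.e. on exactly p workers."""
    masks = [[0] * n_params for _ in range(n_workers)]
    for j in range(n_params):
        for s in range(p):
            masks[(j + s) % n_workers][j] = 1
    return masks


def masked_average(grads, masks):
    """Average each parameter's gradient over the workers that hold it.
    grads[k][j] is worker k's gradient for parameter j; entries where
    masks[k][j] == 0 are ignored."""
    n_workers, n_params = len(grads), len(grads[0])
    averaged = []
    for j in range(n_params):
        holders = [k for k in range(n_workers) if masks[k][j]]
        if not holders:
            # No worker owns this parameter, so it receives no update.
            averaged.append(0.0)
            continue
        averaged.append(sum(grads[k][j] for k in holders) / len(holders))
    return averaged


n, d, p = 4, 8, 2
masks = cyclic_masks(n, d, p)
coverage = p / n  # coverage ratio C = p / n
params_per_worker = [sum(m) for m in masks]  # each is C * d under this assignment

# Worker k reports a constant gradient of k + 1 on every coordinate.
grads = [[float(k + 1)] * d for k in range(n)]
avg = masked_average(grads, masks)
```

With `p` equal to `n`, every mask is all ones and `masked_average` becomes a plain mean over workers, which mirrors the caption's remark that SDP reduces to standard DDP when $\mathcal{C} = 1$.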

Theorems & Definitions (6)

  • Proposition 1: Deviation bound under backward masking
  • Theorem 1: SGD rate under backward masking
  • Proposition 2: Deviation bound under masking
  • proof
  • Theorem 2: Nonconvex rate under BM
  • proof : Proof (two steps)