Fold-CP: A Context Parallelism Framework for Biomolecular Modeling

Dejun Lin, Simon Chu, Vishanth Iyer, Youhan Lee, John St John, Kevin Boyd, Brian Roland, Xiaowei Ren, Guoqing Zhou, Zhonglin Cao, Polina Binder, Yuliya Zhautouskaya, Jakub Zakrzewski, Maximilian Stadler, Kyle Gion, Yuxing Peng, Xi Chen, Tianjing Zhang, Philipp Junk, Michelle Dimon, Paweł Gniewek, Fabian Ortega, McKinley Polen, Ivan Grubisic, Ali Bashir, Graham Holt, Danny Kovtun, Matthias Grass, Luca Naef, Rui Wang, Jian Peng, Anthony Costa, Saee Paliwal, Eddie Calleja, Timur Rvachov, Neha Tadimeti, Roy Tal, Emine Kucukbenli

Abstract

Understanding cellular machinery requires atomic-scale reconstruction of large biomolecular assemblies. However, predicting the structures of these systems has been constrained by the hardware memory requirements of models such as AlphaFold 3, which impose a practical ceiling of a few thousand residues that can be processed on a single GPU. Here we present NVIDIA BioNeMo Fold-CP, a context parallelism framework that overcomes this barrier by distributing the inference and training pipelines of co-folding models across multiple GPUs. We use the Boltz models as open-source reference architectures and implement custom multidimensional primitives that efficiently parallelize both the dense triangular updates and the irregular, data-dependent pattern of window-batched local attention. Our approach achieves efficient memory scaling: for an $N$-token input distributed across $P$ GPUs, per-device memory scales as $O(N^2/P)$, enabling structure prediction of assemblies exceeding 30,000 residues on 64 NVIDIA B300 GPUs. We demonstrate the scientific utility of this approach through successful developer use cases: Fold-CP enabled scoring of over 90% of the Comprehensive Resource of Mammalian protein complexes (CORUM) database, as well as folding of the disease-relevant PI4KA lipid kinase complex bound to an intrinsically disordered region without cropping. By providing a scalable pathway for modeling massive systems with full global context, Fold-CP represents a significant step toward the realization of a virtual cell.
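To make the $O(N^2/P)$ figure concrete, consider a single pair representation of shape $N \times N \times c$. Assuming $c = 128$ channels stored in bf16 (2 bytes per value), which is an illustrative assumption rather than a number from the paper, one such tensor for $N = 30{,}000$ tokens occupies $30{,}000^2 \times 128 \times 2 \approx 2.3 \times 10^{11}$ bytes (about 230 GB), while an even split over $P = 64$ GPUs leaves about 3.6 GB per device. Arranging the GPUs as a $\sqrt{P} \times \sqrt{P}$ grid gives each device one $(N/\sqrt{P}) \times (N/\sqrt{P})$ tile, and a Cannon-style 2D ring (the scheme the Figure 2 caption names for triangle multiplication) computes dense products over such tiles without any device gathering the full matrix. The sketch below is a textbook, single-process simulation of Cannon's algorithm for a single-channel matrix product, the core of a triangle multiplication once the appropriate operand is transposed. The grid side `q`, all helper names, and the list-of-lists representation are hypothetical choices for illustration, not Fold-CP code.

```python
def matmul(a, b):
    """Dense product of two list-of-lists matrices."""
    k, m = len(b), len(b[0])
    return [[sum(a[i][t] * b[t][j] for t in range(k)) for j in range(m)]
            for i in range(len(a))]


def add_inplace(c, d):
    """Accumulate tile d into tile c (the additive step of Cannon's algorithm)."""
    for i, row in enumerate(d):
        for j, v in enumerate(row):
            c[i][j] += v


def split_tiles(x, q):
    """Cut an N x N matrix into a q x q grid of (N/q) x (N/q) tiles."""
    n = len(x)
    assert n % q == 0, "N must be divisible by the grid side q"
    s = n // q
    return [[[row[c * s:(c + 1) * s] for row in x[r * s:(r + 1) * s]]
             for c in range(q)] for r in range(q)]


def join_tiles(tiles):
    """Inverse of split_tiles: reassemble a q x q grid of tiles."""
    q, s = len(tiles), len(tiles[0][0])
    return [[v for c in range(q) for v in tiles[r][c][i]]
            for r in range(q) for i in range(s)]


def cannon_matmul(a, b, q):
    """Compute A @ B on a simulated q x q grid (P = q * q ranks).

    Conceptually, rank (r, c) only ever holds one tile of A, one of B and
    one of C, i.e. O(N^2 / P) entries each.
    """
    A, B = split_tiles(a, q), split_tiles(b, q)
    s = len(A[0][0])
    # Initial skew: row r of A shifts left by r, column c of B shifts up by c,
    # so rank (r, c) starts with A[r][k] and B[k][c] for k = (r + c) % q.
    A = [[A[r][(c + r) % q] for c in range(q)] for r in range(q)]
    B = [[B[(r + c) % q][c] for c in range(q)] for r in range(q)]
    C = [[[[0] * s for _ in range(s)] for _ in range(q)] for _ in range(q)]
    for _ in range(q):
        for r in range(q):  # real ranks would compute concurrently
            for c in range(q):
                add_inplace(C[r][c], matmul(A[r][c], B[r][c]))
        # Ring shifts: A moves one hop left along rows, B one hop up along columns.
        A = [[A[r][(c + 1) % q] for c in range(q)] for r in range(q)]
        B = [[B[(r + 1) % q][c] for c in range(q)] for r in range(q)]
    return join_tiles(C)


if __name__ == "__main__":
    N, q = 6, 3  # P = q * q = 9 simulated ranks, each holding 2 x 2 tiles
    a = [[i + 2 * j for j in range(N)] for i in range(N)]
    b = [[(i * j) % 5 for j in range(N)] for i in range(N)]
    assert cannon_matmul(a, b, q) == matmul(a, b)
    print("Cannon result matches the dense product")
```

With `N = 6` and `q = 3`, each simulated rank holds $N^2/P = 4$ entries per operand, and after `q` multiply-and-shift rounds every output tile has accumulated all `q` partial products. In a distributed setting each shift would be a point-to-point exchange between neighboring ranks rather than a list reindexing.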

Paper Structure (27 sections, 5 equations, 10 figures, 2 tables)

Figures (10)

  • Figure 1: Software Architecture of the Fold-CP Framework. The Fold-CP design takes a bottom-up approach, starting with the foundational torch.distributed APIs, making module-level operators CP-aware, and incrementally building up to model-specific workflows. This design allows developers to port specific distributed modules into diverse model organizations without managing low-level synchronization logic or altering the model flow.
  • Figure 2: Distributed forward-pass flow for the three core ring-communication modules. (a) Triangle Attention: cross-axis transpose of the triangular bias, followed by a 1D ring rotation of keys, values, and bias with a tiled-softmax merge. (b) Triangle Multiplication: Cannon-style 2D ring in which projections shift along orthogonal axes with additive accumulation. (c) Ring Attention Pair Bias: 1D ring rotation of keys, values, and pair bias with a tiled-softmax merge. Pair Weighted Averaging (see its dedicated section) shares the softmax-merge archetype of (a) and (c) over a 2D ring; Outer Product Mean (see its dedicated section) shares the Cannon accumulation archetype of (b), replacing matrix multiplication with outer products.
  • Figure 3: Boltz-2 CP scalability on NVIDIA B300 GPUs, measured by maximum context length. For each GPU count $P$, the maximum context length reached is plotted. Linear trend lines are fitted against $\sqrt{P}$, the number of GPUs along each side of the CP grid, because Fold-CP requires a square GPU topology. (Left) Inference scaling: the maximum context length grows linearly at approximately 4,000 tokens per unit of $\sqrt{P}$, with runtimes annotated. (Right) Training scaling: training follows a similar trend with approximately half the slope of inference.
  • Figure 4: Inference consistency validation on the Boltz-1 test set. Numerical equivalence between the DP baseline and the Fold-CP implementation is verified via a Two One-Sided Tests (TOST) procedure built on the Wilcoxon signed-rank test, with an equivalence margin of $\epsilon = 0.01$. Error bars show the standard error of the mean (SEM) over 5 diffusion samples, for DP along the x-axis (horizontal bars) and CP along the y-axis (vertical bars).
  • Figure 5: Training consistency validation. Validation lDDT trajectories for five randomly seeded DP baseline runs and one $\text{CP}=2\times2$ run, with the model trained on a 256-token crop size. The Fold-CP validation lDDT curve largely mirrors the DP references, with the variation stemming from the stochastic nature of the model. The model used in this validation experiment is a truncated Boltz-1 architecture, so it is not expected to match the published reference. See text for details.
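Panels (a) and (c) of Figure 2 rest on the same idea: each rank keeps its query rows fixed while key/value shards rotate around a 1D ring, and per-shard partial results are combined with a tiled-softmax merge, so no rank materializes the full attention matrix. The single-process sketch below simulates that pattern under stated assumptions: equal contiguous shards, one attention head, and an additive bias looked up by global (query, key) index through a caller-supplied function `bias(i, j)`. How Fold-CP actually transposes and rotates the triangular or pair bias tiles is not modeled, and the names `attend_block`, `merge`, `ring_attention`, and `full_attention` are hypothetical.

```python
import math


def attend_block(q, keys, values, biases):
    """Partial attention of one query against one K/V block.

    Returns (m, l, acc): the block's maximum logit m, the sum l of
    exp(logit - m), and acc, the exp(logit - m)-weighted sum of values.
    """
    scale = 1.0 / math.sqrt(len(q))
    logits = [scale * sum(a * b for a, b in zip(q, k)) + bias
              for k, bias in zip(keys, biases)]
    m = max(logits)
    w = [math.exp(x - m) for x in logits]
    acc = [sum(wi * v[d] for wi, v in zip(w, values)) for d in range(len(values[0]))]
    return m, sum(w), acc


def merge(s1, s2):
    """Tiled-softmax (online softmax) merge of two partial results."""
    if s1 is None:  # first block seen by this query
        return s2
    (m1, l1, a1), (m2, l2, a2) = s1, s2
    m = max(m1, m2)
    # Rescale both partials to the shared maximum before summing.
    c1, c2 = math.exp(m1 - m), math.exp(m2 - m)
    return m, c1 * l1 + c2 * l2, [c1 * x + c2 * y for x, y in zip(a1, a2)]


def ring_attention(Q, K, V, bias, P):
    """Simulate attention with K/V shards rotating around a 1D ring of P ranks.

    Rank r owns rows [r*s, (r+1)*s) of Q, K and V. At step t it holds K/V
    shard (r - t) % P, so each shard moves one hop per step and every rank
    has seen all P shards after P steps.
    """
    n = len(Q)
    assert n % P == 0, "sequence length must divide evenly across ranks"
    s = n // P
    out = [None] * n
    for r in range(P):  # real ranks would run concurrently
        states = [None] * s  # running (m, l, acc) per local query
        for t in range(P):
            src = (r - t) % P
            ks, vs = K[src * s:(src + 1) * s], V[src * s:(src + 1) * s]
            for i in range(s):
                gi = r * s + i
                biases = [bias(gi, src * s + j) for j in range(s)]
                states[i] = merge(states[i], attend_block(Q[gi], ks, vs, biases))
        for i, (_, l, acc) in enumerate(states):
            out[r * s + i] = [x / l for x in acc]  # final normalization
    return out


def full_attention(Q, K, V, bias):
    """Single-device reference: one softmax over all keys for each query."""
    out = []
    for i, q in enumerate(Q):
        _, l, acc = attend_block(q, K, V, [bias(i, j) for j in range(len(K))])
        out.append([x / l for x in acc])
    return out
```

Because `merge` rescales both partial sums to a shared running maximum before adding them, the ring result agrees with `full_attention` up to floating-point rounding for any rank count that divides the sequence length. The merge itself does not depend on where the bias comes from, which is why the same archetype covers both the triangular bias of panel (a) and the pair bias of panel (c).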