GSPN-2: Efficient Parallel Sequence Modeling
Hongjun Wang, Yitong Jiang, Collin McCarthy, David Wehr, Hanrong Ye, Xinhao Li, Ka Chun Cheung, Wonmin Byeon, Jinwei Gu, Ke Chen, Kai Han, Hongxu Yin, Pavlo Molchanov, Jan Kautz, Sifei Liu
TL;DR
GSPN-2 tackles the cost of self-attention in vision transformers by redesigning Generalized Spatial Propagation Networks (GSPN) around a single fused CUDA kernel and compact channel propagation through a proxy space. The approach combines algorithmic changes (channel compression, 2D channel-parallel propagation) with system-level CUDA optimizations (coalesced memory access, SRAM caching, stream concurrency) to deliver large throughput gains on high-resolution image tasks and text-to-image generation without sacrificing accuracy. Experiments show speedups of up to tens of times and near-peak hardware utilization on modern GPUs, with competitive ImageNet accuracy and substantially faster SDXL generation. This work establishes GSPN-2 as a practical, scalable solution for efficient global spatial context modeling in vision, especially for high-resolution and multimodal applications.
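A rough count helps show why channel compression matters. The sketch below assumes, for illustration only, that one scan direction produces three neighbour weights per pixel for each independent weight map, that a per-channel design keeps one map per channel (C maps), and that a compact proxy space keeps C_p maps. The sizes H = W = 64, C = 256 and C_p = 8 are hypothetical, not values from the paper, and the helper names are made up for this example. For scale, it also prints the number of entries in one self-attention score matrix, whereas a left-to-right line scan needs only W dependent column steps.

```python
def weight_map_entries(h: int, w: int, groups: int, neighbours: int = 3) -> int:
    """Propagation weights for one scan direction: a `neighbours`-tuple per pixel
    for each independent weight map (`groups`). Hypothetical accounting."""
    return neighbours * h * w * groups


def attention_score_entries(h: int, w: int) -> int:
    """Entries of one (HW x HW) self-attention score matrix, per head."""
    n = h * w
    return n * n


if __name__ == "__main__":
    H, W = 64, 64
    C, C_PROXY = 256, 8  # hypothetical channel and proxy widths

    per_channel = weight_map_entries(H, W, groups=C)  # one map per channel
    compact = weight_map_entries(H, W, groups=C_PROXY)  # one map per proxy channel (assumed)
    print(per_channel, compact, per_channel // compact)  # 3145728 98304 32
    print(attention_score_entries(H, W), "score entries vs", W, "sequential column steps")
```

Under these assumed sizes the weight maps shrink by a factor of C / C_p = 32; the real saving in GSPN-2 depends on its actual proxy dimension and weight-sharing scheme.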
Abstract
Efficiency remains a bottleneck for vision transformers in high-resolution image and long-video real-world applications. The Generalized Spatial Propagation Network (GSPN) addresses this by replacing quadratic self-attention with a line-scan propagation scheme, bringing the cost close to linear in the number of rows or columns while retaining accuracy. Despite this advance, the existing GSPN implementation still suffers from (i) heavy overhead from repeatedly launching GPU kernels, (ii) excessive data transfers from global GPU memory, and (iii) redundant computation caused by maintaining separate propagation weights for each channel. We introduce GSPN-2, a joint algorithm-system redesign. In particular, we collapse the thousands of micro-launches of the previous implementation into a single 2D kernel, explicitly pin one warp to each channel slice, and stage the previous column's activations in shared memory. On the model side, we introduce a compact channel propagation strategy that replaces per-channel matrices, trimming parameters and aligning naturally with the affinity map used in transformer attention. Experiments demonstrate GSPN-2's effectiveness across image classification and text-to-image synthesis, matching transformer-level accuracy at significantly lower computational cost. Through its unique combination of structured matrix transformations and a GPU-optimized implementation, GSPN-2 establishes a new efficiency frontier for modeling global spatial context in vision applications. Project page: https://whj363636.github.io/GSPN2/
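The line-scan idea and the compact channel strategy can be illustrated with a small, simplified sketch. It follows a GSPN-style recurrence h_c = W_c h_{c-1} + λ_c ⊙ x_c, in which each pixel of column c mixes only its three neighbours (rows r-1, r, r+1) in column c-1, as we understand the original formulation. It covers a single left-to-right direction and omits output gating and the other scan directions. The helper names (`stable_weights`, `line_scan`, `compact_channel_propagate`), the softmax normalization, and the linear down/up projections into a proxy space sharing one weight map are illustrative assumptions, not the authors' implementation. Plain Python keeps the dependency structure visible: the loop over columns is inherently sequential, while rows and channels within a column are independent, which is the work a fused GPU kernel would parallelize.

```python
import math
from typing import List

Grid = List[List[float]]  # one channel, indexed grid[row][col]
Weights = List[List[List[float]]]  # weights[row][col] = [up, mid, down]


def stable_weights(raw: Weights) -> Weights:
    """Softmax the raw [up, mid, down] scores over neighbours that exist, so each
    pixel's weights are non-negative and sum to 1 (row-stochastic). This is one
    common way to keep the scan stable; an assumption for this sketch."""
    h = len(raw)
    out: Weights = []
    for r in range(h):
        valid = [r - 1 >= 0, True, r + 1 < h]  # up / mid / down neighbour exists?
        row_out = []
        for triple in raw[r]:
            exps = [math.exp(v) if ok else 0.0 for v, ok in zip(triple, valid)]
            total = sum(exps)  # > 0 because the mid neighbour always exists
            row_out.append([e / total for e in exps])
        out.append(row_out)
    return out


def line_scan(x: Grid, w: Weights, lam: Grid) -> Grid:
    """Left-to-right scan: h[r][c] = sum_k w[r][c][k] * h[r+k-1][c-1] + lam[r][c] * x[r][c]."""
    n_rows, n_cols = len(x), len(x[0])
    h = [[0.0] * n_cols for _ in range(n_rows)]
    for c in range(n_cols):  # W sequential steps
        for r in range(n_rows):  # independent across rows (parallel on a GPU)
            acc = lam[r][c] * x[r][c]
            if c > 0:
                for k, dr in enumerate((-1, 0, 1)):
                    rr = r + dr
                    if 0 <= rr < n_rows:
                        acc += w[r][c][k] * h[rr][c - 1]
            h[r][c] = acc
    return h


def compact_channel_propagate(xs: List[Grid], down: List[List[float]],
                              up: List[List[float]], w: Weights, lam: Grid) -> List[Grid]:
    """xs: C input channels; down: C_p x C and up: C x C_p are hypothetical
    projections into and out of a proxy space. Every proxy channel reuses the
    same weight map `w` instead of keeping one map per input channel."""
    c_in, n_rows, n_cols = len(xs), len(xs[0]), len(xs[0][0])
    c_proxy = len(down)
    proxies = [
        [[sum(down[p][c] * xs[c][r][j] for c in range(c_in)) for j in range(n_cols)]
         for r in range(n_rows)]
        for p in range(c_proxy)
    ]
    scanned = [line_scan(g, w, lam) for g in proxies]  # shared weights
    return [
        [[sum(up[c][p] * scanned[p][r][j] for p in range(c_proxy)) for j in range(n_cols)]
         for r in range(n_rows)]
        for c in range(c_in)
    ]


if __name__ == "__main__":
    H, W = 4, 5
    ones = [[1.0] * W for _ in range(H)]
    w = stable_weights([[[0.0, 0.0, 0.0] for _ in range(W)] for _ in range(H)])
    out = compact_channel_propagate([ones, ones], [[0.5, 0.5]], [[1.0], [2.0]], w, ones)
    print(out[0][0][-1], out[1][0][-1])  # approximately 5.0 and 10.0
```

With all-zero raw scores the weights are uniform over valid neighbours, so a constant input of ones accumulates to c + 1 in column c. Because the weight matrix is row-stochastic, information from the first column is carried across the image without blowing up or vanishing.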
