
Flash3D: Super-scaling Point Transformers through Joint Hardware-Geometry Locality

Liyan Chen, Gregory P. Meyer, Zaiwei Zhang, Eric M. Wolff, Paul Vernaza

TL;DR

Flash3D transforms 3D point-cloud backbones by co-designing geometry-aware locality with GPU memory locality via Perfect Spatial Hashing (PSH). It introduces Bucket-and-Swin attention, which is fused with FlashAttention and aligned with GPU tiling, enabling region shifts with zero overhead and substantially higher throughput. Empirical results show a 2.25x speedup and 2.4x better memory efficiency than PTv3 on outdoor semantic segmentation tasks. The savings can be spent on wider attention scopes and larger models, which yields higher accuracy than PTv3 under the same compute budget. The work demonstrates that hardware-aware algorithm design can unlock substantial gains in accuracy and efficiency for large-scale point-transformer models, with practical implications for autonomous driving and related 3D perception tasks.

Abstract

Recent efforts recognize the power of scale in 3D learning (e.g., PTv3) and attention mechanisms (e.g., FlashAttention). However, current point cloud backbones fail to holistically unify geometric locality, attention mechanisms, and GPU architectures in one view. In this paper, we introduce Flash3D Transformer, which aligns geometric locality and GPU tiling through a principled locality mechanism based on Perfect Spatial Hashing (PSH). The common alignment with GPU tiling naturally fuses our PSH locality mechanism with FlashAttention at negligible extra cost. This mechanism affords flexible design choices throughout the backbone that result in superior downstream task results. Flash3D outperforms state-of-the-art PTv3 results on benchmark datasets, delivering a 2.25x speed increase and 2.4x memory efficiency boost. This efficiency enables scaling to wider attention scopes and larger models without additional overhead. Such scaling allows Flash3D to achieve even higher task accuracies than PTv3 under the same compute budget.
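The abstract's central idea, grouping geometrically nearby points into buckets whose size matches an attention tile and computing attention only within each bucket, can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the authors' implementation: `spatial_hash` is a generic XOR-of-primes spatial hash standing in for Perfect Spatial Hashing, the linear-probing overflow rule in `assign_buckets` is a hypothetical stand-in for the paper's rebalancing, the `tile` capacity is an assumed proxy for a GPU tile, and queries, keys, and values are all taken to be the raw features. All helper names are hypothetical.

```python
import math


def voxel_key(p, voxel):
    """Integer voxel coordinates of a 3D point p = (x, y, z)."""
    return tuple(math.floor(c / voxel) for c in p)


def spatial_hash(key, num_buckets):
    """Generic XOR-of-primes spatial hash; a stand-in, NOT the paper's PSH."""
    x, y, z = key
    return ((x * 73856093) ^ (y * 19349663) ^ (z * 83492791)) % num_buckets


def assign_buckets(points, voxel, tile):
    """Hash points into buckets that each hold at most `tile` points.

    Points sharing a voxel hash to the same bucket. When a bucket is full,
    the point probes linearly to the next bucket with room (a hypothetical
    rebalancing rule). Total capacity >= len(points), so probing terminates.
    """
    num_buckets = max(1, math.ceil(len(points) / tile))
    buckets = [[] for _ in range(num_buckets)]
    for idx, p in enumerate(points):
        b = spatial_hash(voxel_key(p, voxel), num_buckets)
        while len(buckets[b]) >= tile:  # bucket full: try the next one
            b = (b + 1) % num_buckets
        buckets[b].append(idx)
    return buckets


def bucket_attention(feats, buckets):
    """Scaled dot-product self-attention computed only inside each bucket.

    For brevity, queries, keys, and values are all the raw feature vectors.
    """
    out = [None] * len(feats)
    for members in buckets:
        for i in members:
            d = len(feats[i])
            scores = [sum(q * k for q, k in zip(feats[i], feats[j])) / math.sqrt(d)
                      for j in members]
            m = max(scores)  # subtract the max for a numerically stable softmax
            w = [math.exp(s - m) for s in scores]
            z = sum(w)
            out[i] = [sum(wk * feats[j][c] for wk, j in zip(w, members)) / z
                      for c in range(d)]
    return out


pts = [(0.1, 0.2, 0.0), (0.3, 0.1, 0.0), (5.0, 5.0, 5.0), (5.2, 5.1, 5.0), (9.0, 0.0, 1.0)]
feats = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5], [2.0, 0.0]]
buckets = assign_buckets(pts, voxel=1.0, tile=2)
out = bucket_attention(feats, buckets)
```

Because every bucket holds at most `tile` points, each bucket's attention is a small dense block of bounded size, which is the kind of regularity that tiled kernels such as FlashAttention exploit.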

Paper Structure

This paper contains 37 sections, 6 equations, 13 figures, 7 tables, and 2 algorithms.

Figures (13)

  • Figure 1: Effectiveness of our Flash3D transformer, which unifies geometric locality, FlashAttention (FA2), and the GPU tiling architecture. This unified perspective leads to drastically improved speed and scalability of point transformers.
  • Figure 2: High-level schematic overview of PTv3. Numbered rectangles represent locations in memory. Adjacent rectangles are adjacent in memory. Arrows indicate data movement. See the Preliminaries section for details.
  • Figure 3: High-level schematic overview of Flash3D. Flash3D performs multiple rounds of attention with different neighborhood definitions via our bucket-and-swin approach, which saves an expensive global shuffle in each round. See the Method section for details.
  • Figure 4: Illustration of bucket assignments using four hash functions after rebalancing. Colors of points indicate their bucket assignments. We demonstrate our PSH algorithm on a sample point cloud with 100k points from BuildingNet (Selvaraju et al., ICCV 2021).
  • Figure 5: Overall Latencies vs. Input Sizes for Flash3D and PTv3.
  • ...and 8 more figures
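Figures 3 and 4 describe multiple attention rounds, each with a different bucket assignment (four hash functions in Figure 4). The plain-Python sketch below illustrates only the neighborhood-changing idea, in the spirit of Swin's shifted windows: each round voxelizes the points on a grid shifted by a hypothetical per-round offset, groups points by voxel, and splits any group larger than `tile` into tile-sized chunks as a simple stand-in for rebalancing. The offsets, `tile`, and helper names are assumptions for illustration. This is not the paper's PSH construction, and unlike Flash3D's zero-overhead region shifts, it regroups points explicitly in every round.

```python
import math
from collections import defaultdict

# Hypothetical per-round grid offsets, in voxel units: one "hash function" per round.
ROUND_OFFSETS = [(0.0, 0.0, 0.0), (0.5, 0.0, 0.0), (0.0, 0.5, 0.0), (0.5, 0.5, 0.5)]


def shifted_key(p, voxel, offset):
    """Voxel coordinates of p on a grid shifted by `offset` voxels."""
    return tuple(math.floor(c / voxel + o) for c, o in zip(p, offset))


def round_buckets(points, voxel, offset, tile):
    """Group points by shifted voxel, then split oversized groups into tile-sized chunks."""
    groups = defaultdict(list)
    for idx, p in enumerate(points):
        groups[shifted_key(p, voxel, offset)].append(idx)
    buckets = []
    for members in groups.values():
        for start in range(0, len(members), tile):
            buckets.append(members[start:start + tile])
    return buckets


def co_bucketed(buckets, i, j):
    """True if points i and j share a bucket, i.e. can attend to each other this round."""
    return any(i in b and j in b for b in buckets)


# Points 0 and 1 are close but straddle the x = 1 boundary of the unshifted grid.
pts = [(0.9, 0.2, 0.2), (1.1, 0.2, 0.2), (0.2, 0.2, 0.2)]
rounds = [round_buckets(pts, 1.0, off, tile=4) for off in ROUND_OFFSETS]
together = [co_bucketed(b, 0, 1) for b in rounds]
```

Here `together` evaluates to `[False, True, False, True]`: the pair shares a bucket only in rounds whose offset shifts the x axis, so information can cross the boundary that separates the two points in round 0.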