
Sequence Length Scaling in Vision Transformers for Scientific Images on Frontier

Aristeidis Tsaris, Chengming Zhang, Xiao Wang, Junqi Yin, Siyan Liu, Moetasim Ashfaq, Ming Fan, Jong Youl Choi, Mohamed Wahib, Dan Lu, Prasanna Balaprakash, Feiyi Wang

TL;DR

This paper tackles scaling Vision Transformers to ultra-long sequences for high-resolution scientific imagery on Frontier, addressing memory and communication bottlenecks. It introduces distributed sequence parallelism using DeepSpeed-Ulysses and Long Sequence Segmentation (LSS), augmented by pipeline and tensor parallelism and Flash Attention v2, to reach up to 1M tokens for models up to 10B parameters. The authors demonstrate a 94% batch scaling efficiency on 2,048 AMD MI250X GPUs and show up to 20% improvements in ERA5 temperature predictions when using longer sequences, including training a transformer with full attention at 188K sequence length. They provide empirical baselines, performance bottleneck analyses, and practical guidelines for deploying ultra-long ViTs on scientific data, highlighting the importance of sequence length alongside model size. The work paves the way for robust, scalable foundation models in Earth system science and other data-intensive domains.
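The 94% batch scaling efficiency is a weak-scaling figure. The per-GPU batch size is held fixed, so the global batch grows with the GPU count. The sketch below assumes the conventional definition: per-GPU throughput at scale divided by per-GPU throughput in a small baseline run. The paper's exact baseline configuration is not stated here, and the throughput numbers are hypothetical and used only for illustration.

```python
def weak_scaling_efficiency(base_throughput: float, base_gpus: int,
                            throughput: float, gpus: int) -> float:
    """Per-GPU throughput at `gpus` relative to per-GPU throughput at `base_gpus`.

    Assumes weak scaling (fixed per-GPU batch), so ideal efficiency is 1.0.
    """
    base_per_gpu = base_throughput / base_gpus
    per_gpu = throughput / gpus
    return per_gpu / base_per_gpu


# Hypothetical throughputs in samples/s, NOT measurements from the paper.
eff = weak_scaling_efficiency(base_throughput=800.0, base_gpus=8,
                              throughput=192_512.0, gpus=2048)
print(f"{eff:.0%}")  # 94%
```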

Abstract

Vision Transformers (ViTs) are pivotal for foundation models in scientific imagery, including Earth science applications, due to their capability to process long sequences. While advances in transformers for text have inspired scaling sequence lengths in ViTs, adapting these techniques to ViTs introduces unique challenges. We develop distributed sequence parallelism for ViTs, enabling them to handle up to 1M tokens. Our approach, leveraging DeepSpeed-Ulysses and Long-Sequence-Segmentation with model sharding, is the first to apply sequence parallelism to ViT training, achieving a 94% batch scaling efficiency on 2,048 AMD MI250X GPUs. Evaluating sequence parallelism in ViTs, particularly in models of up to 10B parameters, revealed substantial bottlenecks. We countered these with a hybrid of sequence, pipeline, and tensor parallelism combined with flash attention, scaling beyond single-GPU memory limits. Our method improves the accuracy of temperature predictions in climate modeling by 20% and marks the first training of a transformer model on a full attention matrix over a 188K sequence length.
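DeepSpeed-Ulysses shards activations along the sequence dimension. Before attention, an all-to-all exchange gives each rank the full sequence for a subset of heads. After attention, a second all-to-all restores the sequence sharding. The sketch below simulates P ranks inside one NumPy process: Python lists stand in for ranks, and there is no real communication. Its purpose is to show the data movement. The function names are hypothetical and are not DeepSpeed's API. The sketch also exposes a known constraint: the head count must be divisible by the sequence-parallel degree. With the 16-head model of Figure 5, Ulysses can therefore be at most 16-way.

```python
import numpy as np


def attention(q, k, v):
    """Standard softmax attention per head; q, k, v have shape [heads, seq, d]."""
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    w = np.exp(scores)
    w /= w.sum(axis=-1, keepdims=True)
    return w @ v


def seq_to_head_all_to_all(shards, p):
    """Rank r holds [heads, seq/p, d]; afterwards it holds [heads/p, seq, d]."""
    hp = shards[0].shape[0] // p
    return [np.concatenate([shards[src][r * hp:(r + 1) * hp] for src in range(p)], axis=1)
            for r in range(p)]


def head_to_seq_all_to_all(shards, p):
    """Inverse exchange: [heads/p, seq, d] per rank back to [heads, seq/p, d]."""
    sp = shards[0].shape[1] // p
    return [np.concatenate([shards[src][:, r * sp:(r + 1) * sp] for src in range(p)], axis=0)
            for r in range(p)]


def ulysses_attention(q, k, v, p):
    """Simulated Ulysses-style sequence-parallel attention over p 'ranks'."""
    heads, seq, _ = q.shape
    assert heads % p == 0 and seq % p == 0, "heads and seq must divide by p"
    # Each rank starts with a contiguous chunk of the sequence.
    qs, ks, vs = (seq_to_head_all_to_all(np.split(x, p, axis=1), p) for x in (q, k, v))
    # Each rank now computes full-sequence attention for its own heads.
    outs = [attention(qs[r], ks[r], vs[r]) for r in range(p)]
    # Return to sequence sharding, then stitch the shards for comparison.
    return np.concatenate(head_to_seq_all_to_all(outs, p), axis=1)
```

Because every head still sees the whole sequence, the result is exact full attention. Only the placement of the data changes between the two exchanges.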

Paper Structure

This paper contains 23 sections, 2 equations, 10 figures, and 1 table.

Figures (10)

  • Figure 1: Computation cost in FLOPs as a function of the number of tokens (left plot) and of the model parameter size (right plots). Both plots use Equations (1) and (2) with the total number of tokens for the ERA5 datasets, i.e., around 350K images.
  • Figure 2: Validation accuracy on half a year of ERA5 data for the Z500, T850, T2m, and U10 variables. The effect on accuracy is shown for three spatial resolutions: 5.625°, 1.40625°, and 1.0°. With a patch size of 4, the sequence lengths for these resolutions are 128, 2048, and 4050, respectively.
  • Figure 3: Validation accuracy on half a year of ERA5 data for the Z500, T850, T2m, and U10 variables. The figure shows the effect on accuracy of including all 92 variables in the model (Multi-Ch-ViT) versus Agg-Ch-ViT. The sequence length is 11,776 for the first approach and 128 for the second.
  • Figure 4: Single-GPU measurements with and without flash-attention-2 on a Frontier MI250X GPU (1 GCD). The top plot shows GPU throughput for two model sizes, ViT-Base and ViT-Large. The bottom plot shows the maximum GPU memory reserved during each run. The flash-attention-2 measurements used a PyTorch 2.4 nightly build (03/16/2024) with ROCm 5.7.
  • Figure 5: Weak scaling of the sequence length for each distributed method separately. The top plot shows the DeepSpeed-Ulysses and pipeline parallelism (PP) methods within the DeepSpeed framework, along with linear batch-size scaling using the DeepSpeed Zero Redundancy Optimizer (DP). The bottom plot shows the Long Short-Sequence (LSS) and tensor parallelism (TP) methods within PyTorch's FSDP framework, along with linear batch-size scaling using the fully sharded method (DP). For both implementations, the model has a hidden dimension of 2560 and 16 attention heads in total.
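The sequence lengths in the Figure 2 and Figure 3 captions follow from simple patch arithmetic. The sketch below assumes a global latitude-longitude grid of (180/res) × (360/res) points, which reproduces all three Figure 2 values. It also assumes that Multi-Ch-ViT gives every variable its own set of patch tokens while Agg-Ch-ViT keeps one token per patch location. This reading is inferred from 92 × 128 = 11,776 and is not taken from the paper's text. The function names are hypothetical.

```python
def patches_per_variable(resolution_deg: float, patch_size: int) -> int:
    """Patch tokens for one variable on an assumed (180/res) x (360/res) grid."""
    h = round(180 / resolution_deg)
    w = round(360 / resolution_deg)
    if h % patch_size or w % patch_size:
        raise ValueError("grid is not divisible by the patch size")
    return (h // patch_size) * (w // patch_size)


def sequence_length(resolution_deg: float, patch_size: int,
                    n_vars: int = 1, per_variable_tokens: bool = False) -> int:
    """Sequence length with variables either as separate tokens or folded per patch."""
    n = patches_per_variable(resolution_deg, patch_size)
    return n_vars * n if per_variable_tokens else n


for res in (5.625, 1.40625, 1.0):
    print(res, patches_per_variable(res, patch_size=4))  # 128, 2048, 4050
print(sequence_length(5.625, 4, n_vars=92, per_variable_tokens=True))  # 11776
```

Sequence length grows with the inverse square of the grid spacing and linearly with the number of variables tokenized separately. This explains why higher resolution and multi-channel tokenization both push ViTs toward very long sequences.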
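Figure 1 is computed from the paper's Equations (1) and (2), which this summary does not reproduce. As a stand-in, the sketch below uses the widely cited approximation from Kaplan et al. (2020). In that approximation, training costs about 6N FLOPs per token for the dense layers plus 6·n_layers·S·d_model per token for attention. This is an assumption and not the paper's formula. It does reproduce the qualitative point: total cost grows linearly in S through the parameter term and quadratically through the attention term.

```python
def training_flops(n_params: float, n_layers: int, d_model: int,
                   seq_len: int, n_samples: int, epochs: int = 1) -> float:
    """Approximate training FLOPs (Kaplan et al., 2020 style estimate).

    Forward per token ~ 2*N + 2*n_layers*seq_len*d_model; backward ~ 2x forward,
    so training ~ 3x forward per token. Multiply by tokens seen in training.
    """
    per_token = 6 * n_params + 6 * n_layers * seq_len * d_model
    return per_token * seq_len * n_samples * epochs


# Hypothetical ViT-Base-like settings (86M params, 12 layers, d=768), ~350K images.
for seq in (128, 2048, 4050):
    print(f"{seq:>5} tokens: {training_flops(86e6, 12, 768, seq, 350_000):.2e} FLOPs")
```

Under this approximation, the attention term overtakes the parameter term once S > N / (n_layers · d_model). That is roughly 9.3K tokens for these illustrative values.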
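The memory savings from flash-attention-2 in Figure 4 come from never materializing the full S × S score matrix. Keys and values are processed in blocks, and the softmax is computed online by carrying a running row maximum and a running normalizer. The NumPy sketch below models only that online-softmax recurrence over key/value blocks. The real kernel also tiles queries, fuses the work on-chip, and recomputes attention in the backward pass, and none of that is represented here.

```python
import numpy as np


def naive_attention(q, k, v):
    """Reference attention; materializes the full [S_q, S_k] score matrix."""
    s = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v


def tiled_attention(q, k, v, block=64):
    """Online-softmax attention; only [S_q, block] scores exist at any time."""
    scale = 1.0 / np.sqrt(q.shape[-1])
    out = np.zeros((q.shape[0], v.shape[1]))
    m = np.full(q.shape[0], -np.inf)  # running row maximum
    l = np.zeros(q.shape[0])          # running softmax denominator
    for start in range(0, k.shape[0], block):
        kb, vb = k[start:start + block], v[start:start + block]
        s = (q @ kb.T) * scale
        m_new = np.maximum(m, s.max(axis=-1))
        correction = np.exp(m - m_new)  # rescales earlier partial sums (0 on first block)
        p = np.exp(s - m_new[:, None])
        l = l * correction + p.sum(axis=-1)
        out = out * correction[:, None] + p @ vb
        m = m_new
    return out / l[:, None]
```

Peak score memory drops from O(S²) to O(S · block), which is why the savings in Figure 4 grow with sequence length.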