Blockwise Parallel Transformer for Large Context Models

Hao Liu, Pieter Abbeel

TL;DR

The paper tackles the memory bottleneck in long-context Transformers by introducing Blockwise Parallel Transformer (BPT), which fuses self-attention and feedforward computation in a blockwise fashion so that full-sequence intermediates are never materialized. By processing the input in blocks and fusing the FFN with attention, BPT trains on sequences up to 32× longer than vanilla attention and up to 4× longer than prior memory-efficient methods, with competitive throughput. Experiments on OpenWebText and ExoRL show reduced memory usage, longer context lengths, and improved RL performance when conditioning on multiple trajectories. The work shows that blockwise parallelism, coupled with careful normalization and a blockwise FFN, enables scalable training of large-context GPT-like models, with practical impact for NLP and reinforcement learning research. Potential future directions include lower-level optimizations (CUDA/Triton) to improve runtime speed further while keeping the memory benefits.

Abstract

Transformers have emerged as the cornerstone of state-of-the-art natural language processing models, showcasing exceptional performance across a wide range of AI applications. However, the memory demands posed by the self-attention mechanism and the large feedforward network in Transformers limit their ability to handle long sequences, thereby creating challenges for tasks involving multiple long sequences or long-term dependencies. We present a distinct approach, Blockwise Parallel Transformer (BPT), that leverages blockwise computation of self-attention and feedforward network fusion to minimize memory costs. By processing longer input sequences while maintaining memory efficiency, BPT enables training sequences 32 times longer than vanilla Transformers and up to 4 times longer than previous memory-efficient methods. Extensive experiments on language modeling and reinforcement learning tasks demonstrate the effectiveness of BPT in reducing memory requirements and improving performance.

Paper Structure (17 sections, 5 equations, 3 figures, 6 tables, 1 algorithm)

Figures (3)

  • Figure 1: Maximum context length during training with GPT models using different methods. Model sizes range from 1B to 70B. Panels (A), (B), and (C) show evaluation on one A100, eight A100s, and 64 TPUv4, respectively, with a single sequence. Our method enables training sequences 32 times longer than the vanilla attention-based Transformer (Vaswani et al., 2017), and 2 to 4 times longer than FlashAttention (Dao et al., 2022) and Memory Efficient Attention (Rabe and Staats, 2021). The memory cost section of the paper provides a memory cost breakdown.
  • Figure 2: We use the same model architecture as the original Transformer but organize the compute differently. The diagram illustrates this: the first incoming input block (bottom row) is projected into a query; we then iterate over the same input sequence (shown above the bottom row) and project it into keys and values. The query, keys, and values are used to compute self-attention (yellow box), whose output is passed to the feedforward network (cyan box), followed by a residual connection. In our proposed approach, this process is then repeated for the other incoming input blocks.
  • Figure 3: Key parts of the implementation of Blockwise Parallel in JAX. The full code is available on GitHub: https://github.com/lhao499/llm_large_context
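
The compute pattern described in the Figure 2 caption can be sketched in JAX as follows. This is a minimal illustration under stated assumptions, not the authors' implementation: it is single-head and non-causal, omits layer normalization, places a residual connection after both attention and the feedforward network, and uses plain Python loops. The names `blockwise_layer`, `reference_layer`, `feedforward`, and all weight arrays are hypothetical. For each query block, the inner loop streams over key-value blocks while keeping a running softmax maximum and denominator, and the feedforward network is applied to that block before moving on, so only `(block_size, hidden)` activations exist at a time.

```python
import jax
import jax.numpy as jnp


def feedforward(x, w1, w2):
    """Position-wise two-layer MLP with GELU (hypothetical weights)."""
    return jax.nn.gelu(x @ w1) @ w2


def blockwise_layer(x, wq, wk, wv, w1, w2, block_size):
    """Attention + FFN computed one query block at a time (illustrative sketch)."""
    seq_len, dim = x.shape
    assert seq_len % block_size == 0, "sketch assumes seq_len divisible by block_size"
    scale = dim ** -0.5
    outputs = []
    for qs in range(0, seq_len, block_size):
        x_q = x[qs:qs + block_size]
        q = x_q @ wq
        # Running statistics for a numerically stable streaming softmax.
        row_max = jnp.full((block_size, 1), -jnp.inf)
        denom = jnp.zeros((block_size, 1))
        numer = jnp.zeros((block_size, dim))
        for ks in range(0, seq_len, block_size):
            x_kv = x[ks:ks + block_size]
            k, v = x_kv @ wk, x_kv @ wv
            scores = (q @ k.T) * scale
            new_max = jnp.maximum(row_max, scores.max(axis=-1, keepdims=True))
            correction = jnp.exp(row_max - new_max)  # rescales earlier partial sums
            p = jnp.exp(scores - new_max)
            denom = denom * correction + p.sum(axis=-1, keepdims=True)
            numer = numer * correction + p @ v
            row_max = new_max
        h = x_q + numer / denom                      # residual after attention (assumption)
        outputs.append(h + feedforward(h, w1, w2))   # FFN on this block only, plus residual
    return jnp.concatenate(outputs, axis=0)


def reference_layer(x, wq, wk, wv, w1, w2):
    """Same layer with the full attention matrix materialized, for comparison."""
    q, k, v = x @ wq, x @ wk, x @ wv
    weights = jax.nn.softmax((q @ k.T) * x.shape[-1] ** -0.5, axis=-1)
    h = x + weights @ v
    return h + feedforward(h, w1, w2)


# Small random example (all sizes are arbitrary illustration values).
kx, kq, kk, kv, k1, k2 = jax.random.split(jax.random.PRNGKey(0), 6)
dim, hidden, seq_len = 8, 32, 16
x = jax.random.normal(kx, (seq_len, dim))
wq, wk, wv = (0.3 * jax.random.normal(k_, (dim, dim)) for k_ in (kq, kk, kv))
w1 = 0.3 * jax.random.normal(k1, (dim, hidden))
w2 = 0.3 * jax.random.normal(k2, (hidden, dim))
out_block = blockwise_layer(x, wq, wk, wv, w1, w2, block_size=4)
out_ref = reference_layer(x, wq, wk, wv, w1, w2)
print(jnp.max(jnp.abs(out_block - out_ref)))
```

The blockwise result should match the reference up to floating-point error, since the streaming softmax is an exact reformulation and the feedforward network acts on each position independently. The sketch only shows the order of computation; actual memory savings during training additionally depend on how intermediates are stored for the backward pass, which is not modeled here.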