LV-XAttn: Distributed Cross-Attention for Long Visual Inputs in Multimodal Large Language Models

Tzu-Tao Chang, Shivaram Venkataraman

TL;DR

LV-XAttn tackles the memory and communication bottlenecks of cross-attention in multimodal LLMs when handling long visual inputs. It introduces a distributed, exact cross-attention mechanism that keeps the large key-value blocks local to each worker and passes only the small query blocks between workers. This sharply reduces inter-node communication and lets computation hide the remaining communication time. An activation recomputation strategy further eases memory pressure by sharing a single copy of the visual tokens across cross-attention layers and recomputing keys and values during backpropagation, at modest overhead. Empirical results across Llama 3-V, mPLUG-Owl3, and OpenFlamingo show end-to-end speedups of up to $10.62\times$ over Ring Attention and substantial improvements in processing longer visual contexts, demonstrating practical gains for scalable multimodal inference and training.
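The mechanism in the TL;DR can be illustrated with a single-process simulation. The NumPy sketch below is a hypothetical illustration, not the authors' implementation. It assumes unmasked, single-head cross-attention and treats list positions as workers. It merges partial results with a FlashAttention-style online softmax that uses running statistics m (row max) and l (softmax denominator). The function names `merge_local_block`, `lv_xattn_simulated`, and `reference_attention` are made up for this example.

```python
import numpy as np


def merge_local_block(q, k, v, o, m, l):
    """Fold one local key/value block into a query block's running state.

    Online-softmax update: o is the unnormalized output, m the running row max,
    l the running softmax denominator (all for the rows of q).
    """
    s = q @ k.T / np.sqrt(q.shape[-1])      # scaled scores against the local KV block
    m_new = np.maximum(m, s.max(axis=-1))   # new running max per query row
    scale = np.exp(m - m_new)               # rescales previously accumulated terms
    p = np.exp(s - m_new[:, None])          # unnormalized probabilities for this block
    l_new = l * scale + p.sum(axis=-1)
    o_new = o * scale[:, None] + p @ v
    return o_new, m_new, l_new


def lv_xattn_simulated(Q, K, V, num_workers):
    """Simulate the rotation on one process: KV blocks stay put, (Q, O, m, l) move."""
    k_blocks = np.array_split(K, num_workers)   # worker w permanently owns K_w, V_w
    v_blocks = np.array_split(V, num_workers)
    held = []                                   # (block_id, q, o, m, l) per worker
    for b, q in enumerate(np.array_split(Q, num_workers)):
        n = q.shape[0]
        held.append((b, q, np.zeros_like(q), np.full(n, -np.inf), np.zeros(n)))
    for step in range(num_workers):
        # each worker attends its current query block to its local KV block
        held = [
            (b, q, *merge_local_block(q, k_blocks[w], v_blocks[w], o, m, l))
            for w, (b, q, o, m, l) in enumerate(held)
        ]
        if step < num_workers - 1:
            # ring send: worker w passes its small (Q, O, m, l) bundle to worker w + 1
            held = held[-1:] + held[:-1]
    out = [None] * num_workers
    for b, q, o, m, l in held:
        out[b] = o / l[:, None]                 # final softmax normalization
    return np.concatenate(out, axis=0)


def reference_attention(Q, K, V):
    """Plain single-device softmax(Q K^T / sqrt(d)) V for comparison."""
    s = Q @ K.T / np.sqrt(Q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ V


rng = np.random.default_rng(0)
Q = rng.standard_normal((8, 16))    # few text (query) tokens
K = rng.standard_normal((64, 16))   # many visual (key/value) tokens
V = rng.standard_normal((64, 16))
print(np.allclose(lv_xattn_simulated(Q, K, V, num_workers=4), reference_attention(Q, K, V)))
```

Only the (Q, O, m, l) bundle moves between iterations. `k_blocks[w]` and `v_blocks[w]` are only ever read by worker `w`. This property makes the scheme attractive when the visual sequence is much longer than the text sequence. The result stays exact because the online-softmax merge is order-independent up to floating-point rounding.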

Abstract

Cross-attention is commonly adopted in multimodal large language models (MLLMs) for integrating visual information into the language backbone. However, in applications with large visual inputs, such as video understanding, processing a large number of visual tokens in cross-attention layers leads to high memory demands and often necessitates distributed computation across multiple GPUs. Existing distributed attention mechanisms face significant communication overheads, making cross-attention layers a critical bottleneck for efficient training and inference of MLLMs. To address this, we propose LV-XAttn, a distributed, exact cross-attention mechanism with minimal communication overhead. We observe that in applications involving large visual inputs, the size of the query block is typically much smaller than that of the key-value blocks. Thus, in LV-XAttn we keep the large key-value blocks locally on each GPU and exchange smaller query blocks across GPUs. We also introduce an efficient activation recomputation technique to support longer visual context. We theoretically analyze the communication benefits of LV-XAttn and show that it can achieve speedups for a wide range of models. Our evaluations with Llama 3-V, mPLUG-Owl3 and OpenFlamingo models find that LV-XAttn achieves up to 10.62$\times$ end-to-end speedup compared to existing approaches.
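To see why exchanging query blocks pays off, one can count the bytes each worker sends per ring step in the forward pass. The sketch below is a back-of-the-envelope estimate under stated assumptions, not the paper's theoretical analysis. It assumes 2-byte (bf16/fp16) tensors and fp32 softmax statistics, with one m and one l value per query row per head. It also assumes that Ring Attention rotates K and V blocks while LV-XAttn rotates Q and O blocks plus m and l. The hidden size d = 4096 and the head count of 32 are hypothetical. The sequence lengths follow the Llama 3-V setting in the Figure 2 caption, reading K as 1,000.

```python
def ring_attention_bytes_per_step(s_kv, d, workers, elem_bytes=2):
    """Bytes one worker sends per ring step in Ring Attention: its K and V blocks."""
    return 2 * (s_kv / workers) * d * elem_bytes


def lv_xattn_bytes_per_step(s_q, d, workers, heads, elem_bytes=2, stat_bytes=4):
    """Bytes one worker sends per ring step in LV-XAttn: Q and O blocks plus m and l."""
    rows = s_q / workers
    tensors = 2 * rows * d * elem_bytes      # query block and output block
    stats = 2 * rows * heads * stat_bytes    # running max m and denominator l, per head
    return tensors + stats


# S_Q = 1K, S_KV = 1200K, 16 GPUs (Figure 2's Llama 3-V setting);
# d = 4096 and heads = 32 are hypothetical sizes chosen only for illustration.
ring = ring_attention_bytes_per_step(s_kv=1_200_000, d=4096, workers=16)
lv = lv_xattn_bytes_per_step(s_q=1_000, d=4096, workers=16, heads=32)
print(f"Ring Attention: {ring / 1e9:.3f} GB/step, LV-XAttn: {lv / 1e6:.3f} MB/step, "
      f"ratio {ring / lv:.0f}x")
```

When $S_Q \ll S_{KV}$ and d is much larger than the head count, the ratio approaches $S_{KV}/S_Q$. Here that is roughly three orders of magnitude.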

Paper Structure

This paper contains 14 sections, 4 equations, 7 figures, 6 tables, 1 algorithm.

Figures (7)

  • Figure 1: MLLM with cross-attention.
  • Figure 2: Runtime breakdown for a single iteration of Llama 3-V, mPLUG-Owl3-7b, and OpenFlamingo-3b using Ring Attention and LV-XAttn on 16 A100 GPUs. LV-XAttn reduces the time spent on cross-attention computation by 96%, 93%, and 53% for the three models, respectively, compared to Ring Attention. "FWD Vision" refers to the forward pass through the vision encoder and the projection layer; "FWD CA" and "BWD CA" refer to the forward and backward passes through the cross-attention layers in the LLM; and "FWD Non-CA" and "BWD Non-CA" refer to the forward and backward passes through the non-cross-attention layers in the LLM. Llama 3-V was evaluated with a text length of 1K and a frame count of 192 $(S_Q = 1\text{K},\ S_{KV} = 1200\text{K})$; mPLUG-Owl3-7b was evaluated with a text length of 4K and a frame count of 2K $(S_Q = 4\text{K},\ S_{KV} = 1458\text{K})$; and OpenFlamingo-3b was evaluated with a text length and a frame count of 32K $(S_Q = 32\text{K}, S_{KV} = 2048\text{K})$.
  • Figure 3: LV-XAttn with 4 workers. We partition the KV blocks, and each worker stores its own large key-value blocks $K_i, V_i$. We also partition the query ($Q_i$), output ($O_i$), and softmax statistics ($m_i$ and $l_i$, omitted in the figure). The query and output are rotated among workers to compute the attention.
  • Figure 4: The theoretical speedup of LV-XAttn over Ring Attention for cross-attention on a 4-node cluster. Each node is equipped with 4 A100 GPUs, and nodes are interconnected by a 25 GB/s network. The markers represent processing a 2,386-second video and a 3,128-word text prompt (the average values for long videos in Video-MME [fu2024video-mme]) using the Llama 3-V, mPLUG-Owl3, and OpenFlamingo models at different frame rates. Note that a special <image> token has to be added to the text prompt for each frame, resulting in $S_Q = 2386 + 3128 = 5514$.
  • Figure 5: Ablation study on the effect of overlapping communication and computation with 6 A100 40GB GPUs. The frame count is set to 2048 per worker. Since processing the same total number of frames on a single GPU is not feasible due to memory constraints, the "no communication" runtime is derived by running the same per-worker input size on a single GPU and then scaling the result by 6. LV-XAttn incurs an overhead of less than 0.42% compared to the no-communication baseline.
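The Figure 4 and Figure 5 captions rest on the idea that each ring step's send can hide behind that step's local compute. The sketch below is a deliberately crude cost model of that idea, not the paper's analysis. It assumes perfect overlap, so a step costs the maximum of compute time and send time. It also assumes a single uniform link bandwidth, ignoring the faster intra-node links of Figure 4's cluster, and it ignores the softmax statistics and the backward pass. `achieved_flops` is an assumed sustained throughput, not an A100 datasheet figure. The function names are hypothetical.

```python
def step_time(flops, bytes_sent, achieved_flops, bandwidth):
    """One ring step when sending the next block fully overlaps the local compute."""
    return max(flops / achieved_flops, bytes_sent / bandwidth)


def estimated_speedup(s_q, s_kv, d, workers, achieved_flops, bandwidth, elem_bytes=2):
    """Ratio of Ring Attention to LV-XAttn forward time under a crude overlap model."""
    n_q, n_kv = s_q / workers, s_kv / workers
    flops = 4 * n_q * n_kv * d                 # Q K^T plus P V for one block pair
    ring_bytes = 2 * n_kv * d * elem_bytes     # Ring Attention moves K and V blocks
    lv_bytes = 2 * n_q * d * elem_bytes        # LV-XAttn moves Q and O (stats ignored)
    # both schemes run `workers` steps with identical per-step compute,
    # so the ratio of per-step times equals the ratio of total times
    ring = step_time(flops, ring_bytes, achieved_flops, bandwidth)
    lv = step_time(flops, lv_bytes, achieved_flops, bandwidth)
    return ring / lv


# Hypothetical inputs: 25 GB/s links as in Figure 4, an assumed 100 TFLOP/s of
# achieved attention throughput per GPU, and d = 4096.
print(estimated_speedup(s_q=4096, s_kv=1_048_576, d=4096, workers=16,
                        achieved_flops=100e12, bandwidth=25e9))
```

In this toy setting, Ring Attention is communication-bound while LV-XAttn's much smaller send fits under the compute time. The model therefore predicts a speedup equal to the ratio of Ring Attention's send time to the compute time. With unlimited bandwidth both become compute-bound and the estimate falls to 1.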