Accelerating Direct Preference Optimization with Prefix Sharing

Franklin Wang, Sumanth Hegde

TL;DR

This work introduces prefix sharing for preference tuning, a novel technique that processes the chosen and rejected responses as one sequence with a shared prefix. It improves training throughput on popular DPO datasets by 1.1-1.5× without any effect on convergence.
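Processing both responses as one sequence means a single forward pass must yield log-probabilities for each response separately. The following is a minimal sketch, not the authors' implementation. It assumes a decoder-only LM with next-token outputs, the layout `[prompt | chosen | rejected]`, and an attention mask in which the rejected response cannot see the chosen one. It shows which output position scores each response token and how the standard DPO loss is then formed. All function names here are hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def predictor_indices(p: int, c: int, r: int) -> Tuple[List[int], List[int]]:
    """Layout: [prompt (p) | chosen (c) | rejected (r)] (assumed order).

    Returns, for each chosen and each rejected token, the index of the
    position whose next-token distribution scores that token.
    """
    if min(p, c, r) < 1:
        raise ValueError("prompt and both responses must be non-empty")
    # Chosen tokens directly follow the prompt: ordinary shift-by-one.
    chosen = [p + i - 1 for i in range(c)]
    # The rejected response cannot attend to the chosen one, so its first
    # token is scored by the last prompt position, not the last chosen token.
    rejected = [p - 1] + [p + c + j - 1 for j in range(1, r)]
    return chosen, rejected


def response_logprob(logprobs: Sequence[Sequence[float]], tokens: Sequence[int],
                     start: int, predictors: Sequence[int]) -> float:
    """Sum of log p(token) over one response whose tokens begin at `start`.

    `logprobs[q][v]` is the log-probability of vocabulary id `v` predicted
    at position `q` (i.e. a log-softmax over the model's outputs).
    """
    return sum(logprobs[q][tokens[start + i]] for i, q in enumerate(predictors))


def dpo_loss(policy_chosen: float, policy_rejected: float,
             ref_chosen: float, ref_rejected: float, beta: float = 0.1) -> float:
    """Standard DPO loss, -log(sigmoid(margin)), for one preference pair."""
    margin = beta * ((policy_chosen - ref_chosen) - (policy_rejected - ref_rejected))
    x = -margin
    # -log(sigmoid(margin)) == softplus(-margin); this form avoids overflow.
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))
```

For example, with `p=3, c=2, r=2` the chosen tokens at indices 3 and 4 are scored by positions 2 and 3. The rejected tokens at indices 5 and 6 are scored by positions 2 and 5. The last prompt position therefore scores the first token of both responses.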

Abstract

Offline paired preference optimization algorithms have become a popular approach for fine-tuning on preference data, outperforming traditional supervised fine-tuning in various tasks. However, traditional implementations often involve redundant computations, especially for tasks with long shared prompts. We introduce prefix sharing for preference tuning, a novel technique that processes chosen and rejected responses as one sequence with a shared prefix. To prevent cross-response contamination, we use a custom block-sparse attention mask. Our method achieves $1.1$-$1.5\times$ improvement in training throughput on popular DPO datasets, without any effect on convergence. When combined with sequence packing, we observe consistent $1.3$-$1.6\times$ speedups, benefiting even datasets with smaller sequence lengths. While we focus on Direct Preference Optimization (DPO), our approach is applicable to other paired preference tuning methods. By enhancing computational efficiency, our work contributes to making preference-based fine-tuning more accessible for a wider range of applications and model sizes. We open-source our code at https://github.com/frankxwang/dpo-prefix-sharing.
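The block-sparse mask can be stated as a simple predicate: a query may attend to a key if the key is causally earlier and lies either in the prompt or in the query's own response. Sequence packing then places one such mask per preference pair along the diagonal. The following is a dense NumPy sketch for illustration only; a real kernel (e.g. a FlexAttention `mask_mod`) would evaluate the same predicate and skip fully masked blocks instead of materializing the matrix. The function names and the position-id scheme are assumptions. In particular, the sketch restarts the rejected response's positions at the prompt length so that each response sees the same positions it would in an unshared sequence.

```python
from typing import Sequence, Tuple

import numpy as np


def prefix_sharing_mask(p: int, c: int, r: int) -> np.ndarray:
    """Boolean [n, n] mask for [prompt (p) | chosen (c) | rejected (r)].

    mask[q, k] is True when query position q may attend to key position k.
    """
    n = p + c + r
    causal = np.tril(np.ones((n, n), dtype=bool))
    seg = np.array([0] * p + [1] * c + [2] * r)  # 0=prompt, 1=chosen, 2=rejected
    # A key is visible if it is in the prompt or in the query's own segment,
    # which blocks chosen <-> rejected attention (cross-response contamination).
    prompt_or_same = (seg[None, :] == 0) | (seg[None, :] == seg[:, None])
    return causal & prompt_or_same


def prefix_sharing_positions(p: int, c: int, r: int) -> np.ndarray:
    """Position ids where both responses continue from the end of the prompt."""
    return np.concatenate([np.arange(p), np.arange(p, p + c), np.arange(p, p + r)])


def packed_mask(samples: Sequence[Tuple[int, int, int]]) -> np.ndarray:
    """Block-diagonal mask for several (p, c, r) units packed into one sequence."""
    blocks = [prefix_sharing_mask(*s) for s in samples]
    n = sum(b.shape[0] for b in blocks)
    out = np.zeros((n, n), dtype=bool)
    offset = 0
    for b in blocks:
        k = b.shape[0]
        # Each unit attends only within itself; off-diagonal blocks stay False.
        out[offset:offset + k, offset:offset + k] = b
        offset += k
    return out
```

With `p = c = r = 2`, the mask allows 17 query-key pairs, and the first rejected token (index 4) can see the prompt (indices 0 and 1) but not the chosen tokens (indices 2 and 3).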

Paper Structure

This paper contains 17 sections, 1 equation, 5 figures, 3 tables, 3 algorithms.

Figures (5)

  • Figure 1: Method overview. Prefix sharing removes redundant computation of the shared prompt prefix by combining the responses into a single sequence and modifying the attention mask to prevent cross-response contamination.
  • Figure 2: Sequence packing with and without prefix sharing for paired preference inputs, illustrated for two training samples. Without prefix sharing, a sequence packing implementation must treat the chosen and rejected responses, each prefixed by the common prompt, as a single unit and then pack these units together. With prefix sharing, the unit for sequence packing is instead the shared prompt followed by the chosen and rejected responses.
  • Figure 3: Microbenchmarking results of the MLP layer for Mistral 7B. Relative speedups of prefix sharing over normal paired data are shown alongside the ideal speedup (assuming linear runtime). The MLP layer scales very closely to the ideal speedup, and for a given prefix-to-completion ratio, increasing the prefix length pushes the speedup closer to the ideal.
  • Figure 4: Microbenchmarking results of the self-attention operation only for Mistral 7B. Relative speedups of FlexAttention with prefix sharing over FlashAttention-3 and FlexAttention are shown, along with the ideal speedup (assuming perfect quadratic scaling). For long prefixes, FlexAttention with prefix sharing attains nearly ideal speedups over FlexAttention without prefix sharing, but overall it is still slower than or similar in speed to FlashAttention-3. In practice, however, self-attention contributes little to overall training time, so this has minimal impact.
  • Figure 5: Microbenchmarking results of the full self-attention layer (QKV projection + self-attention) for Mistral 7B. Relative speedups of FlexAttention with prefix sharing over FlashAttention-3 and FlexAttention are shown, along with the ideal speedup (assuming linear runtime). Although FlexAttention is slower than FlashAttention-3 at lower prefix-to-completion length ratios, FlexAttention with prefix sharing becomes faster as the ratio grows.
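The figure captions compare measured speedups against an "ideal" speedup under either a linear-runtime or a quadratic-scaling assumption. Because the exact formulas are not given here, what follows is one plausible cost model rather than the paper's definition. It assumes linear cost is proportional to the number of tokens processed, and attention cost is proportional to the number of query-key pairs allowed by a causal (or prefix-sharing) mask. The names `p`, `c`, and `r` are hypothetical labels for the prompt, chosen, and rejected lengths.

```python
def tri(n: int) -> int:
    """Number of query-key pairs in a causal mask over n tokens."""
    return n * (n + 1) // 2


def ideal_linear_speedup(p: int, c: int, r: int) -> float:
    """Token-count ratio: two prompt copies (unshared) vs one (shared)."""
    unshared_tokens = (p + c) + (p + r)
    shared_tokens = p + c + r
    return unshared_tokens / shared_tokens


def ideal_attention_speedup(p: int, c: int, r: int) -> float:
    """Ratio of allowed query-key pairs, assuming cost scales with pair count."""
    # Unshared: two independent causal sequences, prompt+chosen and prompt+rejected.
    unshared_pairs = tri(p + c) + tri(p + r)
    # Shared: causal prompt, then each response attends to the prompt and to itself.
    shared_pairs = tri(p) + (p * c + tri(c)) + (p * r + tri(r))
    return unshared_pairs / shared_pairs
```

Under this model, the only attention saving is the prompt's causal self-attention, which the unshared layout computes twice. With `p = c = r = 2`, the unshared pair allows 20 query-key pairs and the shared sequence allows 17. Both ratios approach 2 as the prompt grows relative to the responses.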