Transfer Learning for Text Diffusion Models

Kehang Han, Kathleen Kenealy, Aditya Barua, Noah Fiedel, Noah Constant

TL;DR

The paper investigates whether text diffusion can replace autoregressive (AR) decoding for training and deploying large language models. It introduces AR2Diff, a lightweight procedure for adapting pretrained AR checkpoints into diffusion models. Comparing architectures and pretraining objectives, the authors find that a decoder-only model trained with a prefix LM objective is best or near-best across several tasks, and they build their transfer learning experiments on this setup. On WMT translation, diffusion underperforms AR. On code synthesis and extractive QA, diffusion models trained from scratch outperform AR baselines in many cases, and AR2Diff yields further quality gains. Because diffusion decoding can be significantly faster than AR decoding for long outputs, the results point to lightweight diffusion-based transfer learning as a promising direction, with performance trade-offs that depend on the task.
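Architecturally, the move from AR to diffusion in a decoder-only model comes down mostly to the attention pattern. According to the Figure 2 description, an AR decoder pretrained with causal attention is continued as a diffusion model with bidirectional attention. The minimal pure-Python sketch below contrasts three masks: the causal mask, an AR prefix-LM mask (bidirectional over the input, causal over the target), and the fully bidirectional mask a diffusion decoder would use. The function names and the boolean `mask[i][j]` convention are illustrative assumptions, not the authors' implementation.

```python
from typing import List

Mask = List[List[bool]]  # assumption: mask[i][j] is True if position i may attend to j


def causal_mask(n: int) -> Mask:
    # AR decoder: each position sees itself and earlier positions only.
    return [[j <= i for j in range(n)] for i in range(n)]


def prefix_lm_mask(n_prefix: int, n_target: int) -> Mask:
    # AR prefix LM: every position sees the whole input prefix, prefix
    # positions never see targets, and targets see earlier targets causally.
    n = n_prefix + n_target
    return [[j < n_prefix or j <= i for j in range(n)] for i in range(n)]


def bidirectional_mask(n: int) -> Mask:
    # Diffusion decoder: every position attends to every other position,
    # so all target tokens can be refined in parallel.
    return [[True] * n for _ in range(n)]


def render(mask: Mask) -> str:
    # '1' = may attend, '.' = blocked.
    return "\n".join("".join("1" if ok else "." for ok in row) for row in mask)


if __name__ == "__main__":
    print(render(prefix_lm_mask(2, 2)))
    # 11..
    # 11..
    # 111.
    # 1111
```

Only the mask changes, so the pretrained AR weights can be reused unchanged when training switches to diffusion. Whether the authors' diffusion decoder attends bidirectionally over both the input and the target is our reading of the Figure 2 caption, not a stated implementation detail.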

Abstract

In this report, we explore the potential for text diffusion to replace autoregressive (AR) decoding for the training and deployment of large language models (LLMs). We are particularly interested to see whether pretrained AR models can be transformed into text diffusion models through a lightweight adaptation procedure we call "AR2Diff". We begin by establishing a strong baseline setup for training text diffusion models. Comparing across multiple architectures and pretraining objectives, we find that training a decoder-only model with a prefix LM objective is best or near-best across several tasks. Building on this finding, we test various transfer learning setups for text diffusion models. On machine translation, we find that text diffusion underperforms the standard AR approach. However, on code synthesis and extractive QA, we find diffusion models trained from scratch outperform AR models in many cases. We also observe quality gains from AR2Diff (adapting AR models to use diffusion decoding). These results are promising given that text diffusion is relatively underexplored and can be significantly faster than AR decoding for long text generation.
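The speed claim depends on how many sequential forward passes each decoder needs. The toy sketch below uses stub functions in place of a real model and assumes a fixed budget of denoising steps (the value 10 is arbitrary, not the paper's setting). It counts passes for AR decoding, which needs one per generated token, against diffusion decoding, which needs one per denoising step because each step refines every target position in parallel. `ar_decode`, `diffusion_decode`, and the `MASK` placeholder are hypothetical names.

```python
from typing import Callable, List, Tuple

Tokens = List[int]
MASK = -1  # hypothetical id for a fully corrupted position


def ar_decode(step_fn: Callable[[Tokens], int], length: int) -> Tuple[Tokens, int]:
    # AR decoding: one sequential forward pass per generated token.
    out: Tokens = []
    calls = 0
    for _ in range(length):
        out.append(step_fn(out))
        calls += 1
    return out, calls


def diffusion_decode(
    denoise_fn: Callable[[Tokens], Tokens], length: int, num_steps: int
) -> Tuple[Tokens, int]:
    # Diffusion decoding: start from a fully corrupted target (a placeholder
    # here; a real sampler might start from random tokens) and refine all
    # positions at once for a fixed number of steps.
    x: Tokens = [MASK] * length
    calls = 0
    for _ in range(num_steps):
        x = denoise_fn(x)
        calls += 1
    return x, calls


if __name__ == "__main__":
    for n in (64, 256, 1024):
        _, ar_calls = ar_decode(lambda prefix: len(prefix), n)
        _, diff_calls = diffusion_decode(lambda x: list(range(len(x))), n, num_steps=10)
        print(n, ar_calls, diff_calls)
    # 64 64 10
    # 256 256 10
    # 1024 1024 10
```

The sketch ignores per-pass cost. A diffusion step runs over the whole target with bidirectional attention, while AR decoding can reuse cached keys and values. Actual wall-clock differences, such as those in Figure 3, therefore depend on sequence length, step count, and hardware.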

Paper Structure (11 sections, 3 figures, 3 tables)

Figures (3)

  • Figure 1: Pretraining objectives and model architectures. The <X> and <Y> symbols are unique sentinel tokens denoting masked spans. Note that the "masking noise" applied to produce the span corruption input/target is independent of the "diffusion noise", which randomly corrupts a subset of target tokens. Loss is computed only over target tokens. In the decoder-only setting, input tokens are frozen when computing the unrolled logits input ($l_2$).
  • Figure 2: Illustration of our AR2Diff method. 1) Pretrain an AR decoder with causal attention. 2) Continue pretraining as a diffusion model with bidirectional attention. 3) Fine-tune as a diffusion model on the end task.
  • Figure 3: Inference time of autoregressive decoding vs. diffusion decoding, measured while varying the decoding sequence length.
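The Figure 1 caption separates two independent noise sources: span-corruption "masking noise" that builds the input/target pair, and "diffusion noise" that corrupts a random subset of target tokens. It also states that loss is computed only over target tokens. The pure-Python sketch below assembles one decoder-only training example under stated assumptions. The corruption rate is drawn uniformly per example, corrupted tokens are replaced with uniform draws from a vocabulary, and the unrolled-logits input ($l_2$) is not modeled. `span_corrupt`, `diffusion_noise`, and `decoder_only_example` are hypothetical helpers, not the authors' code.

```python
import random
from typing import List, Optional, Sequence, Tuple


def span_corrupt(
    tokens: Sequence[str],
    spans: Sequence[Tuple[int, int]],
    sentinels: Sequence[str],
) -> Tuple[List[str], List[str]]:
    """Masking noise: replace each [start, end) span with a sentinel.

    Assumes spans are sorted and non-overlapping. The input keeps the
    unmasked text plus sentinels; the target lists each sentinel followed
    by the tokens it hides.
    """
    inp: List[str] = []
    tgt: List[str] = []
    prev = 0
    for (start, end), sentinel in zip(spans, sentinels):
        inp.extend(tokens[prev:start])
        inp.append(sentinel)
        tgt.append(sentinel)
        tgt.extend(tokens[start:end])
        prev = end
    inp.extend(tokens[prev:])
    return inp, tgt


def diffusion_noise(
    target: Sequence[str], vocab: Sequence[str], rng: random.Random
) -> List[str]:
    """Diffusion noise: corrupt a random subset of target tokens.

    Assumption: a per-example rate drawn from U(0, 1), with corrupted
    tokens replaced by uniform draws from `vocab`. Input tokens are never
    touched, so this is independent of the masking noise above.
    """
    rate = rng.random()
    return [rng.choice(vocab) if rng.random() < rate else tok for tok in target]


def decoder_only_example(
    inp: Sequence[str], clean_tgt: Sequence[str], noisy_tgt: Sequence[str]
) -> Tuple[List[str], List[Optional[str]], List[bool]]:
    """Pack one decoder-only training example.

    The model reads the clean input followed by the noisy target and is
    trained to recover the clean target; loss is masked to target positions.
    """
    sequence = list(inp) + list(noisy_tgt)
    labels: List[Optional[str]] = [None] * len(inp) + list(clean_tgt)
    loss_mask = [False] * len(inp) + [True] * len(clean_tgt)
    return sequence, labels, loss_mask


if __name__ == "__main__":
    rng = random.Random(0)
    tokens = "The cat sat on the mat".split()
    inp, tgt = span_corrupt(tokens, [(1, 2), (4, 5)], ["<X>", "<Y>"])
    print(inp)  # ['The', '<X>', 'sat', 'on', '<Y>', 'mat']
    print(tgt)  # ['<X>', 'cat', '<Y>', 'the']
    noisy = diffusion_noise(tgt, vocab=["the", "cat", "dog", "mat"], rng=rng)
    seq, labels, mask = decoder_only_example(inp, tgt, noisy)
    print(sum(mask))  # 4 target positions contribute to the loss
```

Keeping the two noise functions separate mirrors the caption's point. Span corruption decides which text becomes the target, and diffusion noise then corrupts only that target, so the input prefix always reaches the model intact.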