Beyond Autoregression: Fast LLMs via Self-Distillation Through Time

Justin Deschenaux, Caglar Gulcehre

TL;DR

This work tackles the latency of autoregressive (AR) LLMs by introducing Self-Distillation Through Time (SDTT) for discrete diffusion language models. SDTT enables parallel generation of tokens (e.g., 32 tokens at once) and reduces the number of inference steps by a factor of 32–64. It trains a student denoiser to imitate a teacher over shortened decoding trajectories: in each of several iterated rounds, the student learns to match, in fewer steps, targets that the teacher generates with a limited number of sampling steps. The result is a substantial speedup (generation up to 8x faster than KV-cached AR baselines) while text quality, as measured by LAMBADA accuracy and MAUVE, is preserved or improved. The approach scales to models with up to 860M parameters, maintains downstream task performance, and achieves favorable latency without relying on activation caching. Overall, SDTT provides a practical and scalable path to fast, high-quality diffusion language models suited to tasks that require multiple completions and planning.
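The TL;DR describes SDTT only at a high level. The numpy sketch below shows one plausible shape of a single distillation step for a masked diffusion model; it is not the authors' implementation. Every name in it is hypothetical, and it rests on stated assumptions: a toy denoiser `toy_denoiser`, a linear masking schedule, a teacher that takes two half-steps where the student takes one, a target rule in which tokens revealed by the teacher become one-hot targets while still-masked tokens keep the teacher's last predicted distribution, and a KL-divergence loss (Figure 2 compares several distillation losses, so KL is only one option).

```python
import numpy as np

MASK = -1      # hypothetical sentinel id for a masked position
V, L = 8, 6    # toy vocabulary size and sequence length (assumptions)

def softmax(z):
    """Numerically stable softmax over the last axis."""
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def toy_denoiser(x, params):
    """Hypothetical denoiser: per-position probabilities over clean tokens.
    Context enters only through the fraction of already-revealed tokens."""
    revealed = np.mean(x != MASK)
    return softmax(params * (1.0 + revealed))

def unmask_step(x, probs, reveal_prob, rng):
    """Reveal each masked position with probability reveal_prob, sampling from probs.
    Already-revealed positions are carried over unchanged."""
    x = x.copy()
    for i in np.flatnonzero(x == MASK):
        if rng.random() < reveal_prob:
            x[i] = rng.choice(V, p=probs[i])
    return x

def teacher_target(x_t, t, dt, params, rng):
    """Two teacher half-steps from t to t - dt under an assumed linear schedule
    (a token masked at time t is still masked at time s < t with probability s / t).
    Tokens the teacher reveals become one-hot targets; tokens still masked keep
    the teacher's last predicted distribution (assumed target rule)."""
    t_mid = t - dt / 2
    x_mid = unmask_step(x_t, toy_denoiser(x_t, params), (t - t_mid) / t, rng)
    probs_last = toy_denoiser(x_mid, params)
    x_end = unmask_step(x_mid, probs_last, (t_mid - (t - dt)) / t_mid, rng)
    target = probs_last.copy()
    newly = (x_t == MASK) & (x_end != MASK)
    target[newly] = np.eye(V)[x_end[newly]]
    return target

def distill_loss(student_params, x_t, target, eps=1e-12):
    """KL(target || student), averaged over positions masked in x_t.
    The student makes a single prediction directly from x_t."""
    p_s = toy_denoiser(x_t, student_params)
    m = x_t == MASK
    kl = np.sum(target[m] * (np.log(target[m] + eps) - np.log(p_s[m] + eps)), axis=-1)
    return kl.mean()

# Usage: one target/loss computation from a fully masked sequence (t = 1).
rng = np.random.default_rng(0)
W = rng.normal(size=(L, V))                  # toy teacher parameters
x_t = np.full(L, MASK)
target = teacher_target(x_t, t=1.0, dt=0.5, params=W, rng=rng)
print(distill_loss(W.copy(), x_t, target))   # student initialized from the teacher
```

The gradient update on the student's parameters is omitted. In an iterated setup, the trained student would presumably serve as the teacher of the next round, shortening the trajectory again each time.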

Abstract

Autoregressive (AR) Large Language Models (LLMs) have demonstrated significant success across numerous tasks. However, the AR modeling paradigm presents certain limitations; for instance, contemporary autoregressive LLMs are trained to generate one token at a time, which can result in noticeable latency. Recent advances have indicated that search and repeated sampling can enhance performance in various applications, such as theorem proving, code generation, and alignment, by utilizing greater computational resources during inference. In this study, we demonstrate that diffusion language models are capable of generating at least 32 tokens simultaneously, while exceeding the performance of AR models in text quality and on the LAMBADA natural language understanding benchmark. This outcome is achieved through a novel distillation method for discrete diffusion models, which reduces the number of inference steps by a factor of 32–64. Practically, at the 1.3B-parameter scale, diffusion models, even without caching, can generate tokens up to 8 times faster than AR models employing KV-caching, and we anticipate further improvements with the inclusion of caching. Moreover, we demonstrate the efficacy of our approach for diffusion language models with up to 860M parameters.
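The step-reduction factor and the number of tokens produced per forward pass are linked by simple arithmetic. The Python sketch below assumes a hypothetical 1024-step teacher, a hypothetical 1024-token sequence, and that each distillation round halves the step count; none of these values is stated in the abstract.

```python
def halving_schedule(initial_steps, rounds):
    """Decoding-step counts after each distillation round, assuming every
    round halves the number of steps (an illustrative assumption)."""
    steps = [initial_steps]
    for _ in range(rounds):
        steps.append(steps[-1] // 2)
    return steps

def tokens_per_step(seq_len, steps):
    """Average number of tokens revealed per forward pass when a fully
    masked sequence of seq_len tokens is decoded in `steps` steps."""
    return seq_len / steps

seq_len = 1024  # hypothetical sequence length
for rounds in (5, 6):
    sched = halving_schedule(1024, rounds)  # hypothetical 1024-step teacher
    reduction = sched[0] // sched[-1]
    print(rounds, sched[-1], reduction, tokens_per_step(seq_len, sched[-1]))
```

Under these assumptions, 5 rounds give a 32x reduction (32 steps, 32 tokens per forward pass on average) and 6 rounds give a 64x reduction (16 steps, 64 tokens per forward pass).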

Paper Structure

This paper contains 53 sections, 12 equations, 22 figures, 2 tables, and 2 algorithms.

Figures (22)

  • Figure 1: Perplexity versus latency. The diffusion models (169M) use 16, 32, 64, 128 and 256 decoding steps.
  • Figure 2: Performance on LAMBADA after multiple rounds of SDTT with different distillation losses. We pre-train with the masked diffusion language modeling objective (MDLM) sahoo2024simpleeffectivemaskeddiffusion and distill with 7 rounds of SDTT. Note that a single word in the LAMBADA dataset often consists of multiple tokens. We greedily decode all tokens in a single forward pass for the diffusion models and decode autoregressively for the AR models.
  • Figure 3: SDTT. In figure (a), we illustrate how we prepare the distillation targets. In figure (b), we display the generative perplexity of samples after distillation.
  • Figure 4: Sampling step ablations on perplexity. Perplexity of samples after each round of iterated SDTT. (a): Iterated SDTT on a small model trained for 1M steps. (b): Scaling SDTT to larger models trained for 400K steps.
  • Figure 5: MAUVE performance of the student after each round of SDTT. The teacher performance is computed using samples generated with 128 decoding steps.
  • ...and 17 more figures
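Figure 2's caption states that the diffusion models decode all tokens of the LAMBADA target word greedily in a single forward pass. A minimal Python sketch of that evaluation follows. It assumes a hypothetical `denoiser` callable mapping a token-id array to per-position logits, a hypothetical `MASK_ID`, and that the number of target tokens is known in advance. The rule that a word counts as correct only when all of its tokens match is also an assumption, motivated by the caption's remark that one word often spans several tokens.

```python
import numpy as np

MASK_ID = 0  # hypothetical id of the [MASK] token

def greedy_one_pass(denoiser, context, n_target):
    """Append n_target [MASK] tokens, run the denoiser once, and take the
    argmax at every masked slot simultaneously."""
    x = np.concatenate([np.asarray(context), np.full(n_target, MASK_ID)])
    logits = np.array(denoiser(x), dtype=float)  # shape (len(x), vocab)
    logits[:, MASK_ID] = -np.inf                 # assumption: never predict [MASK] itself
    return logits[len(context):].argmax(axis=-1)

def lambada_correct(denoiser, context, target_ids):
    """Count the last word as correct only if all of its tokens match (assumed rule)."""
    pred = greedy_one_pass(denoiser, context, len(target_ids))
    return bool(np.array_equal(pred, np.asarray(target_ids)))

# Usage with a toy denoiser that favours token (position % 9) + 1.
V = 10
def toy_denoiser(x):
    logits = np.zeros((len(x), V))
    pos = np.arange(len(x))
    logits[pos, pos % (V - 1) + 1] = 1.0
    return logits

print(lambada_correct(toy_denoiser, [7, 8, 9], [4, 5]))  # True
print(lambada_correct(toy_denoiser, [7, 8, 9], [4, 6]))  # False
```

An AR baseline would instead need one forward pass per target token, feeding each prediction back as input.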