MemDLM: Memory-Enhanced DLM Training

Zehua Pei, Hui-Ling Zhen, Weizhe Lin, Sinno Jialin Pan, Yunhe Wang, Mingxuan Yuan, Bei Yu

Abstract

Diffusion Language Models (DLMs) offer attractive advantages over Auto-Regressive (AR) models, such as full-attention parallel decoding and flexible generation. However, they suffer from a notable train-inference mismatch: DLMs are trained with a static, single-step masked prediction objective, but deployed through a multi-step progressive denoising trajectory. We propose MemDLM (Memory-Enhanced DLM), which narrows this gap by embedding a simulated denoising process into training via Bi-level Optimization. An inner loop updates a set of fast weights, forming a Parametric Memory that captures the local trajectory experience of each sample, while an outer loop updates the base model conditioned on this memory. By offloading memorization pressure from token representations to parameters, MemDLM yields faster convergence and lower training loss. Moreover, the inner loop can be re-enabled at inference time as an adaptation step, yielding additional gains on long-context understanding. We find that, when activated at inference time, this Parametric Memory acts as an emergent in-weight retrieval mechanism, helping MemDLM further reduce token-level attention bottlenecks on challenging Needle-in-a-Haystack retrieval tasks. Code: https://github.com/JarvisPei/MemDLM.
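The bi-level scheme described above (an inner loop that adapts fast weights per sample and an outer loop that updates the base model conditioned on them) can be illustrated with a toy first-order sketch. Everything here is an assumption for illustration, not the authors' implementation (see the linked repository for that). A linear model with squared loss stands in for the DLM and its masked-token loss. The fast weights `phi` are an additive delta on `theta` and are reset for each sample. Gradients do not flow through the inner loop (a first-order approximation). The names `bilevel_step`, `inner_batch`, `outer_batch`, `inner_lr`, `outer_lr` and `inner_steps` are hypothetical.

```python
from typing import List, Sequence, Tuple

# (features, target): a toy stand-in for one denoising example (hypothetical).
Example = Tuple[List[float], float]


def predict(w: Sequence[float], x: Sequence[float]) -> float:
    # Toy linear "model": dot product of effective weights and features.
    return sum(wi * xi for wi, xi in zip(w, x))


def loss_and_grad(w: Sequence[float], x: Sequence[float], y: float) -> Tuple[float, List[float]]:
    # Squared error 0.5 * (w.x - y)^2 and its gradient with respect to w.
    err = predict(w, x) - y
    return 0.5 * err * err, [err * xi for xi in x]


def bilevel_step(
    theta: List[float],
    inner_batch: Sequence[Example],
    outer_batch: Sequence[Example],
    inner_lr: float = 0.1,
    outer_lr: float = 0.05,
    inner_steps: int = 2,
) -> Tuple[List[float], List[float], float]:
    # Inner loop: fresh fast weights (the toy "parametric memory") for this sample,
    # adapted on trajectory examples while theta stays fixed.
    phi = [0.0] * len(theta)
    for _ in range(inner_steps):
        for x, y in inner_batch:
            w = [th + p for th, p in zip(theta, phi)]
            _, g = loss_and_grad(w, x, y)
            phi = [p - inner_lr * gi for p, gi in zip(phi, g)]

    # Outer loop: loss on the anchor examples with effective weights theta + phi.
    # First-order assumption: phi is treated as a constant, so the gradient with
    # respect to theta equals the gradient with respect to the effective weights.
    grad = [0.0] * len(theta)
    total = 0.0
    for x, y in outer_batch:
        w = [th + p for th, p in zip(theta, phi)]
        loss, g = loss_and_grad(w, x, y)
        total += loss
        grad = [a + b for a, b in zip(grad, g)]
    n = len(outer_batch)
    new_theta = [th - outer_lr * gi / n for th, gi in zip(theta, grad)]
    return new_theta, phi, total / n
```

In this toy, `inner_batch` plays the role of examples taken from the simulated local trajectory and `outer_batch` the role of the anchor state. Re-enabling the inner loop at inference time, as the abstract describes, would roughly correspond to running only the `phi` updates on the test input and then predicting with `theta + phi`.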

Paper Structure (26 sections, 8 equations, 11 figures, 3 tables)

Figures (11)

  • Figure 1: Needle-in-a-Haystack results overview. Gray bars denote Standard MDLM and blue bars denote MemDLM. Left: detailed results on RULER-MV, RULER-VT, RULER-CWE, and BABILong for the LLaDA-MoE-7B-A1B-Base and LLaDA2.1-mini backbones. Right: mean absolute improvement of MemDLM over Standard MDLM for each task, averaged across the evaluated context lengths within each backbone.
  • Figure 2: Overview of MemDLM. Left: standard MDLM training uses a static single-step denoising objective from $x_t$ to $x_0$. Right: MemDLM uses Bi-level Optimization in which an inner loop updates fast weights $\phi$ along an anchor-consistent local trajectory ($x_{t_{\textit{pre}}} \rightarrow x_t \rightarrow x_0$), and the outer loop updates the base model $\theta$ on the anchor state $x_t$ conditioned on this parametric memory. Legend: dark tokens denote mask tokens, light tokens denote observed tokens, straight arrows denote forward or reverse prediction flow, and blue curved arrows denote inner-loop fast-weight updates.
  • Figure 3: Exposure Bias Ratio ($\mathcal{R}_{\text{EB}}$) across denoising steps. Standard MDLM degrades rapidly, while MemDLM remains substantially flatter.
  • Figure 4: Training dynamics on the LLaDA-MoE and LLaDA2.1 backbones. We compare Standard MDLM and MemDLM using train loss and evaluation loss. For the train-loss panels, faint curves show the raw logged values and bold curves show a smoothed trend. Across both backbones, MemDLM converges faster and reaches consistently lower train and evaluation loss, supporting the view that memory-aware training improves optimization by reducing the burden of preserving local trajectory information purely in token space.
  • Figure 5: Comparison with the untuned pretrained LLaDA-MoE-7B-A1B-Base model across context lengths.
  • ...and 6 more figures
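One way to read the anchor-consistent local trajectory in Figure 2 ($x_{t_{\textit{pre}}} \rightarrow x_t \rightarrow x_0$) is as a pair of nested masked states built from a single sample. The sketch below rests on an interpretation that the captions do not confirm: it assumes $t_{\textit{pre}} \ge t$ and that the mask set of $x_{t_{\textit{pre}}}$ contains the mask set of $x_t$. Under that assumption, each arrow only unmasks tokens, as in absorbing-state masked diffusion. The function name `anchor_consistent_pair`, the `[M]` placeholder and the rounding rule are all hypothetical.

```python
import random
from typing import List, Sequence, Tuple

MASK = "[M]"  # hypothetical mask-token placeholder


def anchor_consistent_pair(
    x0: Sequence[str], t: float, t_pre: float, rng: random.Random
) -> Tuple[List[str], List[str]]:
    """Return (x_pre, x_t) with nested masks: every token masked in x_t is also masked in x_pre."""
    if not 0.0 <= t <= t_pre <= 1.0:
        raise ValueError("need 0 <= t <= t_pre <= 1")
    n = len(x0)
    order = list(range(n))
    rng.shuffle(order)  # one random masking order shared by both states
    masked_t = set(order[: round(t * n)])
    # A longer prefix of the same order, hence a superset of masked_t.
    masked_pre = set(order[: round(t_pre * n)])
    x_t = [MASK if i in masked_t else tok for i, tok in enumerate(x0)]
    x_pre = [MASK if i in masked_pre else tok for i, tok in enumerate(x0)]
    return x_pre, x_t
```

Under this reading, a MemDLM-style step would feed trajectory states such as `x_pre` to the inner loop and the anchor `x_t` to the outer loop. The standard MDLM objective (Figure 2, left) would use only `x_t`.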