
Logits Replay + MoClip: Stabilized, Low-Cost Post-Training with Minimal Forgetting

Suming Qiu, Jing Li, Zhicheng Zhou, Junjie Huang, Linyuan Qiu, Zhijie Sun

TL;DR

This work tackles the forgetting-versus-specialization dilemma in LLM post-training by introducing Logits Replay + MoClip, a two-stage framework that combines logit-space compression with optimizer stabilization. Stage 0 records dynamic Top-$K$ subsets of logits to form compact targets, and Stage 1 replays these subsets to compute exact renormalized losses without a full-vocabulary softmax. MoClip complements this with gradient–momentum angle clipping and arctan2-based update scaling, providing provable stability without extra data replay. Empirically, the approach improves domain performance on communication-technology (CT) QA and NL2SQL while preserving general reasoning on MMLU, BBH, GPQA, and MATH, and it reduces training cost by over 40% on both 4B and 8B LLM variants. Together, these components offer a scalable, architecture-agnostic route to efficient, low-forgetting LLM adaptation, with practical relevance for continual deployment.
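
A minimal sketch of a Stage 0 selection rule, assuming "dynamic Top-$K$" means the smallest set of highest-probability tokens whose cumulative probability reaches a threshold, with the gold token appended if it falls outside that set. The names `dynamic_topk`, `tau`, and `k_max` and their defaults are hypothetical; the paper's actual threshold, cap, and position-sampling strategy (Figure 4 compares Random, Last-token, and Bucket) are not reproduced here.

```python
import numpy as np


def softmax(z: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis."""
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def dynamic_topk(logits: np.ndarray, gold: int, tau: float = 0.95, k_max: int = 64):
    """Hypothetical Stage 0 rule for a single position.

    Returns (token_ids, logit_values): the smallest prefix of tokens sorted by
    probability whose cumulative mass reaches tau (capped at k_max entries),
    with the gold token appended if it is not already included.
    """
    p = softmax(logits)
    order = np.argsort(-p)  # token ids, most probable first
    cum = np.cumsum(p[order])
    # first index where cumulative mass >= tau, turned into a count, then capped
    k = min(int(np.searchsorted(cum, tau)) + 1, k_max)
    ids = order[:k]
    if gold not in ids:
        ids = np.append(ids, gold)  # always keep the gold label
    return ids, logits[ids]
```

Under this rule, confident positions store only a few entries while high-entropy positions store more, so the size of each recorded target adapts per position.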

Abstract

Large language models (LLMs) often face a trade-off in post-training: improvements on specialized domains frequently come at the expense of general capabilities. Existing solutions attempt to mitigate this tension through regularization, selective parameter updates, or data-centric replay, but each imposes significant costs in computation, data access, or adaptability. Recent work has shown that training signals can be compressed to subsets of logits without severe accuracy loss, suggesting a path toward efficient adaptation. However, naive truncation destabilizes optimization and exacerbates forgetting. We introduce Logits Replay + MoClip, a two-stage framework that compresses supervision in logit space and stabilizes optimization at the update level. In Stage 0, we record dynamic Top-$K$ token subsets that cover a probability threshold and always include the gold label. In Stage 1, we replay these compact subsets to compute exact renormalized losses, which avoids a full-vocabulary softmax and acts as an implicit regularizer. To ensure stability, we design MoClip, an optimizer that caps the rotation angle between the gradient and the momentum and applies an arctan2-based rescaling of updates. Empirically, our method improves domain performance on Communication Technology (CT) and NL2SQL tasks while mitigating forgetting on general benchmarks (MMLU, BBH, GPQA, MATH), and it reduces training cost by over 40%. Together, these contributions offer a scalable, architecture-agnostic path to domain adaptation of LLMs without sacrificing generalization.
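
The sketch below illustrates the two MoClip ingredients on a flat parameter vector. It is not the paper's optimizer. It assumes that capping rotation means rotating the gradient toward the previous momentum, within their common plane, until the angle between them is at most $\Delta_{\max}$ while keeping the gradient norm, and it substitutes the known Adam-atan2 variant, which replaces $\hat m/(\sqrt{\hat v}+\epsilon)$ with $\arctan2(\hat m, \sqrt{\hat v})$, for the paper's arctan2 scaling. `clip_angle`, `MoClipSketch`, and all defaults are hypothetical; the 50° default is only taken from the $[45^\circ, 60^\circ]$ range that Figure 4 recommends for $\Delta_{\max}$.

```python
import numpy as np


def clip_angle(g: np.ndarray, m: np.ndarray, max_angle: float) -> np.ndarray:
    """Rotate g toward m so that angle(g, m) <= max_angle (radians), keeping ||g||."""
    gn, mn = np.linalg.norm(g), np.linalg.norm(m)
    if gn == 0.0 or mn == 0.0:
        return g  # no reference direction yet (e.g. the first step)
    u_m = m / mn
    cos_t = np.clip(g @ u_m / gn, -1.0, 1.0)
    if np.arccos(cos_t) <= max_angle:
        return g  # already inside the angular cap
    g_perp = g - (g @ u_m) * u_m  # component of g orthogonal to m
    pn = np.linalg.norm(g_perp)
    if pn < 1e-12:
        return gn * u_m  # g exactly anti-parallel to m: fall back to m's direction
    u_perp = g_perp / pn
    # unit vector at exactly max_angle from m, in the plane spanned by g and m
    return gn * (np.cos(max_angle) * u_m + np.sin(max_angle) * u_perp)


class MoClipSketch:
    """Hypothetical Adam-style optimizer: angle-clipped gradient + atan2 update."""

    def __init__(self, lr=1e-3, betas=(0.9, 0.999), max_angle_deg=50.0):
        self.lr, (self.b1, self.b2) = lr, betas
        self.max_angle = np.deg2rad(max_angle_deg)
        self.m = self.v = None
        self.t = 0

    def step(self, params: np.ndarray, grad: np.ndarray) -> np.ndarray:
        """Return updated params (the input array is not modified)."""
        if self.m is None:
            self.m, self.v = np.zeros_like(params), np.zeros_like(params)
        self.t += 1
        g = clip_angle(grad, self.m, self.max_angle)  # cap rotation vs. old momentum
        self.m = self.b1 * self.m + (1 - self.b1) * g
        self.v = self.b2 * self.v + (1 - self.b2) * g * g
        m_hat = self.m / (1 - self.b1 ** self.t)
        v_hat = self.v / (1 - self.b2 ** self.t)
        # atan2 replaces m_hat / (sqrt(v_hat) + eps): bounded and epsilon-free
        return params - self.lr * np.arctan2(m_hat, np.sqrt(v_hat))
```

Because $\sqrt{\hat v} \ge 0$, each coordinate of $\arctan2(\hat m, \sqrt{\hat v})$ lies in $[-\pi/2, \pi/2]$ and is defined (as 0) at $(0, 0)$, so per-coordinate steps are bounded by `lr * pi / 2` without an $\epsilon$ hyperparameter.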


Paper Structure

This paper contains 31 sections, 10 theorems, 21 equations, 4 figures, 6 tables, 1 algorithm.

Key Result

Lemma 1

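For full softmax cross-entropy (CE) with $p = \mathrm{softmax}(z)$ and one-hot gold label $y$, the logit gradient is $g^{\mathrm{full}}_z = p - y$. For restricted, renormalized CE over a subset $S$ that contains the gold token, the logit gradient is $g^S_z = q^S - y$, where $q^S_i = p_i / \rho(S)$ for $i \in S$, $q^S_i = 0$ otherwise, and $\rho(S) = \sum_{j \in S} p_j$. Hence the logit-space bias $\Delta g_z := g^S_z - g^{\mathrm{full}}_z = q^S - p$ satisfies $\Delta g_{z,i} = p_i \,(1 - \rho(S))/\rho(S)$ for $i \in S$ and $\Delta g_{z,i} = -p_i$ for $i \notin S$.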
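
Lemma 1 can be checked numerically. The sketch below (all function names hypothetical) computes a Stage 1 style loss as the cross-entropy of the softmax renormalized over a subset $S$ that contains the gold token, i.e. a log-sum-exp over $S$ only. It compares the analytic logit gradient $q^S - y$ (where $q^S$ is the softmax over $S$, zero elsewhere) against central finite differences, and verifies that the bias $q^S - p$ has $\ell_1$ norm $2(1 - \rho(S))$, where $\rho(S)$ is the full-softmax probability mass of $S$. This uses hard gold labels as in the lemma; how the paper combines them with replayed logits is not assumed here.

```python
import numpy as np


def softmax(z: np.ndarray) -> np.ndarray:
    """Numerically stable softmax of a 1-D logit vector."""
    e = np.exp(z - z.max())
    return e / e.sum()


def restricted_ce(z: np.ndarray, S: np.ndarray, gold: int) -> float:
    """-log of the gold token's probability after renormalizing over S (gold in S)."""
    zs = z[S]
    lse = zs.max() + np.log(np.exp(zs - zs.max()).sum())  # log-sum-exp over S only
    return lse - z[gold]


def restricted_grad(z: np.ndarray, S: np.ndarray, gold: int) -> np.ndarray:
    """Analytic logit gradient q^S - y; entries outside S are zero."""
    g = np.zeros_like(z)
    g[S] = softmax(z[S])
    g[gold] -= 1.0
    return g


def full_grad(z: np.ndarray, gold: int) -> np.ndarray:
    """Full softmax-CE logit gradient p - y."""
    g = softmax(z)
    g[gold] -= 1.0
    return g
```

In this sketch only the $|S|$ selected logits enter the normalizer, and logits outside $S$ receive exactly zero gradient.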

Figures (4)

  • Figure 1: Overview of the Logits Replay + MoClip framework.
  • Figure 2: Domain & general benchmarks on Qwen3-4B/8B. Bars show 4B (solid) and 8B (hatched) across three groups: CT (DataComm, Wireless, CloudCore), NL2SQL (Birds, Spider), and General (MMLU, BBH, GPQA, MATH). Dynamic Top-$K$ + MoClip consistently improves domain scores over AdamW and remains competitive or better on general tasks. Vertical dashed lines separate task groups.
  • Figure 3: Stability (loss variance, gradient-norm CV, spike count) and efficiency (step and epoch time) on Qwen3-4B and Qwen3-8B. Lower is better for stability metrics and time.
  • Figure 4: Ablation overview. (A) Stage-0 position strategy (Random / Last-token / Bucket) on 4B across five tasks: bucket sampling consistently lifts all metrics, especially NL2SQL. (B1) Pareto scatter of Loss std vs. NL2SQL avg; marker size reflects retention. (B2) Birds/Spider (bars, left axis) and MMLU-Pro retention (line, right axis) across $\Delta_{\max}$, with the recommended $[45^\circ,60^\circ]$ shaded. (C) 8B ablation summary: per-method CT Avg (DataComm/Wireless/CloudCore) and NL2SQL Avg (Birds/Spider) as bars; right axis overlays retention (%) and loss variance. Our Dyn. Top-$K$ + MoClip attains the best domain averages with strong retention and lowest variance.
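
The stability metrics named in Figure 3 can be computed from logged training curves roughly as follows. This is a sketch under our own definitions, since the caption does not define them: loss variance is the population variance of per-step losses, gradient-norm CV is the standard deviation divided by the mean of per-step gradient norms, and a spike is any step whose loss exceeds the previous step's loss by more than a relative margin. `stability_metrics` and `spike_ratio` are hypothetical.

```python
import numpy as np


def stability_metrics(losses, grad_norms, spike_ratio: float = 0.5) -> dict:
    """Hypothetical definitions of loss variance, gradient-norm CV, and spike count."""
    losses = np.asarray(losses, dtype=float)
    grad_norms = np.asarray(grad_norms, dtype=float)
    loss_var = losses.var()  # population variance (ddof=0)
    grad_cv = grad_norms.std() / grad_norms.mean()  # coefficient of variation
    # a spike: loss jumps by more than spike_ratio relative to the previous step
    spikes = int(np.sum(losses[1:] > (1.0 + spike_ratio) * losses[:-1]))
    return {"loss_var": loss_var, "grad_norm_cv": grad_cv, "spike_count": spikes}
```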

Theorems & Definitions (16)

  • Lemma 1: Logit-space gradient forms
  • proof
  • Proposition 1: Bias magnitude in $\ell_1$ and $\ell_2$
  • proof
  • Remark 1: Exact $\ell_2$ form
  • Proposition 2: Parameter-space bias via Jacobian
  • Corollary 1: Bias control via mass threshold
  • Remark 2: Distributional perspective
  • Lemma 2: Angular cap implies cosine lower bound
  • Remark 3: Intuition
  • ...and 6 more
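
The following derivation, in our own notation, illustrates the kind of bounds the listed results describe. It follows directly from the gradient forms $g^{\mathrm{full}}_z = p - y$ and $g^S_z = q^S - y$, with $q^S_i = p_i/\rho(S)$ on $S$ (zero elsewhere) and $\rho(S) = \sum_{j \in S} p_j$; $J = \partial z/\partial\theta$ is the logit Jacobian, $\tilde g_t$ the angle-clipped gradient, and $m_{t-1}$ the previous momentum. It is not a transcription of the paper's Proposition 1, Remark 1, Proposition 2, Corollary 1, or Lemma 2, whose exact statements may differ.

```latex
% Per-coordinate bias, assuming the gold token lies in S
\Delta g_{z,i} = q^S_i - p_i =
\begin{cases}
  p_i \, \dfrac{1 - \rho(S)}{\rho(S)}, & i \in S,\\[4pt]
  -p_i, & i \notin S.
\end{cases}

% l1 magnitude: the S part and the complement part each contribute 1 - rho(S)
\|\Delta g_z\|_1
  = \frac{1 - \rho(S)}{\rho(S)} \sum_{i \in S} p_i + \sum_{i \notin S} p_i
  = 2\bigl(1 - \rho(S)\bigr).

% exact l2 magnitude
\|\Delta g_z\|_2^2
  = \Bigl(\frac{1 - \rho(S)}{\rho(S)}\Bigr)^{2} \sum_{i \in S} p_i^2
    + \sum_{i \notin S} p_i^2.

% parameter space, via the operator norm of the Jacobian
\|J^{\top} \Delta g_z\|_2 \le \|J\|_2 \, \|\Delta g_z\|_2.

% mass threshold: if Stage 0 guarantees rho(S) >= tau, then
\|\Delta g_z\|_1 \le 2(1 - \tau).

% angular cap: if angle(g~_t, m_{t-1}) <= Delta_max in [0, pi],
% then, since cos is decreasing on [0, pi],
\langle \tilde g_t, m_{t-1} \rangle
  \ge \cos(\Delta_{\max}) \, \|\tilde g_t\| \, \|m_{t-1}\|.
```

For instance, a hypothetical coverage threshold $\tau = 0.95$ bounds the $\ell_1$ logit-space bias by $0.1$, and $\Delta_{\max} = 60^\circ$ guarantees a cosine of at least $1/2$ between the clipped gradient and the momentum. The mass bound holds only at positions where the recorded subset actually reaches $\tau$; a hard cap on $K$ can prevent that.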