CoT2Align: Cross-Chain of Thought Distillation via Optimal Transport Alignment for Language Models with Different Tokenizers
Anh Duc Le, Tu Vu, Nam Le Hai, Nguyen Thi Ngoc Diep, Linh Ngo Van, Trung Le, Thien Huu Nguyen
TL;DR
CoT2Align tackles knowledge distillation across language models with different tokenizers by introducing reasoning-aware distillation through Chain-of-Thought augmentation and Cross-CoT Alignment. It extends Optimal Transport from token-level to sequence- and layer-level alignment, enabling effective transfer of multi-step reasoning regardless of vocabulary differences. The method defines two alignment losses and combines them with the standard KD loss into a unified objective. Across diverse domain-specific datasets and model scales, it consistently outperforms state-of-the-art baselines, demonstrating the practicality and robustness of reasoning-aware distillation for real-world LLM deployment with heterogeneous vocabularies.
Abstract
Large Language Models (LLMs) achieve state-of-the-art performance across various NLP tasks but face deployment challenges due to high computational costs and memory constraints. Knowledge distillation (KD) is a promising solution that transfers knowledge from large teacher models to smaller student models. However, existing KD methods often assume shared vocabularies and tokenizers, which limits their flexibility. While approaches such as Universal Logit Distillation (ULD) and Dual-Space Knowledge Distillation (DSKD) address vocabulary mismatches, they overlook the critical aspect of reasoning-aware distillation. To bridge this gap, we propose CoT2Align, a universal KD framework that integrates Chain-of-Thought (CoT) augmentation and introduces Cross-CoT Alignment to enhance reasoning transfer. Additionally, we extend Optimal Transport beyond token-wise alignment to a sequence-level and layer-wise alignment approach that adapts to varying sequence lengths while preserving contextual integrity. Comprehensive experiments demonstrate that CoT2Align outperforms existing KD methods across different vocabulary settings, improving reasoning capabilities and robustness in domain-specific tasks.
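To make the idea of sequence-level alignment across different sequence lengths concrete, the following is a minimal sketch (not the authors' implementation) of entropy-regularized Optimal Transport, solved with Sinkhorn iterations, between a student and a teacher hidden-state sequence of different lengths. It assumes both sequences have already been projected to a shared hidden dimension, places uniform mass on each token, and uses cosine distance as the transport cost; the layer-wise variant simply averages over a hypothetical, pre-chosen list of paired layers. The function names and the `eps` and `n_iters` values are illustrative assumptions, not taken from the paper.

```python
import numpy as np


def cosine_cost(x, y):
    """Pairwise cosine distance between rows of x (n, d) and y (m, d).

    Assumption: student and teacher states are already projected to the same d.
    """
    x_n = x / np.linalg.norm(x, axis=1, keepdims=True)
    y_n = y / np.linalg.norm(y, axis=1, keepdims=True)
    return 1.0 - x_n @ y_n.T  # shape (n, m), values in [0, 2]


def sinkhorn(cost, eps=0.1, n_iters=200):
    """Entropy-regularized OT plan between uniform marginals (Sinkhorn-Knopp)."""
    n, m = cost.shape
    a = np.full(n, 1.0 / n)  # uniform mass over student tokens
    b = np.full(m, 1.0 / m)  # uniform mass over teacher tokens
    K = np.exp(-cost / eps)  # Gibbs kernel
    u = np.ones(n)
    v = np.ones(m)
    for _ in range(n_iters):
        u = a / (K @ v)      # rescale rows toward marginal a
        v = b / (K.T @ u)    # rescale columns to match marginal b exactly
    return u[:, None] * K * v[None, :]  # transport plan, shape (n, m)


def ot_alignment_loss(student, teacher, eps=0.1, n_iters=200):
    """Sequence-level OT loss: expected cost under the Sinkhorn plan."""
    cost = cosine_cost(student, teacher)
    plan = sinkhorn(cost, eps, n_iters)
    return float(np.sum(plan * cost))


def layerwise_ot_loss(student_layers, teacher_layers, **kwargs):
    """Average OT loss over paired layers (the pairing is a hypothetical input)."""
    losses = [ot_alignment_loss(s, t, **kwargs)
              for s, t in zip(student_layers, teacher_layers)]
    return float(np.mean(losses))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    student = rng.normal(size=(7, 16))   # 7 student tokens
    teacher = rng.normal(size=(11, 16))  # 11 teacher tokens for the same text
    print(ot_alignment_loss(student, teacher))
```

Because the transport plan has shape (student length × teacher length), no one-to-one token correspondence is required, which is why an OT-based loss of this kind can sidestep tokenizer mismatches; in a training setup it would be computed on framework tensors so gradients flow back to the student.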
