Aligning LLMs with Domain Invariant Reward Models
David Wu, Sanjiban Choudhury
TL;DR
This work tackles the challenge of aligning LLMs in domains that lack human preference data. It proposes DIAL, a dual-loss framework that learns domain-invariant reward models by combining Wasserstein-distance-based domain alignment with a preference objective on the source domain. The approach trains a base LM jointly with a domain critic and a reward head, separating domain-specific signals from domain-agnostic reward concepts. This enables transfer from labeled source data to unlabeled target data. Theoretical bounds relate target-domain performance to source-domain performance and the discrepancy between domains. Extensive experiments demonstrate DIAL's effectiveness across cross-lingual, clean-to-noisy, few-shot-to-full, and simple-to-complex transfer, and they are complemented by analyses of embeddings, data scaling, and RLHF-shift adaptation. The results suggest that domain-invariant reward models can significantly improve scalable alignment of LLMs in resource-poor domains, with practical implications for broad, low-cost deployment of RLHF-based systems.
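The Wasserstein-based domain alignment can be read through the Kantorovich-Rubinstein duality. A domain critic $f$, restricted to 1-Lipschitz functions, estimates $W_1$ between the source and target embedding distributions $\mu_S$ and $\mu_T$ as the gap between its expected outputs on the two domains. The lines below are a schematic sketch, and the notation is an assumption: $\theta$ stands for the base LM and reward-head parameters, $\mathcal{L}_{\text{src}}$ for the source preference loss, and $\lambda$ for a trade-off weight. The final line shows only the generic shape of target-error bounds in Wasserstein-based domain-adaptation theory. Here $C$ is a placeholder constant and $\lambda^{*}$ is the error of the best joint hypothesis. This is an illustration, not the paper's exact theorem.

```latex
W_1(\mu_S, \mu_T) = \sup_{\|f\|_{L} \le 1} \Big( \mathbb{E}_{z \sim \mu_S}[f(z)] - \mathbb{E}_{z \sim \mu_T}[f(z)] \Big)

\min_{\theta} \; \mathcal{L}_{\text{src}}(\theta) + \lambda \, W_1\big(\mu_S(\theta), \mu_T(\theta)\big)

\varepsilon_T(h) \le \varepsilon_S(h) + C \, W_1(\mu_S, \mu_T) + \lambda^{*}
```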
Abstract
Aligning large language models (LLMs) to human preferences is challenging in domains where preference data is unavailable. We address the problem of learning reward models for such target domains by leveraging feedback collected from simpler source domains, where human preferences are easier to obtain. Our key insight is that, while domains may differ significantly, human preferences convey \emph{domain-agnostic} concepts that can be effectively captured by a reward model. We propose DIAL, a framework that trains domain-invariant reward models by optimizing a dual loss: a domain loss that minimizes the divergence between the source and target distributions, and a source loss that optimizes preferences on the source domain. We show that DIAL is a general approach by evaluating and analyzing it across four distinct settings: (1) Cross-lingual transfer (accuracy: $0.621 \rightarrow 0.661$), (2) Clean-to-noisy transfer (accuracy: $0.671 \rightarrow 0.703$), (3) Few-shot-to-full transfer (accuracy: $0.845 \rightarrow 0.920$), and (4) Simple-to-complex task transfer (correlation: $0.508 \rightarrow 0.556$). Our code, models, and data are available at \url{https://github.com/portal-cornell/dial}.
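The dual loss described above can be illustrated with a minimal, dependency-free sketch. It is not the authors' implementation and rests on several assumptions. First, the shared LM has already produced fixed embedding vectors. Second, the reward head and the domain critic are both linear. Third, the source loss is the Bradley-Terry pairwise loss commonly used for reward models. All function names (`source_preference_loss`, `critic_wasserstein_estimate`, `dual_loss`) and the weight `lam` are hypothetical. The critic weights are rescaled to unit norm, which makes the linear critic 1-Lipschitz. Its score gap is therefore a valid lower bound on $W_1$ between the two embedding samples.

```python
import math
from typing import Sequence

Vector = Sequence[float]


def dot(w: Vector, x: Vector) -> float:
    """Inner product of two equal-length vectors."""
    return sum(wi * xi for wi, xi in zip(w, x))


def source_preference_loss(
    w_reward: Vector, chosen: Sequence[Vector], rejected: Sequence[Vector]
) -> float:
    """Mean Bradley-Terry loss -log sigmoid(r(chosen) - r(rejected)) on source pairs."""
    total = 0.0
    for c, r in zip(chosen, rejected):
        margin = dot(w_reward, c) - dot(w_reward, r)
        # -log sigmoid(m) == softplus(-m), written in a numerically stable form.
        total += max(-margin, 0.0) + math.log1p(math.exp(-abs(margin)))
    return total / len(chosen)


def critic_wasserstein_estimate(
    w_critic: Vector, source_emb: Sequence[Vector], target_emb: Sequence[Vector]
) -> float:
    """Kantorovich-Rubinstein estimate E_S[f] - E_T[f] for a linear critic f.

    Rescaling the critic weights to unit L2 norm makes f(x) = <w, x>
    1-Lipschitz, so the returned gap never exceeds W_1 between the samples.
    """
    norm = math.sqrt(dot(w_critic, w_critic))
    if norm == 0.0:
        return 0.0
    unit = [wi / norm for wi in w_critic]
    src_mean = sum(dot(unit, x) for x in source_emb) / len(source_emb)
    tgt_mean = sum(dot(unit, x) for x in target_emb) / len(target_emb)
    return src_mean - tgt_mean


def dual_loss(
    w_reward: Vector,
    w_critic: Vector,
    chosen: Sequence[Vector],
    rejected: Sequence[Vector],
    source_emb: Sequence[Vector],
    target_emb: Sequence[Vector],
    lam: float = 0.1,
) -> float:
    """Source preference loss plus lam times the critic's domain-distance estimate.

    In a full training loop (not shown), the critic would be updated to
    maximize the estimate, while the encoder and reward head minimize this sum.
    """
    return source_preference_loss(w_reward, chosen, rejected) + lam * (
        critic_wasserstein_estimate(w_critic, source_emb, target_emb)
    )


if __name__ == "__main__":
    chosen = [[1.0, 0.5], [0.8, 0.2]]
    rejected = [[0.2, 0.1], [0.1, 0.4]]
    source = [[1.0, 0.0], [3.0, 0.0]]
    target = [[0.0, 0.0], [0.0, 5.0]]
    print(dual_loss([1.0, 0.0], [2.0, 0.0], chosen, rejected, source, target, lam=0.5))
```

A real system would use a neural critic over LM hidden states. That critic would need some other mechanism to enforce the Lipschitz constraint, such as a gradient penalty, and the unit-norm linear critic here sidesteps that issue only for the sake of a short example.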
