Cascade Reward Sampling for Efficient Decoding-Time Alignment
Bolian Li, Yifan Wang, Anamika Lochab, Ananth Grama, Ruqi Zhang
TL;DR
Decoding-time alignment can be efficient but often suffers from wasted token generation and excessive reward evaluations. CARDS introduces segment-level rejection sampling guided by uncertainty-based segmentation, which splits generation into semantically complete pieces and evaluates rewards at the segment level, reducing redundant computation while preserving alignment quality. The method is supported by analysis showing that reward models (RMs) score semantically complete segments accurately and that segment-level rewards correlate strongly with full-length rewards, enabling efficient sampling of high-reward prefixes. Empirical results across safety, usefulness, and general-utility benchmarks show that CARDS achieves roughly 70% faster decoding and wins on multiple evaluation metrics, demonstrating robust improvements and broad applicability.
Abstract
Aligning large language models (LLMs) with human preferences is essential for their applications. Recently, decoding-time alignment has emerged as an effective plug-and-play technique that avoids fine-tuning model parameters. This approach retains the general utility of pretrained LLMs but often suffers from significant inefficiencies during decoding, primarily due to wasted token generation and excessive reward evaluations. To address these challenges, we introduce Cascade Reward Sampling (CARDS), which resolves both efficiency bottlenecks in decoding-time alignment. Specifically, we develop a segment-level rejection sampling algorithm that minimizes redundant computation in both LLMs and reward models (RMs). Central to CARDS is an uncertainty-based segmentation mechanism, which ensures accurate RM evaluations on incomplete text. Furthermore, we provide a detailed analysis of reward scores on segments to elucidate the improved alignment performance. Experimental results demonstrate that CARDS significantly improves decoding efficiency, alignment quality, and general utility compared to existing decoding-time alignment methods, achieving approximately a 70% reduction in decoding time and over 90% win-ties on utility and safety benchmarks.
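To make the segment-level rejection sampling idea concrete, here is a minimal, self-contained Python sketch. It is not the authors' implementation: the toy "LLM" emits random tokens with a random uncertainty proxy, and the toy "RM" scores the fraction of `"good"` tokens. All function names, thresholds (`tau_u`, `tau_r`), and the fallback after `max_tries` are illustrative assumptions; the real method uses predictive uncertainty from an LLM and scores from a learned reward model.

```python
import random

def generate_segment(rng, tau_u=0.8, max_len=20):
    """Toy stand-in for an LLM decoder: sample tokens until the
    per-token uncertainty proxy spikes (treated as a semantic
    boundary, i.e. uncertainty-based segmentation) or max_len hits."""
    segment = []
    for _ in range(max_len):
        segment.append(rng.choice(["good", "bad", "ok"]))
        uncertainty = rng.random()  # proxy for predictive entropy
        if uncertainty > tau_u:     # segment boundary reached
            break
    return segment

def reward(tokens):
    """Toy reward model: fraction of 'good' tokens in the prefix.
    A real RM would score the whole prefix for human preference."""
    if not tokens:
        return 0.0
    return sum(t == "good" for t in tokens) / len(tokens)

def cards_decode(rng, n_segments=5, tau_r=0.4, max_tries=50):
    """Segment-level rejection sampling: resample a candidate segment
    until reward(prefix + segment) clears the threshold tau_r, then
    commit it and move on. Only one RM call per candidate segment,
    rather than one per token or per full-length response."""
    prefix = []
    for _ in range(n_segments):
        for _ in range(max_tries):
            candidate = generate_segment(rng)
            if reward(prefix + candidate) >= tau_r:
                break  # accept this segment
        prefix = prefix + candidate  # commit (or fall back after max_tries)
    return prefix

rng = random.Random(0)
out = cards_decode(rng)
print(len(out), round(reward(out), 3))
```

Because rejection happens at segment granularity, a low-reward candidate wastes at most one segment's worth of generation instead of a full-length response, which is the source of the efficiency gain described above.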
