Contextual Similarity Distillation: Ensemble Uncertainties with a Single Model
Moritz A. Zanger, Pascal R. Van der Vaart, Wendelin Böhmer, Matthijs T. J. Spaan
TL;DR
The paper tackles the challenge of reliable uncertainty quantification in large deep networks, particularly for reinforcement learning, where ensembles are effective but costly. It introduces Contextual Similarity Distillation (CSD), which reinterprets the predictive variance of a random-initialization ensemble as a structured regression problem using NTK-based kernel similarities, enabling a single forward pass to estimate uncertainty. By incorporating a context variable and allowing unlabeled data or augmentations, CSD can produce uncertainty estimates for arbitrary queries with competitive accuracy relative to ensembles and Bayesian baselines, at reduced computational cost. Empirical results on distribution shift detection and VizDoom exploration demonstrate strong performance and tangible benefits for exploration and safety-aware decision-making, positioning CSD as a scalable alternative for uncertainty quantification in deep learning and RL.
Abstract
Uncertainty quantification is a critical aspect of reinforcement learning and deep learning, with numerous applications ranging from efficient exploration and stable offline reinforcement learning to outlier detection in medical diagnostics. The scale of modern neural networks, however, complicates the use of many theoretically well-motivated approaches such as full Bayesian inference. Approximate methods like deep ensembles can provide reliable uncertainty estimates but still remain computationally expensive. In this work, we propose contextual similarity distillation, a novel approach that explicitly estimates the variance of an ensemble of deep neural networks with a single model, without ever learning or evaluating such an ensemble in the first place. Our method builds on the predictable learning dynamics of wide neural networks, governed by the neural tangent kernel, to derive an efficient approximation of the predictive variance of an infinite ensemble. Specifically, we reinterpret the computation of ensemble variance as a supervised regression problem with kernel similarities as regression targets. The resulting model can estimate predictive variance at inference time with a single forward pass, and can make use of unlabeled target-domain data or data augmentations to refine its uncertainty estimates. We empirically validate our method across a variety of out-of-distribution detection benchmarks and sparse-reward reinforcement learning environments. We find that our single-model method performs competitively with, and sometimes better than, ensemble-based baselines and serves as a reliable signal for efficient exploration. These results, we believe, position contextual similarity distillation as a principled and scalable alternative for uncertainty quantification in reinforcement learning and general deep learning.
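The reinterpretation of variance as regression can be illustrated with a small numerical sketch. It is not the paper's implementation: it assumes a generic Gaussian-process-style variance of the form k(x,x) - k(x,X)(K + ridge I)^-1 k(X,x), substitutes an RBF kernel for the neural tangent kernel, and solves the regression in closed form rather than training a network with a context input. All function names and parameters (`rbf_kernel`, `closed_form_variance`, `variance_via_regression`, `ridge`) are hypothetical. The point it shows is that the data-dependent term of the variance at a context point c equals the prediction at c of a kernel ridge regression whose targets are the similarities k(x_i, c).

```python
import numpy as np


def rbf_kernel(a, b, lengthscale=1.0):
    """Squared-exponential kernel; an assumed stand-in for the NTK."""
    sq_dists = ((a[:, None, :] - b[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-0.5 * sq_dists / lengthscale**2)


def closed_form_variance(x_train, x_query, ridge=1e-3):
    """Variance k(x,x) - k(x,X) (K + ridge*I)^-1 k(X,x) for each query x."""
    gram = rbf_kernel(x_train, x_train) + ridge * np.eye(len(x_train))
    k_query = rbf_kernel(x_train, x_query)  # shape (n_train, n_query)
    explained = np.einsum("iq,iq->q", k_query, np.linalg.solve(gram, k_query))
    return 1.0 - explained  # k(x,x) = 1 for the RBF kernel


def variance_via_regression(x_train, x_query, ridge=1e-3):
    """Same quantity, phrased as one regression problem per context point."""
    gram = rbf_kernel(x_train, x_train) + ridge * np.eye(len(x_train))
    variances = []
    for context in x_query:
        context = context[None, :]
        # Regression targets: similarity of each training input to the context.
        targets = rbf_kernel(x_train, context)[:, 0]
        # Kernel ridge regression fit on (x_train, targets).
        weights = np.linalg.solve(gram, targets)
        # Evaluate the fitted regressor at the context point itself.
        prediction = rbf_kernel(context, x_train)[0] @ weights
        variances.append(1.0 - prediction)
    return np.array(variances)


rng = np.random.default_rng(0)
x_train = rng.uniform(-1.0, 1.0, size=(20, 1))
x_query = np.array([[0.0], [3.0]])  # one in-distribution, one far-away query

var_closed = closed_form_variance(x_train, x_query)
var_regression = variance_via_regression(x_train, x_query)
print(var_closed, var_regression)
```

In this toy setting both routes agree, and the query far from the training data receives the larger variance. A single model that takes the context as an additional input and is trained on such similarity targets could then amortize the per-context regression into one forward pass, which is the role the abstract assigns to the distilled model.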
