Contextual Similarity Distillation: Ensemble Uncertainties with a Single Model

Moritz A. Zanger, Pascal R. Van der Vaart, Wendelin Böhmer, Matthijs T. J. Spaan

TL;DR

The paper tackles the challenge of reliable uncertainty quantification in large deep networks, particularly for reinforcement learning, where ensembles are effective but costly. It introduces Contextual Similarity Distillation (CSD), which reinterprets the predictive variance of a random-initialization ensemble as a structured regression problem using NTK-based kernel similarities, enabling a single forward pass to estimate uncertainty. By incorporating a context variable and allowing unlabeled data or augmentations, CSD can produce uncertainty estimates for arbitrary queries with competitive accuracy relative to ensembles and Bayesian baselines, at reduced computational cost. Empirical results on distribution shift detection and VizDoom exploration demonstrate strong performance and tangible benefits for exploration and safety-aware decision-making, positioning CSD as a scalable alternative for uncertainty quantification in deep learning and RL.

Abstract

Uncertainty quantification is a critical aspect of reinforcement learning and deep learning, with numerous applications ranging from efficient exploration and stable offline reinforcement learning to outlier detection in medical diagnostics. The scale of modern neural networks, however, complicates the use of many theoretically well-motivated approaches such as full Bayesian inference. Approximate methods like deep ensembles can provide reliable uncertainty estimates but remain computationally expensive. In this work, we propose contextual similarity distillation, a novel approach that explicitly estimates the variance of an ensemble of deep neural networks with a single model, without ever learning or evaluating such an ensemble in the first place. Our method builds on the predictable learning dynamics of wide neural networks, governed by the neural tangent kernel, to derive an efficient approximation of the predictive variance of an infinite ensemble. Specifically, we reinterpret the computation of ensemble variance as a supervised regression problem with kernel similarities as regression targets. The resulting model can estimate predictive variance at inference time with a single forward pass, and can make use of unlabeled target-domain data or data augmentations to refine its uncertainty estimates. We empirically validate our method across a variety of out-of-distribution detection benchmarks and sparse-reward reinforcement learning environments. We find that our single-model method performs competitively with, and sometimes better than, ensemble-based baselines and serves as a reliable signal for efficient exploration. These results, we believe, position contextual similarity distillation as a principled and scalable alternative for uncertainty quantification in reinforcement learning and general deep learning.
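
The starting point of the abstract is that networks trained from random initialization form an ensemble whose variance has a closed form. This can be checked numerically in a simplified setting. The sketch below is an illustration under stated assumptions, not the authors' code. It assumes a model that is linear in its parameters, $f(x) = \phi(x)^\top\theta$, where the features $\phi$ are fixed random cosine features standing in for a linearized, NTK-regime network. It further assumes an initialization $\theta_0 \sim N(0, I)$ and gradient flow on squared loss until the training points are interpolated. Under these assumptions, $\Theta(x,x') = \phi(x)^\top\phi(x')$ is both the tangent kernel and the prior covariance. The variance across ensemble members at a query $x$ then equals $\Theta(x,x) - \Theta(x,X)\Theta(X,X)^{-1}\Theta(X,x)$. The function names `random_features`, `ensemble_predictions`, and `closed_form_variance` are hypothetical.

```python
import numpy as np

def random_features(x, w, b):
    """Fixed random cosine features phi(x); f(x) = phi(x) @ theta is linear in theta.

    With w ~ N(0, 1) and b ~ U(0, 2*pi), phi(x) @ phi(x') approximates an RBF kernel."""
    p = w.shape[0]
    return np.sqrt(2.0 / p) * np.cos(np.outer(x, w) + b)

def ensemble_predictions(phi_train, y, phi_query, thetas0):
    """Predictions at phi_query of models started from each row of thetas0 and trained
    by gradient flow on squared loss until they interpolate y = phi_train @ theta.

    Gradient flow only moves theta within the row space of phi_train, so the limit is
    theta0 + phi_train.T @ K^{-1} (y - phi_train @ theta0) with K = phi_train @ phi_train.T."""
    K = phi_train @ phi_train.T
    residuals = y[None, :] - thetas0 @ phi_train.T      # (members, n)
    alpha = np.linalg.solve(K, residuals.T).T           # (members, n)
    return thetas0 @ phi_query + alpha @ (phi_train @ phi_query)

def closed_form_variance(phi_train, phi_query):
    """Theta(x, x) - Theta(x, X) Theta(X, X)^{-1} Theta(X, x) with Theta(a, b) = phi(a) @ phi(b)."""
    K = phi_train @ phi_train.T
    k = phi_train @ phi_query
    return phi_query @ phi_query - k @ np.linalg.solve(K, k)

rng = np.random.default_rng(0)
p, members = 200, 20_000
w = rng.normal(size=p)
b = rng.uniform(0.0, 2.0 * np.pi, size=p)

x_train = np.array([-1.0, 0.0, 1.5])
y_train = np.sin(x_train)
phi_train = random_features(x_train, w, b)
phi_query = random_features(np.array([0.7]), w, b)[0]

thetas0 = rng.normal(size=(members, p))   # theta0 ~ N(0, I): one row per ensemble member
preds = ensemble_predictions(phi_train, y_train, phi_query, thetas0)

mc_var = preds.var()                      # Monte Carlo estimate over the ensemble
cf_var = closed_form_variance(phi_train, phi_query)
print(f"ensemble variance {mc_var:.4f} vs closed form {cf_var:.4f}")
```

The two numbers agree up to Monte Carlo error. For deep networks of finite or infinite width, the tangent kernel and the initialization covariance generally need not coincide, so this identity is a simplification. What carries over is the observation that the converged ensemble's variance is a deterministic function of kernel quantities, and can therefore be estimated without training the ensemble.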

Paper Structure

This paper contains 26 sections, 26 equations, 4 figures, 11 tables.

Figures (4)

  • Figure 1: Illustration of regression tasks with query-dependent NTK similarities as labels. The difference between the kernel prior function $\Theta(x,x)$ (dotted line) and the post-training regression function $g_{x_t}(x,\tilde{\theta}_\infty)$ exactly matches the ensemble variance at $x_t$. Plots from left to right depict the same principle for different query points $x_t$.
  • Figure 2: Top Row: Variance of an ensemble of 100 randomly initialized neural networks on a 2D toy regression task. Red dots are training points. Bottom Row: Variance prediction by contextual similarity distillation (CSD) with a single model on the same regression task.
  • Figure 3: (Left): Visual observation in the VizDoom environment (Kempka2016ViZDoom). (Remaining panels): Mean learning curves in variations of the MyWayHome VizDoom environment. Shaded regions are $90\%$ Student’s t confidence intervals from 10 seeds.
  • Figure 4: Left: Original Image. Right: Perturbed OOD Image.
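
Figure 1 describes the central reformulation. For a query $x_t$, fit a regression function $g_{x_t}$ to the labels $\Theta(X, x_t)$ at the training inputs $X$, then read off the ensemble variance as $\Theta(x_t,x_t) - g_{x_t}(x_t)$. Below is a minimal sketch. It assumes the trained regressor behaves like noiseless kernel regression with $\Theta$, which is the infinite-width picture, and it substitutes an RBF kernel for the NTK purely for illustration. The names `rbf_kernel`, `fit_similarity_regression`, and `variance_via_regression` are hypothetical. This version fits one regression per query point. CSD, as summarized above, instead trains a single model that receives the query as a context variable, so that one forward pass suffices.

```python
import numpy as np

def rbf_kernel(a, b, lengthscale=1.0):
    """Stand-in for the NTK Theta (an assumption): RBF kernel on scalar inputs."""
    return np.exp(-0.5 * (a[:, None] - b[None, :]) ** 2 / lengthscale**2)

def fit_similarity_regression(x_train, x_t, jitter=1e-10):
    """Kernel regression whose labels are the similarities Theta(X, x_t) to the query x_t.

    Returns g_{x_t}(x) = Theta(x, X) Theta(X, X)^{-1} Theta(X, x_t) as a callable."""
    # Small diagonal jitter for numerical stability of the linear solve.
    K = rbf_kernel(x_train, x_train) + jitter * np.eye(len(x_train))
    labels = rbf_kernel(x_train, np.array([x_t]))[:, 0]
    weights = np.linalg.solve(K, labels)
    return lambda x: rbf_kernel(np.atleast_1d(x), x_train) @ weights

def variance_via_regression(x_train, x_t):
    """Prior Theta(x_t, x_t) minus the fitted regression function evaluated at x_t."""
    g = fit_similarity_regression(x_train, x_t)
    prior = rbf_kernel(np.array([x_t]), np.array([x_t]))[0, 0]
    return prior - g(x_t)[0]

x_train = np.array([-2.0, -0.5, 0.3, 1.8])
for x_t in [-0.5, 0.0, 1.0, 4.0]:
    print(f"x_t = {x_t:+.1f}: variance = {variance_via_regression(x_train, x_t):.4f}")
```

The printed variance is essentially zero at a training input ($x_t = -0.5$) and small between nearby training points. Far from the data ($x_t = 4.0$), it approaches the prior value $\Theta(x_t,x_t) = 1$. This is qualitatively the same pattern as the ensemble variance in Figure 2.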