The Ray Tracing Sampler: Bayesian Sampling of Neural Networks for Everyone
Peter Behroozi
TL;DR
This work introduces the Ray Tracing Sampler, a Bayesian MCMC method that propagates rays through parameter space in a medium with refractive index $n(x)=\mathcal{L}(x)^{1/(D-1)}$, where $D$ is the number of parameters. The rays travel at constant speed, and their time-averaged density follows the likelihood. By conserving radiance and étendue, the method samples fairly even with imperfect integrators and across likelihood barriers, while remaining highly resilient to stochastic gradients. The framework recovers prior methods (HMC, Langevin, Gibbs, Metropolis, etc.) as special cases of generalized ray tracing, and the paper demonstrates scalable posterior sampling for neural networks with thousands to 1.5 billion parameters, including GPT-2-scale models on consumer hardware. Empirically, ray tracing matches HMC in low-noise settings but outperforms it under stochastic gradients, enabling practical Bayesian uncertainty quantification for large-scale neural networks, with implications for model reliability and architecture design.
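One way to see why the exponent is $1/(D-1)$ is a short phase-space argument. The sketch below assumes the standard optical Hamiltonian $H(x,p)=|p|-n(x)$ in $D$ dimensions and that rays fill the $H=0$ shell uniformly (ergodicity); it is offered as intuition, not as the paper's own derivation. Because the speed $|\dot{x}|$ is always 1, time spent in a region is proportional to path length there, so the time-averaged ray density equals the spatial marginal of the shell measure. Here $S_D$ denotes the surface area of the unit sphere in $\mathbb{R}^D$.

```latex
\begin{aligned}
H(x,p) &= |p| - n(x), \qquad
\dot{x} = \frac{\partial H}{\partial p} = \frac{p}{|p|}, \qquad
\dot{p} = -\frac{\partial H}{\partial x} = \nabla n(x), \\
\rho(x) &\propto \int_{\mathbb{R}^D} \delta\big(H(x,p)\big)\, d^D p
  = \int_0^\infty \delta\big(r - n(x)\big)\, S_D\, r^{D-1}\, dr
  = S_D\, n(x)^{D-1}, \\
n(x) &= \mathcal{L}(x)^{1/(D-1)} \;\Longrightarrow\; \rho(x) \propto \mathcal{L}(x).
\end{aligned}
```

On the shell $p = n(x)\,\dot{x}$, so the momentum equation is the familiar ray equation $\frac{d}{dt}\big(n\,\dot{x}\big) = \nabla n$: rays bend toward higher refractive index, i.e., toward higher likelihood.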
Abstract
We derive a Markov Chain Monte Carlo sampler based on following ray paths in a medium where the refractive index $n(x)$ is a function of the desired likelihood $\mathcal{L}(x)$. The sampling method propagates rays at constant speed through parameter space, leading to orders-of-magnitude higher resilience to heating from stochastic gradients compared to Hamiltonian Monte Carlo (HMC), as well as the ability to cross any likelihood barrier, including holes in parameter space. Using the ray tracing method, we sample the posterior distributions of neural network outputs for a variety of architectures, up to the 1.5 billion-parameter GPT-2 (Generative Pre-trained Transformer 2) architecture, all on a single consumer-level GPU. We also show that prior samplers, including traditional HMC, microcanonical HMC, Metropolis, Gibbs, and even Monte Carlo integration, are special cases within a generalized ray tracing framework, which can sample according to an arbitrary weighting function. Public code and documentation for C, JAX, and PyTorch are available at https://bitbucket.org/pbehroozi/ray-tracing-sampler/src
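To make "propagating rays at constant speed" concrete, here is a minimal sketch of tracing a single ray through the medium $n(x)=\mathcal{L}(x)^{1/(D-1)}$. It assumes a unit-speed ray obeying the standard ray equation $\frac{d}{ds}\big(n\,\frac{dx}{ds}\big)=\nabla n$, which for the unit direction $u$ reduces to $\frac{du}{ds}=g-(u\cdot g)\,u$ with $g=\nabla\log\mathcal{L}/(D-1)$. The function names (`log_refractive_index`, `ray_step`, `grad_log_gaussian`) and the plain explicit Euler step are hypothetical illustrations, not the API or integrator of the released C/JAX/PyTorch code, and the sketch omits the rest of a full sampler.

```python
import numpy as np


def log_refractive_index(log_like, dim):
    """Return log n(x) = log L(x) / (D - 1).

    Staying in log space avoids underflow: L(x) itself is usually far
    too small to represent when D is large.
    """
    return log_like / (dim - 1)


def ray_step(x, u, grad_log_like, dim, ds):
    """Advance a unit-speed ray by path length ds with one explicit Euler step.

    The ray equation d/ds (n dx/ds) = grad n with |dx/ds| = 1 gives
    du/ds = g - (u . g) u, where g = grad log n = grad log L / (D - 1).
    """
    g = grad_log_like(x) / (dim - 1)
    du = g - np.dot(u, g) * u        # bend toward higher n, perpendicular to u
    x_new = x + ds * u               # move at unit speed along the current direction
    u_new = u + ds * du
    u_new /= np.linalg.norm(u_new)   # renormalize so the speed stays exactly 1
    return x_new, u_new


def grad_log_gaussian(x):
    """Gradient of log L for a standard Gaussian, log L = -|x|^2 / 2."""
    return -x


# Usage: a ray launched tangent to the unit circle of a 2D Gaussian.
dim = 2
x = np.array([0.0, 1.0])   # starting position
u = np.array([1.0, 0.0])   # unit direction of travel
for _ in range(100):
    x, u = ray_step(x, u, grad_log_gaussian, dim, ds=0.01)
print("position:", x, "speed:", np.linalg.norm(u))
```

In this 2D Gaussian example the bending toward higher $n$ has curvature $|x|/(D-1) = 1$ at radius 1, exactly matching the unit circle, so the ray should stay close to radius 1 while its speed remains 1 (the explicit Euler step lets it drift slowly outward). Working with $\log n$ rather than $n$ is the design choice that matters at scale: for a model with a billion parameters, $\mathcal{L}$ and $n$ are only representable through their logarithms.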
