Efficient NeRF Optimization -- Not All Samples Remain Equally Hard
Juuso Korhonen, Goutham Rangu, Hamed R. Tavakoli, Juho Kannala
TL;DR
This work tackles the high computational cost of NeRF training by identifying the backward pass as the primary bottleneck and introducing online hard sample mining. The method uses a two-forward-one-backward scheme. A first forward pass in inference mode locates hard samples via the propagated pixel loss. A second forward pass then builds the computational graph only for a dynamically sized minibatch of these hard samples, and backpropagation runs over that minibatch alone. Applied to Instant-NGP, this approach yields about a 2x speedup to reach the same PSNR, an average PSNR improvement of ~1 dB for a given wall-clock training time, and roughly 40% memory savings because no graph is built for easy samples. The technique is hyperparameter-free and interfaces only with the network module, so it is expected to apply broadly to NeRF variants used for 3D reconstruction and rendering.
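To make "propagated pixel loss" more concrete, here is a minimal NumPy sketch of one way a pixel loss can be pushed back onto individual ray samples without running the network's backward pass. It assumes standard NeRF volume rendering (C = sum_i w_i c_i with w_i = T_i (1 - exp(-sigma_i delta_i))), an L2 pixel loss, and per-sample densities and colors taken from the inference-mode forward pass. Each sample is scored by the norm of the loss gradient with respect to its color output, 2 w_i ||C - C_gt||; gradients through density are ignored. The functions `render_weights`, `sample_hardness` and `select_hard`, and the above-mean threshold used to size the hard set, are hypothetical illustrations, not the authors' implementation.

```python
import numpy as np


def render_weights(sigma, delta):
    # Standard NeRF quadrature weights: w_i = T_i * (1 - exp(-sigma_i * delta_i)),
    # with transmittance T_i = exp(-sum_{j<i} sigma_j * delta_j).
    tau = sigma * delta                                   # (rays, samples)
    alpha = 1.0 - np.exp(-tau)
    # Exclusive cumulative sum: optical depth accumulated before sample i.
    depth_before = np.concatenate(
        [np.zeros_like(tau[:, :1]), np.cumsum(tau, axis=-1)[:, :-1]], axis=-1
    )
    return np.exp(-depth_before) * alpha


def sample_hardness(sigma, delta, rgb, target):
    # Hypothetical score: norm of dL/dc_i for the L2 pixel loss L = ||C - C_gt||^2.
    # Since C = sum_i w_i c_i, dL/dc_i = 2 w_i (C - C_gt), whose norm is
    # 2 w_i ||C - C_gt||. Gradients through sigma are ignored in this sketch,
    # and the gradient is taken through the rendering step only, not the network.
    w = render_weights(sigma, delta)                      # (rays, samples)
    color = np.einsum("rs,rsc->rc", w, rgb)               # rendered pixels (rays, 3)
    residual = np.linalg.norm(color - target, axis=-1)    # per-ray error (rays,)
    return 2.0 * w * residual[:, None]                    # per-sample score


def select_hard(scores):
    # Hypothetical dynamic rule: keep every sample scoring above the batch mean,
    # so the hard-set size varies from step to step. Returns flat indices.
    flat = scores.ravel()
    return np.flatnonzero(flat > flat.mean())
```

Because the threshold depends on the current score distribution, the number of kept samples changes between steps, which is one simple way to obtain a dynamically sized hard minibatch. Rays whose rendered color already matches the target contribute no hard samples at all.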
Abstract
We propose an application of online hard sample mining for efficient training of Neural Radiance Fields (NeRF). NeRF models produce state-of-the-art quality for many 3D reconstruction and rendering tasks but require substantial computational resources. Because the scene information is encoded within the NeRF network parameters, training requires stochastic sampling. We observe that during training, a major part of the compute time and memory usage is spent on processing already learnt samples, which no longer significantly affect the model update. We identify the backward pass over the stochastic samples as the computational bottleneck of the optimization. We therefore perform the first forward pass in inference mode as a relatively low-cost search for hard samples. This is followed by building the computational graph and updating the NeRF network parameters using only the hard samples. To demonstrate the effectiveness of the proposed approach, we apply our method to Instant-NGP, obtaining significant improvements in view-synthesis quality over the baseline (on average a 1 dB improvement for the same training time, or a 2x speedup to reach the same PSNR level) along with approximately 40% memory savings from building the computational graph with only the hard samples. As our method interfaces only with the network module, we expect it to be widely applicable.
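The structure of a two-forward-one-backward step can be illustrated on a toy problem. The NumPy sketch below replaces the NeRF network with a hypothetical linear model, treats each data point as a "sample", uses its squared error as the hardness score, keeps a fixed fraction of the hardest samples, and writes the backward pass as an analytic gradient instead of autodiff. These are all simplifying assumptions; only the control flow mirrors the description above: a cheap forward pass over all samples that keeps no graph, followed by a forward and backward pass over the hard subset alone.

```python
import numpy as np


def forward(w, x):
    # Toy stand-in for the NeRF network (hypothetical): a linear model y = x @ w.
    return x @ w


def train_step(w, x, y, keep_fraction=0.25, lr=0.1):
    # Forward pass 1, "inference mode": only per-sample losses are computed and
    # nothing is stored for a backward pass (in PyTorch this would run under
    # torch.inference_mode()).
    per_sample_loss = (forward(w, x) - y) ** 2
    k = max(1, int(keep_fraction * len(x)))
    hard = np.argsort(per_sample_loss)[-k:]      # indices of the k largest losses

    # Forward pass 2 on the hard subset only; with autodiff, this is the only
    # place a computational graph would be built.
    xh, yh = x[hard], y[hard]
    residual = forward(w, xh) - yh

    # Single backward pass, restricted to the hard subset:
    # gradient of the mean squared error over the k hard samples.
    grad = 2.0 * xh.T @ residual / k
    return w - lr * grad, hard
```

Calling `train_step` repeatedly spends gradient computation only on the samples that are currently fit worst. In a real autodiff framework the memory saving comes from never storing activations for the easy samples, since the first pass keeps no graph at all.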
