Learning Discretized Neural Networks under Ricci Flow
Jun Chen, Hanwen Chen, Mengmeng Wang, Guang Dai, Ivor W. Tsang, Yong Liu
TL;DR
This work addresses gradient mismatch in discretized neural networks by recasting training as a metric perturbation problem on a Riemannian manifold and proposes a geometric remedy using Ricci flow. The authors construct a Linearly Nearly Euclidean (LNE) manifold via an information-geometric framework and show that Ricci flow exponentially dampens metric perturbations, enabling stable training of discretized networks. They then derive practical gradient computations on the evolving LNE manifold, introducing strong and weak approximations to invert the LNE metric and implement RF-DNNs with discrete Ricci flow. Empirical results on CIFAR and ImageNet demonstrate improved stability and accuracy over STE-based methods across bit-widths and architectures. Overall, the paper provides a theoretically grounded, geometry-driven approach to training low-precision DNNs with competitive performance and stability advantages.
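The TL;DR mentions approximations used to invert the LNE metric when computing gradients. The sketch below is only an illustration under stated assumptions. It assumes the metric can be written as a small perturbation of the identity, $g = I + H$ with the spectral radius of $H$ below 1, and it approximates $g^{-1}$ by a truncated Neumann series. This is a generic technique, not necessarily what the paper calls its strong or weak approximation. The names `neumann_inverse` and `riemannian_gradient` are hypothetical.

```python
import numpy as np


def neumann_inverse(H: np.ndarray, order: int) -> np.ndarray:
    """Approximate (I + H)^{-1} by sum_{k=0}^{order} (-H)^k.

    Assumption: the spectral radius of H is < 1, so the series converges.
    order=1 gives the first-order approximation I - H.
    """
    n = H.shape[0]
    result = np.eye(n)
    term = np.eye(n)
    for _ in range(order):
        term = term @ (-H)  # next power of (-H)
        result = result + term
    return result


def riemannian_gradient(grad: np.ndarray, H: np.ndarray, order: int = 1) -> np.ndarray:
    """Hypothetical helper: precondition a Euclidean gradient by g^{-1}, g = I + H."""
    return neumann_inverse(H, order) @ grad


# Toy check: a small symmetric perturbation of the Euclidean metric.
rng = np.random.default_rng(0)
n = 4
A = rng.standard_normal((n, n))
H = 0.05 * (A + A.T) / 2
exact_inverse = np.linalg.inv(np.eye(n) + H)

# Frobenius error of each truncation order against the exact inverse.
errors = [np.linalg.norm(neumann_inverse(H, k) - exact_inverse) for k in (0, 1, 2)]
print(errors)

grad = rng.standard_normal(n)
print(riemannian_gradient(grad, H, order=1))
```

For symmetric $H$, the truncation error is $(-H)^{k+1}(I+H)^{-1}$, so each additional term shrinks the error whenever all eigenvalues of $H$ lie in $(-1, 1)$. Higher orders trade extra matrix products for accuracy.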
Abstract
In this paper, we study Discretized Neural Networks (DNNs) composed of low-precision weights and activations. During training, these networks suffer from either infinite or zero gradients because the discretization function is non-differentiable. Most training-based DNNs in such scenarios employ the standard Straight-Through Estimator (STE) to approximate the gradient w.r.t. discrete values. However, STE introduces the problem of gradient mismatch, which arises from perturbations in the approximated gradient. To address this problem, we show that this mismatch can be interpreted as a metric perturbation in a Riemannian manifold, viewed through the lens of duality theory. Building on information geometry, we construct the Linearly Nearly Euclidean (LNE) manifold for DNNs, which provides a background for handling such perturbations. By introducing a partial differential equation on metrics, i.e., the Ricci flow, we establish the dynamical stability and convergence of the LNE metric under an $L^2$-norm perturbation. In contrast to previous perturbation theories with fractional-power convergence rates, the metric perturbation under the Ricci flow decays exponentially in the LNE manifold. Experimental results across various datasets demonstrate that our method achieves superior and more stable performance for DNNs compared to other representative training-based methods.
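To make the gradient-mismatch problem concrete, the sketch below shows one common STE variant on binarized weights. It is not the paper's method or its exact baseline configuration, which are not specified here. The forward pass uses the sign function, whose exact derivative is zero almost everywhere. The clipped STE instead passes the upstream gradient straight through where $|w| \le 1$. The difference between the two gradients is the mismatch the abstract refers to. All function names and the clip threshold are illustrative assumptions.

```python
import numpy as np


def quantize_sign(w: np.ndarray) -> np.ndarray:
    """Forward pass: binarize weights to {-1, +1} (w == 0 is mapped to +1)."""
    return np.where(w >= 0, 1.0, -1.0)


def true_grad_sign(w: np.ndarray, upstream: np.ndarray) -> np.ndarray:
    """Exact chain-rule gradient through sign(w): zero wherever sign is
    differentiable (w != 0); at w == 0 the derivative does not exist."""
    return np.zeros_like(upstream)


def ste_grad(w: np.ndarray, upstream: np.ndarray, clip: float = 1.0) -> np.ndarray:
    """Clipped straight-through estimator: treat sign as the identity on
    |w| <= clip (pass the gradient through) and block it elsewhere."""
    return upstream * (np.abs(w) <= clip)


w = np.array([-1.5, -0.3, 0.2, 0.8, 2.0])
upstream = np.array([0.1, -0.4, 0.5, 0.2, -0.3])

print(quantize_sign(w))
print(ste_grad(w, upstream))

# Gradient mismatch: the surrogate gradient minus the exact (vanishing) one.
mismatch = ste_grad(w, upstream) - true_grad_sign(w, upstream)
print(mismatch)
```

Because the exact gradient vanishes, the entire STE update is a perturbation relative to the true derivative. This is the quantity that the paper reinterprets as a metric perturbation.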
