Understanding Straight-Through Estimator in Training Activation Quantized Neural Nets

Penghang Yin, Jiancheng Lyu, Shuai Zhang, Stanley Osher, Yingyong Qi, Jack Xin

TL;DR

The paper addresses the fact that, in activation-quantized networks, the gradient of the training loss vanishes almost everywhere. It formulates a coarse gradient via straight-through estimators (STEs) and analyzes its effect on training a two-linear-layer CNN with binary activation under Gaussian inputs. It proves that properly chosen STEs, notably the vanilla ReLU and the clipped ReLU, yield descent directions for the population loss and that the associated coarse gradient descent converges to critical points, whereas the identity STE can produce non-descent directions and instability. Experiments on MNIST and CIFAR-10 complement the theory: the clipped ReLU often gives the best performance on deeper networks, and the identity STE can be unstable near good minima. These findings provide a principled basis for selecting STEs in quantized networks and highlight the risks of poorly matched surrogate gradients.

Abstract

Training activation quantized neural networks involves minimizing a piecewise constant function whose gradient vanishes almost everywhere, which is undesirable for standard back-propagation or the chain rule. An empirical way around this issue is to use a straight-through estimator (STE) (Bengio et al., 2013) in the backward pass only, so that the "gradient" through the modified chain rule becomes non-trivial. Since this unusual "gradient" is certainly not the gradient of the loss function, the following question arises: why does searching in its negative direction minimize the training loss? In this paper, we provide a theoretical justification of the concept of STE by answering this question. We consider the problem of learning a two-linear-layer network with binarized ReLU activation and Gaussian input data. We refer to the unusual "gradient" given by the STE-modified chain rule as the coarse gradient. The choice of STE is not unique. We prove that if the STE is properly chosen, the expected coarse gradient correlates positively with the population gradient (which is not available for training), and its negation is a descent direction for minimizing the population loss. We further show that the associated coarse gradient descent algorithm converges to a critical point of the population loss minimization problem. Moreover, we show that a poor choice of STE leads to instability of the training algorithm near certain local minima, which is verified with CIFAR-10 experiments.
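The STE-modified chain rule can be sketched for a single sample of the two-linear-layer model. This is a minimal illustration, not the authors' code, and it rests on assumptions from our reading of the setup: an input matrix ${\bm{Z}}$ given as a list of $m$ rows, network output ${\bm{v}}^\top\sigma({\bm{Z}}{\bm{w}})$ with binary activation $\sigma(x)=1$ for $x>0$ and $0$ otherwise, and a squared loss against a label $y^*$. The forward pass uses $\sigma$; the backward pass for ${\bm{w}}$ replaces the almost-everywhere-zero $\sigma'$ with the derivative $\mu'$ of a proxy $\mu$ (identity, vanilla ReLU, or clipped ReLU). The helpers `coarse_gradient` and `coarse_gd_step` are hypothetical, and the mini-batch update is an empirical illustration that is not necessarily identical to the paper's algorithm.

```python
# Minimal sketch (hypothetical helpers): coarse gradient via an STE for
# y(Z) = v^T sigma(Z w) with binary activation sigma(x) = 1{x > 0}.

def binary_act(x):
    """Forward-pass activation: 1 if x > 0, else 0 (derivative is 0 a.e.)."""
    return 1.0 if x > 0 else 0.0

# Derivatives mu'(x) of candidate STE proxies, used only in the backward pass.
STE_DERIVATIVES = {
    "identity": lambda x: 1.0,
    "relu": lambda x: 1.0 if x > 0 else 0.0,
    "clipped_relu": lambda x: 1.0 if 0 < x < 1 else 0.0,
}

def matvec(Z, w):
    """Compute Z w for Z given as a list of rows."""
    return [sum(z * wj for z, wj in zip(row, w)) for row in Z]

def coarse_gradient(v, w, Z, y_star, ste="relu"):
    """Gradients of 0.5 * (v^T sigma(Z w) - y_star)^2.

    The v-gradient is exact. For w, sigma' is replaced by mu' (the STE),
    giving the coarse gradient Z^T (mu'(Z w) * v) * residual.
    """
    mu_prime = STE_DERIVATIVES[ste]
    pre = matvec(Z, w)                       # pre-activations Z w
    act = [binary_act(p) for p in pre]       # sigma(Z w)
    residual = sum(vi * ai for vi, ai in zip(v, act)) - y_star
    grad_v = [ai * residual for ai in act]
    # Backward signal with sigma' swapped for mu': (mu'(Zw) * v) * residual
    delta = [mu_prime(p) * vi * residual for p, vi in zip(pre, v)]
    grad_w = [sum(row[j] * d for row, d in zip(Z, delta)) for j in range(len(w))]
    return grad_v, grad_w

def coarse_gd_step(v, w, batch, lr=0.01, ste="relu"):
    """One step along the negative coarse gradient averaged over a batch.

    batch is a list of (Z, y_star) pairs.
    """
    gv = [0.0] * len(v)
    gw = [0.0] * len(w)
    for Z, y_star in batch:
        dv, dw = coarse_gradient(v, w, Z, y_star, ste)
        gv = [a + b / len(batch) for a, b in zip(gv, dv)]
        gw = [a + b / len(batch) for a, b in zip(gw, dw)]
    new_v = [vi - lr * g for vi, g in zip(v, gv)]
    new_w = [wi - lr * g for wi, g in zip(w, gw)]
    return new_v, new_w

if __name__ == "__main__":
    Z = [[1.0, 0.0], [0.0, -1.0]]
    v, w = [1.0, 2.0], [1.0, 1.0]
    # The identity STE propagates signal through the inactive unit,
    # the ReLU STE does not, and the clipped ReLU STE blocks both here.
    for name in STE_DERIVATIVES:
        print(name, coarse_gradient(v, w, Z, y_star=0.0, ste=name)[1])
```

The example shows that the three STEs give different ${\bm{w}}$-directions from the same forward pass, which is exactly why the choice of STE matters for whether the negative coarse gradient is a descent direction.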

Paper Structure

This paper contains 14 sections, 17 theorems, 133 equations, 4 figures, 2 tables, 1 algorithm.

Key Result

Lemma 1

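If ${\bm{w}}\neq \mathbf{0}_n$, the population loss $f({\bm{v}},{\bm{w}})$ is given by
$$f({\bm{v}},{\bm{w}}) = \frac{1}{8}\Big[{\bm{v}}^\top ({\bm{I}}_m + \bm{1}_m\bm{1}_m^\top ){\bm{v}} - 2\,{\bm{v}}^\top \Big(\big(1-\tfrac{2}{\pi}\theta({\bm{w}},{\bm{w}}^*)\big){\bm{I}}_m + \bm{1}_m\bm{1}_m^\top\Big){\bm{v}}^* + ({\bm{v}}^*)^\top ({\bm{I}}_m + \bm{1}_m\bm{1}_m^\top ){\bm{v}}^*\Big],$$
where $\theta({\bm{w}},{\bm{w}}^*)$ is the angle between ${\bm{w}}$ and ${\bm{w}}^*$. In addition, $f({\bm{v}},{\bm{w}}) = \frac{1}{8}({\bm{v}}^*)^\top ({\bm{I}}_m + \bm{1}_m\bm{1}_m^\top ){\bm{v}}^*$ for ${\bm{w}} = \mathbf{0}_n$.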
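Lemma 1 can be sanity-checked numerically. The sketch below assumes, from our reading of the paper's setup rather than from this summary, that ${\bm{Z}}\in\mathbb{R}^{m\times n}$ has i.i.d. standard Gaussian entries, $\sigma(x)=1$ for $x>0$ and $0$ otherwise, and $f({\bm{v}},{\bm{w}}) = \frac{1}{2}\mathbb{E}_{{\bm{Z}}}\big[({\bm{v}}^\top\sigma({\bm{Z}}{\bm{w}}) - ({\bm{v}}^*)^\top\sigma({\bm{Z}}{\bm{w}}^*))^2\big]$. It evaluates the closed form and compares it with a Monte Carlo estimate; the function names are hypothetical. The closed form follows from $\mathbb{E}[\sigma({\bm{z}}^\top{\bm{w}})\,\sigma({\bm{z}}^\top{\bm{w}}^*)] = \frac{\pi-\theta}{2\pi}$ for a single Gaussian row ${\bm{z}}$, and $\frac{1}{4}$ for two independent rows.

```python
import math
import random

def angle(a, b):
    """Angle between vectors a and b (both assumed nonzero)."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(x * x for x in b))
    return math.acos(max(-1.0, min(1.0, dot / (na * nb))))  # clamp for rounding

def quad(x, y, diag):
    """Bilinear form x^T (diag * I_m + 1_m 1_m^T) y."""
    return diag * sum(a * b for a, b in zip(x, y)) + sum(x) * sum(y)

def population_loss(v, w, v_star, w_star):
    """Closed-form population loss from Lemma 1."""
    if all(wi == 0 for wi in w):
        return quad(v_star, v_star, 1.0) / 8.0
    theta = angle(w, w_star)
    return (quad(v, v, 1.0)
            - 2.0 * quad(v, v_star, 1.0 - 2.0 * theta / math.pi)
            + quad(v_star, v_star, 1.0)) / 8.0

def monte_carlo_loss(v, w, v_star, w_star, num_samples=20000, seed=0):
    """Monte Carlo estimate of 0.5 * E[(v^T sigma(Zw) - v*^T sigma(Zw*))^2]."""
    rng = random.Random(seed)
    m, n = len(v), len(w)
    total = 0.0
    for _ in range(num_samples):
        Z = [[rng.gauss(0.0, 1.0) for _ in range(n)] for _ in range(m)]
        # sigma(Z w) selects the rows with positive pre-activation
        pred = sum(vi for vi, row in zip(v, Z)
                   if sum(z * wj for z, wj in zip(row, w)) > 0)
        target = sum(vi for vi, row in zip(v_star, Z)
                     if sum(z * wj for z, wj in zip(row, w_star)) > 0)
        total += 0.5 * (pred - target) ** 2
    return total / num_samples

if __name__ == "__main__":
    v, w = [1.0, 0.5], [1.0, 0.0]
    v_star, w_star = [0.5, 1.0], [0.0, 1.0]
    print(population_loss(v, w, v_star, w_star))   # 0.3125
    print(monte_carlo_loss(v, w, v_star, w_star))  # close to the closed form
```

Setting ${\bm{v}}={\bm{v}}^*$ and ${\bm{w}}={\bm{w}}^*$ makes $\theta=0$ and the three terms cancel, so the population loss is zero at the teacher parameters, as expected.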

Figures (4)

  • Figure 1: The empirical loss after one step in the direction of the negative coarse gradient vs. the learning rate (step size) $\eta$, for different sample sizes.
  • Figure 2: When initialized with the weights (good minima) produced by the vanilla (orange) and clipped (blue) ReLUs on ResNet-20 with 4-bit activations, coarse gradient descent using the identity STE is eventually repelled from them. The learning rate is set to $10^{-5}$ until epoch 20.
  • Figure 3: Plots of the 2-bit quantized ReLU $\sigma_{\alpha}(x)$ (with $2^2=4$ quantization levels including 0) and the associated clipped ReLU $\tilde{\sigma}_{\alpha}(x)$. The resolution $\alpha$ is fixed before network training.
  • Figure 4: When initialized with the weights produced by the clipped ReLU STE on ResNet-20 with 2-bit activations (88.38% validation accuracy), the coarse gradient descent using the ReLU STE with $10^{-5}$ learning rate is not stable there, and both classification and training errors begin to increase.
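The activations in Figure 3 can be written out for a general bit-width $b$. This is a sketch under assumptions: we take the quantized ReLU to have the $2^b$ levels $\{0, \alpha, \dots, (2^b-1)\alpha\}$, mapping $x \in ((k-1)\alpha, k\alpha]$ to $k\alpha$ and saturating at $(2^b-1)\alpha$; this rounding rule at level boundaries is our assumption, not taken from this summary. The clipped ReLU is taken as $\min(\max(x,0),(2^b-1)\alpha)$, and its derivative (1 on $(0,(2^b-1)\alpha)$, 0 elsewhere) is what would stand in for the quantized ReLU's almost-everywhere-zero derivative in the backward pass. All function names are hypothetical.

```python
import math

def quantized_relu(x, alpha, bits=2):
    """b-bit quantized ReLU with levels {0, alpha, ..., (2**bits - 1) * alpha}.

    Assumed rounding rule: x in ((k-1)*alpha, k*alpha] maps to k*alpha,
    saturating at the top level.
    """
    top_level = 2 ** bits - 1
    if x <= 0:
        return 0.0
    k = min(math.ceil(x / alpha), top_level)
    return k * alpha

def clipped_relu(x, alpha, bits=2):
    """Clipped ReLU: identity on [0, (2**bits - 1) * alpha], constant outside."""
    ceiling = (2 ** bits - 1) * alpha
    return min(max(x, 0.0), ceiling)

def clipped_relu_ste_grad(x, alpha, bits=2):
    """Derivative of the clipped ReLU, used as the backward-pass surrogate
    for the quantized ReLU (whose true derivative is 0 almost everywhere)."""
    ceiling = (2 ** bits - 1) * alpha
    return 1.0 if 0 < x < ceiling else 0.0

if __name__ == "__main__":
    xs = [i / 10 for i in range(-20, 50)]
    print(sorted({quantized_relu(x, alpha=1.0) for x in xs}))  # [0.0, 1.0, 2.0, 3.0]
```

With `bits=2` and `alpha=1.0` the quantized ReLU takes exactly four values, matching the $2^2=4$ levels in the Figure 3 caption, while the clipped ReLU follows the identity between 0 and $3\alpha$.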

Theorems & Definitions (36)

  • Lemma 1
  • Lemma 2
  • Proposition 1
  • Lemma 3
  • Theorem 1: Convergence
  • Remark 1
  • Remark 2
  • Lemma 4
  • Lemma 5
  • Lemma 6
  • ...and 26 more