Straightening Out the Straight-Through Estimator: Overcoming Optimization Challenges in Vector Quantized Networks

Minyoung Huh, Brian Cheung, Pulkit Agrawal, Phillip Isola

TL;DR

This work tackles training instability in vector-quantized networks using straight-through estimation by identifying misalignment between embedding and codebook distributions as the root cause of index collapse. It introduces three techniques—affine re-parameterization of code-vectors, alternating optimization, and a synchronized commitment update—to better align distributions and reduce gradient estimation error. Empirical results across AlexNet, ResNet, and ViT for classification and across CIFAR10/CelebA with MaskGIT for generative modeling demonstrate improved codebook utilization, reduced sparsity, and enhanced performance. These contributions offer practical, mathematically grounded strategies to stabilize discrete latent representations in deep networks and provide insight into the optimization dynamics of VQNs.
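The TL;DR ties index collapse to poor codebook utilization. One standard way to measure utilization is the perplexity of the code-assignment distribution, exp(H(p)), where p_k is the fraction of inputs assigned to code k. It equals 1 under full index collapse and K when all K codes are used equally. The sketch below assumes this common definition; the paper may compute it differently (for example per batch). The helper names `codebook_perplexity` and `active_fraction` are illustrative, not from the paper.

```python
import math
from collections import Counter


def codebook_perplexity(assignments, num_codes):
    """exp(entropy) of the empirical code-usage distribution.

    Returns 1.0 when every input maps to a single code (full index collapse)
    and num_codes when all codes are used equally often.
    """
    counts = Counter(assignments)
    n = len(assignments)
    entropy = 0.0
    for k in range(num_codes):
        p = counts.get(k, 0) / n
        if p > 0:
            entropy -= p * math.log(p)
    return math.exp(entropy)


def active_fraction(assignments, num_codes):
    """Fraction of code-vectors that received at least one assignment."""
    return len(set(assignments)) / num_codes


# Usage: a healthy codebook vs. a collapsed one (K = 4 codes).
print(codebook_perplexity([0, 1, 2, 3, 0, 1, 2, 3], 4))  # ~4.0
print(codebook_perplexity([2, 2, 2, 2, 2, 2, 2, 2], 4))  # 1.0
print(active_fraction([0, 0, 1, 1], 4))                   # 0.5
```

Perplexity is more informative than the active fraction alone: a codebook can have every code "active" while most inputs still land on a handful of codes.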

Abstract

This work examines the challenges of training neural networks that use vector quantization with straight-through estimation. We find that a primary cause of training instability is the discrepancy between the model embedding and the code-vector distribution. We identify the factors that contribute to this issue, including the codebook gradient sparsity and the asymmetric nature of the commitment loss, which leads to misaligned code-vector assignments. We propose to address this issue via affine re-parameterization of the code vectors. Additionally, we introduce an alternating optimization to reduce the gradient error introduced by the straight-through estimation. Moreover, we propose an improvement to the commitment loss to ensure better alignment between the codebook representation and the model embedding. These optimization methods improve the mathematical approximation of the straight-through estimation and, ultimately, the model performance. We demonstrate the effectiveness of our methods on several common model architectures, such as AlexNet, ResNet, and ViT, across various tasks, including image classification and generative modeling.
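The abstract's diagnosis rests on how gradients flow through a vector-quantization layer. The NumPy sketch below writes those gradients out by hand under common VQ-VAE conventions, which are assumptions on our part rather than details taken from this paper: nearest-neighbour assignment by squared L2 distance, a straight-through estimator that copies dL/dz_q to z_e, a commitment term weighted by beta = 0.25, and a codebook term that pulls each assigned code-vector toward its embeddings. It makes the codebook gradient sparsity concrete: code-vectors with no assignment receive exactly zero gradient. The affine re-parameterization is modeled, again as an assumption, as c_k = scale * raw_k + shift (elementwise), with one scale and one shift vector shared by all code-vectors; `affine_grads` is a hypothetical helper showing that these shared parameters receive gradient whenever any code is used, so a step on them moves every code-vector, including unused ones.

```python
import numpy as np


def quantize(z_e, codebook):
    """Assign each embedding (N, D) to its nearest code-vector (K, D)."""
    dist = ((z_e[:, None, :] - codebook[None, :, :]) ** 2).sum(axis=-1)  # (N, K)
    idx = dist.argmin(axis=1)
    return idx, codebook[idx]


def vq_backward(z_e, codebook, idx, grad_zq, beta=0.25):
    """Hand-written gradients of one VQ layer (common VQ-VAE conventions).

    grad_zq is dL_task/dz_q from the layers above.
    Returns (dL/dz_e, dL/dcodebook).
    """
    z_q = codebook[idx]
    # Straight-through: the task gradient at z_q is copied to z_e unchanged,
    # plus the commitment term beta * ||z_e - sg(z_q)||^2.
    grad_ze = grad_zq + 2.0 * beta * (z_e - z_q)
    # Codebook term ||sg(z_e) - c_k||^2: only code-vectors selected by at
    # least one embedding accumulate a gradient (gradient sparsity).
    grad_cb = np.zeros_like(codebook)
    np.add.at(grad_cb, idx, 2.0 * (z_q - z_e))
    return grad_ze, grad_cb


def affine_grads(raw_codes, scale, grad_cb):
    """Hypothetical helper: chain rule for c_k = scale * raw_k + shift,
    where scale and shift (shape (D,)) are shared by all K code-vectors."""
    grad_scale = (grad_cb * raw_codes).sum(axis=0)
    grad_shift = grad_cb.sum(axis=0)
    grad_raw = grad_cb * scale
    return grad_scale, grad_shift, grad_raw


# Usage: embeddings that have drifted away from the codebook.
rng = np.random.default_rng(0)
raw = rng.normal(size=(8, 2))              # K = 8 code-vectors in 2-D
scale, shift = np.ones(2), np.zeros(2)
codebook = raw * scale + shift
z_e = rng.normal(loc=3.0, size=(16, 2))    # N = 16 shifted embeddings
idx, z_q = quantize(z_e, codebook)
grad_ze, grad_cb = vq_backward(z_e, codebook, idx, grad_zq=np.zeros_like(z_e))
unused = np.setdiff1d(np.arange(8), idx)
_, grad_shift, _ = affine_grads(raw, scale, grad_cb)
print("unused codes:", unused)                # these rows of grad_cb are zero
print("shared shift gradient:", grad_shift)  # a step on shift moves every code
```

In an autodiff framework the stop-gradients `sg(.)` would typically be written with a detach operation rather than derived by hand; the manual form is used here only to make each gradient path visible.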

Paper Structure (31 sections, 31 equations, 11 figures, 10 tables)

Figures (11)

  • Figure 1: Illustration of internal codebook covariate shift: During training, the embedding distribution $\mathcal{P}_{\mathbf{z}}$ drifts from its initialization. When the model undergoes a distributional shift, the codebook $\mathcal{C}_{\mathbf{z}}$ (blue) becomes misaligned with $\mathcal{P}_{\mathbf{z}}$. The code-vectors that receive assignments are denoted by $\mathcal{Q}_{\mathbf{z}}$ (red); they are initialized to overlap with $\mathcal{C}_{\mathbf{z}}$. As training proceeds, $\mathcal{Q}_{\mathbf{z}}$ diverges and becomes misaligned with $\mathcal{P}_{\mathbf{z}}$. Code-vectors without assignments receive no gradients and stop being trained, which bifurcates the codebook distribution and ultimately leads to index collapse.
  • Figure 2: Divergence vs Accuracy: We visualize the divergence between $\mathcal{P}_{\mathbf{z}}$, $\mathcal{Q}_{\mathbf{z}}$ and $\mathcal{C}_{\mathbf{z}}$ for ResNet18 during training. ResNet18 is trained on ImageNet100 classification, and the codes are initialized using K-means. We sample $2048$ embedding vectors from $\mathcal{P}_{\mathbf{z}}$ and use the full $\mathcal{C}_{\mathbf{z}}$ and $\mathcal{Q}_{\mathbf{z}}$. The vectors associated with the 10th training iteration are embedded using t-SNE. We also track distribution shift throughout training by computing histograms of the PCA projections; lighter colors indicate earlier training iterations. The standard approach results in a bifurcation of the codebook. On the right, we show the result of our method using affine re-parameterization, which leads to better distribution matching.
  • Figure 3: Codebook update dynamics on toy experiment: Optimization dynamics of vector quantization in a toy setup. The experiment uses a single code-vector with a stationary target for $\mathbf{z}_e$ (red). A Euclidean loss is computed with respect to the code-vector $\mathbf{z}_q$ (blue), and the resulting gradient is used to update the embedding $\mathbf{z}_e$ via the straight-through approximation (black). A commitment loss is applied between $\mathbf{z}_q$ and $\mathbf{z}_e$ using the $l_2$ distance. All methods are optimized with SGD at the same fixed learning rate of $0.1$. $\mathsf{No\ VQ}$ optimizes the embedding without the quantization function. $\mathsf{Joint}$ optimizes the embedding with $\mathcal{L}_{\mathsf{task}}$ and $\mathcal{L}_{\mathsf{commit}}$ together. $\mathsf{Alternated}$ first optimizes the codebook assignment and then optimizes the model with the task loss, using a single iteration for each step. The $\mathsf{Lookahead}$ objective predicts the trajectory of $\mathbf{z}_e$ and updates $\mathbf{z}_q$ towards it. Standard VQ training traces a large spiral because of the straight-through approximation error; alternated optimization reduces this error and shrinks the spiral. Note that the approximation error is also caused by the code-vector representation, which is a delayed historical moving average of the embedding. The lookahead optimizer reuses the gradient from $\mathcal{L}_{\mathsf{task}}$ to better synchronize the code-vector representation and accelerate convergence.
  • Figure 4: MaskGIT FID training curves: MaskGIT (Chang et al., 2022) trained on CelebA with only the reconstruction loss. We report rFID (left) and FID (right) training curves. We use a slimmed-down version of MaskGIT: the VQGAN uses $32$ channels instead of $128$, and the transformer uses $8$ blocks instead of $24$.
  • Figure 5: Warmup improves perplexity: We plot the perplexity and test accuracy for methods trained with and without a linear warmup. All methods were initialized with K-means.
  • ...and 6 more figures
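The toy experiment described in the Figure 3 caption can be outlined in a few lines of NumPy. This is our reading of the caption, not the authors' code: the target point, the starting embedding, the commitment weight beta = 0.25, and the exact form of the joint and alternated updates are assumptions, and the Lookahead variant is omitted because the caption does not give its update rule. With a single code-vector, quantization always returns that code, so the straight-through task gradient is evaluated at the code-vector rather than at z_e; that mismatch is the approximation error the caption refers to.

```python
import numpy as np


def toy_vq(mode, target, z_e0, steps=200, lr=0.1, beta=0.25):
    """Single code-vector toy setup (an assumed reading of Figure 3).

    mode: "no_vq", "joint" or "alternated".
    Returns the trajectory of z_e as an array of shape (steps + 1, dim).
    """
    z_e = np.array(z_e0, dtype=float)
    c = z_e.copy()                 # code-vector initialized on the embedding
    t = np.asarray(target, dtype=float)
    traj = [z_e.copy()]
    for _ in range(steps):
        if mode == "no_vq":
            # Exact gradient of ||z_e - t||^2, no quantization.
            z_e = z_e - lr * 2.0 * (z_e - t)
        elif mode == "joint":
            # Task gradient evaluated at z_q = c and copied to z_e (STE);
            # commitment loss updates z_e and c in the same step.
            g_task = 2.0 * (c - t)
            new_z_e = z_e - lr * (g_task + 2.0 * beta * (z_e - c))
            c = c - lr * 2.0 * (c - z_e)   # uses the pre-update z_e
            z_e = new_z_e
        elif mode == "alternated":
            # 1) codebook step: move c toward the current embedding.
            c = c - lr * 2.0 * (c - z_e)
            # 2) model step: straight-through task gradient at the updated code.
            z_e = z_e - lr * 2.0 * (c - t)
        else:
            raise ValueError(f"unknown mode: {mode}")
        traj.append(z_e.copy())
    return np.stack(traj)


# Usage: compare trajectories toward a fixed target.
target = np.array([2.0, 1.0])
trajs = {m: toy_vq(m, target, z_e0=[0.0, 0.0])
         for m in ("no_vq", "joint", "alternated")}
for m, traj in trajs.items():
    # Final distance to the target, and overshoot along the first axis.
    print(m, np.linalg.norm(traj[-1] - target), traj[:, 0].max() - target[0])
```

With these settings the joint update is a linear map with complex eigenvalues, so z_e overshoots the target and circles back, whereas the No VQ update approaches it monotonically. Plotting the joint and alternated trajectories side by side is a quick way to explore the effect the caption describes; this sketch is not meant to reproduce the figure quantitatively.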