Improving Discrete Optimisation Via Decoupled Straight-Through Estimator

Rushi Shah, Mingyuan Yan, Michael Curtis Mozer, Dianbo Liu

TL;DR

The paper proposes Decoupled Straight-Through (Decoupled ST), a minimal modification of the straight-through estimator that introduces separate temperatures for the forward pass ($\tau_f$) and the backward pass ($\tau_b$), enabling independent tuning of exploration and gradient dispersion.

Abstract

The Straight-Through Estimator (STE) is the dominant method for training neural networks with discrete variables, enabling gradient-based optimisation by routing gradients through a differentiable surrogate. However, existing STE variants conflate two fundamentally distinct concerns: forward-pass stochasticity, which controls exploration and latent space utilisation, and backward-pass gradient dispersion, i.e., how learning signals are distributed across categories. We show that these concerns are qualitatively different and that tying them to a single temperature parameter leaves significant performance gains untapped. We propose Decoupled Straight-Through (Decoupled ST), a minimal modification that introduces separate temperatures for the forward pass ($\tau_f$) and the backward pass ($\tau_b$). This simple change enables independent tuning of exploration and gradient dispersion. Across three diverse tasks (Stochastic Binary Networks, Categorical Autoencoders, and Differentiable Logic Gate Networks), Decoupled ST consistently outperforms Identity STE, Softmax STE, and Straight-Through Gumbel-Softmax. Crucially, optimal $(\tau_f, \tau_b)$ configurations lie far off the diagonal $\tau_f = \tau_b$, confirming that the two concerns do require different answers and that single-temperature methods are fundamentally constrained.
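The abstract describes the mechanism only in words. The following is one plausible reading for a single categorical variable with logits $\ell$, written in plain Python so it runs without an autodiff library; it is an assumption, not the authors' implementation, and every function name is hypothetical. The forward pass draws a one-hot sample $z$ from $\mathrm{softmax}(\ell/\tau_f)$ via the Gumbel-max trick. The backward pass replaces the undefined $\partial z/\partial \ell$ with the Jacobian of $p = \mathrm{softmax}(\ell/\tau_b)$, namely $J = (\mathrm{diag}(p) - p p^\top)/\tau_b$. Whether Gumbel noise also enters the backward surrogate is not stated here, so this sketch leaves it out.

```python
import math
import random


def softmax(logits, tau):
    """Tempered softmax, softmax(logits / tau), shifted by the max for stability."""
    scaled = [x / tau for x in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_gumbel(rng):
    """One standard Gumbel(0, 1) draw via -log(-log(U)), U ~ Uniform(0, 1)."""
    u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)  # keep both logs finite
    return -math.log(-math.log(u))


def decoupled_st_forward(logits, tau_f, rng):
    """Hypothetical forward pass: one-hot sample from Categorical(softmax(logits / tau_f)).

    Gumbel-max on the tempered logits: small tau_f approaches a deterministic
    argmax, while large tau_f pushes the sample toward uniform (more exploration).
    """
    perturbed = [x / tau_f + sample_gumbel(rng) for x in logits]
    k = max(range(len(logits)), key=lambda i: perturbed[i])
    return [1.0 if i == k else 0.0 for i in range(len(logits))]


def decoupled_st_backward(logits, tau_b, grad_z):
    """Hypothetical backward pass: map dL/dz to dL/dlogits.

    The non-differentiable dz/dlogits is replaced by the Jacobian of
    p = softmax(logits / tau_b), J = (diag(p) - p p^T) / tau_b, which gives
    (J^T g)_i = p_i * (g_i - sum_j p_j g_j) / tau_b.
    """
    p = softmax(logits, tau_b)
    weighted_mean = sum(pj * gj for pj, gj in zip(p, grad_z))
    return [pi * (gi - weighted_mean) / tau_b for pi, gi in zip(p, grad_z)]


if __name__ == "__main__":
    rng = random.Random(0)
    logits = [2.0, 0.5, -1.0]
    # Low tau_f: near-deterministic choice. Higher tau_b: flatter p, so the
    # surrogate gradient is spread more evenly across categories.
    z = decoupled_st_forward(logits, tau_f=0.1, rng=rng)
    grad_logits = decoupled_st_backward(logits, tau_b=2.0, grad_z=[1.0, -0.5, 0.25])
    print(z, grad_logits)
```

Under this reading, setting $\tau_f = \tau_b$ gives back a single-temperature estimator of the kind the abstract argues is constrained. In an autodiff framework the same decoupling is commonly written as `z_hard + p_b - stop_gradient(p_b)` with `p_b = softmax(logits / tau_b)`. Examples are `jax.lax.stop_gradient` in JAX and `.detach()` in PyTorch. The result equals `z_hard` in the forward pass (up to rounding) and backpropagates through `p_b`.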

Paper Structure

This paper contains 30 sections, 9 equations, 6 figures, 1 table, and 2 algorithms.

Figures (6)

  • Figure 1: Gradient flow comparison for different gradient estimation methods. (a) In the continuous setting, node $\bm{z}$ is a deterministic variable, and gradients can be propagated back through $\bm{z}$ and $\mathcal{L}(\bm{z})$ directly using the chain rule. (b) When $\bm{z}$ represents a discrete categorical variable, the sampling/argmax process from $\bm{p}$ breaks the backpropagation path. (c) Standard STE, where $\partial \bm{z} / \partial \bm{p}$ is approximated as 1 during the backward pass, allowing gradients to flow through non-differentiable stochastic nodes. Different choices for $\boldsymbol{F}$ yield different STE variants. (d) ST-GS with temperature: stochasticity is injected by adding a Gumbel noise sample, and the logits are then scaled by a single temperature $\tau$ shared by the forward and backward passes. (e) Decoupled ST uses separate temperatures: $\tau_f$ controls forward-pass stochasticity while $\tau_b$ controls backward-pass gradient dispersion.
  • Figure 2: Stochastic Binary Networks on FashionMNIST. (a, b) Decoupled ST with $\tau_f = 0.1$ and $\tau_b = 0.7$ achieves faster convergence and higher final accuracy than all baselines. The optimal configuration lies off the diagonal ($\tau_f \neq \tau_b$), confirming the benefit of decoupling.
  • Figure 3: Categorical Autoencoder on MNIST (4 latents $\times$ 8 classes). (a) Decoupled ST achieves the lowest validation loss and consistent convergence across all seeds, whereas baselines fail to converge reliably despite extensive temperature and learning rate tuning (error bands omitted due to high variance). (b) Latent space utilization, measured by perplexity (the effective number of codes used), shows that Decoupled ST achieves near-maximum utilization ($\approx 7.9$ out of 8), while baselines severely underutilize the available codes.
  • Figure 4: Differentiable Logic Gate Networks on CIFAR10. (a, b) Decoupled ST with $\tau_f = 0.1$ and $\tau_b = 2.0$ achieves faster convergence and higher final accuracy than all baselines. The low forward temperature encourages near-deterministic gate selection for stable training, while the high backward temperature ensures broad gradient flow across gate options.
  • Figure 5: Gate distribution in Differentiable Logic Gate Networks. Distribution across the 16 logic gate types after training, where lighter colors indicate more uniform usage. Decoupled ST achieves the most balanced distribution, measured by perplexity (higher is more uniform, max is 16), indicating effective gradient dispersion across the combinatorial gate space.
  • ...and 1 more figure
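Figures 3 and 5 report code and gate utilisation as perplexity, described as the effective number of codes used. A common definition, assumed here because the summary does not show the paper's exact formula, is the exponential of the Shannon entropy of the empirical usage distribution. The helper `perplexity` below is hypothetical. It counts discrete assignments, though the paper may instead average predicted probabilities.

```python
import math
from collections import Counter


def perplexity(assignments):
    """Perplexity exp(H) of the empirical distribution over discrete codes.

    H is the Shannon entropy in nats. Uniform use of K codes gives K;
    collapse onto a single code gives 1.
    """
    if not assignments:
        raise ValueError("need at least one assignment")
    counts = Counter(assignments)
    n = len(assignments)
    entropy = -sum((c / n) * math.log(c / n) for c in counts.values())
    return math.exp(entropy)


if __name__ == "__main__":
    print(perplexity([0, 1, 2, 3, 4, 5, 6, 7]))  # balanced use of 8 codes
    print(perplexity([3, 3, 3, 3]))              # all samples on one code
```

On this definition, perfectly balanced use of 8 codes gives a perplexity of 8 and total collapse gives 1. This matches the "max is 16" bound quoted for the 16 gate types in Figure 5.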