Overcoming Valid Action Suppression in Unmasked Policy Gradient Algorithms

Renos Zabounidis, Roy Siegelmann, Mohamad Qadri, Woojun Kim, Simon Stepputtis, Katia P. Sycara

TL;DR

This work identifies a distinct failure mode of unmasked training: it systematically suppresses valid actions at states the agent has not yet visited. It proves that, for softmax policies with shared features, when an action is invalid at visited states but valid at an unvisited state, the probability of that action at the unvisited state is bounded above by an exponentially decaying term, a consequence of parameter sharing and the zero-sum identity of softmax logits.

Abstract

In reinforcement learning environments with state-dependent action validity, action masking consistently outperforms penalty-based handling of invalid actions, yet existing theory only shows that masking preserves the policy gradient theorem. We identify a distinct failure mode of unmasked training: it systematically suppresses valid actions at states the agent has not yet visited. This occurs because gradients that push down invalid actions at visited states propagate through shared network parameters to unvisited states where those actions are valid. We prove that for softmax policies with shared features, when an action $a$ is invalid at visited states but valid at an unvisited state $s^*$, the probability $\pi(a \mid s^*)$ is bounded above by an exponentially decaying term, a consequence of parameter sharing and the zero-sum identity of softmax logits. This bound reveals that entropy regularization trades off protection of valid actions against sample efficiency, a tradeoff that masking eliminates. We empirically validate that deep networks exhibit the feature alignment condition required for suppression, and experiments on Craftax, Craftax-Classic, and MiniHack confirm the predicted exponential suppression and demonstrate that feasibility classification enables deployment without oracle masks.
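The suppression mechanism described in the abstract can be reproduced in a toy setting. The sketch below is a minimal illustration under stated assumptions, not the paper's experimental setup: a linear softmax policy $z(s) = W\phi(s)$ over $n = 4$ actions, hand-picked (hypothetical) two-dimensional features, a penalty of $-1$ on the invalid action "descend" and a reward of $+1$ on "right" at two visited corridor states, and exact expected gradients of $\mathbb{E}_\pi[r] + \beta H(\pi)$ in place of sampled PPO updates. The unvisited state $s^*$ is never trained on; its features are positively aligned with the visited ones. All function and variable names are hypothetical.

```python
import math

# Toy model (assumption): linear softmax policy z(s) = W phi(s) with hand-picked features,
# trained with exact expected gradients of E_pi[r] + beta * H(pi) at visited states only.


def softmax(z):
    """Numerically stable softmax over a list of logits."""
    m = max(z)
    e = [math.exp(v - m) for v in z]
    total = sum(e)
    return [v / total for v in e]


def logits(W, phi):
    """Logits z(s) = W phi(s) for an n x d weight matrix stored as a list of rows."""
    return [sum(w * x for w, x in zip(row, phi)) for row in W]


def logit_grad(pi, reward, beta):
    """Gradient of E_pi[r] + beta * H(pi) with respect to the logits.

    Component k equals pi_k * (f_k - E_pi[f]) with f_k = r_k - beta * log(pi_k),
    so the components always sum to zero (the softmax zero-sum identity).
    """
    f = [r - (beta * math.log(p) if beta > 0 else 0.0) for r, p in zip(reward, pi)]
    mean_f = sum(p * fk for p, fk in zip(pi, f))
    return [p * (fk - mean_f) for p, fk in zip(pi, f)]


def train_unmasked(visited, reward, beta, lr=0.5, steps=2000):
    """Gradient ascent on visited states only; unvisited states are never trained directly."""
    n, d = len(reward), len(visited[0])
    W = [[0.0] * d for _ in range(n)]  # zero init gives the uniform policy 1/n everywhere
    for _ in range(steps):
        dW = [[0.0] * d for _ in range(n)]
        for phi in visited:
            g = logit_grad(softmax(logits(W, phi)), reward, beta)
            for k in range(n):
                for i in range(d):
                    dW[k][i] += g[k] * phi[i]  # chain rule: dz_k / dW[k][i] = phi[i]
        for k in range(n):
            for i in range(d):
                W[k][i] += lr * dW[k][i]
    return W


# Actions: 0 = descend (invalid at visited states, penalized), 1 = right (rewarded), 2-3 = others.
REWARD = [-1.0, 1.0, 0.0, 0.0]
VISITED = [[1.0, 0.2], [0.2, 1.0]]  # corridor states (hypothetical features)
S_STAR = [0.6, 0.6]                 # unvisited staircase state, positively aligned with both

for beta in (0.0, 1.0):
    W = train_unmasked(VISITED, REWARD, beta)
    print(f"beta={beta}: pi(descend | s*) = {softmax(logits(W, S_STAR))[0]:.2e}")
```

With $\beta = 0$ the probability of descend at $s^*$ should end far below the uniform $1/4$, even though $s^*$ never appears in training. With $\beta = 1$ it should instead settle near the entropy-regularized optimum $\pi \propto \exp(r/\beta)$, about $e^{-1}/(e^{-1}+e+2) \approx 0.072$, because in this symmetric toy $s^*$ receives exactly the logits of the visited states: a floor, but still well below uniform.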

Paper Structure (63 sections, 6 theorems, 40 equations, 8 figures, 5 tables)

Key Result

Theorem 1

Fix action $a$ and let $s^*$ be a state where $a$ is valid but $s^* \notin \mathcal{S}_{\mathrm{vis}}$ (the set of visited states), so that $s^*$ is a first valid occurrence of $a$. Assume uniform initialization $\pi_0(j \mid s^*) = 1/n$, and suppose Conditions (i) and (ii) hold throughout training prior to visiting $s^*$. Let $\beta \geq 0$ be the entropy regularization coefficient, and define the cumulative suppression $K_T = \sum_{\tau=0}^{T-1} \kappa_\tau$. Then after $T$ gradient steps …
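The theorem statement is cut off before its conclusion, and $\kappa_\tau$ and Conditions (i) and (ii) are not defined in this summary, so the following does not reproduce the paper's bound. It is a minimal sketch of a generic softmax fact that turns logit suppression into exponential probability suppression, under the assumption that the $n$ logits at $s^*$ sum to zero (as they do, for example, for a zero-initialized linear policy head trained only with zero-sum logit updates): if $z_a(s^*)$ has been driven down to $-K$, then Jensen's inequality gives $\pi(a \mid s^*) \le 1/\bigl(1 + (n-1)e^{nK/(n-1)}\bigr)$, which equals $1/n$ at $K = 0$ and decays exponentially in $K$. The helper names `zero_sum_bound` and `check_bound` are hypothetical.

```python
import math
import random


def softmax(z):
    """Numerically stable softmax over a list of logits."""
    m = max(z)
    e = [math.exp(v - m) for v in z]
    total = sum(e)
    return [v / total for v in e]


def zero_sum_bound(z_a, n):
    """Upper bound on pi(a | s) when the n logits at s sum to zero and action a has logit z_a.

    The other n-1 logits then average -z_a / (n-1), so their mean gap over z_a is
    -n * z_a / (n-1). Jensen's inequality on exp gives
        pi(a | s) <= 1 / (1 + (n - 1) * exp(-n * z_a / (n - 1))).
    At z_a = 0 this equals 1/n (uniform initialization).
    """
    mean_gap = -n * z_a / (n - 1)
    return 1.0 / (1.0 + (n - 1) * math.exp(mean_gap))


def check_bound(trials=1000, seed=0):
    """Randomized check of the bound on zero-sum logit vectors."""
    rng = random.Random(seed)
    for _ in range(trials):
        n = rng.randint(2, 10)
        z = [rng.gauss(0.0, 3.0) for _ in range(n)]
        mean = sum(z) / n
        z = [v - mean for v in z]  # enforce sum_j z_j = 0
        a = rng.randrange(n)
        if softmax(z)[a] > zero_sum_bound(z[a], n) + 1e-12:
            return False
    return True


print("bound holds on random zero-sum logits:", check_bound())
# Driving z_a down by K shrinks the bound roughly like exp(-n * K / (n - 1)).
for K in (0.0, 1.0, 2.0, 4.0):
    print(K, zero_sum_bound(-K, n=4))
```

The bound only needs the average gap to the competing logits, not a gap to each one, which is why a zero-sum constraint is enough to make it bite.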

Figures (8)

  • Figure 1: Suppression mechanism illustrated on a staircase corridor. (a) The agent must traverse a corridor to reach a staircase at $s^*$, where descend is the only valid goal-reaching action. At $T\!=\!0$, all actions are equally likely ($1/n$). After $N$ training steps, gradient updates at visited states reinforce right and suppress all other actions, including descend. Because parameters are shared, this suppression propagates to $s^*$ before the agent arrives. The descend action is suppressed at the one state where it is needed. (b) Upper bound on $\pi(\textsf{descend} \mid s^*)$ from Theorem 1. Without entropy regularization ($\beta\!=\!0$), the probability decays exponentially. With entropy regularization ($\beta\!>\!0$), a floor emerges but cannot eliminate suppression (see the entropy sandwich bound).
  • Figure 2: Architecture overview. A shared encoder $\phi(s)$ feeds three heads: a classification head (blue) predicting action validity $\hat{\nu}(s,a)$ via a sigmoid, a policy head (green) producing the oracle-masked policy $\pi^{\text{oracle}}_\theta(a \mid s)$, and a value head (orange) estimating $V(s)$. The predicted validity is used to construct a predicted policy $\pi^{\text{pred}}_\theta(a \mid s)$. The KL divergence between $\pi^{\text{oracle}}$ and $\pi^{\text{pred}}$ yields per-action weights $w_a$ for the classification loss (see the KL weighting equation). The total loss combines the PPO and weighted classification objectives (see the total loss equation). Dashed arrows indicate loss aggregation. At deployment, there are two modes: (i) if oracle masks are available, the classification head can be discarded; (ii) otherwise, the learned predictor supplies the validity estimates.
  • Figure 3: Action suppression in Craftax and MiniHack Corridor-5 (PPO-Hybrid). Left: probability of rare critical actions at valid states (log scale). Right: fraction of timesteps with valid actions selected. Unmasked training (C2, red) exhibits exponential suppression, while oracle masking (C1, blue) prevents collapse. Dashed line: uniform initialization.
  • Figure 4: Representational correlation between valid and invalid states (PPO-Hybrid). Oracle masking preserves highly entangled representations, while KL-balanced classification induces validity-aware features without reintroducing policy-level suppression.
  • Figure 5: Environments used in our experiments.
  • ...and 3 more figures
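Figure 2 describes the classification objective only at the level of a diagram, and its weighting equation is not reproduced in this summary. The sketch below shows one way to compute, for a single state, the oracle-masked policy, a predicted policy built from the validity head, their KL divergence, and a per-action weighted classification loss. Every modeling choice here is an assumption, not the paper's equation: $\pi^{\text{pred}}(a \mid s) \propto \hat{\nu}(s,a)\exp z_a(s)$ (soft masking by predicted validity), the hypothetical weight $w_a = 1 + \lvert \partial\,\mathrm{KL}(\pi^{\text{oracle}} \Vert \pi^{\text{pred}}) / \partial \hat{\nu}(s,a) \rvert = 1 + \lvert \pi^{\text{pred}}(a) - \pi^{\text{oracle}}(a) \rvert / \hat{\nu}(s,a)$, and binary cross-entropy averaged over actions. In an autodiff framework the weights would be treated as constants (stop-gradient). All names are hypothetical.

```python
import math


def masked_softmax(logits, weights):
    """pi(a) proportional to weights[a] * exp(logits[a]); a zero weight removes the action."""
    m = max(z for z, w in zip(logits, weights) if w > 0)
    e = [w * math.exp(z - m) if w > 0 else 0.0 for z, w in zip(logits, weights)]
    total = sum(e)
    return [v / total for v in e]


def kl(p, q):
    """KL(p || q), summing only over actions with p > 0."""
    return sum(pa * math.log(pa / qa) for pa, qa in zip(p, q) if pa > 0)


def bce(prob, target, eps=1e-7):
    """Binary cross-entropy of a predicted validity probability against a 0/1 label."""
    prob = min(max(prob, eps), 1.0 - eps)
    return -(target * math.log(prob) + (1 - target) * math.log(1.0 - prob))


def kl_weighted_validity_loss(policy_logits, oracle_mask, nu_hat):
    """Hypothetical KL-weighted validity loss for one state.

    oracle_mask: 0/1 validity labels; nu_hat: predicted validity, assumed in (0, 1).
    Returns (weighted BCE, KL(pi_oracle || pi_pred), per-action weights).
    """
    pi_oracle = masked_softmax(policy_logits, oracle_mask)  # hard oracle mask
    pi_pred = masked_softmax(policy_logits, nu_hat)         # soft mask from the validity head
    # dKL/d nu_hat[a] = (pi_pred[a] - pi_oracle[a]) / nu_hat[a]; its magnitude sets the weight.
    weights = [1.0 + abs(pp - po) / nh for pp, po, nh in zip(pi_pred, pi_oracle, nu_hat)]
    loss = sum(w * bce(nh, y) for w, nh, y in zip(weights, nu_hat, oracle_mask)) / len(nu_hat)
    return loss, kl(pi_oracle, pi_pred), weights


logits = [2.0, 0.5, 0.0, -1.0]
oracle = [0, 1, 1, 1]            # action 0 is invalid in this state
good = [0.02, 0.97, 0.95, 0.90]  # accurate validity predictions
bad = [0.02, 0.03, 0.95, 0.90]   # wrongly predicts valid action 1 as invalid
for name, nu in (("good", good), ("bad", bad)):
    loss, div, w = kl_weighted_validity_loss(logits, oracle, nu)
    print(name, round(loss, 3), round(div, 3), [round(x, 2) for x in w])
```

Under this weighting, a misclassification that would actually distort the deployed policy (here, hiding the high-probability valid action 1) receives a much larger weight than errors the policy barely notices.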

Theorems & Definitions (13)

  • Theorem 1: Probability suppression at first valid occurrence
  • Lemma 2: Expected logit update
  • proof
  • Lemma 3: Zero-sum identity
  • proof
  • Lemma 4: Invalid-action dominance gap
  • proof
  • Lemma 5: Expected parameter update
  • proof
  • Proposition 6: Generalization of logit suppression
  • ...and 3 more
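The list above gives only lemma titles. For reference, the standard softmax identity that the title "Zero-sum identity" most plausibly refers to is written out below, with logits $z_j(s)$, $n$ actions, and $g(s) = \nabla_{z(s)} \log \pi_\theta(a_s \mid s)$. The second part additionally assumes (our assumption, not stated in this summary) a linear policy head $z(s) = W\phi(s) + b$ on a fixed encoder $\phi$, and updates of the form $\eta \sum_s c_s\, \nabla \log \pi_\theta(a_s \mid s)$ with scalar weights $c_s$ such as advantages. Under those assumptions the logits at every state, including an unvisited $s^*$, keep a constant sum, so pushing $z_a(s^*)$ down necessarily pushes the average of the other logits at $s^*$ up.

```latex
\pi_\theta(j \mid s) = \frac{\exp z_j(s)}{\sum_{k=1}^{n} \exp z_k(s)},
\qquad
\frac{\partial \log \pi_\theta(a \mid s)}{\partial z_j(s)} = \delta_{ja} - \pi_\theta(j \mid s),

\mathbf{1}^{\top} g(s) = \sum_{j=1}^{n} \frac{\partial \log \pi_\theta(a_s \mid s)}{\partial z_j(s)}
  = 1 - \sum_{j=1}^{n} \pi_\theta(j \mid s) = 0 .

% Linear head z(s) = W \phi(s) + b with a fixed encoder \phi (assumption):
\Delta W = \eta \sum_{s} c_s \, g(s)\, \phi(s)^{\top},
\qquad
\Delta b = \eta \sum_{s} c_s \, g(s)
\;\;\Longrightarrow\;\;
\sum_{j=1}^{n} \Delta z_j(s^*) = \mathbf{1}^{\top} \Delta W \, \phi(s^*) + \mathbf{1}^{\top} \Delta b = 0 .
```

The same cancellation holds for an entropy bonus, since $\partial H / \partial z_k = -\pi_k(\log \pi_k + H)$ also sums to zero over $k$.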