Disentangled (Un)Controllable Features

Jacob E. Kooi, Mark Hoogendoorn, Vincent François-Lavet

TL;DR

This work introduces a disentangled latent representation for high-dimensional MDPs by partitioning the latent state into controllable $z^c$ and uncontrollable $z^u$ components. It combines an action-conditioned forward predictor for $z^c$, a state-only forward predictor for $z^u$, a contrastive loss to avoid representation collapse, and an adversarial loss to minimize information leakage from $z^u$ into $z^c$, enabling planning directly in the controllable latent. The approach is validated across three environment types, showing interpretable latent separation and competitive downstream learning performance, with planning in the controllable subspace providing practical advantages in unseen mazes. These results point toward interpretable, task-relevant latent representations that support planning and potential causal reasoning in RL, with emphasis on robust disentanglement and generalization to complex environments.
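The four training signals listed above can be made concrete with a small sketch in plain Python, with latent vectors stored as lists. It illustrates the idea under stated assumptions and is not the authors' implementation. The following are all assumptions:

- the residual form of both predictors;
- the squared-error prediction losses;
- a contrastive term $\exp(-C\lVert z_a - z_b \rVert)$ on pairs of latents;
- an adversary that tries to recover $z^u$ from $z^c$, whose loss is subtracted in the encoder objective.

`predict_c`, `predict_u` and `adversary` are hypothetical stand-ins for learned networks.

```python
import math

# Illustrative sketch only; every loss form below is an assumption, not the paper's code.

def sq_err(pred, target):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((p - t) ** 2 for p, t in zip(pred, target))

def add(u, v):
    """Element-wise sum of two vectors."""
    return [a + b for a, b in zip(u, v)]

def forward_loss_c(zc_t, a_t, zc_next, predict_c):
    # Assumed residual, action-conditioned model: zc_{t+1} ~ zc_t + T_c(zc_t, a_t)
    return sq_err(add(zc_t, predict_c(zc_t, a_t)), zc_next)

def forward_loss_u(zu_t, zu_next, predict_u):
    # Assumed residual, state-only model: zu_{t+1} ~ zu_t + T_u(zu_t); the action is never seen
    return sq_err(add(zu_t, predict_u(zu_t)), zu_next)

def contrastive_loss(z_a, z_b, c=5.0):
    # Assumed form: close to 1 when two different latents collapse onto each other
    return math.exp(-c * math.sqrt(sq_err(z_a, z_b)))

def adversarial_loss(zc_t, zu_t, adversary):
    # Hypothetical leakage measure: how well an adversary recovers zu_t from zc_t alone
    return sq_err(adversary(zc_t), zu_t)

def encoder_objective(batch, predict_c, predict_u, adversary, lam_adv=0.1):
    """Mean loss the encoder would minimize over transitions (zc_t, zu_t, a_t, zc_next, zu_next).

    The adversarial term is subtracted: the adversary is trained separately to
    minimize it, while the encoder is pushed to make it large.
    """
    n = len(batch)
    total = 0.0
    for i, (zc_t, zu_t, a_t, zc_next, zu_next) in enumerate(batch):
        # Use the next item in the batch as a stand-in for a randomly sampled negative
        zc_o, zu_o = batch[(i + 1) % n][0], batch[(i + 1) % n][1]
        total += forward_loss_c(zc_t, a_t, zc_next, predict_c)
        total += forward_loss_u(zu_t, zu_next, predict_u)
        # List concatenation zc + zu gives the full latent z_t
        total += contrastive_loss(zc_t + zu_t, zc_o + zu_o)
        total -= lam_adv * adversarial_loss(zc_t, zu_t, adversary)
    return total / n
```

With a single-transition batch the negative is the transition itself, so the contrastive term saturates at 1. A fuller setup would presumably draw negatives from random states in the replay buffer.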

Abstract

In the context of MDPs with high-dimensional states, downstream tasks are predominantly performed on a compressed, low-dimensional representation of the original input space. A variety of learning objectives have therefore been used to attain useful representations. However, the individual features of these representations usually lack interpretability. We present a novel approach that is able to disentangle latent features into a controllable and an uncontrollable partition. We illustrate that the resulting partitioned representations are easily interpretable on three types of environments and show that, in a distribution of procedurally generated maze environments, it is feasible to interpretably employ a planning algorithm in the isolated controllable latent partition.
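Planning in the isolated controllable partition can be pictured as a depth-limited lookahead that rolls only $z^c$ forward. The sketch below illustrates this under stated assumptions and is not the paper's planning algorithm. `transition_c`, `reward` and `q_value` are hypothetical stand-ins for a learned residual transition model on $z^c$, a learned reward model and a Q-network. $z^u$ is held fixed during the rollout, an assumption that suits static features such as maze walls.

```python
# Illustrative depth-limited lookahead in the controllable latent.
# transition_c, reward and q_value are hypothetical stand-ins for learned models.

def plan(zc, zu, depth, actions, transition_c, reward, q_value, gamma=0.9):
    """Return (value, first action) of the best action sequence of length `depth`.

    Only zc is rolled forward with the assumed residual model zc + T_c(zc, a);
    zu is passed through unchanged.
    """
    if depth == 0:
        # Leaf: fall back on the Q-network's value estimate
        best_a = max(actions, key=lambda a: q_value(zc, zu, a))
        return q_value(zc, zu, best_a), best_a
    best_value, best_action = float("-inf"), None
    for a in actions:
        zc_next = [x + dx for x, dx in zip(zc, transition_c(zc, a))]
        future, _ = plan(zc_next, zu, depth - 1, actions, transition_c, reward, q_value, gamma)
        value = reward(zc, zu, a) + gamma * future
        if value > best_value:
            best_value, best_action = value, a
    return best_value, best_action

# Toy usage: a 1-D corridor where action 1 moves right and action 0 moves left;
# stepping onto position 3 yields reward 1, and the Q-network is a constant 0.
def step(zc, a):
    return [1.0 if a == 1 else -1.0]

def goal_reward(zc, zu, a):
    return 1.0 if zc[0] + step(zc, a)[0] == 3.0 else 0.0

def zero_q(zc, zu, a):
    return 0.0

value, action = plan([0.0], [0.0], 3, [0, 1], step, goal_reward, zero_q)
print(action, round(value, 2))  # 1 0.81
```

The search visits $|\mathcal{A}|^{d}$ leaves for $d$ steps of lookahead. A shallow depth such as the 3 reported in Figure 5 therefore stays cheap: with four actions, for example, it visits 64 leaves.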

Paper Structure (39 sections, 13 equations, 15 figures, 1 algorithm)

Figures (15)

  • Figure 1: Visualization in a maze environment of four random pixel observations $s \in \mathbb{R}^{48\times48}$ (left) and the encoded observations $z = f(s;\theta_{enc})\; \forall s \in \mathcal{S}$ (right). On the right, the controllable latent $z^{c} \in \mathbb{R}^{2}$ is disentangled along the horizontal axes and the uncontrollable latent $z^{u} \in \mathbb{R}^{1}$ along the vertical axis. The encoder is trained on high-dimensional tuples $(s_{t}, a_{t}, r_{t}, s_{t+1})$ sampled from a replay buffer $\mathcal{B}$, which is filled with random trajectories from the four maze environments shown on the left. All possible states in all four mazes are encoded and plotted together with the predicted transition for each possible action. The plot reveals a clear disentanglement between the controllable latents (agent x-y position) and the uncontrollable latent (wall architecture). Note that all samples are drawn from a single buffer containing samples from all four mazes.
  • Figure 2: Overview of the disentangling architecture. Dashed lines represent gradient propagation and green rectangles represent parameterized prediction functions. An observation $s_{t}$ is encoded into a latent representation consisting of two parts, $z^{c}_{t}$ and $z^{u}_{t}$, which represent controllable and uncontrollable features, respectively. These partitions are then used separately to make action-conditioned, state-only and adversarial predictions. The gradients of these predictions drive the encoder to disentangle the latent representation $z_{t}$ into controllable ($z^{c}_{t}$) and uncontrollable ($z^{u}_{t}$) partitions.
  • Figure 3: Visualization of the latent feature disentanglement in the catcher environment after 200k training iterations, with $z_{t} = f(s_{t};\theta_{enc})$, $z^{c}_{t} \in \mathbb{R}^{2}$ and $z^{u}_{t} \in \mathbb{R}^{6\times 6}$. In (a) and (b), the left column shows $z^{c}_{t}$, the middle column shows a feature map representing $z^{u}_{t}$, and the right column shows the pixel state $s_{t}$. For illustration purposes, the dashed lines separate observations in which either the ball position or the paddle position is kept fixed. $z^{c}$ tracks the agent position while $z^{u}$ tracks the falling ball. In (b), the controllable state is two-dimensional although only one dimension is needed (see the catcher appendix). Even so, the adversarial loss ensures that distinct ball positions have a negligible effect on $z^{c}$ (left column), even though the high-level features of the agent and the ball might be hard to distinguish.
  • Figure 4: The latent representation of all observations in a single randomly sampled maze under three training setups: (a) training with the aforementioned losses; (b) replacing the action-conditioned forward-prediction loss $\mathcal{L}_{c}$ with an inverse-prediction loss $\mathcal{L}_{inv}$; and (c) updating the encoder end-to-end with only the Q-loss $\mathcal{L}_{Q}$ from DDQN for 500k iterations. The left column shows the controllable latent $z^{c}_{t} \in \mathbb{R}^{2}$, with the current state in blue and the remaining states in red. The predicted movement for each individual action is shown as a differently colored bar. The middle column shows the uncontrollable latent $z^{u}_{t} \in \mathbb{R}^{6 \times 6}$ and the right column shows the original state $s_{t} \in \mathbb{R}^{48 \times 48}$. The controllable representations in (b) and (c) lack disentanglement and interpretability. Furthermore, the representation in (c) shows very little structure at all. This suggests that a representation optimized without prior structural incentives often acts as a black box.
  • Figure 5: Performance of different (pre)trained representations on the random maze environment, measured as the mean (solid line) and standard error (shaded area) over 5 seeds. The 'Interpretable' setting pre-trains an encoder for 50k iterations to obtain a representation as in Fig. 4(a). The encoder is then frozen and a Q-network is trained on top with DDQN for 500k iterations. The 'Interpretable + Planning' setting matches 'Interpretable' but combines DDQN with a depth-3 planning algorithm in the controllable partition of the latent space. The 'DDQN' setting uses an encoder trained end-to-end with only DDQN for 500k iterations. The 'Inverse Prediction' setting matches 'Interpretable' but uses an encoder pre-trained with $\mathcal{L}_{inv}$ instead of $\mathcal{L}_{c}$.
  • ...and 10 more figures
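Figures 4(b) and 5 compare the forward-prediction loss $\mathcal{L}_{c}$ with an inverse-prediction loss $\mathcal{L}_{inv}$. The sketch below shows what such a baseline could look like: an inverse model scores every action from two consecutive controllable latents. It assumes discrete actions and a softmax cross-entropy objective, since the exact form is not given here. `MOVES` and `dot_inverse_model` are hypothetical illustrations, not the paper's inverse model.

```python
import math

# Illustrative inverse-prediction loss; the softmax cross-entropy form is an assumption.

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def inverse_prediction_loss(zc_t, zc_next, a_t, inverse_model):
    """Cross-entropy of recovering the discrete action a_t from (zc_t, zc_next)."""
    probs = softmax(inverse_model(zc_t, zc_next))
    return -math.log(probs[a_t])

# Hypothetical inverse model: score each move direction by its dot product
# with the observed displacement in the controllable latent.
MOVES = [(0.0, 1.0), (0.0, -1.0), (1.0, 0.0), (-1.0, 0.0)]

def dot_inverse_model(zc_t, zc_next):
    dx, dy = zc_next[0] - zc_t[0], zc_next[1] - zc_t[1]
    return [mx * dx + my * dy for mx, my in MOVES]

loss = inverse_prediction_loss([0.0, 0.0], [1.0, 0.0], 2, dot_inverse_model)
```

Unlike the forward loss in the earlier sketch, this objective only requires that consecutive latents make the action identifiable. It never asks the latent to predict where the agent ends up after the action.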