Can We Really Learn One Representation to Optimize All Rewards?
Chongyi Zheng, Royina Karegoudra Jayanth, Benjamin Eysenbach
TL;DR
The paper scrutinizes forward-backward (FB) representation learning, which aims to pretrain a single latent representation that enables zero-shot optimization of any reward in reinforcement learning. It shows that ground-truth FB representations can exist only under strict rank conditions and that the FB objective is a TD-like least-squares importance fitting (LSIF) loss tied to an FB Bellman operator that is not a contraction, which can hinder convergence. To address these issues, the authors propose one-step FB, which fixes the behavioral policy and performs a single step of policy improvement, learning forward and backward representations with a one-step, TD-style LSIF loss and orthonormal regularization. Empirically, one-step FB converges reliably in didactic settings and in continuous control domains, reaching errors up to $10^5$ times smaller and an average gain of about 24% in zero-shot performance across 8 state-based and 2 image-based domains, while also providing a strong initialization for subsequent fine-tuning. The work offers a practical unsupervised pre-training method with theoretical grounding and demonstrated empirical benefits, and it clarifies the limits of using a universal FB representation to optimize all rewards.
Abstract
As machine learning has moved towards leveraging large models as priors for downstream tasks, the community has debated the right form of prior for solving reinforcement learning (RL) problems. If one were to try to prefetch as much computation as possible, one would attempt to learn a prior over the policies for some yet-to-be-determined reward function. Recent work on forward-backward (FB) representation learning has tried this, arguing that an unsupervised representation learning procedure can enable optimal control over arbitrary rewards without further fine-tuning. However, FB's training objective and learning behavior remain mysterious. In this paper, we demystify FB by clarifying when such representations can exist, what its objective optimizes, and how it converges in practice. We draw connections with rank matching, fitted Q-evaluation, and contraction mapping. Our analysis suggests a simplified unsupervised pre-training method for RL that, instead of enabling optimal control, performs one step of policy improvement. We call our proposed method one-step forward-backward representation learning (one-step FB). Experiments in didactic settings, as well as in $10$ state-based and image-based continuous control domains, demonstrate that one-step FB converges to errors $10^5\times$ smaller and improves zero-shot performance by $24\%$ on average. Our project website is available at https://chongyi-zheng.github.io/onestep-fb.
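To make the idea of "one step of policy improvement" from a factorized representation concrete, the following is a minimal tabular sketch, not the authors' implementation. It assumes a small finite MDP, rewards defined on the next state, and a uniform behavioral policy. It also swaps in a closed-form successor measure and a truncated SVD where the paper learns the forward and backward representations with a TD-style LSIF loss. All names (`successor_measure`, `factorize`, `zero_shot_policy`) are hypothetical.

```python
import numpy as np


def successor_measure(P, beta, gamma):
    """Discounted successor measure of a fixed behavioral policy.

    P:    (S, A, S) transition probabilities P[s, a, s'].
    beta: (S, A) action probabilities of the behavioral policy.
    Returns M with M[s, a, s'] = sum_t gamma^t Pr(s_{t+1} = s' | s, a, beta).
    """
    S = P.shape[0]
    P_beta = np.einsum("sa,sap->sp", beta, P)            # state-to-state kernel under beta
    M_state = np.linalg.inv(np.eye(S) - gamma * P_beta)  # sum_t gamma^t P_beta^t
    return np.einsum("sap,pq->saq", P, M_state)


def factorize(M, d):
    """Rank-d factorization M ~= F B via truncated SVD (an illustrative
    stand-in for learned forward/backward representations)."""
    S, A, _ = M.shape
    U, sv, Vt = np.linalg.svd(M.reshape(S * A, S), full_matrices=False)
    F = (U[:, :d] * sv[:d]).reshape(S, A, d)  # "forward" features of (s, a)
    B = Vt[:d]                                # "backward" features of s', orthonormal rows
    return F, B


def zero_shot_policy(F, B, r):
    """Infer a task vector from the reward, then act greedily once."""
    z = B @ r                    # project the reward onto the backward features
    Q = F @ z                    # estimate of Q^beta(s, a) for this reward
    return Q, Q.argmax(axis=1)   # one step of policy improvement over beta


# Hypothetical toy problem: random 5-state, 2-action MDP.
rng = np.random.default_rng(0)
S, A, gamma = 5, 2, 0.9
P = rng.random((S, A, S))
P /= P.sum(axis=-1, keepdims=True)
beta = np.full((S, A), 1.0 / A)

M = successor_measure(P, beta, gamma)
F, B = factorize(M, d=S)         # full rank, so the factorization is exact
r = rng.random(S)                # a reward function seen only at test time
Q, pi = zero_shot_policy(F, B, r)
```

With `d=S` the factorization is exact, so `Q` matches the behavioral policy's Q-function for the sampled reward. Choosing `d < S` gives a low-rank approximation, which is the regime where the rank conditions discussed in the paper start to matter. The greedy policy `pi` improves on `beta` by one step and is not claimed to be optimal.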
