Can Temporal-Difference and Q-Learning Learn Representation? A Mean-Field Theory

Yufeng Zhang, Qi Cai, Zhuoran Yang, Yongxin Chen, Zhaoran Wang

TL;DR

We study how the feature representation of an overparameterized neural network evolves under temporal-difference (TD) learning and Q-learning, and whether it converges to a globally optimal solution. By developing a mean-field framework in Wasserstein space, we show that TD and Q-learning globally minimize the mean-squared projected Bellman error (MSPBE) at a sublinear rate and drive the induced representation toward the optimal one, even when the representation departs from its initialization, that is, beyond the neural tangent kernel (NTK) regime. The analysis extends to soft Q-learning, which is further connected to policy gradient, and provides fixed-point characterizations and finite-width convergence guarantees. These results give a principled account of representation learning in deep reinforcement learning and of global convergence beyond the NTK regime.

Abstract

Temporal-difference and Q-learning play a key role in deep reinforcement learning, where they are empowered by expressive nonlinear function approximators such as neural networks. At the core of their empirical successes is the learned feature representation, which embeds rich observations, e.g., images and texts, into the latent space that encodes semantic structures. Meanwhile, the evolution of such a feature representation is crucial to the convergence of temporal-difference and Q-learning. In particular, temporal-difference learning converges when the function approximator is linear in a feature representation, which is fixed throughout learning, and possibly diverges otherwise. We aim to answer the following questions: When the function approximator is a neural network, how does the associated feature representation evolve? If it converges, does it converge to the optimal one? We prove that, utilizing an overparameterized two-layer neural network, temporal-difference and Q-learning globally minimize the mean-squared projected Bellman error at a sublinear rate. Moreover, the associated feature representation converges to the optimal one, generalizing the previous analysis of Cai et al. (2019) in the neural tangent kernel regime, where the associated feature representation stabilizes at the initial one. The key to our analysis is a mean-field perspective, which connects the evolution of a finite-dimensional parameter to its limiting counterpart over an infinite-dimensional Wasserstein space. Our analysis generalizes to soft Q-learning, which is further connected to policy gradient.

Paper Structure

This paper contains 27 sections, 22 theorems, 159 equations, 2 figures, 2 algorithms.

Key Result

Proposition 3.1

Let the initial distribution $\rho_0$ be the standard Gaussian distribution $N(0, I_D)$. Under certain regularity conditions, $\widehat{\rho}_{\lfloor t/\epsilon\rfloor}^{(m)}$ weakly converges to $\rho_t$ as $\epsilon\rightarrow 0^+$ and $m\rightarrow \infty$.
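To make the objects in Proposition 3.1 concrete, the following is a minimal sketch of the finite-width, discrete-time side of the statement: $m$ neurons initialized i.i.d. from $\rho_0 = N(0, I_D)$ and updated for $\lfloor t/\epsilon\rfloor$ TD steps, so that the rows of the returned array play the role of the empirical distribution $\widehat{\rho}_{\lfloor t/\epsilon\rfloor}^{(m)}$. The `tanh` activation, the scaling constants `alpha` and `eta`, the toy state sampler, and the toy reward are all assumptions made for illustration; they are not taken from the paper, whose exact parameterization and regularity conditions are not reproduced here.

```python
import numpy as np

def init_particles(m, D, rng):
    """Draw m neuron parameters i.i.d. from rho_0 = N(0, I_D)."""
    return rng.standard_normal((m, D))

def q_hat(theta, x, alpha=1.0):
    """Mean-field two-layer network: (alpha / m) * sum_i tanh(theta_i . x).

    The tanh neuron is an illustrative assumption.
    """
    return float(alpha * np.mean(np.tanh(theta @ x)))

def td_step(theta, x, r, x_next, gamma, eta, eps, alpha=1.0):
    """One semi-gradient TD(0) step applied to every particle."""
    # TD error: current estimate minus the bootstrapped target.
    delta = q_hat(theta, x, alpha) - r - gamma * q_hat(theta, x_next, alpha)
    # Gradient of tanh(theta_i . x) w.r.t. theta_i, one row per particle.
    grad = (1.0 - np.tanh(theta @ x) ** 2)[:, None] * x[None, :]
    # Mean-field scaling (assumed): step size eta * eps, no extra 1/m factor.
    return theta - eta * eps * alpha * delta * grad

def run(m, D, t, eps, seed=0, gamma=0.9, eta=1.0):
    """Return the particles after floor(t / eps) TD steps on toy data."""
    rng = np.random.default_rng(seed)
    theta = init_particles(m, D, rng)
    for _ in range(int(np.floor(t / eps))):
        # Toy transition (assumption): independent unit-norm states.
        x = rng.standard_normal(D)
        x /= np.linalg.norm(x)
        x_next = rng.standard_normal(D)
        x_next /= np.linalg.norm(x_next)
        r = float(x[0])  # toy reward (assumption)
        theta = td_step(theta, x, r, x_next, gamma, eta, eps)
    return theta

if __name__ == "__main__":
    particles = run(m=1000, D=4, t=1.0, eps=0.25)
    print(particles.shape, particles.mean(axis=0))
```

Increasing `m` and decreasing `eps` at a fixed `t` is the regime the proposition describes; the sketch only produces the particles and does not itself verify weak convergence.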

Figures (2)

  • Figure 1: We illustrate the first variation formula $\frac{\mathrm{d}}{\mathrm{d} t} \frac{\mathcal{W}_2(\rho_t, \rho^*)^2}{2} = - \langle g(\cdot; \rho_t), v \rangle_{\rho_t}$, where $v$ is the vector field corresponding to the geodesic that connects $\rho_t$ and $\rho^*$. See Lemma \ref{lem:diff} for details.
  • Figure 2: For any $0 \le t \le \min\{t^*, t_*\}$, Eq. (\ref{eq:lem-descent}) of Lemma \ref{lem:descent} holds and $\frac{\mathrm{d}}{\mathrm{d} t} \frac{\mathcal{W}_2(\rho_t, \rho_*)^2}{2} \le 0$.

Theorems & Definitions (23)

  • Proposition 3.1: Informal Version of Proposition \ref{prop:discretization}
  • Theorem 4.3
  • Lemma 4.4
  • Corollary 4.5
  • Lemma 5.1
  • Lemma 5.2
  • Theorem 6.2
  • Proposition 6.4
  • Lemma A.1
  • Lemma A.2
  • ...and 13 more