Provably Efficient Reinforcement Learning with Multinomial Logit Function Approximation
Long-Fei Li, Yu-Jie Zhang, Peng Zhao, Zhi-Hua Zhou
TL;DR
This work studies reinforcement learning with multinomial logit (MNL) function approximation, which ensures valid transition distributions in MNL mixture MDPs. It develops two algorithms: the statistically efficient UCRL-MNL-LL and the computationally efficient UCRL-MNL-OL. Both achieve a regret of $\tilde{O}(d H^2 \sqrt{K} + \kappa^{-1} d^2 H^2)$, whose dominant term is independent of $\kappa$, and UCRL-MNL-OL additionally needs only constant per-episode computation. The statistical contribution uses Bernstein-type concentration and self-concordant-like properties of the log-loss to form tighter confidence sets; the computational contribution relies on online mirror descent and a closed-form second-order bonus to avoid non-convex optimization. The first lower bound for MNL mixture MDPs is derived via a reduction to logistic bandits, establishing a fundamental limit that supports the claimed optimality in $d$ and $K$ up to $H^{1/2}$ factors. Together, these results advance the theoretical understanding of RL with non-linear, probabilistic transitions and provide practical algorithms for large state spaces.
Abstract
We study a new class of MDPs that employs multinomial logit (MNL) function approximation to ensure valid probability distributions over the state space. Despite its significant benefits, incorporating the non-linear function raises substantial challenges in both statistical and computational efficiency. The best-known result of Hwang and Oh [2023] has achieved an $\widetilde{\mathcal{O}}(\kappa^{-1}dH^2\sqrt{K})$ regret upper bound, where $\kappa$ is a problem-dependent quantity, $d$ is the feature dimension, $H$ is the episode length, and $K$ is the number of episodes. However, we observe that $\kappa^{-1}$ exhibits polynomial dependence on the number of reachable states, which can be as large as the state space size in the worst case and thus undermines the motivation for function approximation. Additionally, their method requires storing all historical data and the time complexity scales linearly with the episode count, which is computationally expensive. In this work, we propose a statistically efficient algorithm that achieves a regret of $\widetilde{\mathcal{O}}(dH^2\sqrt{K} + \kappa^{-1}d^2H^2)$, eliminating the dependence on $\kappa^{-1}$ in the dominant term for the first time. We then address the computational challenges by introducing an enhanced algorithm that achieves the same regret guarantee but with only constant cost. Finally, we establish the first lower bound for this problem, justifying the optimality of our results in $d$ and $K$.
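To illustrate why an MNL parameterization always yields a valid distribution, and why a quantity like $\kappa^{-1}$ can grow with the number of reachable states, here is a minimal Python sketch. It rests on assumptions not stated in the abstract: the transition probability is taken to be a softmax of linear scores $\phi(s,a,s')^\top\theta$ over the reachable next states, and the "kappa-like" value is taken to be the product of the two smallest transition probabilities. The function names and the random features are hypothetical and are not the authors' implementation.

```python
import math
import random


def mnl_transition_probs(features, theta):
    """Softmax over reachable next states: p_i proportional to exp(features[i] . theta).

    Assumed form of the MNL transition model; `features` holds one feature
    vector per reachable next state.
    """
    logits = [sum(f * t for f, t in zip(phi, theta)) for phi in features]
    shift = max(logits)  # subtract the max logit for numerical stability
    weights = [math.exp(z - shift) for z in logits]
    total = sum(weights)
    return [w / total for w in weights]


def min_pairwise_product(probs):
    """Product of the two smallest probabilities (an assumed kappa-like quantity)."""
    smallest, second_smallest = sorted(probs)[:2]
    return smallest * second_smallest


if __name__ == "__main__":
    rng = random.Random(0)
    dim = 4  # hypothetical feature dimension d
    theta = [rng.uniform(-1, 1) for _ in range(dim)]
    for num_reachable in (2, 10, 100):
        features = [[rng.uniform(-1, 1) for _ in range(dim)]
                    for _ in range(num_reachable)]
        probs = mnl_transition_probs(features, theta)
        inverse_kappa_like = 1.0 / min_pairwise_product(probs)
        # The probabilities sum to 1 for any theta; the inverse grows with
        # the number of reachable states.
        print(num_reachable, sum(probs), inverse_kappa_like)
```

Under these assumptions, the two smallest of $U$ probabilities are at most $1/U$ and $1/(U-1)$, so the inverse of their product is at least $U(U-1)$. This is consistent with the abstract's observation that $\kappa^{-1}$ depends polynomially on the number of reachable states.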
