Rectifying Regression in Reinforcement Learning

Alex Ayoub, David Szepesvári, Alireza Bakhtiari, Csaba Szepesvári, Dale Schuurmans

TL;DR

This work investigates how the choice of regression objective affects value-based reinforcement learning. It argues that losses aligned with mean absolute error (MAE), notably binary cross-entropy (log-loss) and a reparameterized categorical cross-entropy (cat-loss), support better decision-making than the traditional squared loss, which targets mean squared error (MSE). The authors provide theoretical bounds showing faster convergence for MAE-aligned losses under certain problem structures, along with explicit negative results that illustrate the limitations of MSE-based approaches. Their reparameterized cat-loss preserves the mean of the regression target while casting regression as multi-class classification, and they demonstrate its empirical viability in linear batch RL on an inverted pendulum task. The findings suggest that choosing losses aligned with MAE can improve policy quality and convergence, with potential implications for distributional RL methods and practical algorithm design.
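The mean-preserving idea behind the reparameterized cat-loss can be illustrated with a standard "two-hot" projection of a scalar target onto a fixed, sorted support: the target's mass is split between its two neighbouring atoms so that the resulting distribution has exactly the target's mean. This is a minimal sketch under assumptions; the paper's exact reparameterization may differ, and the helper names (`two_hot`, `cat_loss`, `predicted_mean`) and the support values are hypothetical.

```python
import numpy as np

def two_hot(y, support):
    """Hypothetical helper: project scalar y onto a sorted support as a
    probability vector whose mean equals y (after clipping to the support range)."""
    support = np.asarray(support, dtype=float)
    y = float(np.clip(y, support[0], support[-1]))
    probs = np.zeros_like(support)
    # k is the first atom >= y, so y lies in (support[k-1], support[k]].
    k = int(np.searchsorted(support, y, side="left"))
    if k == 0:
        probs[0] = 1.0
        return probs
    lo, hi = support[k - 1], support[k]
    w_hi = (y - lo) / (hi - lo)  # linear interpolation weight preserves the mean
    probs[k - 1] = 1.0 - w_hi
    probs[k] = w_hi
    return probs

def cat_loss(logits, y, support):
    """Categorical cross-entropy between softmax(logits) and the two-hot target."""
    target = two_hot(y, support)
    z = logits - np.max(logits)            # shift for numerical stability
    log_p = z - np.log(np.sum(np.exp(z)))  # log-softmax
    return -float(np.dot(target, log_p))

def predicted_mean(logits, support):
    """Scalar prediction: the mean of the predicted categorical distribution."""
    z = logits - np.max(logits)
    p = np.exp(z) / np.sum(np.exp(z))
    return float(np.dot(p, support))

support = np.array([0.0, 0.1, 0.3, 0.6, 1.0])  # hypothetical non-uniform 5-point support
target = two_hot(0.45, support)
print(target, float(np.dot(target, support)))  # target mean equals 0.45 up to rounding
```

Because the target distribution keeps the mean of $y$, a model that matches it in cross-entropy can recover the scalar value through `predicted_mean`.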

Abstract

This paper investigates the impact of the loss function in value-based methods for reinforcement learning through an analysis of underlying prediction objectives. We theoretically show that mean absolute error is a better prediction objective than the traditional mean squared error for controlling the learned policy's suboptimality gap. Furthermore, we present results showing that different loss functions are better aligned with different regression objectives: binary and categorical cross-entropy losses with the mean absolute error, and the squared loss with the mean squared error. We then provide empirical evidence that algorithms minimizing these cross-entropy losses can outperform those based on the squared loss in linear reinforcement learning.
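An elementary one-state argument shows why absolute errors appear naturally in the suboptimality gap. Assume a single state with true action values $q(a)$, estimates $\hat q(a)$, optimal action $a^\ast = \arg\max_a q(a)$, and greedy action $\hat a = \arg\max_a \hat q(a)$; this notation is ours and not necessarily the paper's. The gap of the greedy action decomposes as:

```latex
\begin{aligned}
q(a^\ast) - q(\hat a)
  &= \big[q(a^\ast) - \hat q(a^\ast)\big]
   + \underbrace{\big[\hat q(a^\ast) - \hat q(\hat a)\big]}_{\le 0 \text{ since } \hat a \text{ is greedy}}
   + \big[\hat q(\hat a) - q(\hat a)\big] \\
  &\le \lvert q(a^\ast) - \hat q(a^\ast)\rvert + \lvert q(\hat a) - \hat q(\hat a)\rvert .
\end{aligned}
```

Averaged over states, the right-hand side is a first-moment (absolute) error. An MSE guarantee controls it only through Jensen's inequality, $\mathbb{E}\lvert e\rvert \le \sqrt{\mathbb{E}[e^2]}$, which can be loose. This is an illustration of the intuition, not the paper's theorem.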

Paper Structure

This paper contains 21 sections, 8 theorems, 56 equations, and 1 figure.

Key Result

Proposition 2.2

Let $Y$ be a bounded random variable with $Y \in [0,1]$ and mean $\mathbb{E}[Y] = \mu$. Then, for any $x \in [0,1]$, the expected log-loss satisfies:
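As background for the statement above, whose displayed bound is not shown here, a standard fact about the log-loss $\ell(x, y) = -y \log x - (1-y)\log(1-x)$ is that it is linear in $y$, so $\mathbb{E}[\ell(x, Y)] = \ell(x, \mu)$ depends on $Y$ only through its mean, and the excess over $x = \mu$ equals the Bernoulli KL divergence $\mathrm{KL}(\mathrm{Ber}(\mu)\,\|\,\mathrm{Ber}(x))$. The sketch below checks this numerically for a hypothetical discrete $Y$ on $[0,1]$; it does not reproduce the proposition's exact inequality.

```python
import numpy as np

def log_loss(x, y):
    """Binary cross-entropy of prediction x in (0, 1) against a target y in [0, 1]."""
    return -(y * np.log(x) + (1.0 - y) * np.log(1.0 - x))

def bernoulli_kl(mu, x):
    """KL(Bernoulli(mu) || Bernoulli(x)), using the convention 0 * log 0 = 0."""
    kl = 0.0
    if mu > 0.0:
        kl += mu * np.log(mu / x)
    if mu < 1.0:
        kl += (1.0 - mu) * np.log((1.0 - mu) / (1.0 - x))
    return kl

# Hypothetical discrete Y on [0, 1], so expectations can be computed exactly.
values = np.array([0.0, 0.2, 0.9, 1.0])
probs = np.array([0.1, 0.4, 0.3, 0.2])
mu = float(np.dot(probs, values))  # mean of Y, approximately 0.55

def expected_log_loss(x):
    """E[log_loss(x, Y)] under the discrete distribution above."""
    return float(np.dot(probs, log_loss(x, values)))

for x in (0.1, 0.3, 0.55, 0.8):
    excess = expected_log_loss(x) - expected_log_loss(mu)
    print(f"x={x:.2f}  E[loss]={expected_log_loss(x):.4f}  "
          f"excess={excess:.4f}  KL={bernoulli_kl(mu, x):.4f}")
```

The excess column matches the KL column and vanishes at $x = \mu$, which is why minimizing the expected log-loss recovers the mean.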

Figures (1)

  • Figure 1: Failure rates on the inverted pendulum as a function of batch dataset size. Results are averaged over $45$ independently collected datasets, and fitted $Q$-iteration was run for $50$ iterations. Shaded regions show $90\%$ confidence intervals. The left and middle panels use Fourier features of order $2$; the right panel uses Fourier features of order $3$. The left panel uses $5$ uniformly spaced points as the support for the categorical (Cat) loss, while the middle and right panels use $5$ non-uniformly spaced points.
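Fourier features of order $n$ are commonly built as the Fourier basis $\cos(\pi\, c^\top s)$ over all integer coefficient vectors $c \in \{0, \dots, n\}^d$, with the state $s$ rescaled to $[0,1]^d$. A minimal sketch, assuming this standard construction and a hypothetical two-dimensional pendulum state (angle, angular velocity); the caption does not specify the exact basis or state variables, and `fourier_features` is a hypothetical helper.

```python
import itertools
import numpy as np

def fourier_features(state, order):
    """Fourier basis cos(pi * c . s) for every integer vector c in {0, ..., order}^d.
    Assumes `state` has already been rescaled to [0, 1]^d."""
    s = np.asarray(state, dtype=float)
    coeffs = np.array(list(itertools.product(range(order + 1), repeat=s.shape[0])))
    return np.cos(np.pi * coeffs @ s)  # first entry (c = 0) is a constant bias of 1

# Hypothetical 2-D pendulum state (angle, angular velocity), rescaled to [0, 1].
s = np.array([0.5, 0.25])
phi2 = fourier_features(s, order=2)
phi3 = fourier_features(s, order=3)
print(phi2.shape, phi3.shape)  # (9,) (16,)
```

Under these assumptions, order $2$ yields $9$ features per state and order $3$ yields $16$; a linear $Q$-function would then use one weight vector of that length per action.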

Theorems & Definitions (14)

  • Proposition 2.2
  • Proof
  • Remark 2.3
  • Lemma 3.1
  • Proposition 3.2
  • Lemma 3.3
  • Proposition 4.1
  • Proof
  • Proof of the lemma labeled lem:log-loss-bound
  • Proof of the lemma labeled lem:sq-mae-lb
  • ...and 4 more