PixelBrax: Learning Continuous Control from Pixels End-to-End on the GPU

Trevor McInroe, Samuel Garcin

TL;DR

PixelBrax tackles the rendering bottleneck in pixel-based reinforcement learning by running continuous-control tasks end-to-end on the GPU. It combines Brax with a pure JAX renderer (jaxrenderer) to render pixel observations entirely on the GPU, enabling thousands of parallel environments and reproducible experiments through explicit pseudorandom number generation. The approach supports color and video distractors for benchmarking generalization, reports throughput up to two orders of magnitude higher than benchmarks that rely on CPU-based rendering, and supports training standard RL algorithms on pixel inputs. This work lowers the barrier to scalable, reproducible pixel-based RL research and provides open-source tooling to reproduce and extend GPU-based pixel benchmarks.
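To illustrate what explicit pseudorandom number generation buys, here is a minimal JAX sketch built on a hypothetical toy environment (`step` and `rollout` are illustrative names, not PixelBrax's API). All randomness is drawn from keys split off a single seed, so rerunning with the same seed reproduces the trajectory exactly, with no hidden global RNG state.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy dynamics (not PixelBrax's API): the state is a 2-D position,
# and each step adds the action plus small Gaussian noise drawn from a passed key.
def step(state, action, key):
    noise = 0.01 * jax.random.normal(key, shape=state.shape)
    return state + action + noise

def rollout(seed, num_steps=5):
    # Derive one key per timestep from a single seed.
    keys = jax.random.split(jax.random.PRNGKey(seed), num_steps)

    def body(state, step_key):
        next_state = step(state, jnp.full(2, 0.1), step_key)
        return next_state, next_state  # (carry, per-step output)

    _, states = jax.lax.scan(body, jnp.zeros(2), keys)
    return states  # shape (num_steps, 2)

traj_a = rollout(seed=0)
traj_b = rollout(seed=0)
traj_c = rollout(seed=1)
print(bool(jnp.array_equal(traj_a, traj_b)))  # same seed: identical rollout
print(bool(jnp.array_equal(traj_a, traj_c)))  # different seed: different noise
```

Because the key is an ordinary input rather than global state, the same property holds under `jax.jit` and `jax.vmap`, which is what makes batched GPU rollouts reproducible.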

Abstract

We present PixelBrax, a set of continuous control tasks with pixel observations. We combine the Brax physics engine with a pure JAX renderer, allowing reinforcement learning (RL) experiments to run end-to-end on the GPU. PixelBrax can render observations over thousands of parallel environments and can run two orders of magnitude faster than existing benchmarks that rely on CPU-based rendering. Additionally, PixelBrax supports fully reproducible experiments through its explicit handling of any stochasticity within the environments and supports color and video distractors for benchmarking generalization. We open-source PixelBrax alongside JAX implementations of several RL algorithms at github.com/trevormcinroe/pixelbrax.
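To give a sense of how many environments can be reset and rendered in parallel on a single device, here is a minimal sketch using `jax.vmap` and `jax.jit`. The square-drawing `render` function is a hypothetical stand-in, not jaxrenderer's API, and the 64 by 64 resolution and 1000-environment batch are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp

H, W = 64, 64  # assumed observation size
SQUARE = 8     # side length of the toy "agent" square

def render(pos):
    """Hypothetical stand-in renderer (not jaxrenderer): draws a white square
    whose top-left corner is at pixel coordinates `pos` on a black RGB canvas."""
    ys = jnp.arange(H)[:, None]
    xs = jnp.arange(W)[None, :]
    mask = (ys >= pos[0]) & (ys < pos[0] + SQUARE) & (xs >= pos[1]) & (xs < pos[1] + SQUARE)
    img = jnp.where(mask[..., None], 255, 0).astype(jnp.uint8)  # (H, W, 1)
    return jnp.broadcast_to(img, (H, W, 3))

def reset(key):
    # Random initial position that keeps the square inside the canvas.
    return jax.random.randint(key, shape=(2,), minval=0, maxval=H - SQUARE)

# vmap turns the per-environment functions into batched ones; jit compiles them.
batched_reset = jax.jit(jax.vmap(reset))
batched_render = jax.jit(jax.vmap(render))

num_envs = 1000
keys = jax.random.split(jax.random.PRNGKey(0), num_envs)
positions = batched_reset(keys)       # (1000, 2) int32
frames = batched_render(positions)    # (1000, 64, 64, 3) uint8
```

The same per-environment code serves 1 or 1000 environments; `jit` compiles each batched function once via XLA, so no per-environment Python loop is involved and, on a GPU backend, the frames stay on the device.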

Paper Structure

This paper contains 2 sections, 3 figures.

Table of Contents

  1. Introduction
  2. PixelBrax

Figures (3)

  • Figure 1: An initial state in the Humanoid environment with no distractors, color distractors, and video distractors (left to right). With color distractors, the color adjustment applied to each frame changes at every timestep. Likewise, with video distractors, the pixels at each subsequent timestep are overlaid with the next frame of the video.
  • Figure 2: Steps per second for all four currently supported PixelBrax environments (non-hatched bars) and three DeepMind Control (DMC) from-pixels environments (hatched bars) over $\{ 1, 10, 100, 1000 \}$ parallel environments.
  • Figure 3: Example returns for PPO, PPG, and DCPG in HalfCheetah (top-left), Walker2d (top-right), Ant (bottom-left), and Humanoid (bottom-right), all with video distractors. Bold lines represent the mean; shaded areas represent $\pm$ one standard deviation.
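The Figure 1 caption describes two kinds of distractor. The following is a minimal JAX sketch of both, assuming float RGB frames with values in [0, 1]. The uniform per-channel color shift, the alpha-blended video overlay, and all function names and parameters (`scale`, `alpha`) are illustrative assumptions, not PixelBrax's actual distractor code.

```python
import jax
import jax.numpy as jnp

# Illustrative sketch only: the shift and blend rules below are assumptions.

def color_distractor(frame, key, scale=0.3):
    """Shift each RGB channel of one (H, W, 3) frame by a random offset in
    [-scale, scale], then clip back to [0, 1]."""
    shift = jax.random.uniform(key, (3,), minval=-scale, maxval=scale)
    return jnp.clip(frame + shift, 0.0, 1.0)

def video_distractor(frame, background, alpha=0.5):
    """Overlay one background video frame onto a rendered frame by alpha blending."""
    return jnp.clip(alpha * frame + (1.0 - alpha) * background, 0.0, 1.0)

def apply_color_distractors(frames, key):
    # One fresh key per timestep: the color shift changes every step,
    # yet the whole episode is reproducible from `key`.
    keys = jax.random.split(key, frames.shape[0])
    return jax.vmap(color_distractor)(frames, keys)

def apply_video_distractors(frames, video):
    # Episode frame t is overlaid with video frame t, looping the video
    # if the episode is longer than the clip.
    idx = jnp.arange(frames.shape[0]) % video.shape[0]
    return jax.vmap(video_distractor)(frames, video[idx])

# Usage: a 4-step episode of uniform gray 8x8 frames and a 3-frame "video".
episode = jnp.full((4, 8, 8, 3), 0.5)
video = jnp.stack([jnp.full((8, 8, 3), v) for v in (0.0, 0.5, 1.0)])
colored = apply_color_distractors(episode, jax.random.PRNGKey(0))
overlaid = apply_video_distractors(episode, video)
```

Keeping each distractor a pure function of the frame and an explicit key (or video frame) lets it be batched with `vmap` across environments and replayed exactly from the same seed.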