Permutation Invariant Learning with High-Dimensional Particle Filters

Akhilan Boopathy, Aneesh Muppidi, Peggy Yang, Abhiram Iyer, William Yue, Ila Fiete

TL;DR

This work introduces a novel permutation-invariant learning framework based on high-dimensional particle filters, and develops an efficient particle filter for optimizing high-dimensional models, combining the strengths of Bayesian methods with gradient-based optimization.

Abstract

Sequential learning in deep models often suffers from challenges such as catastrophic forgetting and loss of plasticity, largely due to the permutation dependence of gradient-based algorithms, where the order of training data impacts the learning outcome. In this work, we introduce a novel permutation-invariant learning framework based on high-dimensional particle filters. We theoretically demonstrate that particle filters are invariant to the sequential ordering of training minibatches or tasks, offering a principled solution to mitigate catastrophic forgetting and loss of plasticity. We develop an efficient particle filter for optimizing high-dimensional models, combining the strengths of Bayesian methods with gradient-based optimization. Through extensive experiments on continual supervised and reinforcement learning benchmarks, including SplitMNIST, SplitCIFAR100, and ProcGen, we empirically show that our method consistently improves performance while reducing variance compared to standard baselines.
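A toy case shows why ordering drops out of a particle filter. If the particles are held fixed and each one is reweighted by the likelihood of every minibatch, its final weight is a product of likelihoods, and that product does not depend on order. The sketch below is a minimal illustration under loudly stated assumptions: a 1-D Gaussian model with unit variance, a fixed grid of particles, no resampling, no gradient moves, and hypothetical helper names. It is not the paper's high-dimensional filter, which additionally combines these Bayesian updates with gradient-based optimization. For contrast, it also runs plain minibatch gradient descent on the same data in both orders.

```python
import math

# Toy model (an assumption for illustration): scalar observations x ~ N(theta, 1).
def log_likelihood(theta, batch):
    """Gaussian log-likelihood of one minibatch under mean theta, unit variance."""
    return sum(-0.5 * (x - theta) ** 2 - 0.5 * math.log(2 * math.pi) for x in batch)

def particle_weights(particles, batches):
    """Sequentially reweight fixed particles by each minibatch's likelihood.

    No resampling or gradient moves: this is only the Bayesian reweighting step.
    """
    log_w = [0.0] * len(particles)  # uniform prior over particles
    for batch in batches:
        log_w = [lw + log_likelihood(theta, batch) for theta, lw in zip(particles, log_w)]
    m = max(log_w)  # log-sum-exp shift for numerical stability
    unnorm = [math.exp(lw - m) for lw in log_w]
    total = sum(unnorm)
    return [w / total for w in unnorm]

def sgd(theta, batches, lr=0.5):
    """Plain minibatch gradient descent on the mean negative log-likelihood."""
    for batch in batches:
        # d/dtheta of mean 0.5 * (x - theta)^2 over the minibatch
        grad = sum(theta - x for x in batch) / len(batch)
        theta -= lr * grad
    return theta

# Five minibatches with different means, standing in for a sequence of distinct "tasks".
batches = [[k + 0.1 * j for j in range(4)] for k in range(5)]
reordered = batches[::-1]
particles = [-3.0 + 0.25 * i for i in range(25)]

w_fwd = particle_weights(particles, batches)
w_rev = particle_weights(particles, reordered)
pf_gap = max(abs(a - b) for a, b in zip(w_fwd, w_rev))
sgd_gap = abs(sgd(0.0, batches) - sgd(0.0, reordered))

print(f"particle-weight gap: {pf_gap:.1e}")  # ~0, up to floating-point rounding
print(f"SGD final-parameter gap: {sgd_gap:.4f}")
```

Reversing the minibatch order leaves the particle weights unchanged up to floating-point rounding, whereas the SGD endpoint is pulled toward whichever minibatch came last. This is a toy version of the order dependence the abstract attributes to gradient-based algorithms.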

Paper Structure

This paper contains 29 sections, 3 theorems, 47 equations, 3 figures, 2 tables, and 1 algorithm.

Key Result

Theorem 1

Suppose $\sigma_1, \sigma_2, \ldots, \sigma_T$ is a permutation of $1, 2, \ldots, T$ such that $N$ swaps of adjacent elements are required to convert $\sigma_1, \sigma_2, \ldots, \sigma_T$ to $1, 2, \ldots, T$. Denote the initialized particle filter as $\hat{p}_0$. Then,
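The quantity $N$ in Theorem 1 has a standard combinatorial reading. The minimum number of adjacent swaps needed to sort a permutation equals its number of inversions (pairs $i < j$ with $\sigma_i > \sigma_j$), so $N$ ranges from $0$ for the identity order to $T(T-1)/2$ for the fully reversed one. Assuming $N$ is meant as this minimum, the Python sketch below (function names hypothetical) computes it directly and via an adjacent-swap bubble sort, then checks that the two agree on every permutation of length 5.

```python
from itertools import permutations

def inversion_count(sigma):
    """Number of pairs i < j with sigma[i] > sigma[j]."""
    n = len(sigma)
    return sum(1 for i in range(n) for j in range(i + 1, n) if sigma[i] > sigma[j])

def adjacent_swaps_to_sort(sigma):
    """Bubble sort that swaps only adjacent elements; returns the number of swaps made."""
    seq = list(sigma)
    swaps = 0
    for end in range(len(seq) - 1, 0, -1):
        for i in range(end):
            if seq[i] > seq[i + 1]:
                seq[i], seq[i + 1] = seq[i + 1], seq[i]
                swaps += 1
    assert seq == sorted(sigma)
    return swaps

# Bubble sort only swaps inverted neighbours, and each such swap removes exactly
# one inversion, so the two counts agree on every permutation.
for sigma in permutations(range(1, 6)):
    assert inversion_count(sigma) == adjacent_swaps_to_sort(sigma)

print(inversion_count((2, 1, 3, 5, 4)))  # → 2
print(inversion_count((5, 4, 3, 2, 1)))  # → 10
```

Because any single adjacent swap changes the inversion count by exactly one, no sequence of adjacent swaps can sort $\sigma$ in fewer than its inversion count, which is why bubble sort's count is also the minimum.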

Figures (3)

  • Figure 1: Illustration of how our particle filter converges to well-performing regions of the parameter space over the course of training on SplitMNIST. The plot is constructed by using t-SNE to map the particles into two dimensions, then representing each particle with a unimodal Gaussian of fixed variance.
  • Figure 2: Average accuracy versus normalized task variance plots for both SplitCIFAR100 and SplitMNIST. The bottom right region of each plot represents the ideal scenario of high accuracy and low task-specific variance.
  • Figure 3: Mean episode reward curves for the lifelong setups of Starpilot, Fruitbot, and Dodgeball, comparing PPO and PPO with the weighted particle filter. The results indicate that PPO with the weighted particle filter has greater resistance to the loss of plasticity observed in standard PPO.

Theorems & Definitions (6)

  • Theorem 1
  • Theorem 2
  • Theorem 3
  • Proof
  • Proof
  • Proof