Compositional meta-learning through probabilistic task inference

Jacob J. W. Bakermans, Pablo Tano, Reidar Riveland, Charles Findling, Alexandre Pouget

TL;DR

This work addresses meta-learning by enabling rapid adaptation to new tasks through compositional inference over learned task structure. It introduces a probabilistic generative model that separates within-module dynamics from between-module dynamics using a gating network $G_{\bm{\theta}}$ and modular RNNs $\{M^{z}\}$, trained by maximizing the marginal likelihood $L = p(\bm{y}_{1:T})$ with respect to the model parameters. For new tasks, solutions are obtained via inference (e.g., particle filtering) to identify the best module sequence $z_{1:T}$ without updating parameters, achieving one-shot acquisition in abstract rule learning and motor skills, even under sparse feedback. The framework bridges neural expressivity with probabilistic data efficiency, with potential extensions to continual learning, dynamic module growth, and more expressive task grammars.
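
As a hedged reading of this summary, the marginal likelihood can be written as a sum over module sequences. The factorisation below assumes the dependencies stated in the Figure 1 caption (the gating distribution depends on ${\bm{g}}_{t-1}$, ${\textnormal{z}}_{t-1}$ and ${\mathbf{x}}_t$; the output distribution depends on ${\bm{m}}_{t-1}$, ${\mathbf{x}}_t$ and the selected module). Conditioning on the inputs ${\mathbf{x}}_{1:T}$ is an assumption made here for clarity, and the paper's own equations may differ in detail.

$$
p({\mathbf{y}}_{1:T} \mid {\mathbf{x}}_{1:T})
  = \sum_{{\textnormal{z}}_{1:T}} \prod_{t=1}^{T}
    p({\textnormal{z}}_t \mid {\bm{g}}_{t-1}, {\textnormal{z}}_{t-1}, {\mathbf{x}}_t)\,
    p({\mathbf{y}}_t \mid {\bm{m}}_{t-1}, {\mathbf{x}}_t, {\textnormal{z}}_t)
$$

Because the deterministic hidden states ${\bm{g}}_t$ and ${\bm{m}}_t$ depend on the whole history of selected modules, this sum is not amenable to the standard HMM forward recursion, which is consistent with the summary's use of sampling-based inference such as particle filtering.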

Abstract

To solve a new task from minimal experience, it is essential to effectively reuse knowledge from previous tasks, a problem known as meta-learning. Compositional solutions, where common elements of computation are flexibly recombined into new configurations, are particularly well-suited for meta-learning. Here, we propose a compositional meta-learning model that explicitly represents tasks as structured combinations of reusable computations. We achieve this by learning a generative model that captures the underlying components and their statistics shared across a family of tasks. This approach transforms learning a new task into a probabilistic inference problem, which allows for finding solutions without parameter updates through highly constrained hypothesis testing. Our model successfully recovers ground truth components and statistics in rule learning and motor learning tasks. We then demonstrate its ability to quickly infer new solutions from just single examples. Together, our framework joins the expressivity of neural networks with the data-efficiency of probabilistic inference to achieve rapid compositional meta-learning.

Paper Structure

This paper contains 10 sections, 14 equations, 4 figures.

Figures (4)

  • Figure 1: Model overview. a) Model architecture. The model consists of a gating RNN that, for a given gating hidden state ${\bm{g}}_{t-1}$, previously activated module ${\textnormal{z}}_{t-1}$, and input ${\mathbf{x}}_t$, parameterises a discrete probability distribution from which the currently active module RNN ${\textnormal{z}}_t$ is sampled. The selected module RNN processes the input ${\mathbf{x}}_t$ and module hidden state ${\bm{m}}_{t-1}$ to define an output distribution for ${\mathbf{y}}_t$. b) Graphical model. The model learns a probabilistic generative process with stochastic variables (circles) ${\textnormal{z}}_t$ and ${\mathbf{y}}_t$ that depend on input ${\mathbf{x}}_t$ and the deterministic model hidden states (diamonds) ${\bm{g}}_t$ and ${\bm{m}}_t$. Conceptually, this extends an HMM by replacing the transition and emission matrices with input-dependent RNNs. c) Particle filter schematic. To perform inference in this generative model, we define a particle system of $K$ particles (top row). At a given timestep, we sample module activations to calculate likelihoods $l^{(i)}$ of data ${\mathbf{y}}_t$ for each particle (second row). We resample particles from these (normalised) likelihoods (red arrows: particle $(1)$ is sampled twice, whereas particle $(2)$ is terminated) to reflect the module posterior $p({\textnormal{z}}_t|{\mathbf{y}}_{1:t})$ (third row), and continue this process for the next timestep (bottom row).
  • Figure 2: Rule learning. a) Learning curves. The training loss (negative log marginal likelihood) and task performance (mean squared error) decrease while the module and gating accuracy (correlation with ground truth operations and transitions) plateau at 1 (grey lines: five individual seeds; black line: mean across seeds). b) Learned operations. Each of the six (columns) true shift operations (top row) shifts its input entries by a consistent amount. This is shown by plotting in each matrix row $i$ the outcome of applying the shift to unit vector ${\bm{e}}^{(i)}$ as input. The six (columns) learned modules (bottom row) perform the same operations. c) Learned transitions. The true tasks incorporate a specific structure to operation sequences, shown by history-dependent transition matrices (top row). Each matrix row $i$ indicates the probability of the next module activation given that the previously selected module was module $i$ (first column), that the previous two selected modules were module $i$ (second column), that the previous three selected modules were module $i$ (third column), et cetera. The learned transitions (bottom row) reproduce the true transition pattern. d) Example test task. After training, the model infers a solution that accurately produces (second row) the desired output for a held-out task (top row). This test task's sequence of shift operations (third row; each matrix row represents one shift operation and red dots indicate the true underlying operation sequence) is accurately recovered by the model's module posterior (bottom row; matrix rows are module selection probabilities so that each column plots $p({\textnormal{z}}_t|{\mathbf{y}}_{1:t})$, with red dots showing the maximum a posteriori sequence $\text{argmax}_{{\textnormal{z}}_t}p({\textnormal{z}}_t|{\mathbf{y}}_{1:T})$). e) Sparse feedback example. The model still infers an accurate test task solution when feedback is provided in a small minority of timesteps (marked in yellow in the third row). f) Extended task example. Even on a test task that is four times longer than the training tasks, the model infers the correct solution, despite sparse feedback.
  • Figure 3: Control models. a) RNN control model. An RNN trained to perform the tasks with the same input as our model learns neither the training nor the test tasks (grey dots: individual seeds, error bars s.e.m. across tasks; black bars: mean across seeds). b) RNN with task identity input. When the task identity is added to the inputs, the RNN performs well on the training tasks but cannot solve new test tasks without additional training. c) Task inference without gating network. Because our model learns a generative model across tasks, it performs well on the training tasks and infers solutions to held-out test tasks from a single episode. However, without a gating network it performs poorly on tasks with sparse feedback. d) Full model. The full model learns training tasks and infers test tasks even if feedback is sparse.
  • Figure 4: Motor learning. a) Example motor tasks. Each task consists of a sequence of three motor skills, shown as chunks of different colours, starting from the star. b) Learned skills. After training, each module (output of one module in thin grey lines from dark to light in each subplot) learns to perform a skill (thick coloured lines; one true skill in each subplot). c) Learned transitions. The history-dependent transition matrices, analogous to Figure 2c, show that the gating RNN learns to switch between skills depending on their true duration. d) Example test task. The model (thin solid line coloured by module, with white borders, and pre-feedback hypotheses in dotted lines) infers the new task trajectory (thick solid line, starting from star) from feedback at each step (grey circles) by selecting learned modules (bottom right, as in Figure 2d) that execute the true skills (bottom left). e) Sparse feedback example. When feedback is sparse (at grey circle locations, and timesteps marked in yellow at bottom left), the model tests module hypotheses (dotted lines) that branch out at skill switch points until feedback confirms the current skill. This is reflected by the posterior $p({\textnormal{z}}_t|{\mathbf{y}}_{1:t})$ collapsing to a single module at feedback timesteps (bottom right, heatmap) and an accurate maximum a posteriori sequence $\text{argmax}_{{\textnormal{z}}_t}p({\textnormal{z}}_t|{\mathbf{y}}_{1:T})$ (bottom right, red dots).
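
The particle-filter procedure described in the Figure 1c caption (sample module activations, weight by the likelihood of ${\mathbf{y}}_t$, resample) can be illustrated with a minimal sketch. This is not the authors' implementation. It makes several simplifying assumptions that are not in the source: the learned module RNNs are replaced by fixed, stateless cyclic shifts (loosely mirroring the shift operations of Figure 2b), the gating RNN is replaced by a fixed transition matrix `trans`, and the output distribution is an isotropic Gaussian with a hypothetical width `sigma`. The names `shift`, `particle_filter` and `map_sequence` are illustrative only.

```python
import math
import random


def shift(vec, k):
    """Cyclically shift a list right by k positions (stand-in for one module)."""
    k %= len(vec)
    return vec[-k:] + vec[:-k] if k else list(vec)


def particle_filter(xs, ys, trans, n_particles=200, sigma=0.1, seed=0):
    """Bootstrap particle filter over module indices z_t.

    xs, ys : lists of input / target vectors; ys[t] may be None (no feedback).
    trans  : trans[i][j] = p(z_t = j | z_{t-1} = i); an assumed fixed matrix
             standing in for the gating network.
    Returns a list of filtering posteriors p(z_t | y_{1:t}), one per timestep.
    """
    n_modules = len(trans)
    modules = range(n_modules)
    rng = random.Random(seed)
    # Initial module per particle, drawn uniformly (an assumption).
    particles = [rng.randrange(n_modules) for _ in range(n_particles)]
    posteriors = []
    for x, y in zip(xs, ys):
        # 1) Propose: sample the next module for every particle.
        particles = [rng.choices(modules, weights=trans[z])[0] for z in particles]
        # 2) Weight: Gaussian likelihood of the feedback under each module.
        if y is None:
            weights = [1.0] * n_particles  # no feedback: keep the prior
        else:
            weights = []
            for z in particles:
                err = sum((p - q) ** 2 for p, q in zip(shift(x, z), y))
                weights.append(math.exp(-err / (2 * sigma ** 2)))
            if sum(weights) == 0.0:  # guard against numerical underflow
                weights = [1.0] * n_particles
        # 3) Resample particles in proportion to their weights.
        particles = rng.choices(particles, weights=weights, k=n_particles)
        posteriors.append([particles.count(z) / n_particles for z in modules])
    return posteriors


def map_sequence(posteriors):
    """Most probable module at each timestep under the filtering posterior."""
    return [max(range(len(p)), key=p.__getitem__) for p in posteriors]


if __name__ == "__main__":
    x = [1.0, 0.0, 0.0, 0.0, 0.0]
    true_z = [0, 2, 1, 1]                      # hidden module sequence
    ys = [shift(x, z) for z in true_z]         # dense feedback
    uniform = [[1 / 3] * 3 for _ in range(3)]  # uninformative transitions
    print(map_sequence(particle_filter([x] * 4, ys, uniform)))
```

Passing `None` for some entries of `ys` mimics the sparse-feedback setting: at those timesteps the particles spread according to the transition prior alone, and they collapse again once feedback arrives. The sketch returns only the filtering posterior $p({\textnormal{z}}_t|{\mathbf{y}}_{1:t})$; recovering a sequence conditioned on all feedback ${\mathbf{y}}_{1:T}$ would additionally require tracking particle ancestries, which is omitted here.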