Gradient-based inference of abstract task representations for generalization in neural networks

Ali Hummos, Felipe del Río, Brabeeba Mien Wang, Julio Hurtado, Cristian B. Calderon, Guangyu Robert Yang

TL;DR

This paper introduces Gradient-Based Inference (GBI) to learn and reuse abstract task representations in neural networks. By treating task abstractions as latent variables $\boldsymbol{Z}$ and grounding inference in variational/EM-inspired principles, GBI enables efficient test-time inference via backpropagated gradients and optional iterative refinement to recombine abstractions for novel tasks. Across a toy Bayesian dataset, image classification, and language modeling, GBI improves data efficiency and generalization to unseen tasks and reduces forgetting, while providing useful uncertainty estimates and robust OOD detection. The approach demonstrates that separating computation from task abstraction can yield practical advantages for rapid adaptation and reliable learning in domain-general settings. These results motivate future work on unsupervised discovery of multidimensional task representations and scaling to larger, real-world problems.

Abstract

Humans and many animals show remarkably adaptive behavior and can respond differently to the same input depending on their internal goals. The brain not only represents the intermediate abstractions needed to perform a computation but also actively maintains a representation of the computation itself (task abstraction). Such separation of the computation and its abstraction is associated with faster learning, flexible decision-making, and broad generalization capacity. We investigate if such benefits might extend to neural networks trained with task abstractions. For such benefits to emerge, one needs a task inference mechanism that possesses two crucial abilities: First, the ability to infer abstract task representations when no longer explicitly provided (task inference), and second, manipulate task representations to adapt to novel problems (task recomposition). To tackle this, we cast task inference as an optimization problem from a variational inference perspective and ground our approach in an expectation-maximization framework. We show that gradients backpropagated through a neural network to a task representation layer are an efficient heuristic to infer current task demands, a process we refer to as gradient-based inference (GBI). Further iterative optimization of the task representation layer allows for recomposing abstractions to adapt to novel situations. Using a toy example, a novel image classifier, and a language model, we demonstrate that GBI provides higher learning efficiency and generalization to novel tasks and limits forgetting. Moreover, we show that GBI has unique advantages such as preserving information for uncertainty estimation and detecting out-of-distribution samples.
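
As a rough illustration of the two abilities named in the abstract, the sketch below runs gradient-based inference on a deliberately tiny stand-in model. Everything in it is an assumption made for illustration rather than the authors' implementation: the "network" is a fixed linear readout of the task vector z (with weights 0.2 and 0.8), the loss is a squared error, the gradient is written out analytically instead of being backpropagated, and the softmax readout with a `temperature` parameter is a hypothetical choice. The names `predict`, `one_step_inference`, and `iterative_inference` are likewise hypothetical.

```python
import math

# Hypothetical stand-in for a trained network: the prediction is a fixed
# linear readout of the task vector z (weights chosen for illustration only).
TASK_WEIGHTS = [0.2, 0.8]


def predict(z):
    """Prediction of the stand-in model for task vector z."""
    return sum(zi * w for zi, w in zip(z, TASK_WEIGHTS))


def loss_grad_wrt_z(z, x):
    """Analytic gradient of the squared error (x - predict(z))**2 w.r.t. z."""
    err = x - predict(z)
    return [-2.0 * err * w for w in TASK_WEIGHTS]


def softmax(values):
    top = max(values)
    exps = [math.exp(v - top) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def one_step_inference(x, temperature=0.1):
    """Task inference: one gradient taken at the uniform task vector."""
    z0 = [0.5, 0.5]
    grad = loss_grad_wrt_z(z0, x)
    # Assumption: negative gradients are read out through a softmax.
    return softmax([-g / temperature for g in grad])


def iterative_inference(x, steps=200, lr=0.5):
    """Task recomposition: repeated gradient steps on z, model weights fixed."""
    z = [0.5, 0.5]
    for _ in range(steps):
        grad = loss_grad_wrt_z(z, x)
        z = [zi - lr * g for zi, g in zip(z, grad)]
    return z
```

With this stand-in, `one_step_inference(0.8)` puts more mass on the second task unit and `one_step_inference(0.2)` on the first, while `iterative_inference(1.0)` finds a z whose prediction matches an observation outside the two values the readout was built around. With a single linear readout the one-step decision depends only on the sign of the prediction error, so this is a much cruder signal than gradients taken through a trained network.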

Paper Structure (27 sections, 11 equations, 15 figures, 8 tables)

Figures (15)

  • Figure 1: Sequence prediction task in one dimension. A) Bayesian graphical model generating the data. The task node (z) flips periodically and sets the mean of the Gaussian generating the data to either $\mu_1 = 0.2$ or $\mu_2 = 0.8$, with a fixed standard deviation $\sigma = 0.1$. B) Sample generated data observations. C) The likelihood of each context node value given the data, calculated using the Gaussian probability density function (more details in Appendix B). D) The posterior over z values.
  • Figure 2: A) Comparing the training loss of an LSTM trained with task input (GBI-LSTM) and without task input (LSTM). Note that the variability in these loss estimates, despite the simulation being run over 10 random seeds, is due to the stochastic block transitions still coinciding frequently. B) We show the mean loss per block but exclude the initial 20 predictions in each block, to leave out the poor performance of either model that is limited to transitions between tasks. C) We quantify the ratio of shared neurons active in the 0.2 blocks and the 0.8 blocks, using 'task variance' to identify engaged neurons. We binned training blocks into groups of 5 to aggregate the data and increase accuracy.
  • Figure 3: A neural network trained with task abstractions can dynamically tune itself to incoming data and generalizes to novel data points. A) Responses of the GBI-LSTM to sample data and B) the gradients of the prediction loss w.r.t. the two Z units. C) Iteratively optimized Z values accumulating evidence dynamically. D) As a generalization test, while the network was trained on data from Gaussians with means 0.2 and 0.8, we test its responses to Gaussians with means from -0.2 to 1.2 (in 0.1 increments). We record the mean MSE values across 20 runs with different random seeds and compare to the baseline LSTM trained without task abstraction input. We only show runs where the last block of training was a 0.2 block, to highlight the models' behavior on that mean vs. the other (0.8). E) The mean of the Gaussian data vs. the mean network predictions in each generalization block. These results are further quantified in Table \ref{table:toy_task_quantification_of_losses}. F-H) We show the other mode of using the same GBI-LSTM for gradient-based inference. We fix the task abstraction input Z to its state of maximum entropy (here [0.5, 0.5]) and take gradients from that point. While the network responds in the middle (F), the gradients w.r.t. Z (G) or, similarly, the one-step gradient descent Z values (H) behave much like the likelihood function, cf. values computed using the Gaussian PDF in Fig. \ref{fig:toy_task}C.
  • Figure 4: Robust out-of-distribution (OOD) sample detection in a GBI autoencoder compared to a classifier. Both models were trained on MNIST digits and tested on the OOD dataset fashionMNIST. We take the max value of A) the task abstraction layer in GBI and B) the logits in the classifier, and consider how they are distributed for in-distribution and OOD samples. C) ROC curves show that GBI max values are more discriminative.
  • Figure 5: Better data efficiency in an LSTM trained with task abstraction input (dataset ID one-hot) (GBI-LSTM) compared to training without such input (LSTM). Cross-entropy loss decreases faster for GBI-LSTM on the A) training set and B) testing set. Note that we train with a word-level tokenizer with a larger output dimension, which has a higher cross-entropy loss upper bound ($\sim$12). Each simulation run used three language datasets picked randomly from 10 datasets. We made 10 dataset picks, each run with 3 random seeds; shaded areas show SEM. C) One-pass gradient-based inference can infer dataset ID accurately at test time, needing only a few words. GBI-LSTM takes the sequence of words available and outputs predictions over the same sequence shifted by one token (sequence-to-sequence modeling). We take the gradients w.r.t. the task abstraction inputs and use them to infer dataset ID. We show examples from 3 datasets to show variability. D) Iterative optimization can be used to adapt a GBI-LSTM to a novel dataset. We compare it to optimizing the inputs of an LSTM that was trained with no task provided, to rule out any generic optimization effects. Due to the variability of loss values across datasets, for each dataset we take the mean and SEM over 4 seeds and then aggregate the results across datasets.
  • ...and 10 more figures
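
The likelihood and posterior described in the Figure 1 caption can be worked through numerically. The sketch below uses the values stated there ($\mu_1 = 0.2$, $\mu_2 = 0.8$, $\sigma = 0.1$) and the standard Gaussian density. The uniform prior and the assumption that z stays fixed over the observations being combined are simplifications made here; the paper's own computation is described in its Appendix B and may differ. The function names are hypothetical.

```python
import math

MU = (0.2, 0.8)  # task means from the Figure 1 caption
SIGMA = 0.1      # fixed standard deviation from the caption


def gaussian_log_pdf(x, mu, sigma=SIGMA):
    """Log of the Gaussian probability density at x."""
    return (-0.5 * ((x - mu) / sigma) ** 2
            - math.log(sigma * math.sqrt(2.0 * math.pi)))


def likelihoods(x):
    """Likelihood of a single observation x under each task value."""
    return [math.exp(gaussian_log_pdf(x, mu)) for mu in MU]


def posterior(observations, prior=(0.5, 0.5)):
    """Posterior over z for a run of observations.

    Assumptions: z does not flip within the run, and the prior is uniform
    unless another one is passed in.
    """
    log_post = [math.log(p) for p in prior]
    for x in observations:
        for k, mu in enumerate(MU):
            log_post[k] += gaussian_log_pdf(x, mu)
    # Normalize in log space for numerical stability.
    top = max(log_post)
    weights = [math.exp(lp - top) for lp in log_post]
    total = sum(weights)
    return [w / total for w in weights]
```

For example, `posterior([0.25, 0.15, 0.3])` concentrates almost all mass on the first task value, whereas a single observation at 0.5, which is equidistant from both means, leaves the posterior at the prior.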