Gradient-based inference of abstract task representations for generalization in neural networks
Ali Hummos, Felipe del Río, Brabeeba Mien Wang, Julio Hurtado, Cristian B. Calderon, Guangyu Robert Yang
TL;DR
This paper introduces Gradient-Based Inference (GBI) to learn and reuse abstract task representations in neural networks. By treating task abstractions as latent variables $\boldsymbol{Z}$ and grounding inference in variational/EM-inspired principles, GBI enables efficient test-time inference via backpropagated gradients, with optional iterative refinement to recombine abstractions for novel tasks. Across a toy Bayesian dataset, image classification, and language modeling, GBI improves data efficiency and generalization to unseen tasks and reduces forgetting, while providing useful uncertainty estimates and robust out-of-distribution (OOD) detection. The approach demonstrates that separating computation from task abstraction can yield practical advantages for rapid adaptation and reliable learning in domain-general settings. These results motivate future work on unsupervised discovery of multidimensional task representations and scaling to larger, real-world problems.
Abstract
Humans and many animals show remarkably adaptive behavior and can respond differently to the same input depending on their internal goals. The brain not only represents the intermediate abstractions needed to perform a computation but also actively maintains a representation of the computation itself (task abstraction). Such separation of the computation and its abstraction is associated with faster learning, flexible decision-making, and broad generalization capacity. We investigate whether such benefits extend to neural networks trained with task abstractions. For such benefits to emerge, one needs a task inference mechanism with two crucial abilities: first, the ability to infer abstract task representations when they are no longer explicitly provided (task inference), and second, the ability to manipulate task representations to adapt to novel problems (task recomposition). To tackle this, we cast task inference as an optimization problem from a variational inference perspective and ground our approach in an expectation-maximization framework. We show that gradients backpropagated through a neural network to a task representation layer are an efficient heuristic to infer current task demands, a process we refer to as gradient-based inference (GBI). Further iterative optimization of the task representation layer allows for recomposing abstractions to adapt to novel situations. Using a toy example, a novel image classifier, and a language model, we demonstrate that GBI improves learning efficiency and generalization to novel tasks and limits forgetting. Moreover, we show that GBI has unique advantages such as preserving information for uncertainty estimation and detecting out-of-distribution samples.
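
To make the idea of inferring a task from gradients concrete, here is a minimal sketch in plain Python. It is an illustration under stated assumptions, not the authors' implementation. The "network" is a hypothetical frozen model that mixes two already-learned task computations, weighted by a task representation z. The two task functions, the squared-error loss, the uniform starting point for z, and the learning rate are all assumptions chosen for readability. The task is read off from how one gradient step on z (with the model weights held fixed) changes the task representation.

```python
def task_outputs(x):
    """Outputs of two hypothetical, already-trained task computations."""
    return [2.0 * x, 1.0 - x]


def predict(z, x):
    """Model output: task computations weighted by the task representation z."""
    return sum(zk * fk for zk, fk in zip(z, task_outputs(x)))


def mean_loss(z, data):
    """Mean squared error of the frozen model on (x, y) pairs."""
    return sum((predict(z, x) - y) ** 2 for x, y in data) / len(data)


def grad_wrt_z(z, data):
    """Analytic gradient of the mean squared error with respect to z only."""
    grad = [0.0] * len(z)
    for x, y in data:
        err = predict(z, x) - y
        for k, fk in enumerate(task_outputs(x)):
            grad[k] += 2.0 * err * fk / len(data)
    return grad


def infer_task(data, steps=1, lr=0.01):
    """Update z by gradient descent while the model itself stays fixed.

    steps=1 is a one-shot gradient-based guess from an uninformative start;
    steps>1 keeps refining z iteratively.
    """
    z = [0.5, 0.5]  # assumed uniform (uninformative) initial task representation
    for _ in range(steps):
        g = grad_wrt_z(z, data)
        z = [zk - lr * gk for zk, gk in zip(z, g)]
    return z


xs = [1.0, 2.0, 3.0]
data_task0 = [(x, 2.0 * x) for x in xs]  # samples generated by task 0
data_task1 = [(x, 1.0 - x) for x in xs]  # samples generated by task 1

z_after_task0 = infer_task(data_task0)
z_after_task1 = infer_task(data_task1)
print("z after task-0 data:", z_after_task0)
print("z after task-1 data:", z_after_task1)
```

In this toy setting, a single gradient step already moves z toward the component that generated the data, and calling `infer_task` with a larger `steps` value corresponds loosely to the iterative refinement of the task representation described in the abstract.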
