Iterative Amortized Inference: Unifying In-Context Learning and Learned Optimizers
Sarthak Mittal, Divyat Mahajan, Guillaume Lajoie, Mohammad Pezeshki
TL;DR
The paper introduces a unified framework for amortized learning across tasks that separates task adaptation (via $g_{\boldsymbol{\varphi}}$) from prediction (via $f_\gamma$), and then proposes Iterative Amortized Inference, which refines task-specific representations over mini-batches. It defines a taxonomy of amortization into parametric, implicit, and explicit regimes, analyzes their trade-offs, and connects them through a stochastic optimization lens. The main contribution is a scalable, multi-step refinement approach that improves generalization and efficiency on both in-distribution (ID) and out-of-distribution (OoD) tasks, demonstrated on diverse predictive and generative benchmarks. The framework offers a principled path to general-purpose task adaptation in large-scale models such as LLMs and related architectures, and leaves room for extensions beyond greedy refinement.
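The split between $g_{\boldsymbol{\varphi}}$ and $f_\gamma$, and the move from one-shot to mini-batch refinement, can be written compactly as follows. This is a hedged reading, not the paper's exact notation: the task dataset $\mathcal{D}$, mini-batches $\mathcal{B}_t$, task state $\theta$, and number of steps $T$ are symbols introduced here for illustration.

```latex
\begin{aligned}
\text{single step:}\quad & \theta = g_{\boldsymbol{\varphi}}(\mathcal{D}),
  && \hat{y} = f_\gamma(x, \theta) \\
\text{iterative:}\quad & \theta_t = g_{\boldsymbol{\varphi}}(\theta_{t-1}, \mathcal{B}_t),
  \quad \mathcal{B}_t \subset \mathcal{D},\ t = 1, \dots, T,
  && \hat{y} = f_\gamma(x, \theta_T)
\end{aligned}
```

In the single-step form, $g_{\boldsymbol{\varphi}}$ must consume all of $\mathcal{D}$ at once, so its input grows with the dataset; in the iterative form each call sees only $|\mathcal{B}_t|$ examples. Plain SGD, $\theta_t = \theta_{t-1} - \alpha \nabla_\theta \mathcal{L}(\theta_{t-1}; \mathcal{B}_t)$ with $\mathcal{L}$ the loss of $f_\gamma$ on $\mathcal{B}_t$, fits the iterative form as a fixed, unlearned update.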
Abstract
Modern learning systems increasingly rely on amortized learning: the idea of reusing computation or inductive biases shared across tasks to enable rapid generalization to novel problems. This principle spans a range of approaches, including meta-learning, in-context learning, prompt tuning, learned optimizers and more. While motivated by similar goals, these approaches differ in how they encode and leverage task-specific information, often provided as in-context examples. In this work, we propose a unified framework that describes how such methods differ primarily in the aspects of learning they amortize (such as initializations, learned updates, or predictive mappings) and in how they incorporate task data at inference. We introduce a taxonomy that categorizes amortized models into parametric, implicit, and explicit regimes, based on whether task adaptation is externalized, internalized, or jointly modeled. Building on this view, we identify a key limitation in current approaches: most methods struggle to scale to large datasets because their capacity to process task data at inference (e.g., context length) is often limited. To address this, we propose iterative amortized inference, a class of models that refine solutions step-by-step over mini-batches, drawing inspiration from stochastic optimization. Our formulation bridges optimization-based meta-learning with forward-pass amortization in models like LLMs, offering a scalable and extensible foundation for general-purpose task adaptation.
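The mini-batch refinement loop described in the abstract can be sketched in Python. This is a minimal illustration under stated assumptions, not the authors' implementation: the task is 1-D noiseless linear regression, `f_gamma` is a fixed linear predictor, and `make_g_phi` is a hypothetical stand-in for a learned update that takes a gradient step with step size `phi`. In a meta-learned system `phi` would be trained across tasks; here it is fixed by hand. All function names are illustrative.

```python
import random
from typing import Callable, List, Tuple

Batch = List[Tuple[float, float]]         # (x, y) pairs from a single task
Update = Callable[[float, Batch], float]  # g_phi: (theta, mini-batch) -> refined theta


def f_gamma(x: float, theta: float) -> float:
    # Shared predictor; theta holds the task-specific state inferred by g_phi.
    return theta * x


def mse_grad(theta: float, batch: Batch) -> float:
    # Derivative w.r.t. theta of the mean squared error of f_gamma on one mini-batch.
    return sum(2.0 * (f_gamma(x, theta) - y) * x for x, y in batch) / len(batch)


def make_g_phi(phi: float) -> Update:
    # Hypothetical stand-in for a learned update: a gradient step whose step
    # size phi would be meta-learned across tasks (fixed by hand here).
    def g_phi(theta: float, batch: Batch) -> float:
        return theta - phi * mse_grad(theta, batch)
    return g_phi


def iterative_amortized_inference(g_phi: Update, data: Batch, batch_size: int,
                                  steps: int, theta0: float = 0.0,
                                  seed: int = 0) -> float:
    # Refine theta one mini-batch at a time: each call to g_phi sees at most
    # batch_size examples, however large the task dataset is.
    rng = random.Random(seed)
    theta = theta0
    for _ in range(steps):
        batch = rng.sample(data, batch_size)
        theta = g_phi(theta, batch)
    return theta


if __name__ == "__main__":
    data_rng = random.Random(1)
    # Noiseless task y = 3x with 1,000 examples drawn uniformly from [-1, 1].
    task = [(x, 3.0 * x) for x in (data_rng.uniform(-1.0, 1.0) for _ in range(1000))]
    theta = iterative_amortized_inference(make_g_phi(0.5), task, batch_size=16, steps=200)
    print(f"theta = {theta:.4f}, prediction at x=2: {f_gamma(2.0, theta):.4f}")
```

Because each call to `g_phi` sees only `batch_size` examples, the cost of a refinement step does not grow with the size of `task`. A learned `g_phi` need not be limited to gradient steps; it could, for instance, consume the raw examples of each mini-batch directly.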
