Iterative Amortized Inference: Unifying In-Context Learning and Learned Optimizers

Sarthak Mittal, Divyat Mahajan, Guillaume Lajoie, Mohammad Pezeshki

TL;DR

The paper introduces a unified framework for amortized learning across tasks by separating task adaptation (via $g_{\boldsymbol{\varphi}}$) from prediction (via $f_\gamma$), and then proposes Iterative Amortized Inference to refine task-specific representations over mini-batches. It defines a taxonomy of amortization into parametric, implicit, and explicit regimes, analyzes their trade-offs, and connects them through a stochastic optimization lens. The main contribution is a scalable, multi-step refinement approach that improves generalization and efficiency on in-distribution (ID) and out-of-distribution (OoD) tasks, demonstrated on diverse predictive and generative benchmarks. This framework provides a principled path to general-purpose task adaptation in large-scale models such as LLMs and related architectures, with potential for further extensions beyond greedy refinement.

Abstract

Modern learning systems increasingly rely on amortized learning - the idea of reusing computation or inductive biases shared across tasks to enable rapid generalization to novel problems. This principle spans a range of approaches, including meta-learning, in-context learning, prompt tuning, learned optimizers and more. While motivated by similar goals, these approaches differ in how they encode and leverage task-specific information, often provided as in-context examples. In this work, we propose a unified framework which describes how such methods differ primarily in the aspects of learning they amortize - such as initializations, learned updates, or predictive mappings - and how they incorporate task data at inference. We introduce a taxonomy that categorizes amortized models into parametric, implicit, and explicit regimes, based on whether task adaptation is externalized, internalized, or jointly modeled. Building on this view, we identify a key limitation in current approaches: most methods struggle to scale to large datasets because their capacity to process task data at inference (e.g., context length) is often limited. To address this, we propose iterative amortized inference, a class of models that refine solutions step-by-step over mini-batches, drawing inspiration from stochastic optimization. Our formulation bridges optimization-based meta-learning with forward-pass amortization in models like LLMs, offering a scalable and extensible foundation for general-purpose task adaptation.
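The step-by-step refinement over mini-batches described in the abstract can be illustrated with a minimal sketch of the parametric regime. Everything here is an assumption made for illustration and not the authors' implementation: the names `g_phi`, `f_gamma`, and `iterative_amortized_inference` are hypothetical, the prediction function is a plain linear map, and the learned adaptation network is replaced by a hand-written gradient step on squared error, since a trained $g_{\boldsymbol{\varphi}}$ is not available here.

```python
import numpy as np

def f_gamma(x, theta):
    """Prediction function (assumed linear for this sketch)."""
    return x @ theta

def g_phi(theta, x_batch, y_batch, step_size=0.1):
    """Hypothetical stand-in for the learned adaptation network.

    A trained model would map (current state, mini-batch) to a refined
    state; here a fixed gradient step on squared error plays that role.
    """
    residual = f_gamma(x_batch, theta) - y_batch
    grad = x_batch.T @ residual / len(x_batch)
    return theta - step_size * grad

def iterative_amortized_inference(x, y, num_steps=200, batch_size=16, seed=0):
    """Refine the task state theta one mini-batch at a time."""
    rng = np.random.default_rng(seed)
    theta = np.zeros(x.shape[1])  # initial task state
    for _ in range(num_steps):
        idx = rng.choice(len(x), size=batch_size, replace=False)
        theta = g_phi(theta, x[idx], y[idx])  # one refinement step
    return theta

# Toy regression task: the dataset is never processed in a single pass.
rng = np.random.default_rng(1)
true_theta = np.array([2.0, -1.0, 0.5])
x_train = rng.normal(size=(256, 3))
y_train = x_train @ true_theta + 0.01 * rng.normal(size=256)

theta_hat = iterative_amortized_inference(x_train, y_train)
x_query = rng.normal(size=(5, 3))
predictions = f_gamma(x_query, theta_hat)
```

The point of the sketch is the control flow: each call sees only one mini-batch plus the previous state, so the amount of task data that can be used is not bounded by what a single forward pass can hold.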

Paper Structure

This paper contains 36 sections, 20 equations, 15 figures, 6 tables.

Figures (15)

  • Figure 1: Iterative Amortized Inference for parametric, explicit and implicit parameterizations. While the parametric and explicit setups provide task-specific information $\boldsymbol{\theta}$ that is invariant to the query, the implicit setup instead continually refines a prediction tied to the query.
  • Figure 2: Samples generated from the implicit generative model for the GMM and Alphabets tasks.
  • Figure 3: We analyze the benefits, or lack thereof, of leveraging multiple past states and gradients for parametric amortization, when we use both observations and gradients as conditional inputs.
  • Figure 4: Our ablations reveal that carrying over logits as the recurrent state across iterations in implicit models outperforms other state representations, with softmax outputs performing comparably.
  • Figure 5: Masking procedure for the causally masked parametric and explicit models. The context evolves according to a causal mask (the matrix of black and white squares, where white blocks denote masked entries), so that each token, in parallel, predicts the parameters of interest conditioned on the past batch of data and the previous state. This mimics processing datasets of variable size in parallel. The output corresponding to the last token is the $\boldsymbol{\theta}_{\mathcal{T}}^{(t)}$ that is fed back recurrently.
  • ...and 10 more figures
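
The causal masking idea in the Figure 5 caption can be sketched generically. This is only an illustration under stated assumptions: the paper's exact mask layout is not reproduced here, the helper names `causal_prefix_mask` and `masked_mean_pool` are hypothetical, and a masked average stands in for attention. The sketch shows why a lower-triangular mask lets position i behave as if it had processed a dataset containing only the first i + 1 tokens, so all prefix sizes are handled in one parallel pass.

```python
import numpy as np

def causal_prefix_mask(num_tokens):
    """Boolean mask: entry [i, j] is True when token i may see token j."""
    return np.tril(np.ones((num_tokens, num_tokens), dtype=bool))

def masked_mean_pool(tokens, mask):
    """Average the visible tokens for each position (stand-in for attention)."""
    weights = mask / mask.sum(axis=1, keepdims=True)
    return weights @ tokens

# Four toy tokens with two features each.
tokens = np.arange(8, dtype=float).reshape(4, 2)
mask = causal_prefix_mask(4)
pooled = masked_mean_pool(tokens, mask)
```

Under this assumed layout, the first row of `pooled` depends only on the first token, while the last row summarizes all tokens, which corresponds to the last-token output that the caption describes as being fed back recurrently.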