Universal In-Context Approximation By Prompting Fully Recurrent Models

Aleksandar Petrov, Tom A. Lamb, Alasdair Paren, Philip H. S. Torr, Adel Bibi

TL;DR

The paper demonstrates that RNNs, LSTMs, GRUs, Linear RNNs, and linear gated architectures such as Mamba and Hawk/Griffin can serve as universal in-context approximators, a property previously shown for transformers. It also introduces LSRL, a programming language that compiles to these fully recurrent architectures.

Abstract

Zero-shot and in-context learning enable solving tasks without model fine-tuning, making them essential for developing generative model solutions. Therefore, it is crucial to understand whether a pretrained model can be prompted to approximate any function, i.e., whether it is a universal in-context approximator. While it was recently shown that transformer models do possess this property, these results rely on their attention mechanism. Hence, these findings do not apply to fully recurrent architectures like RNNs, LSTMs, and the increasingly popular SSMs. We demonstrate that RNNs, LSTMs, GRUs, Linear RNNs, and linear gated architectures such as Mamba and Hawk/Griffin can also serve as universal in-context approximators. To streamline our argument, we introduce a programming language called LSRL that compiles to these fully recurrent architectures. LSRL may be of independent interest for further studies of fully recurrent models, such as constructing interpretability benchmarks. We also study the role of multiplicative gating and observe that architectures incorporating such gating (e.g., LSTMs, GRUs, Hawk/Griffin) can implement certain operations more stably, making them more viable candidates for practical in-context universal approximation.

Paper Structure (69 sections, 48 equations, 4 figures)

This paper contains 69 sections, 48 equations, 4 figures.

Figures (4)

  • Figure 1: Compilation of an LSRL program to a Linear RNN. An example of a simple LSRL program that takes a sequence of 0s and 1s as an input and outputs 1 if there have been more 1s than 0s and 0 otherwise. The LSRL compiler follows the rules in \ref{sec:debranching_rules} to simplify the computation DAG into a path graph. The resulting path graph can be represented as a Linear RNN with one layer.
  • Figure 2: Intuition behind the LSRL program for universal in-context approximation for continuous functions in \ref{lst:continous_approximation}. Our target function $f$ has input dimension $d_\text{in}=2$ and output dimension $d_\text{out}=1$. Each input dimension is split into two parts, hence $\delta=1/2$. We illustrate an example input sequence of length 5: one token for the query and four prompt tokens, one for each of the discretisation cells. The query $(q_1, q_2)$ falls in the cell corresponding to the third prompt token. We show how the two LinState variables in the program are updated after each step. Note in particular how the state holding the output $\texttt{y}$ is updated after $\bm p_3$ is processed.
  • Figure 3: Intuition behind the LSRL program for universal in-context approximation for discrete functions in \ref{lst:tok2tok_approximation}. Our keys and values have length $n{=}3$ and represent countries and capitals, e.g., AUStria$\mapsto$VIEnna, BULgaria$\mapsto$SOFia, and so on. The query is CAN for Canada and the final $n$ outputs are OTT (Ottawa). We show the values of some of the variables in \ref{lst:tok2tok_approximation} at each step, with the LinState variables being marked with arrows. For cleaner presentation we are tokenizing letters as 0$\mapsto$?, 1$\mapsto$A, 2$\mapsto$B, etc. Vertical separators are for illustration purposes only.
  • Figure 4: Robustness of the various f_ifelse implementations to model parameter noise. We show how the performance of the two universal approximation programs in \ref{lst:continous_approximation} and \ref{lst:tok2tok_approximation} deteriorates as we add Gaussian noise of various magnitudes to the non-zero weights of the resulting compiled models. As expected, the original f_ifelse implementation in \ref{eq:ifelse_direct} exhibits numerical precision errors at the lowest noise magnitude. For the token sequence case, numerical precision errors are present in all samples even in the no-noise setting. Hence, the original f_ifelse implementation is less numerically robust while the implementations with multiplicative gating are the most robust. For \ref{lst:continous_approximation} (approximating $\mathcal{C}^\text{vec}$) we report the Euclidean distance between the target function value and the estimated one over 10 queries for 25 target functions. For \ref{lst:tok2tok_approximation} we report the percentage of wrong token predictions over 5 queries for 25 dictionary maps. Lower values are better in both cases.
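
The behaviour described in the Figure 1 caption (output 1 whenever more 1s than 0s have been seen so far) can be mimicked by a single linear recurrence followed by a step readout. The sketch below is a hand-written illustration in plain Python, not LSRL code and not the output of the LSRL compiler. The function name `majority_linear_rnn`, the scalar state, and the coefficients `a`, `b`, `c` are assumptions chosen for this example.

```python
from typing import Iterable, List


def majority_linear_rnn(bits: Iterable[int]) -> List[int]:
    """Illustrative one-layer linear recurrence (hypothetical, not LSRL output).

    The scalar state tracks (#ones - #zeros) seen so far via the linear update
    s_t = a * s_{t-1} + b * x_t + c, and a step readout emits 1 when s_t > 0.
    """
    # Assumed coefficients: a one contributes +1, a zero contributes -1.
    a, b, c = 1.0, 2.0, -1.0
    state = 0.0
    outputs: List[int] = []
    for x in bits:
        if x not in (0, 1):
            raise ValueError("inputs must be 0 or 1")
        state = a * state + b * x + c          # linear state update
        outputs.append(1 if state > 0 else 0)  # step nonlinearity on the state
    return outputs


if __name__ == "__main__":
    # Running counts of (#ones - #zeros): 1, 0, -1, 0, 1
    print(majority_linear_rnn([1, 0, 0, 1, 1]))
```

The state update itself contains no nonlinearity, which is what makes this a linear recurrence. The comparison with zero is applied only when reading the output from the state.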