In-Context Compositional Learning via Sparse Coding Transformer
Wei Chen, Jingxi Yu, Zichen Miao, Qiang Qiu
TL;DR
This paper tackles in-context compositional learning by reformulating Transformer attention through the lens of sparse coding. It introduces two input-dependent dictionaries, an encoding dictionary $\phi(\cdot)$ and a decoding dictionary $\psi(\cdot)$, and enforces sparsity on the resulting coefficients $\boldsymbol{\alpha}$ to expose and preserve compositional structure. The coefficients of the target task are estimated as a linear combination of the context-task coefficients via a lifting-inspired scheme with learnable weights $\lambda_i$, enabling learned rules to transfer across tasks. On S-RAVEN and RAVEN, sparse-coding attention markedly improves compositional generalization and reconstruction quality over standard Transformers and remains robust on tasks where dense attention fails. The approach also shows potential benefits for language-model reasoning tasks and offers a parameter-efficient way to integrate structured inductive bias into pre-trained architectures, although scaling to very large models remains a limitation.
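A minimal sketch of the encode, sparsify, decode pipeline summarized above, in plain Python. It assumes that the encoding atoms play the role of attention keys (coefficients are dot products of the input with each atom), that the decoding atoms play the role of values, and that sparsity comes from a hypothetical top-k truncation followed by a softmax. The paper's actual sparsity mechanism, any score scaling, and the learned projections that produce $\phi(\cdot)$ and $\psi(\cdot)$ are not reproduced, and the function names are illustrative rather than the authors' code.

```python
import math

def encode(x, enc_atoms):
    # Coefficient scores: inner product of the input with each encoding atom
    # (assumed analogue of query-key scores in standard attention).
    return [sum(xi * ai for xi, ai in zip(x, atom)) for atom in enc_atoms]

def sparsify(scores, k):
    # Hypothetical sparsity step: keep the k largest scores, softmax over them,
    # and set every other coefficient to exactly zero.
    top = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
    m = max(scores[i] for i in top)  # subtract max for numerical stability
    exps = {i: math.exp(scores[i] - m) for i in top}
    z = sum(exps.values())
    return [exps[i] / z if i in exps else 0.0 for i in range(len(scores))]

def decode(alpha, dec_atoms):
    # Output: linear combination of decoding atoms weighted by the coefficients.
    dim = len(dec_atoms[0])
    return [sum(a * atom[d] for a, atom in zip(alpha, dec_atoms)) for d in range(dim)]

def sparse_coding_attention(x, enc_atoms, dec_atoms, k=2):
    alpha = sparsify(encode(x, enc_atoms), k)
    return alpha, decode(alpha, dec_atoms)

# Toy usage: three atoms, keep the two strongest coefficients.
enc_atoms = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
dec_atoms = [[10.0, 0.0], [0.0, 10.0], [5.0, 5.0]]
alpha, y = sparse_coding_attention([2.0, 0.0], enc_atoms, dec_atoms, k=2)
print(alpha, y)  # [0.5, 0.0, 0.5] [7.5, 2.5]
```

With k=2, one of the three coefficients is exactly zero, so the output mixes only two decoding atoms; setting k to the number of atoms recovers dense softmax attention over the same (unscaled) scores.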
Abstract
Transformer architectures have achieved remarkable success across language, vision, and multimodal tasks, and there is growing demand for them to address in-context compositional learning tasks. In these tasks, a model solves the target problem by inferring compositional rules from context examples, which are composed of basic components structured by underlying rules. However, some of these tasks remain challenging for Transformers, which are not inherently designed for compositional tasks and offer limited structural inductive bias. In this work, inspired by the principle of sparse coding, we propose a reformulation of the attention mechanism that enhances its capability on compositional tasks. In sparse coding, data are represented as sparse combinations of dictionary atoms, whose coefficients capture the compositional rules underlying the data. Specifically, we reinterpret the attention block as mapping inputs to outputs through projections onto two sets of learned dictionary atoms: an encoding dictionary and a decoding dictionary. The encoding dictionary decomposes the input into a set of coefficients that represent its compositional structure. To encourage structured representations, we impose sparsity on these coefficients. The sparse coefficients are then used to linearly combine the decoding dictionary atoms to generate the output. Furthermore, to support compositional generalization, we propose estimating the coefficients of the target problem as a linear combination of the coefficients obtained from the context examples. We demonstrate the effectiveness of our approach on the S-RAVEN and RAVEN datasets. On certain compositional generalization tasks, our method maintains its performance even where standard Transformers fail, owing to its ability to learn and apply compositional rules.
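The target-coefficient estimation step can be illustrated with a toy example. This sketch assumes each context example has already been encoded into a coefficient vector, and it uses hand-picked weights $\lambda = (-1, 2)$ that implement linear extrapolation ($\hat{\boldsymbol{\alpha}} = 2\boldsymbol{\alpha}_2 - \boldsymbol{\alpha}_1$) in place of the learned weights described in the paper. The lifting-inspired structure of the actual scheme is not modeled, and all names and values are hypothetical.

```python
def estimate_target_coefficients(context_alphas, lambdas):
    # alpha_hat[j] = sum_i lambda_i * alpha_i[j]: a weighted sum of the
    # context coefficient vectors (weights are hand-picked here, not learned).
    if len(context_alphas) != len(lambdas):
        raise ValueError("need one weight per context example")
    n_atoms = len(context_alphas[0])
    return [sum(lam * a[j] for lam, a in zip(lambdas, context_alphas))
            for j in range(n_atoms)]

def decode(alpha, dec_atoms):
    # Reconstruct the target by linearly combining the decoding atoms.
    dim = len(dec_atoms[0])
    return [sum(a * atom[d] for a, atom in zip(alpha, dec_atoms)) for d in range(dim)]

# Two context panels: coefficient 0 progresses (1 -> 2), coefficient 1 is constant.
context_alphas = [[1.0, 0.5], [2.0, 0.5]]
lambdas = [-1.0, 2.0]  # hypothetical weights encoding "continue the progression"
alpha_hat = estimate_target_coefficients(context_alphas, lambdas)

dec_atoms = [[10.0, 0.0, 0.0], [0.0, 0.0, 4.0]]  # hypothetical decoding atoms
y = decode(alpha_hat, dec_atoms)
print(alpha_hat, y)  # [3.0, 0.5] [30.0, 0.0, 2.0]
```

Because these weights sum to 1, a coefficient that is constant across the context (the second one) stays constant in the prediction, while the first coefficient continues the progression 1, 2, 3.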
