Learning Compositional Functions with Transformers from Easy-to-Hard Data

Zixuan Wang, Eshaan Nichani, Alberto Bietti, Alex Damian, Daniel Hsu, Jason D. Lee, Denny Wu

TL;DR

<p>Problem</p> The paper tackles the learnability of the $k$-fold composition task for transformers, a synthetic yet representative compositional reasoning problem that couples contextual and parametric knowledge. <p>Approach</p> It establishes a constructive upper bound via an $O(\log k)$-deep transformer with embedding dimension $d=\tilde{O}(Nk)$, proves a Statistical Query (SQ) lower bound showing exponential sample complexity in $k$, and devises two gradient-based learning strategies (curriculum and data mixture) that achieve sample complexity and runtime polynomial in $k$ and $N$. <p>Contributions</p> The work demonstrates a statistical-computational gap under SQ while providing practical training schemes that enable efficient learning of deep transformer-based compositional tasks; it also discusses embedding-size tradeoffs and generalization concerns (cyclic variants, non-power-of-two $k$). <p>Impact</p> These results clarify when and how transformers can learn complex compositional tasks and highlight the essential role of easy-hard data mixtures and curricula in enabling efficient gradient-based learning, informing both theory and practice in scalable reasoning models.

Abstract

Transformer-based language models have demonstrated impressive capabilities across a range of complex reasoning tasks. Prior theoretical work exploring the expressive power of transformers has shown that they can efficiently perform multi-step reasoning tasks involving parallelizable computations. However, the learnability of such constructions, particularly the conditions on the data distribution that enable efficient learning via gradient-based optimization, remains an open question. Towards answering this question, in this work we study the learnability of the $k$-fold composition task, which requires computing an interleaved composition of $k$ input permutations and $k$ hidden permutations, and can be expressed by a transformer with $O(\log k)$ layers. On the negative front, we prove a Statistical Query (SQ) lower bound showing that any SQ learner that makes only polynomially-many queries to an SQ oracle for the $k$-fold composition task distribution must have sample size exponential in $k$, thus establishing a statistical-computational gap. On the other hand, we show that this function class can be efficiently learned, with runtime and sample complexity polynomial in $k$, by gradient descent on an $O(\log k)$-depth transformer via two different curriculum learning strategies: one in which data consists of $k'$-fold composition functions with $k' \le k$ presented in increasing difficulty, and another in which all such data is presented simultaneously. Our work sheds light on the necessity and sufficiency of having both easy and hard examples in the data distribution for transformers to learn complex compositional tasks.
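To make the task concrete, the following is a minimal Python sketch of a $k$-fold composition and of why a logarithmic number of pairwise merging rounds is enough to compute it. It is an illustration only, not the paper's transformer construction. The interleaving order (each $\sigma_i$ applied before $\pi_i$) and the function names `k_fold_composition`, `compose`, and `k_fold_by_doubling` are assumptions made for this example; the abstract only specifies an "interleaved composition".

```python
import random


def random_permutation(n, rng):
    """Return a uniformly random permutation of range(n) as a list."""
    perm = list(range(n))
    rng.shuffle(perm)
    return perm


def k_fold_composition(hidden, inputs, x):
    """Sequential evaluation: apply sigma_1, pi_1, ..., sigma_k, pi_k to x.

    The order sigma_i-then-pi_i is an assumption of this sketch.
    """
    for sigma, pi in zip(inputs, hidden):
        x = pi[sigma[x]]
    return x


def compose(p, q):
    """Return the permutation 'p then q', i.e. the map i -> q[p[i]]."""
    return [q[i] for i in p]


def k_fold_by_doubling(hidden, inputs, x):
    """Evaluate the same function by merging adjacent blocks pairwise.

    Composition is associative, so k blocks collapse in log2(k) rounds.
    Returns the output and the number of merging rounds used.
    """
    k = len(hidden)
    if k < 1 or k & (k - 1):
        raise ValueError("this sketch assumes k is a power of two")
    # One block per step: sigma_i followed by pi_i.
    blocks = [compose(s, p) for s, p in zip(inputs, hidden)]
    rounds = 0
    while len(blocks) > 1:
        blocks = [compose(blocks[i], blocks[i + 1])
                  for i in range(0, len(blocks), 2)]
        rounds += 1
    return blocks[0][x], rounds


if __name__ == "__main__":
    rng = random.Random(0)
    n, k = 5, 8
    hidden = [random_permutation(n, rng) for _ in range(k)]
    inputs = [random_permutation(n, rng) for _ in range(k)]
    out, rounds = k_fold_by_doubling(hidden, inputs, 2)
    # Both evaluations agree; rounds is log2(k) = 3 for k = 8.
    print(out == k_fold_composition(hidden, inputs, 2), rounds)
```

The merging rounds here only mirror the parallelizable structure that makes an $O(\log k)$-layer solution plausible; how attention layers actually realize each merge is specified in the paper, not in this sketch.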

Paper Structure

This paper contains 53 sections, 33 theorems, 361 equations, 3 figures, 2 algorithms.

Key Result

Theorem 1

Assume that $k$ is a power of two. There exists an embedding function $\phi$ with $d = kN(3 + \log_2k)$ such that, for any $\pi \in (S_N)^k$, there exists an $L = \log_2 k + 1$ layer transformer which can exactly express the $k$-fold composition task, i.e
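As a quick sanity check on the sizes stated in Theorem 1, the sketch below evaluates $d = kN(3 + \log_2 k)$ and $L = \log_2 k + 1$ for concrete values. The helper name `construction_size` is hypothetical and introduced only for this example.

```python
def construction_size(k, n):
    """Return (d, L) from Theorem 1 for k a power of two and alphabet size n.

    d = k * n * (3 + log2(k)) is the embedding dimension,
    L = log2(k) + 1 is the number of layers.
    """
    if k < 1 or k & (k - 1):
        raise ValueError("Theorem 1 assumes k is a power of two")
    log_k = k.bit_length() - 1  # exact log2 for powers of two
    return k * n * (3 + log_k), log_k + 1


if __name__ == "__main__":
    # k = 8, N = 10: d = 8 * 10 * (3 + 3) = 480 and L = 3 + 1 = 4.
    print(construction_size(8, 10))
```

For fixed $N$, the embedding dimension grows as $k \log k$ while the depth grows only as $\log k$, which matches the $d = \tilde{O}(Nk)$ scaling quoted in the TL;DR.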

Figures (3)

  • Figure 1: $k$-fold composition task -- red arrows represent the hidden permutation $\pi_i$ and green arrows denote input permutations $\sigma_i$. Given an input $(\sigma,x)$, $f_\pi(\cdot,\cdot)$ composes $2k$ permutations to output $f_\pi(\sigma,x)$.
  • Figure 2: Illustration of the format of the input $X^{(0)}$ and the attention pattern in \ref{thm:construction}.
  • Figure 3: Left: Curriculum learning (Algorithm \ref{alg:training_alg}). Middle: Learning with data mixture (Algorithm \ref{alg:training_alg_mix}). Right: Comparison between training with and without mixed data on a standard encoder transformer.
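The two data regimes compared in Figure 3 can be sketched as follows: a curriculum that presents $k'$-fold examples in increasing difficulty, and a mixture that presents all difficulty levels at once. This is a hedged illustration of the data schedules only, not of the paper's Algorithms or training procedure. Several choices are assumptions of this sketch and are not stated in the source: the difficulty levels are taken to be powers of two up to $k$, a $k'$-fold example uses the first $k'$ hidden permutations, each $\sigma_i$ is applied before $\pi_i$, and all function names are hypothetical.

```python
import random


def sample_example(hidden, k_prime, n, rng):
    """Draw one k'-fold example (assumed to use the first k' hidden perms)."""
    sigmas = []
    for _ in range(k_prime):
        s = list(range(n))
        rng.shuffle(s)
        sigmas.append(s)
    x = rng.randrange(n)
    y = x
    for sigma, pi in zip(sigmas, hidden[:k_prime]):
        y = pi[sigma[y]]  # assumed order: sigma_i, then pi_i
    return {"k_prime": k_prime, "sigmas": sigmas, "x": x, "y": y}


def difficulty_levels(k):
    """Assumed difficulty ladder: 1, 2, 4, ..., up to k."""
    levels, level = [], 1
    while level <= k:
        levels.append(level)
        level *= 2
    return levels


def curriculum_stages(hidden, n, per_level, rng):
    """One list of examples per difficulty level, easiest stage first."""
    return [[sample_example(hidden, kp, n, rng) for _ in range(per_level)]
            for kp in difficulty_levels(len(hidden))]


def mixed_dataset(hidden, n, per_level, rng):
    """All difficulty levels pooled and shuffled into a single dataset."""
    data = [ex for stage in curriculum_stages(hidden, n, per_level, rng)
            for ex in stage]
    rng.shuffle(data)
    return data


if __name__ == "__main__":
    rng = random.Random(0)
    n, k = 5, 8
    hidden = []
    for _ in range(k):
        p = list(range(n))
        rng.shuffle(p)
        hidden.append(p)
    stages = curriculum_stages(hidden, n, per_level=4, rng=rng)
    print([stage[0]["k_prime"] for stage in stages])
    print(len(mixed_dataset(hidden, n, per_level=4, rng=rng)))
```

In the curriculum regime a learner would consume `stages` one list at a time, while in the mixture regime it would draw batches from the single shuffled pool; in both cases easy and hard examples are present in the data.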

Theorems & Definitions (69)

  • Remark 1
  • Definition 1: Self-attention head
  • Definition 2: Attention-only transformer
  • Theorem 1
  • Theorem 2
  • Remark 2
  • Theorem 3: Guarantee for \ref{alg:training_alg}
  • Theorem 4: Guarantee for mixed data training
  • Proof of \ref{thm:construction}
  • Theorem 5
  • ...and 59 more