Nonparametric Teaching of Attention Learners

Chen Zhang, Jianghui Wang, Bingyang Cheng, Zhongtao Chen, Wendong XU, Cong Wang, Marco Canini, Francesco Orabona, Yik Chung WU, Ngai Wong

TL;DR

This work shows, for the first time, that teaching attention learners is consistent with teaching importance-adaptive nonparametric learners, and that performance is consistently preserved, and often enhanced, across a diverse set of downstream tasks.

Abstract

Attention learners, i.e., neural networks built on the attention mechanism such as transformers, excel at learning the implicit relationships that map sequences to their corresponding properties, e.g., mapping a given sequence of tokens to the probability of the next token. However, this learning process tends to be costly. To address this, we present a novel paradigm named Attention Neural Teaching (AtteNT) that reinterprets the learning process from a nonparametric teaching perspective. Nonparametric teaching provides a theoretical framework for teaching implicitly defined (i.e., nonparametric) mappings via example selection. Here, such an implicit mapping is embodied by a dense set of sequence-property pairs, and the AtteNT teacher selects a subset of them to accelerate convergence when training the attention learner. By analytically investigating the role of attention in parameter-based gradient descent during training, and by recasting the evolution of attention learners, shaped by parameter updates, as functional gradient descent in nonparametric teaching, we show for the first time that teaching attention learners is consistent with teaching importance-adaptive nonparametric learners. These findings allow AtteNT to be readily applied to improve the learning efficiency of attention learners. Specifically, we observe training time reductions of 13.01% for LLMs and 20.58% for ViTs, spanning both fine-tuning and training-from-scratch regimes. Crucially, these gains come without compromising accuracy: performance is consistently preserved and often enhanced across a diverse set of downstream tasks.
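The abstract describes the teacher as selecting a subset of sequence-property pairs from a dense pool. One simple way to picture such a teacher is the greedy rule sketched below, which is an assumption for illustration and not necessarily AtteNT's exact criterion: rank the pool by the learner's current discrepancy $\|f_\theta(\bm{S}_i)-\bm{y}_i\|$ and keep the top fraction. The function name `select_teaching_subset` and its `ratio` parameter are hypothetical.

```python
import numpy as np


def select_teaching_subset(predictions, targets, ratio):
    """Hypothetical greedy teacher: keep the pairs the learner currently fits worst.

    predictions, targets: (N, m) arrays, one row per sequence-property pair
    (e.g., m = vocabulary size for next-token probabilities).
    ratio: fraction of the pool to keep, in (0, 1].
    Returns indices of the selected pairs, largest discrepancy first.
    """
    if not 0 < ratio <= 1:
        raise ValueError("ratio must be in (0, 1]")
    # Per-example discrepancy ||f(S_i) - y_i|| (assumed selection score).
    discrepancy = np.linalg.norm(predictions - targets, axis=1)
    # Always keep at least one example.
    k = max(1, int(round(ratio * len(discrepancy))))
    # Sort by descending discrepancy; "stable" keeps ties in pool order.
    return np.argsort(-discrepancy, kind="stable")[:k]
```

In a training loop, one would recompute predictions on the pool from time to time and update the learner only on the returned indices; how often to reselect is a separate design choice that this sketch does not settle.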

Paper Structure (30 sections, 4 theorems, 55 equations, 6 figures, 8 tables, 1 algorithm)

Key Result

Theorem 3

Given a convex loss $\mathcal{L}$ and a training set $\{(\bm{S}_i,\bm{y}_i)|\bm{S}_i\in\mathcal{S},\bm{y}_i\in\mathcal{Y}\}_N$, the dynamic ANTK, which is derived from performing gradient descent on the parameters of an ANN, converges pointwise to the importance-adaptive canonical kernel in the dual
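The link between parameter updates and functional gradient descent that Theorem 3 builds on can be illustrated with the standard first-order (NTK-style) argument below. This is a generic sketch, assuming a step size $\eta$, an empirical-risk form of the loss, and a Jacobian $\partial f_{\theta}(\bm{S})/\partial\theta$ whose rows are per-token gradients, so that $K_{\theta}(\bm{S},\bm{S}')=\frac{\partial f_{\theta}(\bm{S})}{\partial\theta}\big(\frac{\partial f_{\theta}(\bm{S}')}{\partial\theta}\big)^{\top}$ as in Figure 3. The paper's exact derivation and its importance-adaptive weighting are not reproduced here.

```latex
% Parameter gradient descent on the empirical risk (step size \eta assumed)
\theta^{t+1} = \theta^{t} - \frac{\eta}{N}\sum_{i=1}^{N}
  \left(\frac{\partial f_{\theta^{t}}(\bm{S}_i)}{\partial \theta}\right)^{\!\top}
  \frac{\partial \mathcal{L}\big(f_{\theta^{t}}(\bm{S}_i),\bm{y}_i\big)}{\partial f_{\theta^{t}}(\bm{S}_i)}

% First-order Taylor expansion of the learner around \theta^{t}
f_{\theta^{t+1}}(\bm{S}) \approx f_{\theta^{t}}(\bm{S})
  + \frac{\partial f_{\theta^{t}}(\bm{S})}{\partial \theta}\,\big(\theta^{t+1}-\theta^{t}\big)
  = f_{\theta^{t}}(\bm{S}) - \frac{\eta}{N}\sum_{i=1}^{N}
    K_{\theta^{t}}(\bm{S},\bm{S}_i)\,
    \frac{\partial \mathcal{L}\big(f_{\theta^{t}}(\bm{S}_i),\bm{y}_i\big)}{\partial f_{\theta^{t}}(\bm{S}_i)}
```

The second line has the form of a kernel functional gradient step on $f$, with the time-varying kernel $K_{\theta^{t}}$ in place of a fixed one; Theorem 3 concerns how this dynamic kernel behaves relative to the importance-adaptive canonical kernel.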

Figures (6)

  • Figure 1: An illustration of the workflow for an attention neural network with an input sequence $\bm{S}$.
  • Figure 2: An illustration of the workflow for different multi-output attention learners, with input sequences $\bm{S}$ and $\bm{S}'$ (in the case of cross-attention).
  • Figure 3: Graphical depiction of the ANTK computation process: $K_{\theta}(\bm{S}_{S},\bm{S}'_{S'})=\left\langle\frac{\partial f_{\theta}(\bm{S})}{\partial \theta},\frac{\partial f_{\theta}(\bm{S}')}{\partial \theta} \right\rangle=[\frac{\partial {f_{\theta}(\bm{S})}_{(i,:)}}{\partial \bm{W}^{V}_{(1)}}\frac{\partial {f_{\theta}(\bm{S}')}_{(j,:)}}{\partial \bm{W}^{V}_{(1)}}+\dots+\frac{\partial {f_{\theta}(\bm{S})}_{(i,:)}}{\partial \bm{W}^{V}_{(d)}}\frac{\partial {f_{\theta}(\bm{S}')}_{(j,:)}}{\partial \bm{W}^{V}_{(d)}}+\frac{\partial {f_{\theta}(\bm{S})}_{(i,:)}}{\partial \bm{W}^{Q}_{(1,1)}}\frac{\partial {f_{\theta}(\bm{S}')}_{(j,:)}}{\partial \bm{W}^{Q}_{(1,1)}}+\dots+\frac{\partial {f_{\theta}(\bm{S})}_{(i,:)}}{\partial \bm{W}^{Q}_{(d,p)}}\frac{\partial {f_{\theta}(\bm{S}')}_{(j,:)}}{\partial \bm{W}^{Q}_{(d,p)}}+\frac{\partial {f_{\theta}(\bm{S})}_{(i,:)}}{\partial \bm{W}^{K}_{(1,1)}}\frac{\partial {f_{\theta}(\bm{S}')}_{(j,:)}}{\partial \bm{W}^{K}_{(1,1)}}+\dots+\frac{\partial {f_{\theta}(\bm{S})}_{(i,:)}}{\partial \bm{W}^{K}_{(d,p)}}\frac{\partial {f_{\theta}(\bm{S}')}_{(j,:)}}{\partial \bm{W}^{K}_{(d,p)}}]_{S\times S';i\in \mathbb{N}_S,j\in\mathbb{N}_{S'}}$.
  • Figure 4: Downstream Task Performance vs Sample Ratio.
  • Figure 5: Frobenius norm of the difference between the empirical NTK at different training steps and the canonical kernel.
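The ANTK depicted in Figure 3 is the inner product of the learner's parameter gradients at two input sequences. A minimal NumPy sketch for a single-head self-attention learner follows, assuming one scalar output per token, $f_{\theta}(\bm{S})=\mathrm{softmax}\big(\bm{S}\bm{W}^{Q}(\bm{S}\bm{W}^{K})^{\top}/\sqrt{p}\big)\,\bm{S}\bm{W}^{V}$ with $\bm{W}^{Q},\bm{W}^{K}\in\mathbb{R}^{d\times p}$ and $\bm{W}^{V}\in\mathbb{R}^{d}$. These shapes are chosen to match the parameter indices in Figure 3; the $1/\sqrt{p}$ scaling and all function names are illustrative assumptions. Gradients are derived by hand instead of with autodiff.

```python
import numpy as np


def softmax_rows(z):
    z = z - z.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def attention_forward(S, WQ, WK, wV):
    """Single-head self-attention with one scalar output per token.

    S: (n, d) token embeddings; WQ, WK: (d, p); wV: (d,).
    Returns f of shape (n,) plus intermediates reused by the Jacobian.
    """
    p = WQ.shape[1]
    Q, K = S @ WQ, S @ WK                   # (n, p) queries and keys
    A = softmax_rows(Q @ K.T / np.sqrt(p))  # (n, n) attention weights
    v = S @ wV                              # (n,) scalar values
    return A @ v, (Q, K, A, v)


def attention_jacobian(S, WQ, WK, wV):
    """Row i is d f_i / d theta, with theta = [wV, vec(WQ), vec(WK)] (row-major)."""
    n, p = S.shape[0], WQ.shape[1]
    f, (Q, K, A, v) = attention_forward(S, WQ, WK, wV)
    G = A * (v[None, :] - f[:, None])  # G[i, k] = d f_i / d logit_{ik} (softmax rule)
    dV = A @ S                         # d f_i / d wV = sum_j A_ij S_j
    dQ = np.einsum("id,ip->idp", S, G @ K) / np.sqrt(p)  # outer(S_i, (G K)_i)
    dK = np.einsum("id,ip->idp", G @ S, Q) / np.sqrt(p)  # outer((G S)_i, Q_i)
    return np.concatenate([dV, dQ.reshape(n, -1), dK.reshape(n, -1)], axis=1)


def antk(S, S2, WQ, WK, wV):
    """Empirical kernel: K[i, j] = <d f(S)_i / d theta, d f(S2)_j / d theta>."""
    return attention_jacobian(S, WQ, WK, wV) @ attention_jacobian(S2, WQ, WK, wV).T
```

Evaluating `antk` at parameters saved from different training steps and computing `np.linalg.norm(K_t - K_ref, "fro")` gives a quantity of the kind plotted in Figure 5, although the paper's canonical reference kernel is not reconstructed here.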

Theorems & Definitions (6)

  • Definition 1
  • Definition 2
  • Theorem 3
  • Proposition 4
  • Lemma 5
  • Lemma 6