Supervised Fine-Tuning Achieve Rapid Task Adaption Via Alternating Attention Head Activation Patterns

Yang Zhao, Li Du, Xiao Ding, Kai Xiong, Ting Liu, Bing Qin

TL;DR

A gradient-based method is used to dissect how SFT adapts LLMs to downstream tasks from the perspective of attention patterns. The analysis finds that LLMs selectively activate task-specific attention heads during SFT, and that activation patterns for complex tasks are combinations of basic task patterns.

Abstract

LLMs' performance on complex tasks is still unsatisfactory. A key issue is that LLMs currently learn in a data-driven paradigm, while instruction data for these complex tasks is both scarce and hard to collect or construct. In contrast, LLMs can learn simpler tasks quickly when they have adequate prior knowledge captured during the pretraining stage. If the prerequisites and mechanism of such rapid generalization could be elucidated, LLMs could learn complex tasks more efficiently and effectively. In this paper, we therefore employ a gradient-based method to dissect how SFT adapts LLMs to downstream tasks from the perspective of attention patterns. We find that: (1) LLMs selectively activate task-specific attention heads during SFT; (2) activation patterns for complex tasks are combinations of basic task patterns; and (3) changes in a few parameters can significantly impact activation patterns after SFT on a small number of samples. Based on these insights, we conduct experiments to enhance the efficiency and effectiveness of SFT.

Paper Structure (29 sections, 10 equations, 5 figures, 3 tables)

Figures (5)

  • Figure 1: Visualization of activation pattern changes in Llama3-8B on the test set before and after SFT with the GSM8K training set.
  • Figure 2: The correlation coefficients of the activation pattern change rates for the Llama3-8B, Gemma-7B, and OPT-6.7B models on tasks before SFT, after SFT, and during the SFT process (corresponding to the top, middle, and bottom sections, respectively).
  • Figure 3: The two tasks above are basic tasks that each rely on a single skill, while the one below is a complex task that relies on both coding and mathematical skills.
  • Figure 4: Top: The least squares method was used to fit the activation pattern changes of traditional NLP tasks in SFT to the activation pattern changes in SGSM. A higher $R^2$ indicates a better fit, with Code Search Net and GSM8K showing the highest $R^2$ values. Bottom: The least squares method was used to fit the activation pattern changes of SFT instruction data to those of tasks requiring both “logical reasoning” skills and “programming and software development” instructions. The combination of “logical reasoning” skills and “programming” instructions achieved the highest $R^2$ value.
  • Figure 5: Changes in the correlation coefficient and MSE of activation patterns for the Llama3-8B, Gemma-7B, and OPT-6.7B models during fine-tuning on Code Search Net, GSM8K, MATH, SGSM, ARC, HellaSwag, and Winogrande.
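
The least-squares fit described in the Figure 4 caption can be illustrated with a small sketch. It is not the authors' implementation: the function name `fit_two_patterns`, the flat per-head lists, the absence of an intercept term, and the toy numbers are all assumptions made for illustration. It fits one complex-task pattern change as a weighted sum of two basic-task pattern changes and reports $R^2$.

```python
def fit_two_patterns(basic_a, basic_b, complex_pattern):
    """Fit complex_pattern ~ w_a * basic_a + w_b * basic_b by least squares.

    Each argument is a flat list with one (hypothetical) activation pattern
    change value per attention head. Returns (w_a, w_b, r2).
    """
    # Entries of the 2x2 normal equations (no intercept term, an assumption).
    s_aa = sum(a * a for a in basic_a)
    s_bb = sum(b * b for b in basic_b)
    s_ab = sum(a * b for a, b in zip(basic_a, basic_b))
    s_ay = sum(a * y for a, y in zip(basic_a, complex_pattern))
    s_by = sum(b * y for b, y in zip(basic_b, complex_pattern))

    det = s_aa * s_bb - s_ab * s_ab
    if det == 0:
        raise ValueError("basic patterns are collinear; the fit is not unique")

    # Solve the normal equations with Cramer's rule.
    w_a = (s_ay * s_bb - s_by * s_ab) / det
    w_b = (s_by * s_aa - s_ay * s_ab) / det

    # R^2 = 1 - SS_res / SS_tot, with SS_tot taken around the mean of the target.
    mean_y = sum(complex_pattern) / len(complex_pattern)
    ss_res = sum(
        (y - (w_a * a + w_b * b)) ** 2
        for a, b, y in zip(basic_a, basic_b, complex_pattern)
    )
    ss_tot = sum((y - mean_y) ** 2 for y in complex_pattern)
    if ss_tot == 0:
        raise ValueError("complex pattern is constant; R^2 is undefined")
    return w_a, w_b, 1.0 - ss_res / ss_tot


# Toy values (invented): the complex pattern is an equal mix of the two basic ones.
code_change = [0.9, 0.1, 0.0, 0.4]
math_change = [0.0, 0.8, 0.2, 0.4]
mixed_change = [0.45, 0.45, 0.1, 0.4]
w_code, w_math, r2 = fit_two_patterns(code_change, math_change, mixed_change)
print(w_code, w_math, r2)
```

In this toy case the target is constructed as an exact mix of the two basic patterns, so the fit recovers nearly equal weights and an $R^2$ close to 1. With real per-head measurements, a lower $R^2$ would indicate that the chosen basic tasks explain less of the complex task's pattern change.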