Continual Fine-Tuning with Provably Accurate and Parameter-Free Task Retrieval

Hang Thi-Thuy Le, Long Minh Bui, Minh Hoang, Trong Nghia Hoang

Abstract

Continual fine-tuning aims to adapt a pre-trained backbone to new tasks sequentially while preserving performance on earlier tasks whose data are no longer available. Existing approaches fall into two categories: input adaptation and parameter adaptation. Input-adaptation methods rely on retrieving the most relevant prompts at test time, but require continuously learning a retrieval function that is prone to forgetting. Parameter-adaptation methods instead use a fixed input embedding function to enable retrieval-free prediction and avoid forgetting, but sacrifice representation adaptability. To combine the strengths of both, we propose a new parameter-adaptation method that enables adaptive use of input embeddings at test time with parameter-free retrieval. We derive task-retrieval error bounds for a clustering-based, parameter-free paradigm. These bounds link low retrieval error to structural properties of task-specific representation clusters, showing how a well-organized clustering structure enables reliable retrieval. Motivated by this insight, our method is designed with two key components: (i) an adaptive module composition strategy that learns informative task-specific updates to preserve and complement prior knowledge, and (ii) a clustering-based retrieval mechanism that captures distinct representation signatures for each task, enabling adaptive representation use at test time. Extensive experiments show that these components work synergistically to improve retrieval and predictive performance under large shifts in task semantics.
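
To make the idea of clustering-based, parameter-free retrieval concrete, the following is a minimal sketch rather than the paper's implementation. It assumes each task is summarized by a few isotropic Gaussian cluster keys and that a query is routed to the task whose best-matching key gives it the highest log-density. The names `log_gauss`, `retrieve_task`, and `task_keys`, the max-over-keys scoring rule, and all numbers are hypothetical choices made for illustration.

```python
import math
import random

def log_gauss(x, mean, var):
    """Log-density of an isotropic Gaussian N(mean, var * I) evaluated at x."""
    d = len(x)
    sq = sum((a - b) ** 2 for a, b in zip(x, mean))
    return -0.5 * (d * math.log(2 * math.pi * var) + sq / var)

def retrieve_task(x, task_keys):
    """Return the task whose best-matching key gives x the highest log-density.

    task_keys maps a task id to a list of (mean, var) cluster keys
    (a hypothetical layout). No retrieval parameters are trained here:
    the lookup only uses stored cluster statistics.
    """
    best_task, best_score = None, -math.inf
    for task, keys in task_keys.items():
        score = max(log_gauss(x, m, v) for m, v in keys)
        if score > best_score:
            best_task, best_score = task, score
    return best_task

# Toy keys: two clusters per task in a 2-D embedding space (made-up numbers).
task_keys = {
    "task_a": [([0.0, 0.0], 1.0), ([0.0, 6.0], 1.0)],
    "task_b": [([10.0, 0.0], 1.0), ([10.0, 6.0], 1.0)],
}

rng = random.Random(0)
# A query drawn near one of task_b's keys.
query = [10.0 + rng.gauss(0, 1), 6.0 + rng.gauss(0, 1)]
print(retrieve_task(query, task_keys))
```

Because the keys are fixed statistics rather than a learned retrieval function, adding a new task in this sketch only appends entries to `task_keys` and leaves earlier entries untouched.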

Paper Structure

This paper contains 28 sections, 4 theorems, 33 equations, 9 figures, 18 tables, and 2 algorithms.

Key Result

Theorem 3.4

[Bounding Retrieval Error Rate] Suppose we are given $n$ training tasks with clustered data distributions $(D_1^t)_{t=1}^\tau, (D_2^t)_{t=1}^\tau, \ldots, (D_n^t)_{t=1}^\tau$. Let $\boldsymbol{E}_k^t$ denote the event that a test input $\boldsymbol{x} \sim D_k^t$, where $D_k^t$ is the $t$-th sub-te… This means the error rate decreases exponentially in $\delta d \geq 0$, where $\kappa$ denotes the m…
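
The qualitative message of the statement above, that retrieval error shrinks quickly as clusters become better separated, can be probed with a small simulation. This is only an illustrative sketch and does not reproduce the theorem's bound or its constants. It assumes two unit-variance isotropic Gaussian clusters and nearest-centroid retrieval. The function `retrieval_error` and its `separation` argument are hypothetical, and `separation` is a plain distance between cluster means, not the paper's cluster separation factor $\delta$.

```python
import random

def retrieval_error(separation, dim=8, trials=2000, seed=0):
    """Estimate how often nearest-centroid retrieval picks the wrong cluster.

    Two unit-variance isotropic Gaussian clusters sit `separation` apart along
    the first axis. Each query is drawn from the first cluster and assigned to
    whichever mean is closer in squared Euclidean distance.
    """
    rng = random.Random(seed)
    mean_a = [0.0] * dim
    mean_b = [separation] + [0.0] * (dim - 1)
    errors = 0
    for _ in range(trials):
        x = [rng.gauss(m, 1.0) for m in mean_a]
        dist_a = sum((xi - mi) ** 2 for xi, mi in zip(x, mean_a))
        dist_b = sum((xi - mi) ** 2 for xi, mi in zip(x, mean_b))
        if dist_b < dist_a:
            errors += 1
    return errors / trials

# Empirical error for increasingly well-separated clusters.
for sep in (0.5, 2.0, 4.0, 8.0):
    print(sep, retrieval_error(sep))
```

In this toy setting a query is misrouted only when its first coordinate exceeds half the separation, so the estimated error falls off as a Gaussian tail when the means move apart.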

Figures (9)

  • Figure 1: The overall workflow of Proteus. A: A LoRA unit is optimized for each task as a combination of previous tuning directions and new orthogonal components that capture task-specific information (Section \ref{sec:adaptive_ft}). B: A database of values (past update directions) and multi-keys (GMM parameters fit to each task's embedding distribution) is updated. At test time, this database is used to retrieve the most informative past updates for each test input (Section \ref{sec:adaptive_ft}). C: LDA prediction (Section \ref{sec:adaptive_ft}).
  • Figure 2: Box plots showing that the distribution of our empirical cluster separation factor $\delta$ lies clearly above the distribution of the minimum value required to guarantee a low retrieval error rate $\epsilon$ across tasks in the Split CIFAR-100 benchmark.
  • Figure 3: Proteus has the most stable performance and the highest final accuracies on three datasets: CIFAR-100, VTAB5T-small, and VTAB-sim50.
  • Figure 4: Task retrieval accuracies achieved by Proteus and its variant without adaptive fine-tuning and without enforcing orthogonality among fine-tuning (FT) directions. Without such conditioning, task retrieval accuracy drops severely.
  • Figure 5: Heatmap visualizations of $\boldsymbol{S}$ at different iterations on ImageNet-R with $\lambda > 0$ and $\lambda=0$.
  • ...and 4 more figures
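
Panel A of Figure 1 describes each task update as a combination of previous tuning directions plus new orthogonal components. The decomposition behind that description can be sketched on plain vectors as below. This is an assumption-laden illustration, not the paper's training procedure: real LoRA units are low-rank matrices, and the helper names `dot` and `split_update`, the orthonormality of `prev_dirs`, and the example numbers are all hypothetical.

```python
def dot(u, v):
    """Euclidean inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))

def split_update(candidate, prev_dirs):
    """Split `candidate` into coefficients on prev_dirs plus an orthogonal residual.

    prev_dirs is assumed to be an orthonormal list of past update directions,
    so each coefficient is just an inner product and the residual is what
    remains after removing every projection.
    """
    coeffs = [dot(candidate, d) for d in prev_dirs]
    residual = list(candidate)
    for c, d in zip(coeffs, prev_dirs):
        residual = [r - c * di for r, di in zip(residual, d)]
    return coeffs, residual

# Two stored directions in a 3-D toy space and one candidate update.
prev_dirs = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
coeffs, new_component = split_update([2.0, -1.0, 3.0], prev_dirs)
print(coeffs, new_component)
```

The coefficients describe how much of the candidate reuses earlier directions, while the residual is the part that is orthogonal to all of them and therefore carries only new information in this toy setting.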

Theorems & Definitions (6)

  • Definition 3.1: Multi-Key Signature
  • Definition 3.3: Cluster Separation Factor
  • Theorem 3.4
  • Theorem 3.5
  • Theorem A.2
  • Theorem A.2