
Tangent Transformers for Composition, Privacy and Removal

Tian Yu Liu, Aditya Golatkar, Stefano Soatto

TL;DR

TAFT introduces Tangent Transformers by applying a first-order Taylor expansion around a pre-trained initialization $w$, yielding a model that is linear in the weight update $\Delta w$: $f^{lin}_w(\cdot)=f_w(\cdot)+\nabla_w f_w(\cdot)\cdot\Delta w$. The Jacobian-Vector Product $\nabla_w f_w(\cdot)\cdot\Delta w$ can be computed efficiently in a single forward pass, keeping training and inference costs within the same order of magnitude as the original non-linear Transformer while using the same number of parameters. The paper demonstrates practical benefits in parallel training, model composition, zero-cost forgetting, and differential privacy, with TAFT performing comparably to non-linear fine-tuning (NLFT) on various downstream visual classification tasks and offering up to $50\times$ faster re-training in shard-based unlearning workflows. Overall, Tangent Transformers provide a composable alternative to non-linear fine-tuning of Vision Transformer models: because the fine-tuning loss is convex in $\Delta w$, models trained on separate shards can be combined or removed directly in weight space.
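To make the expansion concrete, here is a minimal sketch of how a Jacobian-Vector Product can be carried along with the ordinary forward pass. It is not the authors' implementation: the two-layer `tanh` model, the function names (`model`, `forward_with_jvp`), and all numeric values are assumptions chosen only to keep the example small and dependency-free. Each intermediate value is paired with its tangent, so $f_w(x)$ and $\nabla_w f_w(x)\cdot\Delta w$ come out of one pass.

```python
import math

def model(W1, w2, x):
    """Toy non-linear model f_w(x) = w2 . tanh(W1 x) (a stand-in, not a Transformer)."""
    return sum(v * math.tanh(sum(a * b for a, b in zip(row, x)))
               for row, v in zip(W1, w2))

def forward_with_jvp(W1, w2, dW1, dw2, x):
    """Return (f_w(x), grad_w f_w(x) . dw) in a single pass over the network."""
    out, tangent = 0.0, 0.0
    for row, drow, v, dv in zip(W1, dW1, w2, dw2):
        z = sum(a * b for a, b in zip(row, x))    # pre-activation
        dz = sum(a * b for a, b in zip(drow, x))  # tangent of the pre-activation
        h = math.tanh(z)
        dh = (1.0 - h * h) * dz                   # chain rule through tanh
        out += v * h
        tangent += dv * h + v * dh                # product rule on w2 * h
    return out, tangent

# Hypothetical weights, update direction, and input.
W1 = [[0.5, -0.3], [0.1, 0.8]]
w2 = [1.0, -2.0]
dW1 = [[0.2, 0.1], [-0.4, 0.3]]
dw2 = [0.5, 0.25]
x = [1.0, 2.0]

f0, jvp = forward_with_jvp(W1, w2, dW1, dw2, x)
f_lin = f0 + jvp  # tangent-model output f_w(x) + grad_w f_w(x) . dw
```

The tangent term scales exactly with $\Delta w$, which is the linearity the rest of the paper relies on; for a small update, `f_lin` also stays close to the non-linear model evaluated at $w+\Delta w$.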

Abstract

We introduce Tangent Attention Fine-Tuning (TAFT), a method for fine-tuning linearized transformers obtained by computing a First-order Taylor Expansion around a pre-trained initialization. We show that the Jacobian-Vector Product resulting from linearization can be computed efficiently in a single forward pass, reducing training and inference cost to the same order of magnitude as its original non-linear counterpart, while using the same number of parameters. Furthermore, we show that, when applied to various downstream visual classification tasks, the resulting Tangent Transformer fine-tuned with TAFT can perform comparably with fine-tuning the original non-linear network. Since Tangent Transformers are linear with respect to the new set of weights, and the resulting fine-tuning loss is convex, we show that TAFT enjoys several advantages compared to non-linear fine-tuning when it comes to model composition, parallel training, machine unlearning, and differential privacy. Our code is available at: https://github.com/tianyu139/tangent-model-composition
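The composition and unlearning claims follow from linearity in the new weights, which a small sketch can illustrate. The code below is an illustration under stated assumptions, not the paper's procedure: it assumes component updates $\Delta w_i$ are combined by uniform averaging, represents the Jacobian at one input as a fixed list `jac`, and uses hypothetical helper names (`tangent_output`, `compose`, `forget`) and made-up numbers.

```python
def tangent_output(f0, jac, dw):
    """f_lin(x) = f_w(x) + J(x) . dw, given f0 = f_w(x) and jac = J(x)."""
    return f0 + sum(j * d for j, d in zip(jac, dw))

def compose(deltas):
    """Uniformly average component updates dw_i (assumed composition rule)."""
    n = len(deltas)
    return [sum(col) / n for col in zip(*deltas)]

def forget(composed, n, removed):
    """Drop one component from a uniform average of n updates, with no retraining."""
    return [(n * c - r) / (n - 1) for c, r in zip(composed, removed)]

# Hypothetical updates from three shards, plus a fixed output and Jacobian.
deltas = [[1.0, 2.0], [3.0, -1.0], [-2.0, 5.0]]
f0, jac = 0.5, [0.2, -0.4]

composed = compose(deltas)
# Output of the averaged weights equals the average of the component outputs.
ensemble = sum(tangent_output(f0, jac, d) for d in deltas) / len(deltas)
merged = tangent_output(f0, jac, composed)

# Removing shard 1 is a subtraction in weight space.
after_forget = forget(composed, len(deltas), deltas[1])
```

Because the model is linear in $\Delta w$, `merged` matches `ensemble`, and `after_forget` matches what composing only the remaining shards would give. For a non-linear network neither identity holds in general, which is why weight-space arithmetic is attractive here.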

Paper Structure

This paper contains 27 sections, 16 equations, 4 figures, 12 tables.

Figures (4)

  • Figure 1: (a) We show that when the number of samples to forget is small, we can simply remove shards by subtracting the weights of their respective component models, with minimal drop in final model accuracy (computed as an expectation over a uniform distribution of sample forgetting requests). (b) We compare against SISA (Bourtoule et al., 2021), which also uses a sharding technique for zero-cost unlearning. Our method is uniformly better for every number of shards removed, on all datasets. (c) Retraining on the remaining samples in a shard after a forgetting request can further improve the accuracy of the "unlearned" model, while enjoying up to $50\times$ faster training time compared to full re-training.
  • Figure 2: (a) RSL can improve fine-tuning performance, beating CE and MSE by $1.5\%$ and $9.0\%$ respectively across 7 datasets. (b) While computing the tangent model about the full pre-training initialization is already effective on its own, re-initializing the weights of the last attention block before linearization can yield further performance gains. (c) Linearizing the CLS token improves accuracy on downstream datasets which are far from the pre-training tasks.
  • Figure 3: Shard re-training with TAFT (using a sharding factor of 50) compared to the Paragon method of re-training the non-linear model from scratch. While both methods guarantee complete unlearning, TAFT achieves close-to-Paragon performance while speeding up unlearning by up to $50\times$.
  • Figure 4: We plot the $L_2$ change in weight space that results from adding a new component model against the number of existing models in the composition. The impact of adding a new model is significantly larger when the number of existing component models is small. Note that although the datasets are plotted on the same graph, their scales are not meant to be directly comparable, owing to differences in the number of output classes, among other factors.