Memory-Efficient Fine-Tuning of Transformers via Token Selection

Antoine Simoulin, Namyong Park, Xiaoyi Liu, Grey Yang

TL;DR

TokenTune tackles the memory bottleneck in fine-tuning large transformers by backpropagating through only a randomly selected subset of input tokens, which drastically reduces activation memory while preserving task performance. The method generalizes to dense, normalization, and attention layers and can be combined with existing memory- and parameter-efficient fine-tuning approaches such as LoRA and QLoRA. Empirical results on medium- and large-scale models show that TokenTune achieves accuracy on par with full fine-tuning or other memory-efficient methods while reducing the memory footprint; for Llama2-7B, memory can be cut by up to roughly $60\%$ when TokenTune is used alone, and even further when it is combined with LoRA/QLoRA. The approach is model- and task-agnostic within the transformer family and has practical value for domain-specific fine-tuning and system-level co-training; code is available at https://github.com/facebookresearch/tokentune.
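As a worked illustration of why caching fewer activations can suffice, consider a single dense layer applied token-wise to $n$ input tokens. The notation below ($x_i$ as a row vector, $h_i$, $W$, and the selected index set $S$ with $|S| = k$) is ours, and the approximation is a sketch of the idea for this one layer type, not the paper's own equations.

```latex
% Dense layer applied to each token: h_i = x_i W + b, for i = 1, ..., n
% Full fine-tuning: the weight gradient sums over all n tokens,
% so every input x_i must be cached during the forward pass.
\frac{\partial \mathcal{L}}{\partial W} = \sum_{i=1}^{n} x_i^{\top} \frac{\partial \mathcal{L}}{\partial h_i}

% Token selection (sketch): backpropagate only through a random subset S,
% |S| = k < n, so only the inputs x_i with i \in S need to be cached.
\frac{\partial \mathcal{L}}{\partial W} \approx \sum_{i \in S} x_i^{\top} \frac{\partial \mathcal{L}}{\partial h_i}
```

Under this reading, the cached inputs for the layer shrink from $n$ rows to $k$ rows, while the forward output $h_i$ is still computed for every token.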

Abstract

Fine-tuning provides an effective means to specialize pre-trained models for various downstream tasks. However, fine-tuning often incurs high memory overhead, especially for large transformer-based models such as LLMs. While existing methods may reduce certain parts of the memory required for fine-tuning, they still require caching all intermediate activations computed in the forward pass to update weights during the backward pass. In this work, we develop TokenTune, a method to reduce memory usage, specifically the memory to store intermediate activations, in the fine-tuning of transformer-based models. During the backward pass, TokenTune approximates the gradient computation by backpropagating through just a subset of input tokens. Thus, with TokenTune, only a subset of intermediate activations is cached during the forward pass. Also, TokenTune can be easily combined with existing methods like LoRA, further reducing the memory cost. We evaluate our approach on pre-trained transformer models with up to billions of parameters, considering the performance on multiple downstream tasks such as text classification and question answering in a few-shot learning setup. Overall, TokenTune achieves performance on par with full fine-tuning or representative memory-efficient fine-tuning methods, while greatly reducing the memory footprint, especially when combined with other methods with complementary memory reduction mechanisms. We hope that our approach will facilitate the fine-tuning of large transformers, whether to specialize them for specific domains or to co-train them with other neural components of a larger system. Our code is available at https://github.com/facebookresearch/tokentune.
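The following is a minimal, dependency-free Python sketch of the caching idea for a single dense layer, under our own assumptions: the helper names (`dense_forward`, `dense_weight_grad`) are hypothetical, tokens are selected uniformly at random, and no rescaling is applied to the partial gradient. It is not the authors' implementation; it only shows that every token contributes to the forward output, while the weight gradient, and therefore the activation cache, involves just the $k$ selected tokens.

```python
import random


def matmul(a, b):
    """Multiply an (n x d) list-of-lists by a (d x m) list-of-lists."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a):
    return [list(col) for col in zip(*a)]


def dense_forward(x, w, selected):
    """Forward pass over ALL tokens, caching inputs only for `selected`."""
    y = matmul(x, w)                      # every token contributes to the output
    cache = {i: x[i] for i in selected}   # only k input rows are kept for backward
    return y, cache


def dense_weight_grad(cache, grad_y):
    """Weight gradient sum_i x_i^T dL/dy_i restricted to the cached tokens."""
    idx = sorted(cache)
    x_s = [cache[i] for i in idx]         # (k x d)
    dy_s = [grad_y[i] for i in idx]       # (k x m)
    return matmul(transpose(x_s), dy_s)   # (d x m), same shape as W


# Toy example: n=8 tokens, d=4 input features, m=3 outputs, k=2 tuned tokens.
rng = random.Random(0)
n, d, m, k = 8, 4, 3, 2
x = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]
w = [[rng.gauss(0.0, 1.0) for _ in range(m)] for _ in range(d)]
selected = rng.sample(range(n), k)        # hypothetical uniform random selection

y, cache = dense_forward(x, w, selected)
grad_y = [[1.0] * m for _ in range(n)]    # upstream gradient of L = sum(y)
g_token = dense_weight_grad(cache, grad_y)
g_full = matmul(transpose(x), grad_y)     # what full fine-tuning would use
print(len(cache), "of", n, "token inputs cached")
```

In a framework such as PyTorch, one plausible way to get the same effect is to run the frozen positions without building an autograd graph, so their activations are never stored; bias terms and the normalization and attention layers would need analogous handling, which this toy omits.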

Paper Structure

This paper contains 38 sections, 7 equations, 4 figures, 7 tables, 1 algorithm.

Figures (4)

  • Figure 1: TokenTune greatly reduces the GPU memory usage for fine-tuning the Llama2-7B model (e.g., using only 37% of the memory QLoRA (Dettmers et al., 2023) requires), while achieving accuracy similar to representative memory-efficient fine-tuning methods. Accuracy and memory usage numbers are listed in Table \ref{tab:llm-perf} and Fig. \ref{fig:llm-memory}. See Sec. \ref{sec:exp:large} for details on the experiments.
  • Figure 2: TokenTune achieves memory-efficient fine-tuning of transformers via token selection. During the backward pass, we compute the gradient for only a subset of $k$ input tokens, while the others are frozen (in gray in the figure). During the forward pass, all input positions are used, but only a subset of the activations is cached in memory (in blue in the figure). TokenTune is applicable to various transformer-based models, as well as different language modeling tasks, as our experiments with Bert (Devlin et al., 2019) and Llama (Touvron et al., 2023) show.
  • Figure 3: (left) We plot the GPU memory required to train Bert-base on the CoLA task for varying batch sizes, comparing our approach with two PEFT approaches: Ladder Side Tuning (LST) and LoRA. (right) We plot the mean and standard deviation of dev-set performance over five runs when training Bert-base on two tasks from the GLUE benchmark, MRPC and STS-B, using our memory-efficient fine-tuning approach with varying numbers of selected input tokens for the gradient computation.
  • Figure 4: GPU memory required to fine-tune Llama2-7B (Touvron et al., 2023). We measure the memory by fine-tuning the model on artificially generated data with a batch size of 1 and a sequence length of 2048. We show the memory usage when combining TokenTune with LoRA and QLoRA, and plot how the memory required to fine-tune the model on an H100 GPU evolves as the number of trained positions ranges from 256 to 2046 (we leave at least 2 positions untuned). Since we could not perform full fine-tuning on our hardware, we estimate the full fine-tuning memory from the memory reported for TokenTune and LoRA. Specific memory usage values can be found in Table \ref{tab:gpu_mem_usage}.
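The token split described in the Figure 2 and Figure 4 captions can be sketched as follows. This is an illustrative Python sketch under our own assumptions: `split_positions` and `cached_token_fraction` are hypothetical helpers, positions are drawn uniformly at random, and the reported fraction covers only per-token activations. Total GPU memory also includes weights, gradients, and optimizer states, so the overall savings shown in Figure 4 are not expected to scale exactly with $k/n$.

```python
import random


def split_positions(seq_len, k, seed=None):
    """Hypothetical helper: pick k of seq_len positions uniformly at random.

    Returns (tuned, frozen). Tuned positions receive gradients and have their
    activations cached; frozen positions are still used in the forward pass,
    but nothing is cached for them.
    """
    if not 1 <= k <= seq_len:
        raise ValueError("k must satisfy 1 <= k <= seq_len")
    rng = random.Random(seed)
    tuned = sorted(rng.sample(range(seq_len), k))
    tuned_set = set(tuned)
    frozen = [i for i in range(seq_len) if i not in tuned_set]
    return tuned, frozen


def cached_token_fraction(seq_len, k):
    """Share of per-token activations kept for the backward pass (k / n)."""
    return k / seq_len


# Sweep loosely mirroring Figure 4: sequence length 2048, k from 256 to 2046.
SEQ_LEN = 2048
sweep = {}
for k in (256, 512, 1024, 2046):
    tuned, frozen = split_positions(SEQ_LEN, k, seed=0)
    sweep[k] = (len(tuned), len(frozen), cached_token_fraction(SEQ_LEN, k))
    print(f"k={k:4d}  frozen={len(frozen):4d}  cached fraction={sweep[k][2]:.3f}")
```

Drawing a fresh subset at each training step (by varying `seed` or sharing one `random.Random` instance) is one natural choice, so that every position is eventually tuned; this detail is our assumption, not something stated in the captions.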