
Heterogeneous LoRA for Federated Fine-tuning of On-Device Foundation Models

Yae Jee Cho, Luyang Liu, Zheng Xu, Aldi Fahrezi, Gauri Joshi

TL;DR

This work tackles privacy-preserving federated fine-tuning of on-device foundation models by introducing HetLoRA. HetLoRA permits heterogeneous LoRA ranks across clients, lets each client prune its own rank locally, and combines client updates with sparsity-weighted aggregation. The authors formalize the FL objective with variable per-client ranks and implement a three-step HetLoRA workflow: distribution via truncation, local training with rank self-pruning, and sparsity-weighted aggregation. The approach balances convergence speed against generalization while training far fewer parameters than full fine-tuning. In experiments with PaLM 2 XXS and XS, measured by RougeL on Reddit data and by perplexity on multi-session chat, HetLoRA outperforms homogeneous-LoRA baselines and reconstruction-based methods and comes close to full fine-tuning performance at a fraction of the communication and computation cost. The method shows practical potential for on-device, privacy-preserving adaptation of small- to mid-sized foundation models across heterogeneous devices, and leaves a theoretical convergence analysis and rank-assignment strategies to future work.
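To make the claim about trainable parameters concrete, here is a small sketch that counts them for a single weight matrix. It assumes the standard LoRA parameterization $\Delta\mathbf{W}=\mathbf{B}\mathbf{A}$ with $\mathbf{B}\in\mathbb{R}^{d_{\text{out}}\times r}$ and $\mathbf{A}\in\mathbb{R}^{r\times d_{\text{in}}}$; the 2048×2048 layer size and the client ranks are hypothetical placeholders, not PaLM 2 dimensions or the ranks used in the paper.

```python
# Hypothetical layer size; PaLM 2 XXS/XS shapes are not given in this summary.
D_OUT, D_IN = 2048, 2048


def lora_params(rank: int, d_out: int = D_OUT, d_in: int = D_IN) -> int:
    """Trainable parameters of a LoRA update B @ A (B: d_out x r, A: r x d_in)."""
    return rank * (d_out + d_in)


def full_params(d_out: int = D_OUT, d_in: int = D_IN) -> int:
    """Trainable parameters when the whole weight matrix is fine-tuned."""
    return d_out * d_in


# Hypothetical heterogeneous client ranks, from r_min = 1 up to r_max = 50.
client_ranks = [1, 5, 20, 50]
for r in client_ranks:
    frac = lora_params(r) / full_params()
    print(f"rank {r:>2}: {lora_params(r):>9,} trainable params ({frac:.2%} of full)")
```

Because a rank-$r$ module trains $r(d_{\text{out}}+d_{\text{in}})$ parameters instead of $d_{\text{out}}d_{\text{in}}$, even the largest rank in this example trains under 5% of the matrix, and lower-rank clients train proportionally less.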

Abstract

Foundation models (FMs) adapt well to specific domains or tasks through fine-tuning, and federated learning (FL) makes privacy-preserving fine-tuning of FMs with local on-device data possible. For federated fine-tuning, we consider FMs with small to medium parameter counts (at most in the single-digit billions), referred to as on-device FMs (ODFMs), which can be deployed on devices for inference but can only be fine-tuned with parameter-efficient methods. In our work, we tackle the data and system heterogeneity problem of federated fine-tuning of ODFMs by proposing HetLoRA, a novel method that uses heterogeneous low-rank adaptation (LoRA) modules. First, we show that the naive approach of using homogeneous LoRA ranks across devices faces a trade-off between overfitting and slow convergence. We therefore propose HetLoRA, which allows heterogeneous ranks across client devices and efficiently aggregates and distributes these heterogeneous LoRA modules. By applying rank self-pruning locally and sparsity-weighted aggregation at the server, HetLoRA combines the advantages of high- and low-rank LoRA modules, achieving faster convergence and better final performance than homogeneous LoRA. Furthermore, HetLoRA is more computationally efficient than full fine-tuning, making it suitable for federated fine-tuning across heterogeneous devices.
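The abstract does not spell out how rank self-pruning decides when a client should shrink its rank. The Python sketch below illustrates the idea with one hypothetical criterion, which is an assumption and not necessarily the authors' rule: since $\mathbf{B}\mathbf{A}=\sum_i \mathbf{b}_i\mathbf{a}_i^\top$, a client can drop trailing rank-one components whose Frobenius norm is small relative to the whole update, without going below $r_{\text{min}}$. The names `self_prune` and `rel_tol` are illustrative.

```python
import numpy as np


def self_prune(B, A, r_min, rel_tol=0.05):
    """Hypothetical rank self-pruning rule (an assumption, not the paper's exact criterion).

    B is d_out x r and A is r x d_in, so B @ A is a sum of r rank-one terms
    outer(B[:, i], A[i, :]). Trailing terms are dropped while their Frobenius
    norm is at most rel_tol times the norm of the full update, never below r_min.
    """
    total = np.linalg.norm(B @ A)  # Frobenius norm of the full LoRA update
    r = B.shape[1]
    while r > r_min:
        # ||outer(b, a)||_F equals ||b||_2 * ||a||_2 for a rank-one term.
        tail = np.linalg.norm(B[:, r - 1]) * np.linalg.norm(A[r - 1, :])
        if tail > rel_tol * total:
            break
        r -= 1
    return B[:, :r], A[:r, :]


# Usage: a rank-8 module whose last two components are nearly inactive.
scales = np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1e-3, 1e-3])
B = np.eye(16, 8)                    # unit-norm columns
A = np.diag(scales) @ np.eye(8, 16)  # row i has norm scales[i]
B_p, A_p = self_prune(B, A, r_min=2)
print(B_p.shape, A_p.shape)  # → (16, 6) (6, 16)
```

Pruning only trailing components keeps the pruned module a prefix of the original one, which fits the zero-padding and truncation scheme the server uses for heterogeneous ranks.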

Paper Structure (11 sections, 6 figures, 4 tables)

This paper contains 11 sections, 6 figures, 4 tables.

Figures (6)

  • Figure 1: Overview of heterogeneous rank deployment of LoRA: the pretrained weights $\mathbf{W}_0$ are stored on-device and heterogeneous ranks are assigned to different clients with $r_{\text{min}}=r_1<r_2<r_3=r_{\text{max}}$. In our proposed HetLoRA, the server receives the trained heterogeneous LoRA modules and aggregates them with sparsity-weighted aggregation to update the global LoRA module.
  • Figure 2: Overview of the zero-padding, sparsity-weighted aggregation, and truncation steps of HetLoRA. (a): Zero-pad LoRA modules with smaller ranks to $r_{\text{max}}$ (clients with rank $r_{\text{max}}$ do not need padding) and compute each module's sparsity weight as the Frobenius norm of the reconstructed update $\mathbf{S}_k^{(t)}=\Delta\mathbf{W}_k^{(t)}=\mathbf{B}_k^{(t)}\mathbf{A}_k^{(t)}$; (b): After padding, aggregate all of the clients' LoRA modules with the weights $\|\mathbf{S}_{k}^{(t)}\|_F/Z^{(t)}$, where $Z^{(t)}=\sum_{k}\|\mathbf{S}_k^{(t)}\|_F$ normalizes the weights, to get the global LoRA modules; (c): Truncate the global LoRA modules to the rank of the next selected client (shown for a client with rank $r_2$).
  • Figure 3: Performance of homogeneous LoRA for different ranks $r$. Higher ranks reach better performance in fewer communication rounds than lower ranks, but they overfit quickly. Conversely, the lowest rank $r=1$ reaches low perplexity more slowly than higher ranks, but without overfitting.
  • Figure 4: Performance of HetLoRA without rank pruning and with simple-average aggregation. As with homogeneous LoRA, a larger $r_{\text{min}}$ leads to overfitting for heterogeneous LoRA, but the overfitting is less severe than for homogeneous LoRA, even with a large maximum rank $r_{\text{max}}=50$, indicating that the smaller-rank LoRA modules act as a regularizer for HetLoRA.
  • Figure 5: Comparison of homogeneous LoRA, heterogeneous LoRA, and full fine-tuning. Heterogeneous LoRA achieves better performance than homogeneous LoRA in fewer communication rounds.
  • ...and 1 more figure
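The server-side steps in Figures 1 and 2 (zero-padding, sparsity-weighted aggregation, truncation) can be sketched in NumPy as follows. This is a minimal illustration under stated assumptions: each client $k$ returns a pair $(\mathbf{B}_k,\mathbf{A}_k)$ of rank $r_k\le r_{\text{max}}$, the padded $\mathbf{B}$ and $\mathbf{A}$ matrices are averaged separately with weights $\|\mathbf{B}_k\mathbf{A}_k\|_F/Z$, and no data-size weighting is applied. The function names are hypothetical.

```python
import numpy as np


def zero_pad(B, A, r_max):
    """Zero-pad a rank-r LoRA pair (B: d_out x r, A: r x d_in) to rank r_max (Fig. 2a)."""
    r = B.shape[1]
    B_pad = np.zeros((B.shape[0], r_max))
    A_pad = np.zeros((r_max, A.shape[1]))
    B_pad[:, :r] = B
    A_pad[:r, :] = A
    return B_pad, A_pad


def sparsity_weighted_aggregate(client_modules, r_max):
    """Average padded B and A matrices with weights ||B_k A_k||_F / Z (Fig. 2b)."""
    # Frobenius norm of each client's reconstructed update S_k = B_k A_k.
    norms = np.array([np.linalg.norm(B @ A) for B, A in client_modules])
    Z = norms.sum()
    # Fall back to a simple average if every update is exactly zero.
    weights = norms / Z if Z > 0 else np.full(len(norms), 1.0 / len(norms))
    padded = [zero_pad(B, A, r_max) for B, A in client_modules]
    B_glob = sum(w * B_p for w, (B_p, _) in zip(weights, padded))
    A_glob = sum(w * A_p for w, (_, A_p) in zip(weights, padded))
    return B_glob, A_glob


def truncate(B_glob, A_glob, r):
    """Keep the first r rank components for a client with rank r (Fig. 2c)."""
    return B_glob[:, :r], A_glob[:r, :]


# Usage with hypothetical sizes and ranks r_min = 1 < 2 < r_max = 4.
rng = np.random.default_rng(0)
d_out, d_in, r_max = 8, 6, 4
clients = [(rng.normal(size=(d_out, r)), rng.normal(size=(r, d_in))) for r in (1, 2, 4)]
B_g, A_g = sparsity_weighted_aggregate(clients, r_max)
B_2, A_2 = truncate(B_g, A_g, 2)
print(B_g.shape, A_g.shape, B_2.shape, A_2.shape)
```

Replacing `weights` with a uniform vector gives the simple-average aggregation used in the Figure 4 ablation. Zero-padding does not change $\|\mathbf{B}_k\mathbf{A}_k\|_F$, so the weights can be computed either before or after padding.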