Deep Hierarchical Learning with Nested Subspace Networks for Large Language Models

Paulius Rauba, Mihaela van der Schaar

TL;DR

Nested Subspace Networks (NSNs) are a novel architectural paradigm that lets a single model be dynamically and granularly adjusted across a continuous spectrum of compute budgets at inference time; the authors position NSNs as a powerful framework for building the next generation of adaptive foundation models.

Abstract

Large neural networks are typically trained for a fixed computational budget, creating a rigid trade-off between performance and efficiency that is ill-suited for deployment in resource-constrained or dynamic environments. Existing approaches to this problem present a difficult choice: training a discrete collection of specialist models is computationally prohibitive, while dynamic methods like slimmable networks often lack the flexibility to be applied to large, pre-trained foundation models. In this work, we propose Nested Subspace Networks (NSNs), a novel architectural paradigm that enables a single model to be dynamically and granularly adjusted across a continuous spectrum of compute budgets at inference time. The core of our approach is to re-parameterize linear layers to satisfy a nested subspace property, such that the function computed at a given rank is a strict subspace of the function at any higher rank. We show that this entire hierarchy of models can be optimized jointly via an uncertainty-aware objective that learns to balance the contributions of different ranks based on their intrinsic difficulty. We demonstrate empirically that NSNs can be surgically applied to pre-trained LLMs and unlock a smooth and predictable compute-performance frontier. For example, a single NSN-adapted model can achieve a 50% reduction in inference FLOPs with only a 5 percentage point loss in accuracy. Our findings establish NSNs as a powerful framework for creating the next generation of adaptive foundation models.
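
The abstract describes the training objective only as "uncertainty-aware", and Figure 3 mentions learned log-variances per rank. The following is a minimal sketch, assuming the common homoscedastic-uncertainty weighting: one learnable log-variance $s_r$ per rank and a total loss $\sum_r e^{-s_r}\mathcal{L}_r + s_r$. The exact form, any constant factors, the function names, and the example losses are assumptions for illustration, not the authors' code or results.

```python
import math
from typing import Dict


def multi_rank_uncertainty_loss(rank_losses: Dict[int, float],
                                log_vars: Dict[int, float]) -> float:
    """Weight per-rank task losses by learned log-variances.

    ASSUMED form (homoscedastic uncertainty weighting):
        total = sum_r exp(-s_r) * L_r + s_r,  with s_r = log(sigma_r^2).
    The exp(-s_r) factor down-weights uncertain ranks; the + s_r term
    stops the optimizer from inflating every s_r to zero out the loss.
    """
    if rank_losses.keys() != log_vars.keys():
        raise ValueError("rank_losses and log_vars must cover the same ranks")
    return sum(math.exp(-log_vars[r]) * loss + log_vars[r]
               for r, loss in rank_losses.items())


def optimal_log_var(loss: float) -> float:
    """Minimizer of exp(-s) * loss + s over s, for a fixed positive loss."""
    return math.log(loss)


# Hypothetical per-rank losses: lower ranks are harder, so their loss is higher.
losses = {4: 1.2, 16: 0.6, 64: 0.3}
s_star = {r: optimal_log_var(l) for r, l in losses.items()}
weights = {r: math.exp(-s) for r, s in s_star.items()}  # equals 1 / loss
print(weights)
```

Setting the derivative with respect to $s_r$ to zero gives $s_r^\ast = \log \mathcal{L}_r$, so under this assumed form ranks with higher loss settle at larger log-variances and receive smaller weights $e^{-s_r}$. This is one way an objective can balance the contributions of different ranks by their intrinsic difficulty.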

Paper Structure

This paper contains 74 sections, 3 theorems, 38 equations, 15 figures, 3 tables, 4 algorithms.

Key Result

Proposition 1

Let the task loss function $\mathcal{L}(f(\mathbf{x};r), y)$ be $L_{\mathcal{L}}$-Lipschitz continuous with respect to its first argument. Let $E(r) = \mathbb{E}_{(\mathbf{x},y)}[\mathcal{L}(f(\mathbf{x}; r), y)]$ be the expected error at rank $r$. For any ranks $r_1 < r_{\text{int}} < R$, the difference … where $C = L_{\mathcal{L}} \cdot \mathbb{E}[\lVert\mathbf{x}\rVert]$ is a task-dependent …

Figures (15)

  • Figure 1: Illustration of Nested Subspace Networks. NSNs convert linear layers into rank-trainable layers, which enable dynamic control over the computational cost (FLOPs) of a forward pass. Left: standard MLP layers composed of trainable weights. Middle: LoRA fine-tuning, which uses frozen weights and trainable adapters. Right: Nested Subspace Networks replace each linear layer with a single pair of shared factor matrices $(A,B)$ defining a rank-trainable layer. The effective weight at rank $r$, $W_r = B_r A_r$, is obtained by using only the first $r$ rows of $A$ and the first $r$ columns of $B$. Different operating points (different ranks) therefore correspond to different prefixes of the same $(A,B)$. This allows a compute-performance Pareto frontier to be constructed at inference time.
  • Figure 2: Comparison of native-rank training and rank truncation for an MLP on CIFAR-10. The plot compares the accuracy of individually training a model for each specific rank (Native rank training) versus training a single model at a high rank (64) and truncating it to lower ranks at test time (Rank-64 training). The significant performance gap demonstrates that naively truncating a high-rank model results in poor performance.
  • Figure 3: Learned log-variances during training with the multi-rank uncertainty objective on the CIFAR-10 dataset. We train a single model with an anchor rank and variant ranks and find that higher ranks have lower task-dependent uncertainty during training.
  • Figure 4: Energy decay profiles. The energy decay assumption holds in low-rank linear layers trained with our cross-entropy objective but does not hold in a standard MLP setting.
  • Figure 5: Training Dynamics at Interpolated Ranks. Validation accuracy during training for various objective functions, evaluated at ranks not explicitly optimized for. Our proposed method maintains stable learning across all ranks, while simpler baselines exhibit instability and performance collapse.
  • ...and 10 more figures
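
Figure 1 defines the effective weight at rank $r$ as $W_r = B_r A_r$, built from the first $r$ rows of $A$ and the first $r$ columns of $B$. A minimal NumPy sketch of such a layer follows, assuming a scaled-Gaussian initialization and a plain forward pass; the class `NestedRankLinear` and its methods are hypothetical names, not taken from the paper.

```python
import numpy as np


class NestedRankLinear:
    """Hypothetical rank-adjustable linear layer with nested prefixes.

    One shared pair of factors, A with shape (R, d_in) and B with shape
    (d_out, R). At rank r the effective weight is W_r = B[:, :r] @ A[:r, :],
    so every lower-rank model reuses a prefix of the same parameters.
    """

    def __init__(self, d_in: int, d_out: int, max_rank: int, seed: int = 0):
        rng = np.random.default_rng(seed)
        # Assumed scaled-Gaussian init; the paper's initialization may differ.
        self.A = rng.standard_normal((max_rank, d_in)) / np.sqrt(d_in)
        self.B = rng.standard_normal((d_out, max_rank)) / np.sqrt(max_rank)
        self.max_rank = max_rank

    def _check(self, r: int) -> None:
        if not 1 <= r <= self.max_rank:
            raise ValueError(f"rank must be in [1, {self.max_rank}], got {r}")

    def weight(self, r: int) -> np.ndarray:
        """Materialize W_r (d_out x d_in); useful for inspection only."""
        self._check(r)
        return self.B[:, :r] @ self.A[:r, :]

    def forward(self, x: np.ndarray, r: int) -> np.ndarray:
        """Compute x @ W_r.T for x of shape (batch, d_in) without forming W_r."""
        self._check(r)
        h = x @ self.A[:r, :].T      # (batch, r): project onto the first r directions
        return h @ self.B[:, :r].T   # (batch, d_out)

    def macs_per_example(self, r: int) -> int:
        """Multiply-accumulates per input vector on the factored rank-r path."""
        self._check(r)
        d_out, d_in = self.B.shape[0], self.A.shape[1]
        return r * (d_in + d_out)


if __name__ == "__main__":
    layer = NestedRankLinear(d_in=512, d_out=512, max_rank=64)
    x = np.ones((1, 512))
    for r in (16, 32, 64):
        print(r, layer.forward(x, r).shape, layer.macs_per_example(r))
```

The factored path costs $r(d_{\text{in}} + d_{\text{out}})$ multiply-accumulates per input rather than $d_{\text{in}} d_{\text{out}}$ for a dense layer, so halving $r$ halves the cost of every factored layer. This per-layer count is illustrative only and is not the paper's end-to-end FLOP measurement.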

Theorems & Definitions (5)

  • Proposition 1: Bound on Interpolation Error
  • Lemma 1: Sufficient condition for nested subspace property
  • Lemma 2: Adjacent Rank Perturbation
  • Proof
  • Proof of Proposition 1
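
The statements of Lemmas 1 and 2 are not reproduced in this summary. For orientation only, the prefix definition $W_r = B_r A_r$ from Figure 1 already yields the following elementary identities, writing $b_i$ for the $i$-th column of $B$ and $a_i^\top$ for the $i$-th row of $A$. These are not claimed to be the authors' lemmas or proofs.

```latex
W_r = B_r A_r = \sum_{i=1}^{r} b_i a_i^\top,
\qquad
W_{r+1} - W_r = b_{r+1} a_{r+1}^\top \quad (\text{rank at most } 1),

\operatorname{col}(W_r) \subseteq \operatorname{span}\{b_1,\dots,b_r\}
\subseteq \operatorname{span}\{b_1,\dots,b_{r+1}\},

\left\lVert (W_{r+1} - W_r)\,\mathbf{x} \right\rVert
= \lVert b_{r+1}\rVert \, \lvert a_{r+1}^\top \mathbf{x} \rvert
\le \lVert b_{r+1}\rVert \, \lVert a_{r+1}\rVert \, \lVert \mathbf{x}\rVert .
```

The last line, which follows from Cauchy-Schwarz, shows why a factor such as $\mathbb{E}[\lVert\mathbf{x}\rVert]$ can appear in a Lipschitz-style bound like Proposition 1: each rank increment perturbs a layer's output by at most a multiple of the input norm.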