Closed-Form Last Layer Optimization

Alexandre Galashov, Nathaël Da Costa, Liyuan Xu, Philipp Hennig, Arthur Gretton

TL;DR

This work introduces a training paradigm that treats the last linear layer of a neural network as its closed-form ridge solution $W^*(\theta)$, conditioned on the backbone features $\phi_\theta$, and optimizes only the backbone parameters $\theta$. By formulating the reduced loss $\mathcal{L}^*(\theta)=\mathcal{L}(W^*(\theta),\theta)$ and applying an envelope-theorem argument, the method avoids backpropagating through the matrix inverse in $W^*(\theta)$ while remaining compatible with stochastic gradient updates. To handle stochasticity and overfitting to the current batch, a proximal regularization term is added, which yields an SGD update that alternates between solving a proximal last-layer problem and taking gradient steps on the backbone; a Kalman-filter interpretation is also provided. The authors prove convergence in the infinite-width NTK regime under favorable conditions and report empirical improvements over standard SGD with squared loss on regression, deep feature instrumental variable regression (DFIV), and CIFAR classification. ImageNet results are mixed when compared against cross-entropy training, which suggests that the benefits depend on the regime and leaves an extension to cross-entropy as future work. Overall, the approach offers a computationally efficient way to keep the last layer optimal during training, accelerating convergence and reducing overfitting on small batches while preserving accuracy on larger-scale tasks.
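The envelope-theorem claim can be checked numerically on a toy problem. The following is a minimal sketch, not the authors' implementation: the one-layer `tanh` backbone, the loss scaling, the value of `lam`, and all function names are assumptions made for illustration. It computes the ridge solution $W^*(\theta)$, evaluates the gradient of the loss with respect to $\theta$ while holding $W$ fixed at $W^*(\theta)$, and compares it with a finite-difference gradient of the reduced loss $\mathcal{L}^*(\theta)$.

```python
import numpy as np

def features(theta, X):
    # Hypothetical backbone: a single tanh layer, phi_theta(x) = tanh(x @ theta).
    return np.tanh(X @ theta)

def loss(W, theta, X, Y, lam):
    # Squared loss with a ridge penalty on the last layer (assumed scaling).
    R = features(theta, X) @ W - Y
    return 0.5 * np.mean(np.sum(R**2, axis=1)) + 0.5 * lam * np.sum(W**2)

def ridge_last_layer(theta, X, Y, lam):
    # Closed-form minimizer of loss(., theta) over W.
    Phi = features(theta, X)
    n, k = Phi.shape
    A = Phi.T @ Phi / n + lam * np.eye(k)
    return np.linalg.solve(A, Phi.T @ Y / n)

def reduced_loss(theta, X, Y, lam):
    # L*(theta) = L(W*(theta), theta).
    return loss(ridge_last_layer(theta, X, Y, lam), theta, X, Y, lam)

def backbone_grad_fixed_W(W, theta, X, Y):
    # Partial gradient of the loss w.r.t. theta with W treated as a constant.
    Phi = features(theta, X)
    R = Phi @ W - Y
    dZ = (R @ W.T) * (1.0 - Phi**2) / X.shape[0]
    return X.T @ dZ

def finite_diff_grad(f, theta, eps=1e-5):
    # Central finite differences, one entry of theta at a time.
    g = np.zeros_like(theta)
    for idx in np.ndindex(*theta.shape):
        e = np.zeros_like(theta)
        e[idx] = eps
        g[idx] = (f(theta + e) - f(theta - e)) / (2 * eps)
    return g

rng = np.random.default_rng(0)
X = rng.normal(size=(64, 3))
Y = rng.normal(size=(64, 2))
theta = rng.normal(size=(3, 4))
lam = 0.1

W_star = ridge_last_layer(theta, X, Y, lam)
g_envelope = backbone_grad_fixed_W(W_star, theta, X, Y)
g_numeric = finite_diff_grad(lambda t: reduced_loss(t, X, Y, lam), theta)
max_gap = np.max(np.abs(g_envelope - g_numeric))
print("max |envelope gradient - finite-difference gradient|:", max_gap)
```

Because the gradient of the loss with respect to $W$ vanishes at $W^*(\theta)$, the two gradients should agree up to finite-difference error, without ever differentiating through `np.linalg.solve`.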

Abstract

Neural networks are typically optimized with variants of stochastic gradient descent. Under a squared loss, however, the optimal solution for the weights of the linear last layer is known in closed form. We propose to leverage this during optimization, treating the last layer as a function of the backbone parameters, and optimizing solely for these parameters. We show this is equivalent to alternating between gradient descent steps on the backbone and closed-form updates on the last layer. We adapt the method for the setting of stochastic gradient descent, by trading off the loss on the current batch against the accumulated information from previous batches. Further, we prove that, in the Neural Tangent Kernel regime, convergence of this method to an optimal solution is guaranteed. Finally, we demonstrate the effectiveness of our approach compared with standard SGD on a squared loss in several supervised tasks -- both regression and classification -- including Fourier Neural Operators and Instrumental Variable Regression.
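The stochastic variant described in the abstract can be sketched as follows. This is an illustration under stated assumptions rather than the paper's algorithm: the exact proximal objective is not given in this summary, so the code assumes a batch squared loss plus a penalty `mu * ||W - W_prev||^2 / 2` that pulls the last layer toward its previous value, and it reuses a hypothetical one-layer `tanh` backbone. The names `proximal_last_layer`, `alternating_step`, `mu`, and `lr` are all illustrative.

```python
import numpy as np

def proximal_last_layer(Phi, Y, W_prev, mu):
    # Assumed proximal problem:
    #   argmin_W  ||Phi W - Y||^2 / (2b) + (mu / 2) * ||W - W_prev||^2
    # whose closed form solves (Phi^T Phi / b + mu I) W = Phi^T Y / b + mu W_prev.
    b, k = Phi.shape
    A = Phi.T @ Phi / b + mu * np.eye(k)
    B = Phi.T @ Y / b + mu * W_prev
    return np.linalg.solve(A, B)

def alternating_step(theta, W_prev, X, Y, mu, lr):
    # 1) Closed-form proximal update of the last layer on the current batch.
    Phi = np.tanh(X @ theta)
    W = proximal_last_layer(Phi, Y, W_prev, mu)
    # 2) Plain gradient step on the backbone with W held fixed.
    R = Phi @ W - Y
    grad_theta = X.T @ ((R @ W.T) * (1.0 - Phi**2)) / X.shape[0]
    return theta - lr * grad_theta, W

rng = np.random.default_rng(0)
X_all = rng.normal(size=(256, 3))
Y_all = rng.normal(size=(256, 2))
theta = rng.normal(size=(3, 4))
W = np.zeros((4, 2))

for step in range(20):
    batch = rng.choice(256, size=8, replace=False)  # small mini-batch
    theta, W = alternating_step(theta, W, X_all[batch], Y_all[batch], mu=1.0, lr=0.1)
```

In this sketch `mu` controls the trade-off: a large `mu` keeps `W` close to what previous batches produced, while `mu` near zero fits the current batch alone, which is where overfitting to small batches would be expected.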

Paper Structure

This paper contains 39 sections, 4 theorems, 48 equations, 15 figures, 3 algorithms.

Key Result

Theorem 1

For fixed $\theta$, letting $W^\star := W^\star(\theta)$ with eq:closed-form-w, we have

Figures (15)

  • Figure 1: The squared loss landscape of a two-parameter neural network
  • Figure 2: Regression results. The X-axis is the number of iterations and the Y-axis is the test-set mean squared error (MSE). Each column corresponds to a different batch size. Different colors indicate different methods. We use a rolling average with window size $5$ to smooth the curves.
  • Figure 3: DFIV results. The X-axis is the number of iterations and the Y-axis is the test-set MSE. Each column corresponds to a different batch size. Different colors indicate different methods. Solid lines use the last layer re-estimated on the entire training set, while dashed lines use the current last-layer estimates. We use a rolling average with window size $5$ to smooth the curves.
  • Figure 4: CIFAR-10 results. The X-axis is the number of iterations and the Y-axis is the test-set accuracy. Each column corresponds to a different batch size. Different colors indicate different methods.
  • Figure 5: CIFAR-100 results. The X-axis is the number of iterations and the Y-axis is the test-set accuracy. Each column corresponds to a different batch size. Different colors indicate different methods.
  • ...and 10 more figures

Theorems & Definitions (6)

  • Theorem 1
  • proof
  • Theorem 2
  • proof
  • Theorem 3
  • Theorem 4