ORFit: One-Pass Learning via Bridging Orthogonal Gradient Descent and Recursive Least-Squares

Youngjae Min, Namhoon Cho, Navid Azizan

TL;DR

ORFit tackles one-pass learning for streaming data by updating parameters orthogonally to past gradients while interpolating the new datapoint, ensuring minimal forgetting. It combines an orthogonal update with an IPCA-based memory budget to maintain linear memory in the number of parameters and extends to batch updates. Theoretically, ORFit matches the multi-pass SGD solution for overparameterized linear models and its principal directions are minimax-optimal for forgetting; in the NTK regime it extends to nonlinear networks with similar guarantees. Empirically, ORFit improves forgetting behavior under memory constraints and achieves competitive accuracy with reduced memory and computation compared to baselines, with applicability to deep networks via NTK-inspired analyses.

Abstract

While large machine learning models have shown remarkable performance in various domains, their training typically requires iterating for many passes over the training data. However, due to computational and memory constraints and potential privacy concerns, storing and accessing all the data is impractical in many real-world scenarios where the data arrives in a stream. In this paper, we investigate the problem of one-pass learning, in which a model is trained on sequentially arriving data without retraining on previous datapoints. Motivated by the demonstrated effectiveness of overparameterized models and the phenomenon of benign overfitting, we propose Orthogonal Recursive Fitting (ORFit), an algorithm for one-pass learning which seeks to perfectly fit each new datapoint while minimally altering the predictions on previous datapoints. ORFit updates the parameters in a direction orthogonal to past gradients, similar to orthogonal gradient descent (OGD) in continual learning. We show that, interestingly, ORFit's update leads to an operation similar to the recursive least-squares (RLS) algorithm in adaptive filtering but with significantly improved memory and computational efficiency, i.e., linear, instead of quadratic, in the number of parameters. To further reduce memory usage, we leverage the structure of the streaming data via an incremental principal component analysis (IPCA). We show that using the principal components is minimax optimal, i.e., it minimizes the worst-case forgetting of previous predictions for unknown future updates. Further, we prove that, for overparameterized linear models, the parameter vector obtained by ORFit matches what the standard multi-pass stochastic gradient descent (SGD) would converge to. Finally, we extend our results to the nonlinear setting for highly overparameterized models, relevant for deep learning.
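The update described in the abstract can be illustrated for a scalar-output linear model $f(x;w)=\phi(x)^\top w$. The following is only a minimal sketch under stated assumptions, not the authors' implementation: the function names `orfit_step` and `orfit_fit` are hypothetical, past directions are kept as an explicit orthonormal basis `U` with no memory cap (the IPCA-based compression is omitted), and the parameters start at zero.

```python
import numpy as np


def orfit_step(w, U, phi, y, tol=1e-10):
    """One orthogonal, interpolating update for f(x; w) = phi @ w (hypothetical helper).

    w   : current parameters, shape (p,)
    U   : orthonormal basis of past update directions, shape (p, k)
    phi : feature vector of the new datapoint, shape (p,)
    y   : scalar target of the new datapoint
    """
    g = phi                               # gradient of the prediction w.r.t. w
    g_tilde = g - U @ (U.T @ g)           # remove the components along past directions
    norm = np.linalg.norm(g_tilde)
    if norm < tol:                        # direction already spanned: skip the update
        return w, U
    eta = (phi @ w - y) / (phi @ g_tilde) # step size that fits the new point exactly
    w_new = w - eta * g_tilde
    U_new = np.column_stack([U, g_tilde / norm])
    return w_new, U_new


def orfit_fit(Phi, y, w0=None):
    """Single pass over the rows of Phi, visiting each datapoint once."""
    p = Phi.shape[1]
    w = np.zeros(p) if w0 is None else w0.copy()
    U = np.zeros((p, 0))                  # no stored directions yet
    for phi_i, y_i in zip(Phi, y):
        w, U = orfit_step(w, U, phi_i, y_i)
    return w, U


# Toy overparameterized problem: 5 datapoints, 20 parameters (assumed sizes).
rng = np.random.default_rng(0)
Phi = rng.standard_normal((5, 20))
y = rng.standard_normal(5)
w, U = orfit_fit(Phi, y)
print("max residual after one pass:", np.max(np.abs(Phi @ w - y)))
```

In this toy setting, every datapoint is still fitted after the single pass, and starting from zero the result coincides with the minimum-norm interpolant `np.linalg.pinv(Phi) @ y`, which is consistent with the abstract's claim about matching the multi-pass SGD solution.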

Paper Structure

This paper contains 28 sections, 8 theorems, 51 equations, 3 figures, 4 algorithms.

Key Result

Lemma 3.1

Consider a feature-based linear model $f(x;w)=\phi(x)^\top w\in\mathbb{R}$, and let $\tilde{g}$ be the projection defined by eq:ogd_projection of any vector $g\in\mathbb{R}^p$. Then, for any step size $\eta\in\mathbb{R}$, the new parameter $w' = w - \eta \tilde{g}$ preserves the predictions on the p
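The reason the step size does not matter can be written in one line. This is a sketch under an assumption: since eq:ogd_projection is not reproduced here, $\tilde{g}$ is taken to be $g$ with its components along the span of the previously seen feature vectors $\phi(x_k)$ removed, so that $\phi(x_k)^\top \tilde{g} = 0$ for each earlier datapoint $x_k$ (in line with the Figure 1 description of moving orthogonally to $S$).

```latex
\phi(x_k)^\top w'
  = \phi(x_k)^\top \bigl(w - \eta\,\tilde{g}\bigr)
  = \phi(x_k)^\top w - \eta\,\underbrace{\phi(x_k)^\top \tilde{g}}_{=\,0}
  = \phi(x_k)^\top w
  \qquad \text{for every earlier } x_k \text{ and every } \eta \in \mathbb{R}.
```

Under that assumption, $\eta$ is left free and can be chosen purely to fit the new datapoint.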

Figures (3)

  • Figure 1: An illustration of ORFit in the parameter space for a linear model. The parameter $w_{i-1}$ fits the previous datapoints $\{(x_k,y_k)\}_{k=1}^{i-1}$. The set $S$ (which is updated incrementally) consists of the directions along which a move changes the predictions on previous data the most; thus, moving orthogonally to $S$ keeps those predictions intact. Given a new datapoint $(x_i,y_i)$, projecting its corresponding gradient $g$ onto the orthogonal complement of the subspace spanned by $S$ yields the new update direction $\tilde{g}$. ORFit finds a new parameter $w_i$ along the direction of $-\tilde{g}$ that fits the new datapoint $(x_i, y_i)$ in a single step, while still fitting the previous data $\{(x_k,y_k)\}_{k=1}^{i-1}$.
  • Figure 2: Results for the memory-restricted setting (§\ref{subsec:exp_limit}). (\ref{subfig:limited_test}) shows the evolution of the test errors measured after learning each datapoint, while (\ref{subfig:limited_sample}) shows the evolution of the prediction errors for a particular sample (the $16$-th example) after each iteration. The red dashed line indicates the step at which the sample is trained. The shades indicate the standard deviations over $10$ independent runs.
  • Figure 3: Results for the setting without memory restriction (§\ref{subsec:no_memory}). (\ref{subfig:full_test}) shows the evolution of the test and train errors measured after each training step, while (\ref{subfig:full_sample}) shows the evolution of the prediction errors for a particular sample (the $11$-th example) after each iteration. The red dashed line indicates the step at which the sample is trained. The shades indicate the standard deviations over $10$ independent runs.

Theorems & Definitions (13)

  • Lemma 3.1
  • Remark 3.1: Computational overhead
  • Remark 3.2: Nonlinear models
  • Remark 3.3: Other forms of vector-output model
  • Proposition 4.1
  • Remark 4.1
  • Remark 4.2
  • Theorem 4.2
  • Corollary 4.3
  • Proposition 4.4
  • ...and 3 more