ORFit: One-Pass Learning via Bridging Orthogonal Gradient Descent and Recursive Least-Squares
Youngjae Min, Namhoon Cho, Navid Azizan
TL;DR
ORFit tackles one-pass learning on streaming data. Each update moves the parameters in a direction orthogonal to past gradients while exactly interpolating the new datapoint, so predictions on earlier datapoints change as little as possible. The orthogonal update resembles recursive least-squares (RLS) but needs memory linear, rather than quadratic, in the number of parameters. An incremental PCA (IPCA) step keeps the number of stored directions within a fixed memory budget, and the method extends to batch updates. Theoretically, for overparameterized linear models ORFit reaches the same solution as multi-pass SGD, and keeping the principal directions is minimax optimal for worst-case forgetting. In the neural tangent kernel (NTK) regime, the analysis extends to highly overparameterized nonlinear networks with similar guarantees, which makes it relevant to deep learning. Empirically, ORFit reduces forgetting under memory constraints and achieves competitive accuracy with less memory and computation than the baselines.
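The orthogonal update can be sketched for the simplest case: a linear model f(x; w) = w·x, whose gradient with respect to w is just x. This is a minimal illustration under that assumption, not the authors' implementation. The class name `ORFitLinear`, the `tol` threshold, and the choice to skip inputs that already lie in the span of past gradients are all hypothetical. Each step projects the new gradient onto the orthogonal complement of the stored directions, then moves along that projection just far enough to fit the new label.

```python
import numpy as np


class ORFitLinear:
    """Hypothetical sketch of an ORFit-style one-pass learner for f(x; w) = w @ x."""

    def __init__(self, w0, tol=1e-10):
        self.w = np.asarray(w0, dtype=float).copy()
        # Orthonormal basis of past (projected) gradients, one column per stored direction.
        self.U = np.zeros((self.w.size, 0))
        self.tol = tol

    def predict(self, x):
        return float(self.w @ x)

    def update(self, x, y):
        g = np.asarray(x, dtype=float)  # gradient of w @ x with respect to w
        # Remove the components of g that lie in the span of past gradients.
        v = g - self.U @ (self.U.T @ g)
        denom = float(g @ v)  # equals ||v||^2 because g - v is orthogonal to v
        if denom <= self.tol:
            return  # assumption: skip inputs (numerically) in the span of past ones
        # Step along v so that the prediction on x becomes exactly y;
        # v is orthogonal to every past input, so their predictions are unchanged.
        residual = self.predict(g) - y
        self.w -= (residual / denom) * v
        # Store the new unit direction.
        self.U = np.column_stack([self.U, v / np.sqrt(denom)])
```

With fewer datapoints than parameters and linearly independent inputs, one pass leaves `w` interpolating every point and equal to `w0 + np.linalg.pinv(X) @ (y - X @ w0)`, the interpolating solution closest to `w0`. Multi-pass SGD started from `w0` converges to that same point on this least-squares problem, which is consistent with the linear-model claim above. Note that this version stores one direction per datapoint, so its memory grows with the stream.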
Abstract
While large machine learning models have shown remarkable performance in various domains, their training typically requires iterating for many passes over the training data. However, due to computational and memory constraints and potential privacy concerns, storing and accessing all the data is impractical in many real-world scenarios where the data arrives in a stream. In this paper, we investigate the problem of one-pass learning, in which a model is trained on sequentially arriving data without retraining on previous datapoints. Motivated by the demonstrated effectiveness of overparameterized models and the phenomenon of benign overfitting, we propose Orthogonal Recursive Fitting (ORFit), an algorithm for one-pass learning which seeks to perfectly fit each new datapoint while minimally altering the predictions on previous datapoints. ORFit updates the parameters in a direction orthogonal to past gradients, similar to orthogonal gradient descent (OGD) in continual learning. We show that, interestingly, ORFit's update leads to an operation similar to the recursive least-squares (RLS) algorithm in adaptive filtering but with significantly improved memory and computational efficiency, i.e., linear, instead of quadratic, in the number of parameters. To further reduce memory usage, we leverage the structure of the streaming data via an incremental principal component analysis (IPCA). We show that using the principal components is minimax optimal, i.e., it minimizes the worst-case forgetting of previous predictions for unknown future updates. Further, we prove that, for overparameterized linear models, the parameter vector obtained by ORFit matches what the standard multi-pass stochastic gradient descent (SGD) would converge to. Finally, we extend our results to the nonlinear setting for highly overparameterized models, relevant for deep learning.
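To illustrate the memory budget, the sketch below caps the number of stored directions at `m`. It assumes a linear model f(x; w) = w·x. It also assumes that the IPCA step can be approximated by a thin SVD of the stored directions (scaled by their singular values) concatenated with the new projected gradient, keeping the top `m` components. `ORFitBudget`, `m`, and `tol` are hypothetical names, and this is not necessarily the paper's exact IPCA procedure.

```python
import numpy as np


class ORFitBudget:
    """Hypothetical sketch: ORFit for f(x; w) = w @ x keeping at most m directions."""

    def __init__(self, w0, m, tol=1e-10):
        self.w = np.asarray(w0, dtype=float).copy()
        self.m = m
        self.U = np.zeros((self.w.size, 0))  # orthonormal principal directions
        self.s = np.zeros(0)                 # their singular values
        self.tol = tol

    def update(self, x, y):
        g = np.asarray(x, dtype=float)   # gradient of w @ x with respect to w
        v = g - self.U @ (self.U.T @ g)  # project out the stored directions
        denom = float(g @ v)
        if denom <= self.tol:
            return  # assumption: skip inputs already in the stored span
        # Interpolate the new point (x, y) exactly.
        self.w -= ((self.w @ g - y) / denom) * v
        # IPCA-style step (assumed): thin SVD of [U diag(s), v], truncated to m components.
        M = np.column_stack([self.U * self.s, v])
        Q, S, _ = np.linalg.svd(M, full_matrices=False)
        k = min(self.m, S.size)
        self.U, self.s = Q[:, :k], S[:k]
```

Because each new projected gradient is orthogonal to the stored directions, the SVD here effectively ranks directions by their singular values and drops the smallest ones. The newest datapoint is always fitted exactly, but later updates may change predictions along the dropped directions. The minimax result above says that keeping the principal components minimizes the worst case of this forgetting. When `m` is at least the number of datapoints, nothing is dropped and the sketch reduces to the unlimited-memory update.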
