Randomized Matrix Sketching for Neural Network Training and Gradient Monitoring

Harbir Antil, Deepanshu Verma

TL;DR

The paper tackles the memory bottleneck of storing layer activations for gradient computation and for extended gradient monitoring in neural networks. It introduces an EMA-based, control-theoretic matrix sketching framework that maintains three sketches (X, Y, Z) per layer, with adaptive rank, to enable memory-efficient gradient reconstruction and real-time monitoring. The approach is validated on MNIST, CIFAR-10, and physics-informed neural networks (PINNs), showing controllable accuracy-memory tradeoffs and large memory reductions for gradient monitoring (up to 99%), with minimal overhead in PINNs. The work provides theoretical bounds linking gradient reconstruction error to sketch rank and activation tail energy, and demonstrates practical utility in diagnosing training health and stability while keeping diagnostics scalable over long training horizons.
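The three per-layer sketches can be illustrated with the standard three-sketch streaming low-rank approximation (Tropp et al.), on which control-theoretic sketching builds. The following is a minimal sketch under several assumptions: Gaussian test matrices, an EMA factor `beta` applied identically to all three sketches, zero initialization, and the usual oversampled sizes k = 2r + 1 and s = 2k + 1. The helpers `make_test_matrices`, `ema_update`, and `reconstruct` are hypothetical names, not the paper's code. The range sketch Y and co-range sketch X give orthonormal bases Q and P, and the core sketch Z fixes the small k x k matrix C so that the activation estimate is Q C Pᵀ.

```python
import numpy as np

def make_test_matrices(m, n, k, s, rng):
    """Draw Gaussian test matrices (an assumption; other random maps also work)."""
    Upsilon = rng.standard_normal((k, m))  # co-range test matrix, X = Upsilon @ A
    Omega = rng.standard_normal((n, k))    # range test matrix,    Y = A @ Omega
    Phi = rng.standard_normal((s, m))      # core test matrices,   Z = Phi @ A @ Psi
    Psi = rng.standard_normal((n, s))
    return Upsilon, Omega, Phi, Psi

def ema_update(X, Y, Z, A, tests, beta):
    """Blend one batch's activation matrix A (m x n) into the EMA sketches."""
    Upsilon, Omega, Phi, Psi = tests
    X = beta * X + (1 - beta) * (Upsilon @ A)
    Y = beta * Y + (1 - beta) * (A @ Omega)
    Z = beta * Z + (1 - beta) * (Phi @ A @ Psi)
    return X, Y, Z

def reconstruct(X, Y, Z, tests):
    """Low-rank estimate Q @ C @ P.T of the EMA-weighted activation matrix."""
    _, _, Phi, Psi = tests
    Q, _ = np.linalg.qr(Y)      # orthonormal basis for the range (m x k)
    P, _ = np.linalg.qr(X.T)    # orthonormal basis for the co-range (n x k)
    # Core C solves (Phi Q) C (Psi^T P)^T = Z in the least-squares sense,
    # done as two successive least-squares solves.
    W = np.linalg.lstsq(Phi @ Q, Z, rcond=None)[0]         # k x s
    C = np.linalg.lstsq(Psi.T @ P, W.T, rcond=None)[0].T   # k x k
    return Q @ C @ P.T

# Toy run: m neurons, n samples per batch, activations of exact rank r.
rng = np.random.default_rng(0)
m, n, r, beta = 64, 32, 3, 0.9
k, s = 2 * r + 1, 2 * (2 * r + 1) + 1   # oversampled sketch sizes (a common choice)
tests = make_test_matrices(m, n, k, s, rng)
X, Y, Z = np.zeros((k, n)), np.zeros((m, k)), np.zeros((s, s))
U = rng.standard_normal((m, r))          # shared column space across batches
A_ema = np.zeros((m, n))                 # formed explicitly only to check the result
for _ in range(20):
    A = U @ rng.standard_normal((r, n))
    X, Y, Z = ema_update(X, Y, Z, A, tests, beta)
    A_ema = beta * A_ema + (1 - beta) * A

A_hat = reconstruct(X, Y, Z, tests)
rel_err = np.linalg.norm(A_hat - A_ema) / np.linalg.norm(A_ema)
sketch_size = X.size + Y.size + Z.size   # 897 numbers vs A_ema.size == 2048
```

Because the EMA-weighted matrix here has rank r ≤ k, the reconstruction is exact up to rounding; for activations with a slowly decaying spectrum the error instead depends on the tail energy beyond the sketch rank.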

Abstract

Neural network training relies on gradient computation through backpropagation, yet memory requirements for storing layer activations present significant scalability challenges. We present the first adaptation of control-theoretic matrix sketching to neural network layer activations, enabling memory-efficient gradient reconstruction in backpropagation. This work builds on recent matrix sketching frameworks for dynamic optimization problems, where similar state trajectory storage challenges motivate sketching techniques. Our approach sketches layer activations using three complementary sketch matrices maintained through exponential moving averages (EMA) with adaptive rank adjustment, automatically balancing memory efficiency against approximation quality. Empirical evaluation on MNIST, CIFAR-10, and physics-informed neural networks demonstrates a controllable accuracy-memory tradeoff. We demonstrate a gradient monitoring application on MNIST showing how sketched activations enable real-time gradient norm tracking with minimal memory overhead. These results establish that sketched activation storage provides a viable path toward memory-efficient neural network training and analysis.
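The abstract does not spell out the rank-adjustment rule. One plausible rule, consistent with the paper's link between error and activation tail energy, picks the smallest rank whose discarded squared singular values fall below a tolerance. This is a hypothetical sketch, not the authors' method: `tail_energy_fraction` and `choose_rank` are illustrative names, and in a streaming setting the singular values would have to be estimated (for example, from the small reconstructed core) rather than passed in directly as here.

```python
import numpy as np

def tail_energy_fraction(sigma, r):
    """Share of squared Frobenius energy outside the top-r singular values."""
    energy = np.square(np.asarray(sigma, dtype=float))
    return energy[r:].sum() / energy.sum()

def choose_rank(sigma, tol, r_min=1, r_max=None):
    """Smallest rank in [r_min, r_max] whose tail-energy fraction is <= tol.

    Hypothetical rule: falls back to r_max if no rank meets the tolerance.
    """
    r_max = len(sigma) if r_max is None else min(r_max, len(sigma))
    for r in range(r_min, r_max + 1):
        if tail_energy_fraction(sigma, r) <= tol:
            return r
    return r_max

sigma = np.array([10.0, 5.0, 1.0, 0.1, 0.01])   # illustrative spectrum
print(choose_rank(sigma, tol=1e-2))              # 2
print(choose_rank(sigma, tol=1e-3))              # 3
```

Tightening the tolerance raises the rank and memory cost, which is the accuracy-memory tradeoff the abstract describes; the `r_max` cap bounds memory when the tolerance cannot be met.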

Paper Structure

This paper contains 30 sections, 3 theorems, 32 equations, 5 figures, 1 table, 2 algorithms.

Key Result

Lemma 4.1

The EMA sketch updates (Equations eq:X_ema_update through eq:Z_ema_update) can be expressed as exponentially weighted combinations of historical batch contributions. This holds for $\mathbf{X}_s^{[\ell]}(n)$, with analogous expressions for $\mathbf{Y}_s^{[\ell]}(n)$ and $\mathbf{Z}_s^{[\ell]}(n)$. The EMA-weighted activation matrix appearing in these expressions is a conceptual object: it is never explicitly formed, but is implicitly represented through the sketches.
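Because each sketch is a linear map of the activations, the lemma's claim can be checked numerically. A minimal sketch, assuming zero-initialized sketches, a scalar EMA factor $\beta$, and the update $\mathbf{X}(n) = \beta\,\mathbf{X}(n-1) + (1-\beta)\,\boldsymbol{\Upsilon}\mathbf{A}_n$ (the names $\beta$, $\boldsymbol{\Upsilon}$, and $\mathbf{A}_n$ are introduced here, not taken from the paper). Unrolling the recursion gives $\mathbf{X}(n) = \sum_{j=1}^{n} (1-\beta)\beta^{n-j}\,\boldsymbol{\Upsilon}\mathbf{A}_j = \boldsymbol{\Upsilon}\bar{\mathbf{A}}(n)$, where $\bar{\mathbf{A}}(n)$ is the EMA-weighted activation matrix.

```python
import numpy as np

rng = np.random.default_rng(1)
m, n, k, beta, T = 20, 8, 3, 0.8, 6
Upsilon = rng.standard_normal((k, m))   # co-range test matrix (assumed Gaussian)
batches = [rng.standard_normal((m, n)) for _ in range(T)]

# Recursive EMA update of the sketch, starting from zero (assumed initialization).
X = np.zeros((k, n))
for A in batches:
    X = beta * X + (1 - beta) * (Upsilon @ A)

# Closed form: exponentially weighted sum of per-batch sketch contributions.
X_closed = sum((1 - beta) * beta ** (T - j) * (Upsilon @ A)
               for j, A in enumerate(batches, start=1))

# Same result via the conceptual EMA-weighted activation matrix A_bar,
# which a sketching implementation would never form explicitly.
A_bar = sum((1 - beta) * beta ** (T - j) * A
            for j, A in enumerate(batches, start=1))
X_from_bar = Upsilon @ A_bar
```

The same argument applies verbatim to the range and core sketches, since they are also linear in the activations.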

Figures (5)

  • Figure 1: MNIST classification results: (Left) Peak memory usage comparison across methods. (Right) Training accuracy, showing a 3-5% performance gap while preserving convergence dynamics.
  • Figure 2: CIFAR-10 results for a hybrid CNN-MLP architecture: (Left) Memory usage with sketched fully-connected layers. (Right) Accuracy preservation, demonstrating effective integration with convolutional components.
  • Figure 3: PINN training results with sketch-based monitoring: (Left) Peak memory usage, showing minimal overhead (0.57 MB) from sketch storage for monitoring. (Right) Total loss convergence, demonstrating identical training dynamics for standard backpropagation and the sketch-based monitoring variants; this validates that comprehensive gradient monitoring can be achieved without compromising physics-constraint satisfaction. All methods achieve an $L^2$ relative error of 0.31 on the test points.
  • Figure 4: PINN solution quality comparison: (Left) Exact solution. (Right, top to bottom) Standard backpropagation, fixed-rank sketching, and adaptive sketching; for each method, the right side shows the predicted solution and the corresponding absolute error. All methods achieve an $L^2$ relative error of 0.31 on the test points.
  • Figure 5: Gradient monitoring demonstration comparing healthy and problematic network configurations on MNIST. Both sixteen-layer networks (1024 neurons in each hidden layer) use sketch rank $r=4$.
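Figure 5 tracks gradient norms computed from low-rank activations. A toy sketch of per-layer monitoring, assuming dense layers whose weight gradient is $\mathbf{G} = \boldsymbol{\Delta}\mathbf{A}^\top$ ($\boldsymbol{\Delta}$: backpropagated errors, $\mathbf{A}$: stored activations, one column per sample). To stay short it uses a rank-r truncated SVD as a stand-in for the sketch reconstruction; unlike the paper's randomized EMA sketches, this requires the full activations, so it illustrates the monitoring arithmetic, not the memory savings. The helpers `low_rank` and `monitor_layer` are hypothetical, and the random data say nothing about the healthy and problematic networks in the figure.

```python
import numpy as np

def low_rank(A, r):
    """Rank-r truncated SVD of A (stand-in for a sketch reconstruction)."""
    U, s, Vt = np.linalg.svd(A, full_matrices=False)
    return (U[:, :r] * s[:r]) @ Vt[:r]

def monitor_layer(delta, A, r):
    """Exact and approximate weight-gradient norms for one dense layer.

    A is d_in x N (stored activations), delta is d_out x N (backpropagated
    errors), so the weight gradient is G = delta @ A.T.
    """
    A_hat = low_rank(A, r)
    g_exact = np.linalg.norm(delta @ A.T)      # Frobenius norm by default
    g_approx = np.linalg.norm(delta @ A_hat.T)
    # Guaranteed bound: |g_exact - g_approx| <= ||delta||_2 * ||A - A_hat||_F
    bound = np.linalg.norm(delta, 2) * np.linalg.norm(A - A_hat)
    return g_exact, g_approx, bound

rng = np.random.default_rng(2)
width, N, r, depth = 128, 64, 4, 6   # toy sizes, not the paper's setup
history = []
for layer in range(depth):
    # Activations: rank r plus small noise (illustrative only).
    A = rng.standard_normal((width, r)) @ rng.standard_normal((r, N))
    A += 0.01 * rng.standard_normal((width, N))
    delta = rng.standard_normal((width, N))
    history.append(monitor_layer(delta, A, r))

for layer, (g, g_hat, bound) in enumerate(history):
    print(f"layer {layer}: |G|={g:.3f}  |G_hat|={g_hat:.3f}  bound={bound:.3f}")
```

In a real monitor, the per-layer norms would be logged every step, and patterns such as norms shrinking toward the input layers would flag unhealthy training.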

Theorems & Definitions (6)

  • Lemma 4.1: EMA Sketch Temporal Expansion
  • proof
  • Theorem 4.2: EMA Activation Matrix Reconstruction Error
  • proof
  • Theorem 4.3: Gradient Reconstruction Error via EMA Approximation
  • proof
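For intuition about how a bound like Theorem 4.3 can arise, here is an elementary version for a single dense layer. It assumes the weight gradient has the form $\mathbf{G} = \boldsymbol{\Delta}\mathbf{A}^\top$ and that $\widehat{\mathbf{G}} = \boldsymbol{\Delta}\widehat{\mathbf{A}}^\top$ is formed from an approximate activation matrix $\widehat{\mathbf{A}}$. This is not the paper's theorem, which also accounts for the EMA weighting and the randomized sketch; it only shows how activation error propagates to gradient error and why the tail energy $\sum_{i>r}\sigma_i^2$ appears (the second line is the Eckart-Young theorem).

```latex
\begin{aligned}
\|\mathbf{G} - \widehat{\mathbf{G}}\|_F
  &= \|\boldsymbol{\Delta}(\mathbf{A} - \widehat{\mathbf{A}})^\top\|_F
   \le \|\boldsymbol{\Delta}\|_2 \, \|\mathbf{A} - \widehat{\mathbf{A}}\|_F , \\
\min_{\operatorname{rank}(\widehat{\mathbf{A}}) \le r} \|\mathbf{A} - \widehat{\mathbf{A}}\|_F^2
  &= \sum_{i > r} \sigma_i(\mathbf{A})^2 .
\end{aligned}
```

With the best rank-r approximation, the gradient error is therefore at most $\|\boldsymbol{\Delta}\|_2 \big(\sum_{i>r}\sigma_i(\mathbf{A})^2\big)^{1/2}$; randomized sketches with oversampling typically come within a modest factor of this optimum.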