Closed-Form Feedback-Free Learning with Forward Projection

Robert O'Shea, Bipin Rajendran

TL;DR

Forward Projection (FP) is a randomised closed-form training method that requires only a single forward pass over the dataset and no retrograde communication. It achieves generalisation comparable to gradient descent-based local learning methods while yielding a significant training speedup.

Abstract

State-of-the-art backpropagation-free learning methods employ local error feedback to direct iterative optimisation via gradient descent. Here, we examine the more restrictive setting where retrograde communication from neuronal outputs is unavailable for pre-synaptic weight optimisation. We propose Forward Projection (FP), a randomised closed-form training method requiring only a single forward pass over the dataset without retrograde communication. FP generates target values for pre-activation membrane potentials through randomised nonlinear projections of pre-synaptic inputs and labels. Local loss functions are optimised using closed-form regression without feedback from downstream layers. A key advantage is interpretability: membrane potentials in FP-trained networks encode information interpretable layer-wise as label predictions. Across several biomedical datasets, FP achieves generalisation comparable to gradient descent-based local learning methods while requiring only a single forward propagation step, yielding significant training speedup. In few-shot learning tasks, FP produces more generalisable models than backpropagation-optimised alternatives, with local interpretation functions successfully identifying clinically salient diagnostic features.

Paper Structure

This paper contains 21 sections, 72 equations, 7 figures, 5 tables.

Figures (7)

  • Figure 1: Graphical overview of forward projection. A: Forward projection algorithm for fitting layer weights $\mathbf{W}_1, \ldots, \mathbf{W}_l$ to model labels $\mathbf{Y}$ from data $\mathbf{X}$. B: Procedure for generating the $l$-th layer target potentials $\tilde{\mathbf{Z}}_l$. Pre-synaptic inputs $\mathbf{A}_{l-1}$ and labels $\mathbf{Y}$ are projected with fixed matrices $\mathbf{Q}_l$ and $\mathbf{U}_l$, respectively, before applying non-linearity $g_l$. C: Optimising $\mathbf{W}_l$ to predict $\tilde{\mathbf{Z}}_l$ from $\mathbf{A}_{l-1}$ by ridge regression with penalty $\lambda$. D: Interpreting membrane potentials $\mathbf{z}_l$ as a local label prediction $\hat{\mathbf{y}}_l$ given pre-synaptic inputs $\mathbf{a}_{l-1}$ and projection matrices $\mathbf{Q}_l$ and $\mathbf{U}_l^+$, where $\mathbf{U}_l^+$ is the pseudo-inverse of $\mathbf{U}_l$.
  • Figure 2: Performance of Forward Projection compared with backpropagation and local learning approaches. A: Comparison of feedback-free fitting methods on FMNIST. MLP architectures had 1000 neurons in the first and second layers and 100, 200, 400 or 800 neurons in the final hidden layer. B-D: Test performance of few-shot trained 2D-CNN models. Mean test AUC is reported over 50 few-shot training experiments. B: Chest X-ray (CXR) task. C: Optical Coherence Tomography (OCT) task. D: CIFAR2 task in which models were required to classify the first two classes (aeroplane and automobile). Models were fitted with $N\in\{5,10,15,20,30,40,50\}$ training samples from each class in CXR and OCT tasks and $N\in \{25,50,75,100\}$ samples per class for the CIFAR2 task. Predictive Coding and Difference Target Propagation are plotted separately in Supplementary Figure \ref{fig: fewshot extra methods}. BP: backpropagation; FF: Forward-Forward; FP: Forward Projection; LS: Local Supervision; RF: Random Features.
  • Figure 3: Forward Projection layer interpretations for electrocardiogram analysis. A: Visualisation of layer explanations over time in a 1D-convolutional neural network trained by Forward Projection to detect myocardial infarction (MI) in electrocardiograms (ECGs) from PTBXL data. Patients A, B, C, E (diagnosed MI) and patient D (no disease) were extracted from test data. Explanations were extracted from the second, fourth and sixth convolutional layers ($\hat{y}_2,\hat{y}_4,\hat{y}_6$) using equation \ref{eq:layer explanation}. Explanations increase with MI features (highlighted in red), including ST-segment depression (Patient A), ST-segment elevation (Patient B) and QRS widening with T-wave inversion (Patient C). B: Comparison of Forward Projection with GradCAM. Top: ECG data from Patient E (diagnosed MI), showing ST-segment elevation. Middle: Sixth convolutional layer explanation ($\hat{y}_6$) from a model trained by Forward Projection. Bottom: GradCAM output for the sixth convolutional layer of a model trained by backpropagation.
  • Figure 4: Visualisation of layer explanations over space in 2D-CNNs trained by Forward Projection to detect choroidal neovascularisation (CNV) in the OCT task. Ensemble average of five models shown. Patients A-C (diagnosed CNV) and patient D (no disease) were extracted from test data. Explanations were extracted from the second, fourth and sixth convolutional layers ($\hat{y}_2,\hat{y}_4,\hat{y}_6$) using equation \ref{eq:layer explanation}. CNV heat-maps demonstrate high values (red) over CNV features, including retinal/subretinal fluid (Patients A-C), hard exudates (Patient B) and fibrosis (Patient C), with low values (blue) over healthy retina (Patient D).
  • Figure 5: Training procedures for Forward Projection, Local Supervision, Forward-Forward and Backpropagation learning algorithms for the $l$-th hidden layer. A: Forward Projection generates target matrix $\hat{\mathbf{Z}}_l$ by projecting pre-synaptic inputs $\mathbf{A}_{l-1}$ by $\mathbf{Q}_l$ and labels $\mathbf{Y}$ by $\mathbf{U}_l$. Weights $\mathbf{W}_l$ are fitted by regression, generating membrane potential $\mathbf{Z}_l=\mathbf{A}_{l-1}\mathbf{W}_l$. B: In Local Supervision, an auxiliary prediction $\hat{\mathbf{Y}}_l$ is generated as a projection of the post-synaptic outputs $\mathbf{A}_l$, and $\mathbf{W}_l$ is updated by a short backward pass (red arrows). C: In Forward-Forward, "positive" and "negative" pre-synaptic activities, $\mathbf{A}_l$ and $\check{\mathbf{A}}_l$, are generated from true and spurious data-label pairs, respectively. $\mathbf{W}_l$ is updated to maximise positive activity whilst minimising negative activity. D: In backpropagation, $\mathbf{W}_l$ is updated along its gradient with respect to the backpropagated error.
  • ...and 2 more figures
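
The single-layer procedure described in the Figure 1 caption (panels B to D) can be illustrated with a minimal NumPy sketch. This is not the authors' implementation. The caption does not state how the two projections are combined, so the sketch assumes the target is the sum $g_l(\mathbf{A}_{l-1}\mathbf{Q}_l) + g_l(\mathbf{Y}\mathbf{U}_l)$, assumes $g_l$ is the sign function, and assumes Gaussian entries for $\mathbf{Q}_l$ and $\mathbf{U}_l$. The function names `fp_targets`, `ridge_fit`, `fit_layer` and `explain_layer` are hypothetical, and the decoding rule in `explain_layer` is one plausible reading of panel D rather than the paper's stated equation.

```python
import numpy as np


def fp_targets(A_prev, Y, Q, U, g=np.sign):
    """Target potentials from fixed random projections (Figure 1, panel B).

    ASSUMPTION: the two projected terms are summed and g is the sign
    function; the caption only says both are projected before applying g.
    """
    return g(A_prev @ Q) + g(Y @ U)


def ridge_fit(A_prev, Z_target, lam):
    """Closed-form ridge regression (panel C): W = (A^T A + lam I)^-1 A^T Z."""
    d = A_prev.shape[1]
    gram = A_prev.T @ A_prev + lam * np.eye(d)
    return np.linalg.solve(gram, A_prev.T @ Z_target)


def fit_layer(A_prev, Y, n_out, lam, rng, g=np.sign):
    """Fit one layer in a single pass, with no feedback from later layers."""
    # ASSUMPTION: Q and U have i.i.d. standard normal entries.
    Q = rng.standard_normal((A_prev.shape[1], n_out))
    U = rng.standard_normal((Y.shape[1], n_out))
    Z_target = fp_targets(A_prev, Y, Q, U, g)
    W = ridge_fit(A_prev, Z_target, lam)
    return W, Q, U


def explain_layer(a_prev, z, Q, U, g=np.sign):
    """Local label prediction from membrane potentials (panel D).

    ASSUMPTION: under the additive target above, subtracting the input
    term leaves an estimate of g(y U), which is decoded with pinv(U).
    """
    return (z - g(a_prev @ Q)) @ np.linalg.pinv(U)


# Usage on synthetic data (shapes are illustrative only).
rng = np.random.default_rng(0)
X = rng.standard_normal((40, 8))           # 40 samples, 8 features
Y = np.eye(3)[rng.integers(0, 3, size=40)]  # one-hot labels, 3 classes
W1, Q1, U1 = fit_layer(X, Y, n_out=64, lam=1.0, rng=rng)
Z1 = X @ W1                                 # membrane potentials
y_hat_1 = explain_layer(X, Z1, Q1, U1)      # layer-wise label estimate
```

A deeper network would repeat `fit_layer` on the activations of the previous layer, so each weight matrix is obtained from one linear solve rather than iterative gradient updates.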