Recurrent Joint Embedding Predictive Architecture with Recurrent Forward Propagation Learning

Osvaldo M Velarde, Lucas C Parra

TL;DR

This work introduces a biologically inspired vision network built on a joint embedding predictive architecture with recurrent gated circuits. It also introduces Recurrent-Forward Propagation, a learning algorithm that avoids both biologically unrealistic backpropagation through time and memory-inefficient real-time recurrent learning.

Abstract

Conventional computer vision models rely on very deep feedforward networks that process whole images and are trained offline with extensive labeled data. In contrast, biological vision relies on comparatively shallow recurrent networks that analyze sequences of fixated image patches and learn continuously in real time without explicit supervision. This work introduces a vision network inspired by these biological principles. Specifically, it leverages a joint embedding predictive architecture incorporating recurrent gated circuits. The network learns by predicting the representation of the next image patch (fixation) from the sequence of past fixations, a form of self-supervised learning. We show mathematically and empirically that the training algorithm avoids the problem of representational collapse. We also introduce \emph{Recurrent-Forward Propagation}, a learning algorithm that avoids biologically unrealistic backpropagation through time or memory-inefficient real-time recurrent learning. We show mathematically that the algorithm implements exact gradient descent for a large class of recurrent architectures, and confirm empirically that it learns efficiently. This paper focuses on these theoretical innovations and leaves the empirical evaluation of downstream-task performance and the analysis of representational similarity with biological vision for future work.
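The self-supervised objective described above, which predicts the representation of the next fixation from past fixations, can be illustrated with a toy NumPy sketch. This is not the authors' encoder, predictor, or training rule. The one-layer tanh encoder, the linear predictor, the choice of the current representation as the context, and all sizes and weight names (`W_enc`, `W_rec`, `W_pred`) are hypothetical. The sketch only computes the per-step prediction loss and does not implement Recurrent-Forward Propagation.

```python
import numpy as np

def encode(x, c, W_enc, W_rec):
    """Toy recurrent encoder (assumption): representation h(t) of patch x given context c."""
    return np.tanh(W_enc @ x + W_rec @ c)

def representation_loss(patches, W_enc, W_rec, W_pred):
    """Per-step squared error between predicted and actual next representations.

    Assumption: the context c(t) is simply the current representation h(t);
    the paper's architecture keeps a separate context vector.
    """
    c = np.zeros(W_rec.shape[0])
    h = encode(patches[0], c, W_enc, W_rec)
    losses = []
    for x_next in patches[1:]:
        c = h                                  # context after fixation t
        h_hat = W_pred @ c                     # linear prediction of h(t+1)
        h = encode(x_next, c, W_enc, W_rec)    # actual representation of fixation t+1
        losses.append(np.mean((h_hat - h) ** 2))
    return np.array(losses)

rng = np.random.default_rng(0)
D, H, T = 16, 8, 10                  # hypothetical patch size, embedding size, sequence length
patches = rng.normal(size=(T, D))    # random stand-in for a fixation sequence
W_enc = rng.normal(scale=0.3, size=(H, D))
W_rec = rng.normal(scale=0.1, size=(H, H))
W_pred = rng.normal(scale=0.3, size=(H, H))

print(representation_loss(patches, W_enc, W_rec, W_pred).round(3))

# A collapsed encoder (all-zero weights) maps every patch to h = 0 and
# reaches zero loss without encoding anything about the input.
print(representation_loss(patches, np.zeros((H, D)), np.zeros((H, H)), W_pred))  # all zeros
```

The second call shows why collapse is a concern for this kind of objective. If nothing else constrains the encoder, it can satisfy the loss trivially by mapping every input to the same vector. This is the failure mode that the training algorithm is reported to avoid.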

Paper Structure

This paper contains 22 sections, 1 theorem, 42 equations, 4 figures.

Key Result

Theorem 1

In R-JEPA, when minimizing the squared prediction error $E$ under stable network dynamics, a linear representation predictor $W_{Gh}$ converges after repeated gradient descent iterations with weight decay to the following proportionality:

Figures (4)

  • Figure 1: Recurrent Joint Embedding Predictive Architecture (R-JEPA). For an input $x(t)$, the encoder generates a representation $h(t)$. The predictor then generates a prediction $\hat{h}(t+1)$ of the representation of the next input. The encoder and predictor are trained to minimize the prediction loss $\mathcal{L}_R$ in the embedding space $\mathcal{H}$. At the same time, the context vector $c(t)$ is used to predict the action/response $\hat{a}(t)$ to the stimulus $x(t)$.
  • Figure 2: Recurrent Encoder. We implemented an encoder based on ResNet50 chen_exploring_2020 and the Reciprocal Gated Circuit nayebi_recurrent_2022. The encoder is hierarchical: (1) the first five areas correspond to ResNet stages, (2) Area 6 acts as the ResNet head, and (3) the embedding is a linear projection. Areas 1-6 contain recurrent units.
  • Figure 3: R-JEPA avoids collapse. (a) A 2D projection of the trajectory of features $h(t)$ for different inputs $x(t)$ using PCA. (b) Distribution of eigenvalues of the correlation matrix of features, i.e. $H H^T$.
  • Figure 4: Prediction error as a function of time in video. Representation Loss indicates the ability of the recurrent network to predict the content of the next fixation (image patch). Curves are averages over 7000 fixation sequences in the test data. Network behavior is shown before (Epoch=0) and after learning (Epoch=6). Time steps indicate the number of image patches (fixations) since the start of the recurrent iteration, i.e. the start of the test fixation sequences. The drop over time steps indicates that the network accumulates information in the context vector, allowing it to progressively improve its prediction.
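Figure 3(b) diagnoses representational collapse from the eigenvalue spectrum of $H H^T$. A minimal NumPy sketch of that kind of check follows. It assumes that `H` stores one feature vector $h(t)$ per column, and it uses a hypothetical relative threshold `rel_tol` to count non-negligible eigenvalues. Healthy features spread over many directions. Collapsed features, where every input maps to the same vector, leave a single nonzero eigenvalue.

```python
import numpy as np

def feature_spectrum(H):
    """Eigenvalues (descending) of H H^T / N for features H of shape (d, N)."""
    d, n = H.shape
    C = H @ H.T / n                      # d x d, symmetric positive semidefinite
    return np.linalg.eigvalsh(C)[::-1]   # eigvalsh returns ascending order

def numerical_rank(eigs, rel_tol=1e-6):
    """Number of eigenvalues larger than rel_tol times the largest one."""
    return int(np.sum(eigs > rel_tol * eigs[0]))

rng = np.random.default_rng(1)
d, n = 32, 500                                   # hypothetical feature dimension, sample count
healthy = rng.normal(size=(d, n))                # features spread over all directions
collapsed = np.tile(rng.normal(size=(d, 1)), n)  # every input mapped to the same vector

print(numerical_rank(feature_spectrum(healthy)))    # 32
print(numerical_rank(feature_spectrum(collapsed)))  # 1
```

Counting eigenvalues above a threshold is only one way to summarize the spectrum. Plotting the full distribution, as in Figure 3(b), shows more detail, for example a partial collapse in which a few directions dominate.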

Theorems & Definitions (1)

  • Theorem 1
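Theorem 1 concerns a linear predictor $W_{Gh}$ trained by gradient descent with weight decay on a squared prediction error. The NumPy sketch below does not reproduce the theorem's proportionality. It only illustrates the general mechanism, under the simplifying assumption that inputs and targets are fixed, whereas in R-JEPA both come from an encoder that is itself learning. Under that assumption, gradient descent on $\frac{1}{2N}\lVert W X - Y \rVert_F^2 + \frac{\lambda}{2}\lVert W \rVert_F^2$ settles at $W^\ast = C_{yx}(C_{xx} + \lambda I)^{-1}$, with $C_{xx} = X X^T/N$ and $C_{yx} = Y X^T/N$. The names `X`, `Y`, `lam`, and `eta` are illustrative.

```python
import numpy as np

rng = np.random.default_rng(2)
d, n = 6, 400
lam = 0.1      # hypothetical weight-decay coefficient
eta = 0.1      # hypothetical learning rate

X = rng.normal(size=(d, n))                   # stand-in for representations h(t)
A = rng.normal(size=(d, d))
Y = A @ X + 0.1 * rng.normal(size=(d, n))     # stand-in for next representations h(t+1)

Cxx = X @ X.T / n
Cyx = Y @ X.T / n

W = np.zeros((d, d))
for _ in range(2000):
    # Gradient of (1/2n)||WX - Y||_F^2 + (lam/2)||W||_F^2 with respect to W.
    grad = W @ Cxx - Cyx + lam * W
    W -= eta * grad

# Fixed point: W (Cxx + lam I) = Cyx. Solve the transposed system instead of inverting.
M = Cxx + lam * np.eye(d)
W_closed = np.linalg.solve(M.T, Cyx.T).T
print(np.max(np.abs(W - W_closed)))
```

Each iteration applies $W \leftarrow W\,(I - \eta(C_{xx} + \lambda I)) + \eta C_{yx}$. The iteration therefore converges whenever $\eta$ is smaller than $2/\mu_{\max}$, where $\mu_{\max}$ is the largest eigenvalue of $C_{xx} + \lambda I$. Weight decay keeps this matrix well conditioned even when $C_{xx}$ is singular.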