Learning predictable and robust neural representations by straightening image sequences

Xueyan Niu, Cristina Savin, Eero P. Simoncelli

TL;DR

The paper investigates learning predictive neural representations by promoting straightened temporal trajectories in video-derived embeddings. It introduces a parameter-free straightening objective combined with whitening regularizers, and demonstrates its effectiveness on synthetic sequential data, yielding representations that preserve dynamic attributes and enable accurate linear extrapolation. The authors show that straightened representations are more robust to noise and adversarial perturbations than invariance-based SSL methods, and that straightening can boost robustness when used as a regularizer with other SSL objectives. They also provide geometric insight into how straightening shapes trajectory structure to facilitate class separability, and discuss extensions to multi-timescale and hierarchical prediction. Overall, straightening emerges as a practical, robust principle for self-supervised learning from temporal visual inputs with broad applicability to other models and data domains.
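As a rough illustration of how a straightening term and a whitening regularizer might be combined, the NumPy sketch below scores a batch of embedding trajectories. It assumes straightness is measured as the cosine similarity between successive difference vectors $z_t - z_{t-1}$, the quantity histogrammed in Figure 3. The function names, the VICReg-style variance/covariance penalty used as a stand-in for the whitening regularizer, and the weight `lam` are hypothetical choices, not the authors' implementation. NumPy only evaluates the loss; training would require an autodiff framework.

```python
import numpy as np

def straightening_loss(z: np.ndarray, eps: float = 1e-8) -> float:
    """Negative mean cosine similarity between consecutive difference vectors.

    z: embeddings of shape (batch, T, d) with T >= 3. Returns a value in [-1, 1];
    -1 means every trajectory is perfectly straight.
    """
    v = np.diff(z, axis=1)                                   # (batch, T-1, d): z_t - z_{t-1}
    v = v / (np.linalg.norm(v, axis=-1, keepdims=True) + eps)
    cos = np.sum(v[:, 1:] * v[:, :-1], axis=-1)              # (batch, T-2)
    return -float(cos.mean())

def whitening_penalty(z: np.ndarray) -> float:
    """Hypothetical stand-in for a whitening regularizer (VICReg-style).

    Penalizes per-dimension standard deviations below 1 and off-diagonal covariance,
    pooling all time steps of all sequences as samples.
    """
    x = z.reshape(-1, z.shape[-1])
    x = x - x.mean(axis=0)
    std = np.sqrt(x.var(axis=0) + 1e-4)
    var_term = np.mean(np.maximum(0.0, 1.0 - std))           # hinge on the std
    cov = x.T @ x / (x.shape[0] - 1)
    off_diag = cov - np.diag(np.diag(cov))
    cov_term = np.sum(off_diag ** 2) / x.shape[1]
    return float(var_term + cov_term)

def total_loss(z: np.ndarray, lam: float = 1.0) -> float:
    """Straightening term plus a weighted whitening term (lam is an assumed weight)."""
    return straightening_loss(z) + lam * whitening_penalty(z)
```

Without some whitening-like term, a straightening loss alone could be minimized by collapsing all embeddings onto a single line, which is presumably why the objective is paired with regularizers that spread out the representation.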

Abstract

Prediction is a fundamental capability of all living organisms, and has been proposed as an objective for learning sensory representations. Recent work demonstrates that in primate visual systems, prediction is facilitated by neural representations that follow straighter temporal trajectories than their initial photoreceptor encoding, which allows for prediction by linear extrapolation. Inspired by these experimental findings, we develop a self-supervised learning (SSL) objective that explicitly quantifies and promotes straightening. We demonstrate the power of this objective in training deep feedforward neural networks on smoothly rendered synthetic image sequences that mimic commonly occurring properties of natural videos. The learned model contains neural embeddings that are not only predictive but also factorize the geometric, photometric, and semantic attributes of objects. The representations also prove more robust to noise and adversarial attacks compared to previous SSL methods that optimize for invariance to random augmentations. Moreover, these beneficial properties can be transferred to other training procedures by using the straightening objective as a regularizer, suggesting a broader utility for straightening as a principle for robust unsupervised learning.
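The link between straightness and prediction can be made concrete: if consecutive steps of a trajectory point in the same direction, the next point can be guessed by extending the last step, $\hat{z}_{t+1} = z_t + (z_t - z_{t-1}) = 2z_t - z_{t-1}$. The sketch below (hypothetical helper names, NumPy assumed) computes a straightness score in $[-1, 1]$ as the mean cosine between successive steps, together with the error of this two-step linear extrapolation.

```python
import numpy as np

def straightness(z: np.ndarray, eps: float = 1e-8) -> float:
    """Mean cosine between successive steps z_t - z_{t-1} of one trajectory z (T, d), T >= 3."""
    v = np.diff(z, axis=0)
    v = v / (np.linalg.norm(v, axis=1, keepdims=True) + eps)
    return float(np.mean(np.sum(v[1:] * v[:-1], axis=1)))

def extrapolate(z: np.ndarray) -> np.ndarray:
    """Linear extrapolation z_{t+1} = 2 z_t - z_{t-1}; returns predictions for z[2:]."""
    return 2.0 * z[1:-1] - z[:-2]

def extrapolation_error(z: np.ndarray) -> float:
    """Mean Euclidean distance between extrapolated and actual next points."""
    return float(np.mean(np.linalg.norm(extrapolate(z) - z[2:], axis=1)))

# A constant-speed straight line versus a semicircular arc, both with 10 points.
t = np.linspace(0.0, 1.0, 10)
line = np.stack([t, 2.0 * t], axis=1)
theta = np.linspace(0.0, np.pi, 10)
arc = np.stack([np.cos(theta), np.sin(theta)], axis=1)
print(straightness(line), extrapolation_error(line))
print(straightness(arc), extrapolation_error(arc))
```

On the line the score is essentially 1 and the extrapolation error vanishes; on the arc both degrade. Exact extrapolation also requires steps of constant length, since straightness alone fixes only the direction of the next step.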

Paper Structure

This paper contains 23 sections, 4 equations, 8 figures.

Figures (8)

  • Figure 1: Learning straightened representations. A. Illustration of temporal trajectories of four translating digit sequences, in the space of pixel intensities (left), and in a straightened representation (right). Color indicates digit identity. B. The actual two-dimensional t-SNE rendering of 20 temporal trajectories for each of the ten translating digits from our model. Initial (pixel intensity) representation is highly curved and entangled (left). Although the straightening objective is unsupervised (no object labels), the learned representation clearly isolates the trajectories corresponding to different digits (right).
  • Figure 2: Straightening and its benefits, evaluated on a network trained on sequential MNIST. A. Three example sequences, illustrating the three geometric transformations. B. Emergence of straightness throughout layers of network computation. C. Accuracy in decoding various (untrained) variables from the network responses (top). Accuracy in predicting variables at the next time step (bottom). Identity decoding was not considered for prediction as it is constant over the sequence. D. Prediction capabilities of the network. Top: example sequence, with dilating/contracting digit. Middle: reconstructions from the representation at the same time step. Bottom: predictions (linear extrapolation) based on the representation at the previous two time steps.
  • Figure 3: Geometric properties of the straightened representation. Panels A-E show histograms of cosine similarity (normalized dot product) between pairs of difference vectors, $z_t - z_{t-1}$. Insets show example trajectories in each scenario, where color indicates digit identity. A. same digit and transformation type; B. same digit and different transformation; C. different digit and same transformation; D. different digit and transformation; E. all difference vectors vs. digit classifier vectors. F. Average effective dimensionality, measured with participation ratio, of the set of responses $z_t$ in each group.
  • Figure 4: Effect of straightening on representational robustness. A. Two example synthetic sequences from the sequential CIFAR-10 dataset. Top: translation and color shift. Bottom: rescaling (contraction) and color shift, last frame randomly grayscaled. B. Emergence of straightness throughout layers of network computation. Top arrows mark the stages of representation directly targeted for straightening (blue) and invariance (orange). C. Example sequences illustrating successes (left) and failures (right) of straightening. Numbers indicate straightness level $\in [-1, 1]$. D. Noise robustness: classification accuracy as a function of the amplitude of additive Gaussian noise injected in the input. E. Adversarial robustness: classification accuracy as a function of attack budget (see text). F. Relative classification accuracy of straightened network compared to invariance-trained network for various degradations. Color indicates the objective with better performance.
  • Figure 5: Augmentation of other SSL objectives with a straightening regularizer. A. Straightness of representations learned by four different SSL objectives (gray), and their augmentation with a straightening regularizer (blue). B. CIFAR-10 classification accuracy as a function of adversarial attack budget, for the original and straightening-regularized version, for the same four SSL objectives.
  • ...and 3 more figures
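The two analyses described in the Figure 3 caption can be sketched as follows, assuming responses are stored as NumPy arrays: cosine similarities between the difference vectors of two trajectories (panels A-D pair trajectories by digit and transformation type), and the participation ratio $\mathrm{PR} = (\sum_i \lambda_i)^2 / \sum_i \lambda_i^2$ of the covariance eigenvalues $\lambda_i$ of a set of responses (panel F). Function names are illustrative; this is not the authors' analysis code.

```python
import numpy as np

def step_cosines(za: np.ndarray, zb: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Cosine similarity between every step vector of trajectory za and every step of zb.

    za, zb: arrays of shape (T, d). Returns a flat array of (T-1) * (T-1) values in [-1, 1],
    suitable for a histogram.
    """
    da = np.diff(za, axis=0)  # rows are z_t - z_{t-1}
    db = np.diff(zb, axis=0)
    da = da / (np.linalg.norm(da, axis=1, keepdims=True) + eps)
    db = db / (np.linalg.norm(db, axis=1, keepdims=True) + eps)
    return (da @ db.T).ravel()

def participation_ratio(z: np.ndarray) -> float:
    """Effective dimensionality (sum of eigenvalues)^2 / (sum of squared eigenvalues).

    z: responses of shape (n, d), one row per observation. Ranges from 1 to d.
    """
    x = z - z.mean(axis=0)
    cov = x.T @ x / (x.shape[0] - 1)
    lam = np.clip(np.linalg.eigvalsh(cov), 0.0, None)  # guard against tiny negative round-off
    return float(lam.sum() ** 2 / np.sum(lam ** 2))
```

For example, two trajectories translating in parallel give cosines of 1 for every pair of steps, isotropic Gaussian responses in $d$ dimensions give a participation ratio close to $d$, and responses confined to a single line give a participation ratio of 1.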