White-Box Transformers via Sparse Rate Reduction: Compression Is All There Is?

Yaodong Yu, Sam Buchanan, Druv Pai, Tianzhe Chu, Ziyang Wu, Shengbang Tong, Hao Bai, Yuexiang Zhai, Benjamin D. Haeffele, Yi Ma

TL;DR

This paper contends that a natural objective of representation learning is to compress and transform the distribution of the data towards a low-dimensional Gaussian mixture supported on incoherent subspaces, and develops a family of white-box transformer-like deep network architectures, named CRATE, which are mathematically fully interpretable.

Abstract

In this paper, we contend that a natural objective of representation learning is to compress and transform the distribution of the data, say sets of tokens, towards a low-dimensional Gaussian mixture supported on incoherent subspaces. The goodness of such a representation can be evaluated by a principled measure, called sparse rate reduction, that simultaneously maximizes the intrinsic information gain and extrinsic sparsity of the learned representation. From this perspective, popular deep network architectures, including transformers, can be viewed as realizing iterative schemes to optimize this measure. Particularly, we derive a transformer block from alternating optimization on parts of this objective: the multi-head self-attention operator compresses the representation by implementing an approximate gradient descent step on the coding rate of the features, and the subsequent multi-layer perceptron sparsifies the features. This leads to a family of white-box transformer-like deep network architectures, named CRATE, which are mathematically fully interpretable. We show, by way of a novel connection between denoising and compression, that the inverse to the aforementioned compressive encoding can be realized by the same class of CRATE architectures. Thus, the so-derived white-box architectures are universal to both encoders and decoders. Experiments show that these networks, despite their simplicity, indeed learn to compress and sparsify representations of large-scale real-world image and text datasets, and achieve performance very close to highly engineered transformer-based models: ViT, MAE, DINO, BERT, and GPT2. We believe the proposed computational framework demonstrates great potential in bridging the gap between theory and practice of deep learning, from a unified perspective of data compression. Code is available at: https://ma-lab-berkeley.github.io/CRATE .
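The abstract says only that the multi-layer perceptron sparsifies the features. A minimal sketch of one way such a step can work is given here, assuming a single nonnegative ISTA (proximal gradient) iteration on a LASSO objective. The function names, the dictionary `dictionary`, the step size `step`, and the penalty `lam` are illustrative assumptions, not taken from this page.

```python
def ista_step(a, z, dictionary, step=0.5, lam=0.1):
    """One nonnegative ISTA step for
    min_a lam * ||a||_1 + 0.5 * ||z - D a||^2  subject to a >= 0.

    `a` is the current code, `z` the token to encode, `dictionary` a square
    d x d matrix D given as a list of rows (all hypothetical names).
    """
    d = len(z)
    # Reconstruction D a and residual z - D a.
    recon = [sum(dictionary[i][j] * a[j] for j in range(d)) for i in range(d)]
    resid = [z[i] - recon[i] for i in range(d)]
    # Gradient direction D^T (z - D a).
    corr = [sum(dictionary[i][j] * resid[i] for i in range(d)) for j in range(d)]
    # Gradient step, then shrink by step * lam and clip at zero (a ReLU).
    return [max(0.0, a[j] + step * corr[j] - step * lam) for j in range(d)]


def lasso_objective(a, z, dictionary, lam=0.1):
    """Value of lam * ||a||_1 + 0.5 * ||z - D a||^2."""
    d = len(z)
    recon = [sum(dictionary[i][j] * a[j] for j in range(d)) for i in range(d)]
    return lam * sum(abs(v) for v in a) + 0.5 * sum(
        (z[i] - recon[i]) ** 2 for i in range(d)
    )


# Assumed usage: start the code at the token itself and take one step.
rotation = [[0.8, 0.6], [-0.6, 0.8]]  # an orthogonal 2 x 2 dictionary
token = [1.0, 0.2]
code = ista_step(token, token, rotation)
```

With an orthogonal dictionary the smooth part of the objective has Lipschitz constant 1, so a step size of at most 1 does not increase the objective from a feasible (nonnegative) starting point. The final clipping is what produces exact zeros in the code.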

Paper Structure (140 sections, 13 theorems, 283 equations, 27 figures, 17 tables, 6 algorithms)

Key Result

Theorem 6

Suppose $\bm{Z}^{\ell}$ follows the noisy Gaussian codebook model (\ref{model:gaussian_tokens_noise}), with infinitesimal noise level $\sigma^{\ell} > 0$ and subspace memberships $s_{i}$ distributed as i.i.d. categorical random variables on the set of subspace indices $\{1, \dots, K\}$, independently of a

Figures (27)

  • Figure 1: Deep network layers $f^{\ell}$ which optimize the rate reduction. The separate components of the data distribution are transformed by the network operators to a configuration which maximizes the information gain. Here, $f$ may be realized by a ReduNet \cite{chan2021redunet}, in which each layer implements a gradient descent iteration for optimizing the rate reduction.
  • Figure 2: Distribution flow in denoising-diffusion models. Starting with generic noise $\bm{z} = \widetilde{\bm{z}}^{0}$, the probability density of intermediate iterates is shaped towards the true distribution of $\widetilde{\bm{z}}^{L}$ locally and iteratively through the operators $g^{\ell}$, which use the score function $\nabla \log q^{\ell}$ at each layer $\ell$.
  • Figure 3: The optima of the sparse rate reduction. After pre-processing input data $\bm{X}$ into a sequence of tokens $\bm{Z}^{1}$, our CRATE network attempts to optimize the sparse rate reduction of the token features $\bm{Z} = \bm{Z}^{L + 1}$. The optimal representations, according to the sparse rate reduction objective, are linearized (having low-dimensional linear subspace structure), sparse (the subspaces are axis-aligned), and compressed (adhering closely to that structure, with low or no noise). In the sequel, we discuss how CRATE achieves such representations by constructing each layer to iteratively optimize the sparse rate reduction.
  • Figure 4: The autoencoding process to be studied in \ref{sec:encoding} and \ref{sec:autoencoding}. Each encoder layer $f^{\ell}$ and decoder layer $g^{L - \ell}$ are (partial) inverses of each other. Moreover, the overall representation $\bm{Z} = f(\bm{X})$ is parsimonious (compressed, linearized, and sparse, as in \ref{sub:formulation}), and the autoencoding is to be consistent in the sense that $\bm{X} \approx \widehat{\bm{X}}$.
  • Figure 5: Comparison of three sets of representations via rate reduction and sparsity. Each $S_i$ represents one linear subspace, and the number of blue balls represents the difference between the coding rates $\Delta R(\bm{Z} \mid \bm{U}_{[K]}) = R(\bm{Z}) - R^c(\bm{Z} \mid \bm{U}_{[K]})$.
  • ...and 22 more figures
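
The rate reduction in the Figure 5 caption, $\Delta R(\bm{Z} \mid \bm{U}_{[K]}) = R(\bm{Z}) - R^c(\bm{Z} \mid \bm{U}_{[K]})$, can be made concrete with a small numerical sketch. This page does not state the formulas for $R$ and $R^c$, so the code below assumes the usual lossy coding rate $R(\bm{Z}) = \tfrac{1}{2} \log\det(\bm{I} + \tfrac{d}{N \epsilon^2} \bm{Z}\bm{Z}^{\top})$ for a $d \times N$ feature matrix, and takes $R^c$ to be the sum of the coding rates of the projections $\bm{U}_k^{\top} \bm{Z}$. The helper names and the value of `eps` are illustrative.

```python
import math


def matmul(a, b):
    """Product of two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a):
    return [list(col) for col in zip(*a)]


def logdet_spd(m):
    """Log-determinant of a symmetric positive-definite matrix via Cholesky."""
    n = len(m)
    low = [[0.0] * n for _ in range(n)]
    total = 0.0
    for i in range(n):
        for j in range(i + 1):
            s = m[i][j] - sum(low[i][k] * low[j][k] for k in range(j))
            if i == j:
                low[i][i] = math.sqrt(s)
                total += 2.0 * math.log(low[i][i])
            else:
                low[i][j] = s / low[j][j]
    return total


def coding_rate(z, eps=0.5):
    """Assumed form: R(Z) = 1/2 logdet(I + d / (N eps^2) Z Z^T), Z is d x N."""
    d, n = len(z), len(z[0])
    gram = matmul(z, transpose(z))
    scale = d / (n * eps ** 2)
    m = [[(1.0 if i == j else 0.0) + scale * gram[i][j] for j in range(d)]
         for i in range(d)]
    return 0.5 * logdet_spd(m)


def rate_reduction(z, bases, eps=0.5):
    """Assumed form: Delta R = R(Z) - sum_k R(U_k^T Z), each U_k is d x p."""
    compressed = sum(coding_rate(matmul(transpose(u), z), eps) for u in bases)
    return coding_rate(z, eps) - compressed


# Two axis-aligned one-dimensional subspaces of the plane.
bases = [[[1.0], [0.0]], [[0.0], [1.0]]]
# Tokens (columns) spread over both subspaces versus collapsed onto one.
z_spread = [[1.0, -1.0, 0.0, 0.0], [0.0, 0.0, 1.0, -1.0]]
z_collapsed = [[1.0, -1.0, 1.0, -1.0], [0.0, 0.0, 0.0, 0.0]]
gain_spread = rate_reduction(z_spread, bases)
gain_collapsed = rate_reduction(z_collapsed, bases)
```

Under these assumptions the tokens spread over both subspaces give a larger rate reduction than the tokens collapsed onto a single subspace, which matches the caption's reading of $\Delta R$ as a gap between the coding rate of the whole set and the coding rate of its parts.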

Theorems & Definitions (25)

  • Remark 1: Connections to likelihood maximization and energy-based models
  • Remark 2: Intrinsic and extrinsic measures of goodness for the representations
  • Remark 3: Black-box representations learned through pretext tasks
  • Remark 4: Design choices in CRATE
  • Remark 5: The roles of the forward pass and backward propagation
  • Theorem 6: Informal version of \ref{lem:inverse-term} in \ref{app:computations_rr_gradient}
  • Remark 7: Lack of weight-tying in autoencoding architecture
  • Remark 8: Alternate instantiations of structured denoising-diffusion
  • Lemma 10
  • Lemma 11
  • ...and 15 more