Masked Completion via Structured Diffusion with White-Box Transformers
Druv Pai, Ziyang Wu, Sam Buchanan, Yaodong Yu, Yi Ma
TL;DR
This work tackles unsupervised representation learning with interpretable, structured models. It introduces CRATE-MAE, a white-box, transformer-like autoencoder derived by unrolling a sparse rate-reduction objective and by linking compression with diffusion-inspired denoising. It develops a distributional signal model for token representations, constructs distributionally invertible encoder and decoder layers based on MSSA and ISTA blocks, and frames the decoder as a time-reversed diffusion process to enable deterministic autoencoding. Empirically, CRATE-MAE achieves competitive performance on large-scale image datasets with roughly 30% of the parameters of standard masked autoencoders, and it learns semantically meaningful, linearly structured representations with interpretable attention maps. The approach bridges diffusion, rate reduction, and transformer design, showing that principled white-box architectures can be effective for unsupervised learning and scalable vision tasks.
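The TL;DR names ISTA blocks as one of the two layer types. As background only, the sketch below shows the textbook iterative shrinkage-thresholding algorithm (ISTA) for sparse coding, which minimizes ½‖x − Dz‖² + λ‖z‖₁ over codes z for a fixed dictionary D: a gradient step on the reconstruction error followed by soft-thresholding. It is not the CRATE-MAE layer; the function names, the toy dictionary, and the step size are illustrative assumptions. Roughly speaking, unrolling one such step with a learned dictionary is what gives an ISTA-style layer its sparsifying role.

```python
"""Textbook ISTA for sparse coding: min_z 0.5*||x - D z||^2 + lam*||z||_1.

Illustrative sketch only (not the CRATE-MAE layer). Matrices are plain
lists of rows so the example needs no third-party packages.
"""


def matvec(D, z):
    """Return D @ z for D given as a list of rows."""
    return [sum(d * zj for d, zj in zip(row, z)) for row in D]


def rmatvec(D, r):
    """Return D^T @ r."""
    return [sum(D[i][j] * r[i] for i in range(len(D))) for j in range(len(D[0]))]


def soft_threshold(v, t):
    """Proximal operator of t*||.||_1: shrink each entry toward zero by t."""
    return [(1.0 if x > 0 else -1.0) * max(abs(x) - t, 0.0) for x in v]


def ista_step(x, D, z, eta, lam):
    """One ISTA iteration: gradient step on the residual, then shrinkage."""
    residual = [dz - xi for dz, xi in zip(matvec(D, z), x)]  # D z - x
    grad = rmatvec(D, residual)  # gradient of 0.5*||x - D z||^2 w.r.t. z
    v = [zj - eta * gj for zj, gj in zip(z, grad)]
    return soft_threshold(v, eta * lam)


def ista(x, D, eta, lam, steps=100):
    """Run ISTA from z = 0; eta should be at most 1 / ||D||_2^2 to converge."""
    z = [0.0] * len(D[0])
    for _ in range(steps):
        z = ista_step(x, D, z, eta, lam)
    return z


if __name__ == "__main__":
    D = [[2.0, 0.0], [0.0, 2.0]]  # toy dictionary D = 2I
    print(ista([4.0, 1.0], D, eta=0.25, lam=1.0))  # prints [1.75, 0.25]
```

For D = cI the minimizer has the closed form z = soft_threshold(x / c, λ / c²), so with c = 2 the iteration lands on [1.75, 0.25], which is a quick sanity check on the update.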
Abstract
Modern learning frameworks often train deep neural networks on massive amounts of unlabeled data to learn representations by solving simple pretext tasks, then use these representations as foundations for downstream tasks. Because these networks are designed empirically, they are usually not interpretable, their representations are not structured, and their designs are potentially redundant. White-box deep networks, in which each layer explicitly identifies and transforms structures in the data, present a promising alternative. However, existing white-box architectures have only been shown to work at scale in supervised settings with labeled data, such as classification. In this work, we provide the first instantiation of the white-box design paradigm that can be applied to large-scale unsupervised representation learning. We do this by exploiting a fundamental connection between diffusion, compression, and (masked) completion, deriving a deep transformer-like masked autoencoder architecture, called CRATE-MAE, in which the role of each layer is mathematically fully interpretable: each layer transforms the data distribution to and from a structured representation. Extensive empirical evaluations confirm our analytical insights. CRATE-MAE demonstrates highly promising performance on large-scale image datasets while using only about 30% of the parameters of the standard masked autoencoder with the same model configuration. The representations learned by CRATE-MAE have explicit structure and also contain semantic meaning. Code is available at https://github.com/Ma-Lab-Berkeley/CRATE.
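The masked-completion pretext task can be illustrated with MAE-style random masking: patch tokens are split into a visible set, which the encoder sees, and a masked set, which the decoder must complete. The sketch below is a minimal, dependency-free version under stated assumptions; the function names, the 196-patch (14 x 14) grid, and the 75% mask ratio (a common setting for standard masked autoencoders) are illustrative choices, not details taken from CRATE-MAE.

```python
import random


def random_patch_mask(num_patches, mask_ratio, seed=None):
    """Split patch indices into visible and masked sets, MAE-style.

    Returns two sorted index lists; the visible list has
    int(num_patches * (1 - mask_ratio)) entries.
    """
    if not 0.0 <= mask_ratio < 1.0:
        raise ValueError("mask_ratio must be in [0, 1)")
    rng = random.Random(seed)  # seeded generator for reproducible masks
    order = list(range(num_patches))
    rng.shuffle(order)  # uniformly random permutation of patch indices
    num_visible = int(num_patches * (1.0 - mask_ratio))
    return sorted(order[:num_visible]), sorted(order[num_visible:])


def gather(tokens, indices):
    """Select the tokens at the given indices (the encoder's input)."""
    return [tokens[i] for i in indices]


if __name__ == "__main__":
    # Hypothetical setup: a 14 x 14 grid of patch tokens, 75% masked.
    tokens = [f"patch_{i}" for i in range(196)]
    visible, masked = random_patch_mask(len(tokens), mask_ratio=0.75, seed=0)
    encoder_input = gather(tokens, visible)
    print(len(encoder_input), len(masked))  # prints 49 147
```

Sorting the index lists keeps the surviving tokens in their original spatial order, which makes it straightforward to scatter the decoder's predictions for the masked positions back into place.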
