Scaling White-Box Transformers for Vision
Jinrui Yang, Xianhang Li, Druv Pai, Yuyin Zhou, Yi Ma, Yaodong Yu, Cihang Xie
TL;DR
This paper demonstrates scalable training of CRATE-α, a white-box vision transformer built on the CRATE framework, by introducing an overcomplete, decoupled dictionary in the sparse coding block and adding a residual connection. Together with a light training recipe, these architectural changes allow the model to scale from Base to Huge, reaching 85.1% top-1 accuracy on ImageNet-1K with ImageNet-21K pretraining and 72.3% zero-shot accuracy with DataComp-1B, while preserving, and often improving, semantic interpretability as measured by zero-shot segmentation. The work provides a principled path to scaling mathematically interpretable models: it offers performance competitive with ViTs under comparable compute and extends to vision-language pretraining and downstream tasks. Overall, CRATE-α advances scalable, interpretable deep networks by balancing compression, sparsity, and expressive capacity through unrolled optimization grounded in Sparse Rate Reduction.
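To make the idea of an overcomplete, decoupled sparse coding block with a residual connection concrete, here is a minimal NumPy sketch. It is not the authors' implementation: the function names, the use of a few unrolled ISTA steps with a non-negative soft threshold, the zero initialization of the codes, and the values of `step`, `lam`, and `n_iter` are all assumptions chosen for illustration.

```python
import numpy as np


def soft_threshold_nonneg(x, lam):
    """Non-negative soft thresholding, i.e. ReLU(x - lam)."""
    return np.maximum(x - lam, 0.0)


def overcomplete_sparse_block(Z, D, D_hat, step=0.1, lam=0.1, n_iter=2):
    """Illustrative sparse coding block (hypothetical, not the paper's code).

    Z:     (n_tokens, d) token representations.
    D:     (d, m) encoding dictionary, overcomplete when m > d.
    D_hat: (d, m) separate ("decoupled") decoding dictionary.
    Returns the updated tokens and the sparse codes A of shape (n_tokens, m).
    """
    A = np.zeros((Z.shape[0], D.shape[1]))
    for _ in range(n_iter):
        # Gradient of 0.5 * ||Z - A D^T||^2 with respect to A.
        grad = (A @ D.T - Z) @ D
        # Proximal gradient (ISTA) step with a non-negativity constraint.
        A = soft_threshold_nonneg(A - step * grad, step * lam)
    # Decode with the decoupled dictionary and add the residual connection.
    return Z + A @ D_hat.T, A


rng = np.random.default_rng(0)
d, m, n_tokens = 8, 32, 5  # m = 4 * d, so the dictionary is overcomplete
Z = rng.standard_normal((n_tokens, d))
D = rng.standard_normal((d, m)) / np.sqrt(d)
D_hat = rng.standard_normal((d, m)) / np.sqrt(d)

Z_next, A = overcomplete_sparse_block(Z, D, D_hat)
print(Z_next.shape, A.shape, float((A > 0).mean()))
```

The output keeps the token dimension `d` even though the codes live in the larger dimension `m`, which is what allows such a block to be stacked layer after layer. The last printed value is the fraction of active code entries, a rough indicator of sparsity under these assumed hyperparameters.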
Abstract
CRATE, a white-box transformer architecture designed to learn compressed and sparse representations, offers an intriguing alternative to standard vision transformers (ViTs) due to its inherent mathematical interpretability. Despite extensive investigations into the scaling behaviors of language and vision transformers, the scalability of CRATE remains an open question, which this paper aims to address. Specifically, we propose CRATE-α, which features strategic yet minimal modifications to the sparse coding block in the CRATE architecture design, together with a light training recipe designed to improve the scalability of CRATE. Through extensive experiments, we demonstrate that CRATE-α can effectively scale with larger model sizes and datasets. For example, our CRATE-α-B substantially outperforms the prior best CRATE-B model accuracy on ImageNet classification by 3.7%, achieving an accuracy of 83.2%. When scaling further, our CRATE-α-L obtains an ImageNet classification accuracy of 85.1%. More notably, these performance improvements are achieved while preserving, and potentially even enhancing, the interpretability of learned CRATE models: we show that the learned token representations of increasingly larger trained CRATE-α models yield increasingly higher-quality unsupervised object segmentation of images. The project page is https://rayjryang.github.io/CRATE-alpha/.
