Residual Connections Harm Generative Representation Learning

Xiao Zhang, Ruoxi Jiang, William Gao, Rebecca Willett, Michael Maire

TL;DR

This work questions the ubiquity of unmodified residual shortcuts in self-supervised generative learning, arguing that identity paths can suppress abstraction by echoing shallow features into deeper layers. It introduces depth-dependent decayed identity shortcuts, formalized as ${\bm{x}}_{l+1} = \alpha_l {\bm{x}}_l + f_{\bm{\theta}_l}({\bm{x}}_l)$ with $\alpha_l = 1 - \delta_{\alpha} l$ and $\alpha_L^{\rm eff} = \prod_{l=1}^{L} \alpha_l$, controlled by a single hyperparameter $\alpha_{\min}$. For an MAE with a ViT-B/16 backbone on ImageNet-1K, this yields 72.7% linear probing (LP) accuracy and 63.9% KNN accuracy, up from 67.8% and 27.4% for the baseline, while diffusion models show gains in both representation quality and generation quality. The results reveal a link between improved abstractions and a low-rank inductive bias, suggesting that carefully decaying skip connections can enhance unsupervised learning and generative modeling without extra parameters.
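The decayed shortcut rule can be sketched in a few lines of plain Python. This is a minimal illustration, not the authors' implementation: the function names are hypothetical, and the choice $\delta_{\alpha} = (1 - \alpha_{\min})/L$ (so that the last layer's factor equals $\alpha_{\min}$) is an assumption, since the summary above does not state how $\delta_{\alpha}$ is derived from $\alpha_{\min}$.

```python
import math


def decay_schedule(num_layers, alpha_min):
    """Return [alpha_1, ..., alpha_L] with alpha_l = 1 - delta * l.

    ASSUMPTION: delta = (1 - alpha_min) / L, so that alpha_L == alpha_min.
    """
    delta = (1.0 - alpha_min) / num_layers
    return [1.0 - delta * l for l in range(1, num_layers + 1)]


def effective_alpha(alphas):
    """Product of the per-layer factors: the weight with which the
    network input reaches the output through the shortcuts alone."""
    return math.prod(alphas)


def forward(x, blocks, alphas):
    """Apply x_{l+1} = alpha_l * x_l + f_l(x_l) for each layer l.

    `x` is a list of floats; each block maps such a list to a list of
    the same length (a stand-in for an attention or MLP sub-block).
    """
    for alpha, f in zip(alphas, blocks):
        fx = f(x)
        x = [alpha * xi + fi for xi, fi in zip(x, fx)]
    return x


if __name__ == "__main__":
    alphas = decay_schedule(num_layers=12, alpha_min=0.6)  # illustrative values
    print("per-layer factors:", [round(a, 3) for a in alphas])
    print("effective shortcut weight:", effective_alpha(alphas))
```

Setting `alpha_min=1.0` recovers a standard residual network, since every factor is then 1 and the input passes through the shortcuts unattenuated.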

Abstract

We show that introducing a weighting factor to reduce the influence of identity shortcuts in residual networks significantly enhances semantic feature learning in generative representation learning frameworks, such as masked autoencoders (MAEs) and diffusion models. Our modification notably improves feature quality, raising ImageNet-1K K-Nearest Neighbor accuracy from 27.4% to 63.9% and linear probing accuracy from 67.8% to 72.7% for MAEs with a ViT-B/16 backbone, while also enhancing generation quality in diffusion models. This significant gap suggests that, while residual connection structure serves an essential role in facilitating gradient propagation, it may have a harmful side effect of reducing capacity for abstract learning by virtue of injecting an echo of shallower representations into deeper layers. We ameliorate this downside via a fixed formula for monotonically decreasing the contribution of identity connections as layer depth increases. Our design promotes the gradual development of feature abstractions, without impacting network trainability. Analyzing the representations learned by our modified residual networks, we find correlation between low effective feature rank and downstream task performance.
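To make the notion of effective feature rank concrete, the sketch below uses the common entropy-based definition (the exponential of the Shannon entropy of the normalized singular values, due to Roy and Vetterli). Whether the paper uses exactly this definition is an assumption, as the abstract does not specify it; the function name is hypothetical. The singular values would come from an SVD of a feature matrix, for example via `numpy.linalg.svd(features, compute_uv=False)`.

```python
import math


def effective_rank(singular_values, eps=1e-12):
    """Entropy-based effective rank of a singular-value spectrum.

    ASSUMPTION: effective rank = exp(H(p)), where p_i = s_i / sum(s)
    and H is the Shannon entropy. Values range from 1 (one dominant
    direction) up to len(singular_values) (a flat spectrum).
    """
    total = sum(singular_values)
    if total <= 0:
        raise ValueError("singular values must have a positive sum")
    probs = [s / total for s in singular_values]
    # Terms with p ~ 0 contribute nothing to the entropy and are skipped.
    entropy = -sum(p * math.log(p) for p in probs if p > eps)
    return math.exp(entropy)


if __name__ == "__main__":
    print(effective_rank([1.0, 1.0, 1.0, 1.0]))   # flat spectrum
    print(effective_rank([10.0, 0.1, 0.1, 0.1]))  # one dominant direction
```

Under this definition, a lower value means that feature variance is concentrated in fewer directions, which is the sense of "low effective feature rank" used above.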

Paper Structure (21 sections, 6 equations, 8 figures, 5 tables)

Figures (8)

  • Figure 1: We design decayed identity shortcuts (Figure \ref{fig:teaser}), a variant of residual connections, to facilitate self-supervised representation learning in generative models. Compared to standard residual connections, our approach yields superior abstract semantic features (left, visualized using the approach of zhang2024deciphering), whose leading components pop out object instances and classes. Quantitative evaluation shows that our architecture encourages lower feature rank and learns better feature representations for both MAE and diffusion models (middle), along with enhanced generation quality for diffusion models (right). These improvements require no additional learnable parameters.
  • Figure 2: Our decayed identity shortcuts introduce a depth-dependent scaling factor to shortcuts in a residual network, thereby modulating the contribution of preceding layers and fostering greater abstraction in deeper layers. A simple schema for controlling decay factor $\alpha$ suffices to improve feature learning in both MAEs and diffusion models, as well as diffusion model generation quality.
  • Figure 3: Visualization of learned representations using the method of zhang2024deciphering, without cherry-picking. We project the learned representations onto a 3-channel feature map, visualized as RGB images. Our method learns more abstract and semantically consistent representations than the baseline MAE. This visual comparison is further supported by benchmarking on unsupervised semantic segmentation, where our approach achieves better results (10.4 mIoU) than the baseline MAE (4.1 mIoU).
  • Figure 4: For an MAE pretrained on ImageNet-100, we visualize (a) the training dynamics of the effective rank for different values of $\alpha_{\rm{min}}$ and (b) the linear probing accuracy for various $\alpha_{\rm{min}}$, demonstrating that a lower effective feature rank is associated with better performance.
  • Figure 5: We present our enhanced UNet Transformer architecture for masked autoencoders. (1) Left: our customized encoder blocks, equipped with the proposed decayed identity shortcuts. (2) Middle: standard transformer blocks, used as the decoder blocks. (3) Right: we incorporate the decayed identity shortcuts exclusively within the encoder blocks of our UNet transformer and employ standard transformer blocks for the decoder. To support abstract representation learning at the bottleneck, i.e., the last encoder layer (Encoder 12), we adopt the UNet architecture (ronneberger2015u) and create skip connections that transmit every other encoder feature directly to the decoder.
  • ...and 3 more figures