Deep Companion Learning: Enhancing Generalization Through Historical Consistency

Ruizhao Zhu, Venkatesh Saligrama

TL;DR

This work tackles generalization in supervised learning by addressing the variability that SGD introduces during training. It introduces Deep Companion Learning (DCL), which trains a deep-companion network $\omega$ to forecast logits on new inputs from historical versions of the primary model $\theta_t$, and enforces predictive consistency through a data-dependent regularizer that aligns current predictions with this forecast. Empirically, DCL yields state-of-the-art results on CIFAR-100, Tiny-ImageNet, and ImageNet-1K with diverse backbones, often matching or exceeding pre-trained models while training from scratch and reducing computational demands. The approach is versatile: it extends to fine-tuning, semi-supervised learning, self-supervised pretraining, and knowledge distillation. Ablations validate the choice of $\alpha$, the use of MSE as the distance, and the feasibility of smaller companions and reduced data.
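The regularized objective described above can be illustrated with a minimal NumPy sketch. It assumes the primary loss is cross-entropy plus $\alpha$ times the MSE between primary and companion logits, which matches the summary but not necessarily the paper's exact formulation. The function names, the `mix` parameter, and the convex-combination companion target are hypothetical choices made for illustration only.

```python
import numpy as np

def softmax_cross_entropy(logits, labels):
    """Mean cross-entropy of integer labels under softmax(logits)."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()

def dcl_primary_loss(primary_logits, companion_logits, labels, alpha=1.0):
    """Cross-entropy plus alpha * MSE between primary and companion logits.

    Assumption: the consistency term is a plain mean squared error over logits.
    """
    ce = softmax_cross_entropy(primary_logits, labels)
    consistency = np.mean((primary_logits - companion_logits) ** 2)
    return ce + alpha * consistency

def companion_target(prev_companion_logits, primary_logits, mix=0.5):
    """Hypothetical regression target for the companion network.

    Assumption: the companion approximates a convex combination of its own
    previous forecast and the current primary prediction; `mix` is not a
    parameter taken from the paper.
    """
    return mix * prev_companion_logits + (1.0 - mix) * primary_logits

# Toy usage: two samples, two classes.
logits = np.array([[2.0, 0.0], [0.0, 2.0]])
labels = np.array([0, 1])
forecast = np.zeros_like(logits)  # stand-in for the companion's output
loss = dcl_primary_loss(logits, forecast, labels, alpha=0.5)
```

When the primary and companion logits agree, the consistency term vanishes and the loss reduces to plain cross-entropy, so the regularizer only acts on inputs where the current model departs from the forecast.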

Abstract

We propose Deep Companion Learning (DCL), a novel training method for Deep Neural Networks (DNNs) that enhances generalization by penalizing model predictions that are inconsistent with the model's historical performance. To achieve this, we train a deep-companion model (DCM) by using previous versions of the model to provide forecasts on new inputs. This companion model deciphers a meaningful latent semantic structure within the data, thereby providing targeted supervision that encourages the primary model to address the scenarios it finds most challenging. We validate our approach through both theoretical analysis and extensive experimentation, including ablation studies, on a variety of benchmark datasets (CIFAR-100, Tiny-ImageNet, ImageNet-1K) using diverse architectural models (ShuffleNetV2, ResNet, Vision Transformer, etc.), demonstrating state-of-the-art performance.

Paper Structure (32 sections, 8 equations, 6 figures, 12 tables, 1 algorithm)

Figures (6)

  • Figure 1: Method Overview. At iteration $t$, we optimize the instantaneous model $\boldsymbol{\theta}_t$ (the model eventually deployed after training) with the standard cross-entropy loss and a regularizer enforcing consistency with model $\boldsymbol{\omega}$. Model $\boldsymbol{\omega}$ is recursively updated by approximating the predictions of its previous embodiment and the predictions of the current model $\boldsymbol{\theta}_t$. (Right) The probability of each class being the top non-target class is shown. The companion model helps narrow the top non-target classes for class leopard down to tiger and lion. In the initial training stage, the top non-target classes of the deployed model's logits are more randomly distributed and include some irrelevant classes. The companion model can help capture a general semantic structure of the dataset.
  • Figure 2: Higher Top Non-target Consistency Indicates Better Generalization. We visualize training top non-target consistency and test accuracy on CIFAR-100 across different models. (a) Larger CE models generalize better and also show higher top non-target consistency, indicating a positive correlation between top non-target consistency and test accuracy. (b) Across different architectures, improved DCL test accuracy consistently correlates with increasing top non-target consistency. (c) DCL chooses Tiger and Lion most frequently as the top non-target classes of Leopard during training, while CE exhibits inconsistent patterns.
  • Figure 3: Non-target Perplexity for Different Classes. The most frequent (black text) and second most frequent (green text) top non-target classes are shown on each bar. DCL can reduce perplexity relative to the CE baseline.
  • Figure 4: Semi-Supervised Learning. DCL demonstrates superior performance over FixMatch, particularly with fewer labels.
  • Figure 5: Comparison of Model Variation and Test Accuracy along the Training Trajectory on CIFAR-100 with ResNet18. The left section displays t-SNE visualizations of output logits for 50 test data samples, illustrating reduced variation in the output logits across various training stages with DCL. In the plot, $d$ is the average distance over data representations in the logit space. The right section presents the progression of test accuracy during training: DCL predictions show smaller variation and attain better generalization.
  • ...and 1 more figure
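The top non-target statistics discussed in the captions for Figures 1 to 3 can be sketched as follows. This is an illustration under stated assumptions, not the paper's evaluation code: the top non-target class is taken to be the highest-scoring class other than the label, and perplexity is taken to be the exponential of the entropy of the empirical top non-target distribution for one target class. The captions do not define either quantity, so both readings and all function names here are hypothetical.

```python
import numpy as np

def top_non_target(logits, labels):
    """Index of the highest-scoring class other than the label, per sample."""
    labels = np.asarray(labels)
    masked = np.array(logits, dtype=float)  # copy so the input is untouched
    masked[np.arange(len(labels)), labels] = -np.inf  # exclude the target class
    return masked.argmax(axis=1)

def top_non_target_distribution(logits, labels, target_class, num_classes):
    """Empirical distribution of top non-target classes for one target class."""
    labels = np.asarray(labels)
    picks = top_non_target(logits, labels)[labels == target_class]
    counts = np.bincount(picks, minlength=num_classes)
    return counts / max(counts.sum(), 1)

def perplexity(probs):
    """exp(entropy) of a distribution; assumed reading of 'non-target perplexity'."""
    probs = np.asarray(probs, dtype=float)
    p = probs[probs > 0]  # zero-probability classes contribute no entropy
    return float(np.exp(-(p * np.log(p)).sum()))

# Toy usage with hypothetical classes 0=leopard, 1=tiger, 2=lion, 3=truck.
logits = np.array([[5.0, 3.0, 1.0, 0.0],
                   [4.0, 2.0, 3.0, 0.0],
                   [6.0, 2.0, 1.0, 0.0]])
labels = np.array([0, 0, 0])
dist = top_non_target_distribution(logits, labels, target_class=0, num_classes=4)
ppl = perplexity(dist)
```

Under this reading, a model that always picks the same runner-up class has perplexity 1, and a model that spreads its runner-up choices evenly over $k$ classes has perplexity $k$, so lower values correspond to more consistent top non-target predictions.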