Stochastic Re-weighted Gradient Descent via Distributionally Robust Optimization

Ramnath Kumar, Kushal Majmundar, Dheeraj Nagaraj, Arun Sai Suggala

TL;DR

Re-weighted Gradient Descent (RGD) is a novel optimization technique that improves the performance of deep neural networks by dynamically re-weighting training samples, using insights from distributionally robust optimization (DRO) with Kullback-Leibler divergence.

Abstract

We present Re-weighted Gradient Descent (RGD), a novel optimization technique that improves the performance of deep neural networks through dynamic sample re-weighting. Leveraging insights from distributionally robust optimization (DRO) with Kullback-Leibler divergence, our method dynamically assigns importance weights to training data during each optimization step. RGD is simple to implement, computationally efficient, and compatible with widely used optimizers such as SGD and Adam. We demonstrate the effectiveness of RGD on various learning tasks, including supervised learning, meta-learning, and out-of-domain generalization. Notably, RGD achieves state-of-the-art results on diverse benchmarks, with improvements of +0.7% on DomainBed, +1.44% on tabular classification, +1.94% on GLUE with BERT, and +1.01% on ImageNet-1K with ViT.
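
As a rough illustration of the dynamic re-weighting described above, the following minimal sketch performs one re-weighted gradient step on a toy one-dimensional least-squares problem. It assumes that each sample's weight is a clipped, scaled exponential of its loss, normalized over the batch. The function name `rgd_step`, the `scale` and `clip` parameters, and the plain gradient-descent update are illustrative assumptions, not the authors' implementation.

```python
import math

def rgd_step(theta, xs, ys, lr=0.1, scale=1.0, clip=2.0):
    """One re-weighted gradient step for 1-D least squares y ~ theta * x.

    Assumption: each sample's weight is exp(min(scale * loss, clip)),
    normalized over the batch. Weights are treated as constants, so no
    gradient flows through them.
    """
    # Per-sample squared-error losses under the current parameter.
    losses = [0.5 * (theta * x - y) ** 2 for x, y in zip(xs, ys)]
    # Exponential weights; the clip bounds the exponent for stability.
    raw = [math.exp(min(scale * loss, clip)) for loss in losses]
    total = sum(raw)
    weights = [r / total for r in raw]
    # Weighted sum of per-sample gradients with respect to theta.
    grad = sum(w * (theta * x - y) * x for w, x, y in zip(weights, xs, ys))
    return theta - lr * grad, weights

# Toy batch: the last point is poorly fit at theta = 1 and has the highest loss.
xs = [1.0, 1.0, 1.0, 2.0]
ys = [1.0, 1.0, 1.0, 4.0]
theta, weights = rgd_step(1.0, xs, ys)
```

In this toy batch the high-loss point receives the largest weight, so the update moves `theta` toward fitting it. In practice the same weighting could be applied to per-sample losses before handing the gradient to SGD or Adam.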

Paper Structure

This paper contains 58 sections, 3 theorems, 25 equations, 6 figures, 29 tables, and 1 algorithm.

Key Result

Proposition 3.1

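(Shapiro, 2017) Consider DRO with a KL-divergence-based uncertainty set. Then $\min_{\theta\in\Theta}\widehat{R}_{D, n}(\theta)$ can be rewritten as $\min_{\theta\in\Theta}\frac{1}{\gamma}\log \mathbb{E}_{\widehat{P}_{\text{data}}}\left[\exp\left(\gamma\,\ell(z;\theta)\right)\right]$ for some constant $\gamma>0$ that is independent of $\theta$, where $\ell(z;\theta)$ is the loss on sample $z$ and $\widehat{P}_{\text{data}}$ is the empirical data distribution.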
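
To see why a KL-based DRO objective leads to sample re-weighting, consider the following short derivation. It assumes the rewritten objective takes the standard log-exponential form over $n$ training samples, with $\ell_i(\theta)$ denoting the loss on the $i$-th sample (notation introduced here for illustration only):

$$
\nabla_\theta \left[\frac{1}{\gamma}\log\left(\frac{1}{n}\sum_{i=1}^{n} e^{\gamma\,\ell_i(\theta)}\right)\right]
= \sum_{i=1}^{n} w_i(\theta)\,\nabla_\theta \ell_i(\theta),
\qquad
w_i(\theta) = \frac{e^{\gamma\,\ell_i(\theta)}}{\sum_{j=1}^{n} e^{\gamma\,\ell_j(\theta)}}.
$$

Under this assumption, the gradient is a weighted average of per-sample gradients whose weights grow exponentially with the loss, which is consistent with RGD upweighting high-loss points.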

Figures (6)

  • Figure 1: Ablation of the scaling and clipping factors of the RGD training regime on the ImageNet dataset with a ViT-S backbone.
  • Figure 2: Illustration of Distributionally Robust Optimization (DRO). In contrast to ERM, which learns a model that minimizes the expected loss over the original data distribution, DRO learns a model that performs well simultaneously on several perturbed versions of the original data distribution.
  • Figure 3: Illustration of the intuition behind RGD in the binary classification setting. RGD upweights points with high losses, i.e., points that have been misclassified by the model.
  • Figure 4: Convergence of the SGD and RGD algorithms for estimating the linear regression parameter. The L2 distance between the iterates ($\theta$) and the true parameter $\theta^*$ is studied in two panels. One panel depicts the squared error in the frequently appearing directions, where all the techniques perform equally well. The other depicts the rare directions, where our proposed approach is much better.
  • Figure 5: Experiment comparing RGD with baseline cross-entropy loss (CE), focal loss, and class-balanced loss using a ResNet-32 backbone. The $x$-axis represents the imbalance factor in the dataset.
  • ...and 1 more figure

Theorems & Definitions (5)

  • Proposition 3.1
  • Proposition 3.2
  • Proposition B.1
  • proof
  • proof