Improving Generalization via Meta-Learning on Hard Samples

Nishant Jain, Arun S. Suggala, Pradeep Shenoy

TL;DR

The paper tackles generalization gaps in supervised learning by meta-optimizing the LRW framework's validation data, hypothesizing that hard validation samples yield stronger generalization. It introduces MOLERE, a scalable tri-level/minimax approach that jointly learns a splitter for hard samples and an instance-weighting meta-network, enabling end-to-end LRW training on hard-validation splits. The method achieves consistent gains over ERM and various reweighting baselines across in-domain and domain-shift datasets, including robust performance with large pretrained models and in noisy-label and skewed-label regimes, and demonstrates a margin-maximization effect. This work establishes meta-optimization of meta-learning as a viable path to improve generalization in supervised learning with practical implications for robust model training and domain adaptation.
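To make the learned-reweighting (LRW) idea concrete, here is a minimal NumPy sketch of one reweighting step for binary logistic regression. It is not MOLERE. Instead of the paper's instance-weighting meta-network, it uses the simpler one-step lookahead rule of Ren et al. (2018, "Learning to Reweight Examples"). Under that rule, a training example gets positive weight when its loss gradient points the same way as the validation-loss gradient. All function names are hypothetical.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def per_example_grads(theta, X, y):
    """Logistic-loss gradient for every row of X; shape (n, d)."""
    residual = sigmoid(X @ theta) - y
    return residual[:, None] * X


def lookahead_weights(theta, X_tr, y_tr, X_val, y_val, lr=0.1):
    """Instance weights from a one-step lookahead (Ren et al., 2018 style).

    With all weights initialised at zero, the derivative of the validation
    loss after one weighted SGD step w.r.t. w_i is -lr * <g_i, g_val>.
    Examples whose gradient aligns with the validation gradient therefore
    get positive weight; the rest are clipped to zero.
    """
    g_tr = per_example_grads(theta, X_tr, y_tr)                   # (n, d)
    g_val = per_example_grads(theta, X_val, y_val).mean(axis=0)   # (d,)
    raw = np.maximum(lr * (g_tr @ g_val), 0.0)
    total = raw.sum()
    return raw / total if total > 0 else raw  # normalise to sum to one


def weighted_step(theta, X_tr, y_tr, w, lr=0.1):
    """One SGD step on the weighted training loss."""
    grads = per_example_grads(theta, X_tr, y_tr)
    return theta - lr * (w[:, None] * grads).sum(axis=0)


# Toy usage: the last 50 rows play the role of the LRW validation set.
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))
y = (X[:, 0] > 0).astype(float)
theta = np.zeros(3)
for _ in range(50):
    w = lookahead_weights(theta, X[:150], y[:150], X[150:], y[150:])
    theta = weighted_step(theta, X[:150], y[:150], w)
```

The paper's point is that *which* rows make up the validation set (here `X[150:]`) matters. Swapping in hard examples changes `g_val`, and therefore which training examples are upweighted.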

Abstract

Learned reweighting (LRW) approaches to supervised learning use an optimization criterion to assign weights to training instances so as to maximize performance on a representative validation dataset. We pose and formalize the problem of optimized selection of the validation set used in LRW training, in order to improve classifier generalization. In particular, we show that using hard-to-classify instances in the validation set has a theoretical connection to generalization and is backed by strong empirical evidence. We provide an efficient algorithm for training this meta-optimized model, as well as a simple train-twice heuristic for careful comparative study. We demonstrate that LRW with easy validation data performs consistently worse than LRW with hard validation data, establishing the validity of our meta-optimization problem. Our proposed algorithm outperforms a wide range of baselines on a variety of datasets and domain shift challenges (ImageNet-1K, CIFAR-100, Clothing-1M, CAMELYON, WILDS, etc.), with ~1% gains using ViT-B on ImageNet. We also show that using naturally hard examples for validation (ImageNet-R / ImageNet-A) in LRW training for ImageNet improves performance on both clean and naturally hard test instances by 1-2%. Secondary analyses show that using hard validation data in an LRW framework improves margins on test data, hinting at the mechanism underlying our empirical gains. We believe this work opens up new research directions for the meta-optimization of meta-learning in a supervised learning context.
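The Easy/Random/Hard validation splits compared in the paper come from rank-ordering training data by the probabilistic margin of an ERM classifier. Below is a minimal sketch of that split step. It makes two assumptions. First, the margin is the predicted probability of the true class minus the largest probability among the other classes. Second, `probs` holds the softmax outputs of a first (ERM) training run on the training pool. The function names and `val_frac` are hypothetical. In a train-twice setup, a second run would then train LRW on `train_idx`, using `val_idx` as its validation set.

```python
import numpy as np


def probabilistic_margin(probs, labels):
    """p(true class) minus the largest probability among the other classes."""
    probs = np.asarray(probs, dtype=float)
    rows = np.arange(probs.shape[0])
    true_p = probs[rows, labels]
    others = probs.copy()
    others[rows, labels] = -np.inf  # exclude the true class from the max
    return true_p - others.max(axis=1)


def split_by_margin(probs, labels, val_frac=0.2, mode="hard", seed=0):
    """Return (train_idx, val_idx) for the LRW-Hard / LRW-Easy / LRW-Random variants."""
    n = len(labels)
    m = int(round(val_frac * n))
    margins = probabilistic_margin(probs, labels)
    if mode == "hard":
        order = np.argsort(margins)    # smallest (most ambiguous) margins first
    elif mode == "easy":
        order = np.argsort(-margins)   # largest (most confident) margins first
    elif mode == "random":
        order = np.random.default_rng(seed).permutation(n)
    else:
        raise ValueError(f"unknown mode: {mode}")
    return np.sort(order[m:]), np.sort(order[:m])


# Toy usage: 4 examples, 2 classes, half of the pool held out for validation.
probs = np.array([[0.9, 0.1], [0.55, 0.45], [0.2, 0.8], [0.4, 0.6]])
labels = np.array([0, 0, 1, 0])
train_idx, val_idx = split_by_margin(probs, labels, val_frac=0.5, mode="hard")
# val_idx -> [1, 3]: the low-margin and misclassified examples
```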

Paper Structure (33 sections, 1 theorem, 17 equations, 4 figures, 7 tables, 1 algorithm)

This paper contains 33 sections, 1 theorem, 17 equations, 4 figures, 7 tables, 1 algorithm.

Key Result

Theorem 1

Consider the tri-level optimization in Equation (eqn:tri_level). Suppose the weighting function $\phi(\cdot)$ and the splitting function $\Theta(\cdot)$ depend on both $x$ and $y$. Suppose further that $N+M\to\infty$ with $\lim_{N,M\to\infty}\frac{M}{N+M} = \delta$. Moreover, suppose the domains of $\phi$ …

Figures (4)

  • Figure 1: Robustness analysis on benchmark datasets. Left: Comparison of LRW variants by choice of validation set (Easy, Random, Hard, corresponding to the rank-ordering of training data by the probabilistic margin of an ERM classifier). The $y$-axis shows accuracy gains over ERM for each dataset ($x$-axis). Performance is consistently ordered as LRW-Easy $<$ LRW-Random $<$ LRW-Hard, showing the importance of validation set optimization. Right: Comparison against other reweighting methods. Our proposal (LRW-Hard) outperforms the other reweighting techniques on average, with fast sample reweighting (FSR) being competitive on some datasets. In-1K denotes ImageNet-1K. Absolute accuracy values are given in the supplementary material.
  • Figure 2: OOD generalization. Left: Comparison of LRW variants on domain shift benchmarks. The ordering among the validation selection methods is reconfirmed on the domain shift benchmarks, suggesting that the earlier gains do not come from overfitting to the training distribution. Right: Comparison against other reweighting methods. Our proposal (LRW-Hard) outperforms the other reweighting techniques on average, with fast sample reweighting (FSR) being competitive on some datasets. Absolute values are given in the supplementary material.
  • Figure 3: MOLERE improves margins of learned classifiers. (a,b): Paired margin deltas between LRWOpt and ERM are moderately right-skewed, with mean and median greater than zero. (c,d): As a function of ERM margin, LRW-Hard (better) and LRW-Easy (worse) separate clearly in margin gain over ERM (error bars are SEM). All results are on unseen test data; ImageNet-100 and CIFAR-100 are shown for brevity, with similar results for other datasets in the supplementary material.
  • Figure 4: Top: Histograms of the difference in margin between the LRW-Hard classifier and the ERM classifier. Bottom: Mean and standard deviation of margin deltas between the LRW-Hard/Easy methods and the ERM classifier on the test examples, binned by ERM classifier margin with a bin width of 0.2 units. The x-axis shows the start of each ERM-margin bin. LRW-Hard classifiers tend to increase margin beyond the ERM margin, whereas LRW-Easy classifiers appear to reduce it.
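The margin analysis in Figures 3 and 4 groups paired margin deltas by the ERM margin. Below is a sketch of that bookkeeping. It assumes per-example margins for the ERM and LRW classifiers on the same test examples are already computed, for example as probabilistic margins in [-1, 1]. The function name and the [-1, 1] range are assumptions. The 0.2 bin width and the left-edge bin labels follow the Figure 4 caption, and the SEM column matches the Figure 3 error bars.

```python
import numpy as np


def binned_margin_deltas(erm_margin, other_margin, bin_width=0.2, lo=-1.0, hi=1.0):
    """Group paired margin deltas (other - ERM) by the ERM margin.

    Returns (bin_start, count, mean, std, sem) tuples for the non-empty bins;
    bin_start is the left edge of the ERM-margin bin.
    """
    erm_margin = np.asarray(erm_margin, dtype=float)
    delta = np.asarray(other_margin, dtype=float) - erm_margin
    n_bins = int(round((hi - lo) / bin_width))
    edges = lo + bin_width * np.arange(n_bins + 1)
    # np.digitize gives 1..n_bins inside [lo, hi); clipping keeps
    # out-of-range values (including exactly hi) in the end bins.
    idx = np.clip(np.digitize(erm_margin, edges) - 1, 0, n_bins - 1)
    rows = []
    for b in range(n_bins):
        d = delta[idx == b]
        if d.size == 0:
            continue
        std = float(d.std(ddof=1)) if d.size > 1 else 0.0
        rows.append((float(edges[b]), int(d.size), float(d.mean()), std, std / np.sqrt(d.size)))
    return rows


# Toy usage with five paired test margins.
erm = np.array([-0.9, -0.85, 0.1, 0.15, 0.95])
lrw = np.array([-0.8, -0.75, 0.3, 0.25, 0.95])
for start, count, mean, std, sem in binned_margin_deltas(erm, lrw):
    print(f"bin {start:+.1f}: n={count} mean={mean:.3f} sem={sem:.3f}")
```

Under this reading, positive per-bin means indicate the LRW classifier widening the margin relative to ERM, which is the pattern the caption reports for LRW-Hard.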

Theorems & Definitions (3)

  • Theorem 1: Asymptotics
  • Proof: Proof sketch.
  • Proof