
Causal Covariate Shift Correction using Fisher information penalty

Behraj Khan, Behroz Mirza, Tahir Syed

TL;DR

This work tackles distribution shift in streaming and continual learning by addressing causal covariate shift across training batches. It introduces C$^{3}$, a Fisher information penalty that replaces intractable Hessian-based priors with the Fisher information matrix $I(\theta)$. This yields a tractable bound on the divergence and a Tikhonov-style regularizer whose strength is controlled by $\lambda$. The method achieves consistent accuracy gains across 40 benchmarks, including up to $20.3\%$ batchwise improvement and $12.9\%$ over full-dataset baselines, and remains robust under both causal and natural covariate shifts. These findings suggest that C$^{3}$ is a practical tool for reliable model assessment and improved generalization in federated, continual, and AutoML settings where data arrive in batches with evolving densities.
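A generic form of such a penalty can be written as below. This is a hedged sketch rather than the paper's exact formulation. It assumes a quadratic penalty anchored at the parameters $\theta^{*}_{t-1}$ learned up to batch $t-1$, with $\ell_t$ the loss on batch $t$; the symbols $\ell_t$, $\mathcal{L}_t$ and $\theta^{*}_{t-1}$ are introduced here for illustration. The first line states the standard identity that lets the Fisher matrix stand in for the Hessian. It holds under the usual regularity conditions when $y$ is drawn from the model:

$$
I(\theta) = \mathbb{E}_{x}\,\mathbb{E}_{y \sim p(y \mid x, \theta)}\!\left[\nabla_\theta \log p(y \mid x, \theta)\,\nabla_\theta \log p(y \mid x, \theta)^{\top}\right] = -\,\mathbb{E}_{x}\,\mathbb{E}_{y \sim p(y \mid x, \theta)}\!\left[\nabla_\theta^{2} \log p(y \mid x, \theta)\right]
$$

$$
\mathcal{L}_t(\theta) = \ell_t(\theta) + \frac{\lambda}{2}\,(\theta - \theta^{*}_{t-1})^{\top}\, I(\theta^{*}_{t-1})\,(\theta - \theta^{*}_{t-1})
$$

Because $I(\theta)$ is positive semidefinite, the penalty is nonnegative. Replacing $I$ by the identity recovers an ordinary Tikhonov (ridge) term pulling $\theta$ toward $\theta^{*}_{t-1}$, and $\lambda$ trades off fit to the new batch against retention of earlier ones.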

Abstract

Evolving feature densities across batches of training data bias cross-validation, making model selection and assessment unreliable (\cite{sugiyama2012machine}). This work approaches training on temporally distributed data from a distributed density estimation angle. \textit{Causal Covariate Shift Correction (C$^{3}$)} accumulates knowledge about the data density of a training batch using Fisher information and uses it to penalize the loss on all subsequent batches. The penalty improves accuracy by $12.9\%$ over the full-dataset baseline, by up to $20.3\%$ in batchwise benchmarks, and by at least $5.9\%$ in foldwise benchmarks.
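The batchwise procedure described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the authors' implementation. It uses logistic regression and an empirical *diagonal* Fisher estimate computed from observed labels. It sums the Fisher diagonals of all past batches and anchors the quadratic penalty at the most recent batch's solution, an online-EWC-style choice. The function names (`diag_fisher`, `penalized_loss`, `penalized_grad`, `train_batch`, `fit_batches`) and the defaults `lam=1.0`, `lr=0.1`, `epochs=100` are hypothetical.

```python
import math
import random

# Hypothetical sketch: diagonal-Fisher penalty carried across batches
# (an assumption for illustration, not the paper's exact method).


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def score(w, x, y):
    """Gradient of log p(y | x, w) for logistic regression, y in {0, 1}."""
    p = sigmoid(sum(wi * xi for wi, xi in zip(w, x)))
    return [(y - p) * xi for xi in x]


def diag_fisher(w, batch):
    """Empirical diagonal Fisher at w: mean squared score over the batch.

    Uses the observed labels, an approximation to the true Fisher,
    which samples y from the model."""
    f = [0.0] * len(w)
    for x, y in batch:
        for i, g in enumerate(score(w, x, y)):
            f[i] += g * g
    return [fi / len(batch) for fi in f]


def penalized_loss(w, batch, anchor, fisher, lam):
    """Mean negative log-likelihood + (lam / 2) * sum_i F_i * (w_i - anchor_i)^2."""
    nll = 0.0
    for x, y in batch:
        z = sum(wi * xi for wi, xi in zip(w, x))
        nll += math.log1p(math.exp(-z if y == 1 else z))  # -log p(y | x, w)
    penalty = 0.5 * lam * sum(f * (wi - ai) ** 2 for f, wi, ai in zip(fisher, w, anchor))
    return nll / len(batch) + penalty


def penalized_grad(w, batch, anchor, fisher, lam):
    """Gradient of penalized_loss with respect to w."""
    grad = [0.0] * len(w)
    for x, y in batch:
        for i, g in enumerate(score(w, x, y)):
            grad[i] -= g / len(batch)  # mean NLL gradient is minus the mean score
    return [g + lam * f * (wi - ai) for g, f, wi, ai in zip(grad, fisher, w, anchor)]


def train_batch(w, batch, anchor, fisher, lam, lr=0.1, epochs=100):
    """Plain full-batch gradient descent on the penalized loss."""
    for _ in range(epochs):
        grad = penalized_grad(w, batch, anchor, fisher, lam)
        w = [wi - lr * gi for wi, gi in zip(w, grad)]
    return w


def fit_batches(batches, dim, lam=1.0):
    """Train batch by batch; each batch's Fisher penalizes all later batches."""
    w = [0.0] * dim
    anchor = [0.0] * dim
    fisher = [0.0] * dim  # zero Fisher: no penalty while fitting the first batch
    for batch in batches:
        w = train_batch(w, batch, anchor, fisher, lam)
        # Accumulate (sum) the diagonal Fisher of every batch seen so far.
        fisher = [a + b for a, b in zip(fisher, diag_fisher(w, batch))]
        anchor = list(w)  # anchor the penalty at the most recent solution
    return w, fisher


if __name__ == "__main__":
    rng = random.Random(0)

    def make_batch(mean, n=50):
        # Covariate shift: the feature mean drifts, the labelling rule does not.
        batch = []
        for _ in range(n):
            t = rng.gauss(mean, 1.0)
            batch.append(([1.0, t], 1 if t > 0.5 else 0))  # x = [bias, feature]
        return batch

    batches = [make_batch(m) for m in (0.0, 0.5, 1.0)]
    w, fisher = fit_batches(batches, dim=2, lam=1.0)
    print("weights:", w)
    print("accumulated diagonal Fisher:", fisher)
```

A diagonal Fisher keeps the penalty at $O(d)$ memory and cost per step. The price is that it ignores the parameter correlations the full $I(\theta)$ would capture.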


Paper Structure

This paper contains 9 sections, 18 equations, 3 tables, and 1 algorithm.