Interpretable Target-Feature Aggregation for Multi-Task Learning based on Bias-Variance Analysis
Paolo Bonetti, Alberto Maria Metelli, Marcello Restelli
TL;DR
This paper tackles interpretable multi-task learning by combining task clustering with mean-based feature transformation in a two-phase algorithm (NonLinCTFA). It provides a bias-variance framework for linear models trained on aggregated targets and reduced feature sets, showing that variance reductions scale with the number of aggregated tasks and highlighting bias adjustments via $R^2$-type terms. The proposed method iteratively aggregates targets (Phase I) and features (Phase II) with guarantees on MSE improvement. It is validated on synthetic data and on real-world datasets (SARCOS, School, QM9, climate), demonstrating improved predictive accuracy and reduced model complexity while preserving interpretability. The approach offers a principled design for scalable, interpretable MTL with practical impact in Earth sciences and related fields. Key contributions include the theoretical bias-variance results, the two-phase aggregation algorithm with interpretability guarantees, and empirical validation across diverse domains.
Abstract
Multi-task learning (MTL) is a powerful machine learning paradigm designed to leverage shared knowledge across tasks to improve generalization and performance. Previous works have proposed approaches to MTL that can be divided into feature learning, focused on the identification of a common feature representation, and task clustering, where similar tasks are grouped together. In this paper, we propose an MTL approach at the intersection between task clustering and feature transformation based on a two-phase iterative aggregation of targets and features. First, we propose a bias-variance analysis for regression models with additive Gaussian noise, where we provide a general expression of the asymptotic bias and variance of a task, considering a linear regression trained on aggregated input features and an aggregated target. Then, we exploit this analysis to provide a two-phase MTL algorithm (NonLinCTFA). In the first phase, this method partitions the tasks into clusters and aggregates each obtained group of targets with their mean. In the second phase, for each aggregated task, it aggregates subsets of features with their mean in a dimensionality reduction fashion. In both phases, a key aspect is to preserve the interpretability of the reduced targets and features through the aggregation with the mean, which is further motivated by applications to Earth science. Finally, we validate the algorithms on synthetic data, showing the effect of different parameters, and on real-world datasets, exploring the validity of the proposed methodology on classical datasets, recent baselines, and Earth science applications.
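To make the two-phase structure concrete, the following is a minimal sketch of mean-based target and feature aggregation followed by a linear fit. It is not the NonLinCTFA procedure itself: the grouping rule used here (a plain Pearson-correlation threshold against the running group mean) is an assumed stand-in for the criterion derived from the bias-variance analysis, and all function names and threshold parameters are hypothetical. As a further simplification, the feature grouping in this sketch does not depend on the target, so it is computed once rather than separately for each aggregated task.

```python
import numpy as np


def greedy_mean_groups(M, threshold):
    """Greedily partition the columns of M.

    A column joins the first existing group whose current mean has a
    Pearson correlation >= threshold with it; otherwise it starts a new
    group. This rule is an illustrative assumption, not the paper's criterion.
    """
    groups = []
    for j in range(M.shape[1]):
        for group in groups:
            group_mean = M[:, group].mean(axis=1)
            if np.corrcoef(group_mean, M[:, j])[0, 1] >= threshold:
                group.append(j)
                break
        else:
            # No compatible group found: open a new one for column j.
            groups.append([j])
    return groups


def aggregate_by_mean(M, groups):
    """Replace each group of columns with its mean, one column per group."""
    return np.column_stack([M[:, g].mean(axis=1) for g in groups])


def two_phase_fit(X, Y, task_threshold=0.8, feature_threshold=0.8):
    """Hypothetical two-phase pipeline: aggregate targets, then features."""
    # Phase I: cluster the tasks and average the targets in each cluster.
    task_groups = greedy_mean_groups(Y, task_threshold)
    Y_agg = aggregate_by_mean(Y, task_groups)

    # Phase II: cluster the features and average each cluster
    # (computed once here; a simplification of the per-task phase).
    feature_groups = greedy_mean_groups(X, feature_threshold)
    X_agg = aggregate_by_mean(X, feature_groups)

    # Ordinary least squares with an intercept on the reduced problem.
    design = np.column_stack([np.ones(len(X_agg)), X_agg])
    coef, *_ = np.linalg.lstsq(design, Y_agg, rcond=None)
    return task_groups, feature_groups, coef
```

Because every reduced target and feature is the mean of a known group of original columns, each coefficient in `coef` can be read as the effect of an averaged group of features on an averaged group of tasks, which is the interpretability property the abstract emphasizes.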
