Enhancing Performance of Explainable AI Models with Constrained Concept Refinement
Geyu Liang, Senne Michielssen, Salar Fattahi
TL;DR
This work tackles the longstanding tension between interpretability and accuracy in interpretable-by-design models by introducing Constrained Concept Refinement (CCR), which optimizes concept embeddings within a small neighborhood to improve predictive performance without sacrificing explainability. The authors provide theoretical results showing that ignoring refinement can impose a nonzero loss gap, and they prove convergence guarantees under a generative, column-orthogonal setup, including zero training loss and progressive interpretability as more data are used. The CCR framework employs a differentiable surrogate, a gradient-based update with a constrained projection, and a surrogate loss that remains amenable to backpropagation while maintaining interpretability. Empirically, CCR using CLIP-based embeddings and a dispersion step outperforms state-of-the-art explainable methods on several large benchmarks (CIFAR-10/100, ImageNet, Places365) with substantially lower compute, while providing interpretable explanations via top concepts and their weights. Overall, CCR offers a principled, scalable path to jointly enhance accuracy and interpretability in explainable AI for vision tasks.
Abstract
The trade-off between accuracy and interpretability has long been a challenge in machine learning (ML). This tension is particularly significant for emerging interpretable-by-design methods, which aim to redesign ML algorithms for trustworthy interpretability but often sacrifice accuracy in the process. In this paper, we address this gap by investigating how deviations in concept representations (an essential component of interpretable models) affect prediction performance, and we propose a novel framework to mitigate these effects. The framework builds on the principle of optimizing concept embeddings under constraints that preserve interpretability. Using a generative model as a testbed, we rigorously prove that our algorithm achieves zero loss while progressively enhancing the interpretability of the resulting model. Additionally, we evaluate the practical performance of our framework in generating explainable predictions for image classification tasks. Compared to existing explainable methods, our approach improves prediction accuracy across various large-scale benchmarks while preserving model interpretability, and it does so at significantly lower computational cost.
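To make the idea of optimizing concept embeddings under an interpretability-preserving constraint concrete, the following is a minimal sketch, not the algorithm analyzed in the paper. It assumes a plain squared loss on linear concept scores and takes the "small neighborhood" to be a column-wise L2 ball of radius `rho` around the initial embeddings `D0`. The function names, the loss, the choice of ball, and all hyperparameters are illustrative assumptions.

```python
import numpy as np


def project_to_neighborhood(D, D0, rho):
    # Pull each column of D back into the L2 ball of radius rho around
    # the matching column of D0; columns already inside are unchanged.
    # (Assumed constraint set; the paper's exact neighborhood may differ.)
    delta = D - D0
    norms = np.linalg.norm(delta, axis=0, keepdims=True)
    scale = np.minimum(1.0, rho / np.maximum(norms, 1e-12))
    return D0 + delta * scale


def refine_concepts(X, Y, D0, rho=0.1, lr=0.01, steps=100, seed=0):
    # X: (n, d) input features, Y: (n, c) one-hot labels,
    # D0: (d, k) initial concept embeddings, one concept per column.
    # Minimizes 0.5/n * ||X D W - Y||_F^2 (an assumed stand-in loss)
    # by gradient steps on W and projected gradient steps on D.
    rng = np.random.default_rng(seed)
    n = X.shape[0]
    D = D0.copy()
    W = 0.01 * rng.standard_normal((D0.shape[1], Y.shape[1]))
    for _ in range(steps):
        S = X @ D                      # concept scores, shape (n, k)
        R = S @ W - Y                  # prediction residual, shape (n, c)
        grad_W = S.T @ R / n           # gradient w.r.t. the linear head
        grad_D = X.T @ (R @ W.T) / n   # gradient w.r.t. the concepts
        W = W - lr * grad_W
        # Gradient step on the concepts, then projection so every
        # refined concept stays within rho of its original embedding.
        D = project_to_neighborhood(D - lr * grad_D, D0, rho)
    return D, W
```

Because every update ends with the projection, each refined concept stays within `rho` of the embedding it started from. In this sketch, that bound is what keeps the refined concepts tied to their original meaning while the loss is reduced.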
