Interpretable Failure Analysis in Multi-Agent Reinforcement Learning Systems
Risal Shahriar Shefin, Debashis Gupta, Thai Le, Sarra Alqahtani
TL;DR
The paper tackles the challenge of diagnosing cascading failures in multi-agent reinforcement learning by introducing a two-stage gradient-based forensics framework. Stage 1 detects per-agent instability using the Taylor remainder of the policy-gradient cost, $\mathcal{L}_i^t(\eta_i)$, and nominates a Patient-0 candidate at the first threshold crossing. Stage 2 validates that candidate and traces upstream influence through critic derivatives, using first-order sensitivities $G_{ij}^t$ and directional curvature $D_{ij}^t$ aggregated over causal windows to produce directed contagion graphs. Key findings include Patient-0 detection accuracy of 88.2-99.4% on Simple Spread and StarCraft II, and the ability of Stage-2 analysis to correct downstream-first misidentifications by revealing accelerating upstream paths, with instability-occupancy metrics ($IO$) outperforming traditional reward-based signals. The approach provides actionable, gradient-level forensics for diagnosing and mitigating cascading failures in safety-critical MARL systems, enabling more reliable deployment in complex, interconnected environments.
Abstract
Multi-Agent Reinforcement Learning (MARL) is increasingly deployed in safety-critical domains, yet methods for interpretable failure detection and attribution remain underdeveloped. We introduce a two-stage gradient-based framework that provides interpretable diagnostics for three critical failure analysis tasks: (1) detecting the true initial failure source (Patient-0); (2) validating why non-attacked agents may be flagged first due to domino effects; and (3) tracing how failures propagate through learned coordination pathways. Stage 1 performs interpretable per-agent failure detection via Taylor-remainder analysis of policy-gradient costs, declaring an initial Patient-0 candidate at the first threshold crossing. Stage 2 provides validation through geometric analysis of critic derivatives, namely first-order sensitivity and directional second-order curvature aggregated over causal windows, to construct interpretable contagion graphs. This approach explains "downstream-first" detection anomalies by revealing pathways that amplify upstream deviations. Evaluated across 500 episodes in Simple Spread (3 and 5 agents) and 100 episodes in StarCraft II using MADDPG and HATRPO, our method achieves 88.2-99.4% Patient-0 detection accuracy while providing interpretable geometric evidence for detection decisions. By moving beyond black-box detection to interpretable gradient-level forensics, this framework offers practical tools for diagnosing cascading failures in safety-critical MARL systems.
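As a rough illustration of the Stage 1 idea, the sketch below computes a first-order Taylor remainder for a toy cost and reports the first agent whose score crosses a threshold. The quadratic cost, the threshold value, the tie-breaking rule, and the helper names `taylor_remainder` and `first_crossing` are assumptions made for illustration only; they are not taken from the paper's implementation.

```python
import numpy as np


def taylor_remainder(cost, grad, theta, eta):
    """First-order Taylor remainder of `cost` at `theta` along perturbation `eta`.

    R(eta) = cost(theta + eta) - cost(theta) - grad(theta) . eta
    A large |R| means the linear approximation of the cost is poor.
    """
    return cost(theta + eta) - cost(theta) - grad(theta) @ eta


def first_crossing(scores, threshold):
    """Return (agent, step) for the earliest threshold crossing, or None.

    scores: array of shape (T, n_agents) holding per-agent instability scores.
    Tie-break (an assumption): if several agents cross at the same step,
    pick the one with the largest absolute score.
    """
    magnitudes = np.abs(scores)
    crossed = magnitudes > threshold
    steps = np.where(crossed.any(axis=1))[0]
    if steps.size == 0:
        return None
    t = int(steps[0])
    agent = int(np.argmax(magnitudes[t] * crossed[t]))
    return agent, t


# Toy cost (hypothetical): 0.5 * ||theta||^2, whose remainder is 0.5 * ||eta||^2.
cost = lambda th: 0.5 * float(th @ th)
grad = lambda th: th

theta = np.array([0.3, -0.7])
eta = np.array([1.0, 2.0])
remainder = taylor_remainder(cost, grad, theta, eta)  # 0.5 * (1 + 4) = 2.5

# Hypothetical per-step scores for three agents over three time steps.
scores = np.array([
    [0.1, 0.2, 0.1],
    [0.1, 0.9, 0.2],
    [1.5, 1.0, 0.3],
])
candidate = first_crossing(scores, threshold=0.5)  # agent 1 crosses first, at step 1
```

In this toy trace, agent 1 is nominated even though agent 0 later shows the largest score. A first-crossing rule alone cannot say whether the nominated agent is the true source or a downstream victim, which is the gap the Stage 2 validation described above is meant to address.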
