On Fisher Consistency of Surrogate Losses for Optimal Dynamic Treatment Regimes with Multiple Categorical Treatments per Stage

Nilanjana Laha, Nilson Chapagain, Victoria Cicherski, Aaron Sonabend-W

TL;DR

The paper investigates Fisher consistency for surrogate losses used to learn optimal dynamic treatment regimes (DTRs) with multiple stages and multiple treatment levels per stage. It first shows that many concave surrogates, including broad families of permutation equivariant, relative-margin-based (PERM) losses, are not Fisher consistent in the multi-stage setting, which motivates looking beyond concave surrogates. It then establishes necessary and sufficient conditions for Fisher consistency within nonnegative, stagewise-separable surrogates, and constructs smooth, non-concave surrogates (kernel-based and product-based) that are Fisher consistent. Building on these surrogates, the authors introduce Simultaneous Direct Search with Surrogates (SDSS), a gradient-based optimization method that learns the treatment rules for all stages simultaneously, and derive regret decay rates under small-noise and smoothness assumptions. Simulations and a sepsis EHR study suggest that SDSS can outperform stagewise methods such as Q-learning and ACWL, particularly in settings with high misspecification risk or high-dimensional noisy covariates. These findings advance the understanding of surrogate design for multi-stage DTRs and offer a scalable framework for model-free, simultaneous optimization of stagewise treatments.

Abstract

Patients with chronic diseases often receive treatments at multiple time points, or stages. Our goal is to learn the optimal dynamic treatment regime (DTR) from longitudinal patient data. When both the number of stages and the number of treatment levels per stage are arbitrary, estimating the optimal DTR reduces to a sequential, weighted, multiclass classification problem (Kosorok and Laber, 2019). In this paper, we aim to solve this classification problem simultaneously across all stages using Fisher consistent surrogate losses. Although computationally feasible Fisher consistent surrogates exist in special cases, e.g., the binary treatment setting, a unified theory of Fisher consistency remains largely unexplored. We establish necessary and sufficient conditions for DTR Fisher consistency within the class of non-negative, stagewise separable surrogate losses. To our knowledge, this is the first result in the DTR literature to provide necessary conditions for Fisher consistency within a non-trivial surrogate class. Furthermore, we show that many convex surrogate losses fail to be Fisher consistent for the DTR classification problem, and we formally establish this inconsistency for smooth, permutation equivariant, and relative-margin-based convex losses. Building on this, we propose SDSS (Simultaneous Direct Search with Surrogates), which uses smooth, non-concave surrogate losses to learn the optimal DTR. We develop a computationally efficient, gradient-based algorithm for SDSS. When the optimization error is small, we establish a sharp upper bound on SDSS's regret decay rate. We evaluate the numerical performance of SDSS through simulations and demonstrate its real-world applicability by estimating optimal fluid resuscitation strategies for severe septic patients using electronic health record data.
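The weighted classification view mentioned in the abstract can be made concrete with the standard inverse-probability-weighted (IPW) value estimate for a two-stage regime. The following is a minimal sketch under stated assumptions, not the authors' implementation: the `Trajectory` fields, the toy data, and the threshold rules are all hypothetical, and the propensities are assumed known. Surrogate-based approaches such as SDSS replace the product of agreement indicators below with a smooth surrogate; that step is not shown because the surrogate's exact form is not given in this summary.

```python
from typing import Callable, List, NamedTuple


class Trajectory(NamedTuple):
    """One patient's two-stage record (all field names are illustrative)."""
    h1: float  # stage-1 history, reduced here to a single covariate
    a1: int    # stage-1 treatment actually received
    h2: float  # stage-2 history
    a2: int    # stage-2 treatment actually received
    y: float   # total outcome, larger is better
    p1: float  # propensity P(A1 = a1 | H1 = h1), assumed known
    p2: float  # propensity P(A2 = a2 | H2 = h2), assumed known


def ipw_value(
    data: List[Trajectory],
    d1: Callable[[float], int],
    d2: Callable[[float], int],
) -> float:
    """IPW estimate of the value of the regime (d1, d2).

    A trajectory contributes only if the observed treatments agree with
    the regime at BOTH stages; its outcome is then reweighted by the
    inverse of the product of the propensities.
    """
    total = 0.0
    for t in data:
        agrees = d1(t.h1) == t.a1 and d2(t.h2) == t.a2
        if agrees:
            total += t.y / (t.p1 * t.p2)
    return total / len(data)


# Hypothetical toy data with three treatment levels {1, 2, 3} per stage.
toy_data = [
    Trajectory(h1=-1.0, a1=1, h2=0.5, a2=2, y=4.0, p1=0.5, p2=0.25),
    Trajectory(h1=0.8, a1=3, h2=-0.2, a2=1, y=6.0, p1=0.5, p2=0.25),
    Trajectory(h1=0.3, a1=3, h2=0.9, a2=1, y=2.0, p1=0.5, p2=0.25),
    Trajectory(h1=-0.4, a1=1, h2=-0.7, a2=1, y=1.0, p1=0.5, p2=0.25),
]


def rule_1(h: float) -> int:
    """Illustrative stage-1 rule: treatment 3 if h > 0, else treatment 1."""
    return 3 if h > 0 else 1


def rule_2(h: float) -> int:
    """Illustrative stage-2 rule: treatment 2 if h > 0, else treatment 1."""
    return 2 if h > 0 else 1


estimated_value = ipw_value(toy_data, rule_1, rule_2)
```

In this toy data the third trajectory disagrees with `rule_2` at stage 2, so it contributes nothing to the estimate. This all-or-nothing behavior of the indicator product is what makes the objective discontinuous in the rule's parameters and motivates smooth surrogates.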

Paper Structure

This paper contains 161 sections, 66 theorems, 842 equations, 9 figures, 14 tables, 1 algorithm.

Key Result

theorem 3.1

Suppose $T=2$, $k_1,k_2\geq 2$, and $\psi$ is a concave PERM loss that is bounded above and satisfies $\bigcap_{i=1}^{k_1}\bigcap_{j=1}^{k_2}\operatorname{int}\bigl(\operatorname{dom}(-\psi(\cdot;i,j))\bigr)\neq \emptyset$. Further suppose $-\eta$ is proper, closed, and strictly convex, where $\eta$ is the template of $\psi$. Also, we assume $\eta$ is t

Figures (9)

  • Figure 1: Plot of $\Gamma(x,y;1)$ when $k=3$. For the product-based surrogate, $\Gamma$ is as in \ref{def: Gamma: product bases}, and $\tau(x)=(1+\tanh(x))/2$, which is the distribution function of the centered logistic distribution with scale $2$. For kernel-based surrogates, the template $\Gamma$ is provided in \ref{def: Gamma: kernel based}. Its closed-form expressions for the logistic and Gumbel densities are provided in Supplement \ref{sec: short: calc for kernel logistic}.
  • Figure 2: Plots related to $\widehat{V}^{\psi,\text{rel}}$ for the toy example in Section \ref{sec: implementation}. The formula for $\widehat{V}^{\psi,\text{rel}}(x,y)$ is provided in \ref{opti: value fn: toy data} and the toy data are provided in Table \ref{table:toy_data}. In plot (c), the logarithm base is 10. Higher negative values in this plot indicate plateau regions, where the gradients become close to zero.
  • Figure 3: Gradient descent for the toy data in Section \ref{sec: implementation}. The plots display the paths traced by iterates started from 6 different initialization points for (a) vanilla gradient descent, (b) ADAM (without minibatching), (c) SGD, and (d) ADAM with SGD. The white circle and the solid black rectangle mark the starting point and the end point of each path, respectively. Although the algorithms minimize $-\widehat{V}^{\psi,\text{rel}}$, the paths are shown on the contour plot of $\widehat{V}^{\psi,\text{rel}}$ (also provided in Figure \ref{fig: contour}). The legend to the right presents the color scale. The yellow region indicates the optimal plateau, where the objective value is close to the optimum. The first three plots used 5,000 iterations, while the last used 20,000. The SGD batch size was one. The ADAM parameters were set to their default values as specified in Algorithm \ref{alg:multi_stage_opt}.
  • Figure 4: Optimal treatment assignments for Scheme 2 as a function of $O_1^T\mo_p$. As mentioned in Section \ref{sec: simulation}, the optimal treatment assignments in Scheme 2 are functions of $O_1^T\mo_p$ only. The X-axis represents $O_1^T\mo_p$, while the Y-axis indicates the corresponding optimal treatment. Here, the optimal treatment assignments are either two or three, which are colored blue and red, respectively.
  • Figure 5: Boxplots of the estimated value functions of the different methods over 100 replications. The value functions in Scheme 1 are scaled by $10^{-2}$ for better visual comparability. The X-axis indicates the sample size and the Y-axis represents the average value function over 100 replications. Each box shows the interquartile range (IQR), with edges at the 25th and 75th percentiles and a line for the median. Whiskers extend to $1.5\times$IQR. The red dashed line represents the optimal value function. In Scheme 1, $\omega$ corresponds to the complexity parameter, and in Scheme 2, $p$ corresponds to the dimension of the baseline covariate space.
  • ...and 4 more figures
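
The function $\tau$ in the caption of Figure 1 can be rewritten in the usual logistic form by a short calculation. This is a standard identity, added here only as a reading aid:

$$\frac{1+\tanh(x)}{2} = \frac{1}{2}\left(1+\frac{e^{x}-e^{-x}}{e^{x}+e^{-x}}\right) = \frac{e^{x}}{e^{x}+e^{-x}} = \frac{1}{1+e^{-2x}}.$$

So $\tau$ is the standard logistic distribution function evaluated at $2x$. Under the parameterization $F(x)=1/(1+e^{-x/s})$ this corresponds to $s=1/2$, so the caption's "scale $2$" presumably refers to the factor multiplying $x$; that reading is an assumption, since the summary does not state the paper's convention.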

Theorems & Definitions (144)

  • definition 2.1: DTR Fisher Consistency w.r.t. $\mP$
  • theorem 3.1
  • lemma 4.1
  • lemma 4.2
  • theorem 4.1: Sufficient conditions
  • theorem 4.2: Necessary conditions
  • proposition 1
  • remark 1
  • lemma 4.3
  • lemma 4.4
  • ...and 134 more