Treatment Stitching with Schrödinger Bridge for Enhancing Offline Reinforcement Learning in Adaptive Treatment Strategies

Dong-Hee Shin, Deok-Joong Lee, Young-Han Son, Tae-Eui Kam

TL;DR

This work tackles the data scarcity and safety constraints of applying reinforcement learning to adaptive treatment strategies by introducing TreatStitch, a data-augmentation framework that creates clinically valid synthetic trajectories. It combines direct stitching of similar state representations from real trajectories with Schrödinger bridge–based bridging to connect dissimilar states, thereby expanding the offline dataset without violating clinical plausibility. The approach is backed by theoretical guarantees that stitched transitions stay close to the original data distribution, mitigating out-of-distribution risks, and is validated on the EpiCare benchmark and MIMIC-III sepsis data, where it outperforms multiple generative baselines and improves offline RL performance, especially under restricted data. The work offers a practical, model-agnostic augmentation strategy for offline ATS learning, with potential to improve safety and effectiveness of AI-driven clinical decision support systems.

Abstract

Adaptive treatment strategies (ATS) are sequential decision-making processes that enable personalized care by dynamically adjusting treatment decisions in response to evolving patient symptoms. While reinforcement learning (RL) offers a promising approach for optimizing ATS, its conventional online trial-and-error learning mechanism is not permissible in clinical settings due to risks of harm to patients. Offline RL tackles this limitation by learning policies exclusively from historical treatment data, but its performance is often constrained by data scarcity, a pervasive challenge in clinical domains. To overcome this, we propose Treatment Stitching (TreatStitch), a novel data augmentation framework that generates clinically valid treatment trajectories by stitching segments from existing treatment data. Specifically, TreatStitch identifies similar intermediate patient states across different trajectories and stitches their respective segments. When intermediate states are too dissimilar to stitch directly, TreatStitch leverages the Schrödinger bridge method to generate smooth and energy-efficient bridging trajectories that connect them. By adding these synthetic trajectories to the original dataset, offline RL can learn from a more diverse dataset, thereby improving its ability to optimize ATS. Extensive experiments across multiple treatment datasets demonstrate the effectiveness of TreatStitch in enhancing offline RL performance. Furthermore, we provide a theoretical justification showing that TreatStitch maintains clinical validity by avoiding out-of-distribution transitions.
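The direct-stitching step described in the abstract can be illustrated with a minimal sketch. It is not the authors' implementation: the function names `find_stitch_point` and `stitch`, the threshold `eps`, the use of Euclidean distance on raw state vectors, and the states-only trajectory format are all assumptions made for illustration. The sketch looks for the closest pair of states across two trajectories and, if they are within `eps`, joins the prefix of one trajectory to the suffix of the other.

```python
import math


def find_stitch_point(tau_a, tau_b, eps):
    """Return (t, t_prime) for the closest state pair within eps, else None.

    tau_a, tau_b: lists of state vectors (tuples of floats); hypothetical format.
    The last state of tau_a is skipped so that a non-empty suffix always remains.
    """
    best = None
    for t, s_a in enumerate(tau_a[:-1]):
        for t_prime, s_b in enumerate(tau_b):
            d = math.dist(s_a, s_b)  # Euclidean distance (an assumed similarity measure)
            if d <= eps and (best is None or d < best[0]):
                best = (d, t, t_prime)
    return None if best is None else (best[1], best[2])


def stitch(tau_a, tau_b, t, t_prime):
    """Prefix of tau_b up to and including t_prime, then tau_a from t + 1 onward."""
    return tau_b[: t_prime + 1] + tau_a[t + 1 :]


# Toy two-dimensional "patient states" (made-up numbers, not clinical data).
tau_a = [(0.0, 0.0), (1.0, 0.0), (2.0, 0.0), (3.0, 0.0)]
tau_b = [(0.0, 2.0), (1.0, 1.0), (2.0, 0.1), (3.0, 1.0)]

match = find_stitch_point(tau_a, tau_b, eps=0.5)
if match is not None:
    stitched = stitch(tau_a, tau_b, *match)
```

In this toy example the closest pair is the third state of each trajectory, so the stitched result follows `tau_b` for three steps and then finishes with the last state of `tau_a`. When no pair falls within `eps`, the function returns `None`, which corresponds to the case where the abstract says a bridging trajectory is generated instead.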

Paper Structure

This paper contains 58 sections, 1 theorem, 39 equations, 5 figures, 14 tables, 2 algorithms.

Key Result

Theorem 1

Let $\mathcal{F}: \mathcal{S} \times \mathcal{A} \to \mathcal{S}$ be the transition function defining the environment dynamics, and fix a norm $\|\cdot\|$ on $\mathcal{S}$. Suppose $\mathcal{F}$ is $L$-Lipschitz continuous in the state coordinate: $\|\mathcal{F}(s, a) - \mathcal{F}(s', a)\| \le L \|s - s'\|$ for all $s, s' \in \mathcal{S}$ and $a \in \mathcal{A}$. Given the offline dataset $\mathcal{D} = \{\tau_i\}_{i=1}^N$, we construct the stitched trajectory $\tau_{\text{stit}} = \tau_B[0:t'] \cup \tau_A[t+1:T]$ ...
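
The theorem statement is cut off here, but the role of the Lipschitz assumption can be illustrated with a short worked bound. The following is a sketch under stated assumptions, not the paper's result: it assumes deterministic dynamics so that $s^A_{t+1} = \mathcal{F}(s^A_t, a^A_t)$ holds for the recorded trajectory $\tau_A$, that the stitch reuses the action $a^A_t$, and that the two stitched states satisfy $\|s^B_{t'} - s^A_t\| \le \epsilon$ for a hypothetical similarity threshold $\epsilon$.

```latex
\begin{align*}
\bigl\| \mathcal{F}(s^B_{t'}, a^A_t) - s^A_{t+1} \bigr\|
  &= \bigl\| \mathcal{F}(s^B_{t'}, a^A_t) - \mathcal{F}(s^A_t, a^A_t) \bigr\| \\
  &\le L \, \bigl\| s^B_{t'} - s^A_t \bigr\| \\
  &\le L \epsilon .
\end{align*}
```

Under these assumptions, the state the true dynamics would produce from $s^B_{t'}$ differs from the recorded next state $s^A_{t+1}$ by at most $L\epsilon$, so a tighter similarity threshold keeps the stitched transition closer to transitions supported by the data.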

Figures (5)

  • Figure 1: Illustrations of adaptive treatment strategies (ATS) and reinforcement learning (RL). (a) ATS aims to identify effective treatment strategies based on a patient’s evolving symptoms. (b) Online RL learns via trial-and-error, raising safety concerns in clinical settings. (c) Offline RL learns a policy from the offline dataset, removing the need for online trial-and-error on patients.
  • Figure 2: (a) The overall workflow of our treatment stitching framework that produces an augmented dataset for enhancing offline RL. (b) The detailed process of treatment stitching, which generates stitched trajectories from existing data.
  • Figure 3: Overview of Schrödinger bridge for TreatStitch.
  • Figure 4: Comparison of various generative methods.
  • Figure 5: Distribution of bridging trajectory lengths.

Theorems & Definitions (2)

  • Theorem 1
  • Proof: Proof sketch