Hybridizing mechanistic modeling and deep learning for personalized survival prediction after immune checkpoint inhibitor immunotherapy

The results presented herein were obtained using a two-step modeling process: we first applied an established mechanistic model of ICI therapy in solid tumors to a retrospective patient cohort, and then combined the output from the mechanistic model with additional, often non-mechanistic patient data to train an artificial neural network (ANN) for prediction of patient survival probabilities. This approach generates functions that describe the probability of patient death (an event) over time. In a more general sense, this is time-to-event prediction, where we define the start time as the time of first administration of ICI drug (t0), the event as patient death, and censor patients after the time of last follow-up. Multiple methods have been demonstrated for time-to-event prediction, including continuous methods such as Cox proportional hazards37 or DeepSurv38, as well as discrete methods including logistic hazards37,65, probability mass function (PMF)37, and PMF with competing risks (DeepHit)66. In this work, we focus on the logistic hazards method because it avoids the constant hazard ratio assumption (e.g., of Cox models) while offering superior performance to other discrete methods with small sample sizes37; our purpose here is to test the hypothesis that combining mechanistic and clinical parameters yields superior DL model performance, as opposed to comparing performance across methods.

Mechanistic mathematical modeling

Mathematical descriptions of the mechanistic biological and physical processes underlying checkpoint inhibitor therapy may be informed through clinically measurable quantities to predict treatment response and patient outcome67. Moreover, these biological and physical phenomena are mechanistically linked through physical laws and key feedback mechanisms. Based on this understanding, we have developed a mathematical model24 that describes the total tumor burden (ρ) over time and is built upon 1) key mechanistic biological factors or processes (e.g., concentration of therapeutic T cells (ψk), intratumoral concentration of immunotherapy antibody (σ), cytokine secretion (ΛC), and ratio of immune to tumor cells over time (Λψ)) and 2) physical factors or processes (e.g., rates of untreated tumor cell proliferation (α0) and death due to ICI therapy (μ), binding of antibody to targets (λ), specific death rate of cancer cells (λp), chemotaxis (χ), diffusion of antibodies and cytokines (DA, DC), and mass conservation) that underlie ICI immunotherapy intervention (Fig. 1). By mathematically linking these processes, our model quantifies their combined effects (and the feedback processes between them) on the time-dependent change in tumor burden (ρ) under immunotherapy intervention.

The model derivation process has been extensively detailed previously24,25,26; here we present only a brief overview of the key mechanisms shown in Fig. 1 and their mathematical descriptions. Fick’s law descriptions are used to obtain the steady-state diffusion of antibodies (A) and cytokines (C) within the tumor and the balances of antigen–antibody interaction and cancer cell–cytokine concentrations within the tumor microenvironment according to

$${D}_{{\rm{A}}}\cdot {\nabla }^{2}\sigma =\lambda \cdot \sigma \cdot \rho$$
(1)
and

$${D}_{{\rm{C}}}\cdot {\nabla }^{2}C=-{\Lambda }_{{\rm{C}}}\cdot \lambda \cdot \sigma \cdot \rho ,$$
(2)
respectively. The concentration of viable tumor cells over time is a function of the tumor’s intrinsic growth rate, reduced by the tumor cell kill rate due to antibody binding (i.e., checkpoint inhibitors binding to their respective ligands) and the time history of antibody uptake and binding within the tumor, and may be written as

$$\frac{\partial \rho }{\partial t}={\alpha }_{0}\cdot \rho -{\lambda }_{\text{p}}\cdot \rho \cdot {\psi }_{k}\cdot \int_{0}^{t}\lambda \cdot \sigma \cdot \rho \,\text{d}t^{\prime}.$$
(3)
The time-dependent intratumoral concentration of therapeutic immune cells is a function of chemotaxis-mediated migration into and within the tumor, as a result of cytokine signaling and immune cell coupling with tumor cells, as follows:

$$\frac{\partial {\psi }_{k}}{\partial t}=-\chi \cdot \nabla \cdot \left({\psi }_{k}\cdot \nabla C\right)+{\Lambda }_{\psi }\cdot \frac{\partial \rho }{\partial t}.$$
(4)
Applying reasonable assumptions and solving these four equations leads to our master equation, which mechanistically describes tumor burden over time as an outcome of immunotherapy intervention:

$$\frac{d\rho^{\prime}}{dt}=\rho^{\prime}\left({\alpha }_{0}-\mu +\Lambda \mu \right)+{\rho^{\prime}}^{2}\left(-\Lambda \mu \right),$$
(5)
where \(\rho^{\prime}\) is tumor volume normalized by the tumor volume at t = 0 (i.e., \(\rho^{\prime}={\rho }_{t=n}/{\rho }_{t=0}\)), α0 is the intrinsic (baseline) tumor growth rate without treatment intervention, and the key model parameters (these are the mathematical biomarkers, MBs) are the tumor kill rate due to immunotherapy (μ) and the patient anti-tumor immune state (Λ), which we define as the coupling of immune cell activity and tumor cell kill (i.e., the immunogenicity of a tumor) scaled by the ratio of tumor cells to intratumoral immune cells at the time of treatment. For simplicity moving forward, we drop the prime on \(\rho^{\prime}\), so that Eq. 5 becomes

$$\frac{d\rho }{dt}=\rho \left({\alpha }_{0}-\mu +\Lambda \mu \right)+{\rho }^{2}\left(-\Lambda \mu \right).$$
(6)
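To illustrate the behavior described by Eq. 6, the following minimal Python sketch integrates it numerically with SciPy; the parameter values are hypothetical, chosen only for illustration, and are not fitted values from the cohort:

```python
# Minimal sketch: numerically integrating Eq. 6 with illustrative (hypothetical) rates.
import numpy as np
from scipy.integrate import solve_ivp

alpha0, mu, Lam = 0.01, 0.02, 1.5   # hypothetical per-day values of alpha_0, mu, Lambda

def d_rho_dt(t, rho):
    # Eq. 6: d(rho)/dt = rho*(alpha0 - mu + Lam*mu) + rho^2 * (-Lam*mu)
    return rho * (alpha0 - mu + Lam * mu) + rho**2 * (-Lam * mu)

t_eval = np.linspace(0.0, 365.0, 366)              # one year, daily resolution
sol = solve_ivp(d_rho_dt, (0.0, 365.0), [1.0],     # rho(0) = 1 (normalized baseline burden)
                t_eval=t_eval, rtol=1e-8)
rho = sol.y[0]                                     # normalized tumor burden over time
```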
For our initial mechanistic model (Eq. 6), patient-specific estimation of the total patient tumor burden, the mathematical biomarkers μ and Λ, and the associated intrinsic growth rate α0 (that is, the average pre-treatment growth rate across all measured tumors) was performed for each patient from a previously obtained patient cohort (n = 93; patients treated with ipilimumab on clinical trial NCT02239900) using CT imaging measurements after treatment initiation. Briefly, all pathologist-indexed lesions (according to RECIST v1.1 standards) were identified, longest and shortest lesion axes were measured while making efforts to reproduce the long axes from pathological response evaluation, and lesion volumes were approximated as spheres with diameters equal to the average of the long and short axes. Total lesion volumes were summed at each timepoint (before treatment, at start of treatment, and at subsequent follow-ups until death or censor) as the total indexed lesion burden over time, and lesion volumes were normalized by the baseline volume (at the time of first ICI treatment, when t = 0).

The time-dependent version of the model was then fit to the time-course lesion burden for each patient in Mathematica using the function NonlinearModelFit. Residuals were weighted under the Automatic setting in NonlinearModelFit, which assigns equal (unity) weights to all data elements. To avoid local minima, we implemented an unconstrained global optimization method via the Mathematica function NMinimize that used differential evolution, with Levenberg–Marquardt post-processing to refine local minima solutions via the Mathematica function FindMinimum. Deaths were observed in n = 60 patients, with time to death after last follow-up ranging from 4 to 947 days (median = 129 days, mean = 219 days, interquartile range = 42–315 days). Detailed patient characteristics tables may be found in ref. 24 (Table 2 for the institutional validation cohort and Supplemental Table 2 therein). α1 (i.e., the growth rate between the start of treatment and the time of first restaging (t1)) was also calculated by fitting the short-term model solution \((\rho (t)\approx {e}^{{\alpha }_{1}\cdot t})\) to the CT-measured tumor burden (ρ) at the time of treatment initiation (t = 0) and at the time of first restaging (t1), calculated as α1 ≈ ln(ρ(t1))/t1. Finally, the time-dependent model form was fit to the time series volume measurements in Mathematica using the function NonlinearModelFit to obtain values for model parameters μ and Λ (example fits are shown inset in Fig. 2A); interested readers may obtain a working script to fit the model to a sample data set in ref. 68. No parameter constraints or starting values used in the regression analysis were specific to this data set, as this could result in data leakage from previous studies on the entire patient cohort.
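The cohort fits themselves were performed in Mathematica as described above; as a rough, non-authoritative Python analogue of the same global-then-local fitting strategy, the sketch below fits μ and Λ to a hypothetical normalized lesion-burden series, using the standard closed-form logistic solution of Eq. 6 (the authors’ working script is available in ref. 68):

```python
# Rough Python analogue of the fitting step, assuming hypothetical data and a fixed,
# previously estimated alpha0. scipy's differential_evolution performs the global
# search and (with its default polish=True) refines the best candidate locally.
import numpy as np
from scipy.optimize import differential_evolution

t_days = np.array([0.0, 60.0, 120.0, 180.0, 240.0])   # hypothetical restaging times
rho_obs = np.array([1.00, 1.25, 1.05, 0.80, 0.62])    # hypothetical normalized burden
alpha0 = 0.01                                          # pre-treatment growth rate (fixed)

def rho_model(t, mu, lam):
    a = alpha0 - mu + lam * mu        # linear coefficient in Eq. 6
    b = lam * mu                      # quadratic coefficient in Eq. 6
    # closed-form logistic solution of Eq. 6 with rho(0) = 1
    return a / (b + (a - b) * np.exp(-a * t))

def sse(params):
    mu, lam = params
    # equal (unity) weights for all data points, as in the Mathematica fits
    return np.sum((rho_model(t_days, mu, lam) - rho_obs) ** 2)

# only physically motivated, non-data-set-specific bounds: rates must be positive
result = differential_evolution(sse, bounds=[(1e-6, 1.0), (1e-6, 10.0)], seed=0)
mu_fit, lam_fit = result.x

# short-term growth rate from the first restaging, alpha_1 ~ ln(rho(t1))/t1
alpha1 = np.log(rho_obs[1]) / t_days[1]
```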
The only constraints used were based on correct physical descriptions (e.g., tumor volume cannot be negative), were based on mathematical considerations specific to the model (e.g., tumor volume can be very small but not exactly zero, as this can cause numerical solutions to become undefined due to division by zero), or have been determined to be consistent across multiple independent patient cohorts spanning multiple primary tumor histologies and checkpoint inhibitor drugs, including five in-house patient cohorts26 and six cohorts derived from the literature25, thereby providing good evidence that they are not specific to this patient cohort alone.

After solving for these key parameters from the mechanistic mathematical model, we combined them with a set of other clinical measures (Table 1) to train and validate a deep learning (DL) model for predicting individual patient survival. The overall approach is shown in Fig. 2. At this time, we have only included data that were readily available and complete (that is, no missing measurements for any patient) in order to test our hypothesis that a hybrid (mechanistic + DL) approach using both model-derived and clinical data may improve accuracy over either method alone; however, we recognize that this set of model + clinical measures is unlikely to be the one that would enable the DL model to achieve the maximum theoretical predictive accuracy if all possible measures were available. By using a patient cohort collected as part of an in-house clinical trial, we have been able to collect a broad set of individual patient parameters that we hope captures much of the pertinent information about patient prognosis; however, this decision has made it difficult to find matching external validation sets, and as a result we have generated a validation set by withholding a subset of the patients from model training. Note that the neutrophil to lymphocyte ratio (NLR; a commonly reported clinical measure) was not included in our DL model to avoid collinearity with neutrophil and lymphocyte counts and to improve the accuracy of the feature importance analysis (this decision is elaborated on in Discussion).

Deep learning modeling

Modeling approach overview

Survival predictions generated herein are based on a methodology first published by Brown69, who observed that the standard Kaplan-Meier survival curve may be thought of as a set of discrete binary states (where each patient has a unique probability of being alive or deceased at each ‘step’ in the curve). Thus, a Kaplan-Meier curve may be approximated by a series of binary functions over time. Commonly, binary classification is performed using logistic regression functions. By subdividing the time domain into a series of discrete windows, each containing a set of observed patient deaths (these are the events to be predicted), a unique logistic function may be fit to each time increment by minimizing its loss function (Bernoulli’s negative log-likelihood in the case of logistic regression with right-censored data) over the data subset contained within that increment. Machine or deep learning (ML/DL) is a natural choice for this approach, as binary classification is usually based on logistic regression when done with ML/DL37, and ML/DL platforms are robust tools to process many types of data.
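As a minimal sketch of this construction, the following snippet converts hypothetical per-increment logits into a discrete survival curve: each sigmoid gives that increment’s conditional death probability (hazard), and survival is the running product of one minus the hazards:

```python
# Minimal sketch of the discretized survival construction: one logit per time
# increment; the survival curve is the cumulative product of (1 - hazard).
import numpy as np

def survival_from_logits(logits):
    hazards = 1.0 / (1.0 + np.exp(-logits))   # h_j = sigmoid(logit_j)
    return np.cumprod(1.0 - hazards)          # S(t_j) = prod_{k<=j} (1 - h_k)

# hypothetical logits for one patient across four time increments
surv_curve = survival_from_logits(np.array([-3.0, -2.5, -2.0, -1.5]))
```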
Data preprocessing

Data were pre-processed using standard ML approaches; briefly, categorical data were one-hot encoded, data were split into training, validation, and test sets as shown in Fig. 2B, and then continuous data were normalized using the StandardScaler function from scikit-learn70 based on the training set only. It can be observed in Table 1 that the distribution of race is highly unbalanced; however, no feature balancing (e.g., resampling via oversampling, SMOTE, SMOTE-N, etc.44 or cost-sensitive weighting48) was performed in this study; implications of this approach are further examined in Discussion. To validate the model, k-fold validation was implemented, where continuous data were renormalized for each fold (normalization should be based only on the unique training set chosen for each fold to avoid ‘data leakage’ from the validation and test sets into the training set). Finally, data were subdivided into time increments (n = 20 increments was used in this study) in order to generate a unique DL-generated logistic curve for each increment (Fig. 2D); each ‘step’ in the curves shown in Fig. 3D corresponds to one discrete time increment37. Time discretization also enables more accurate prediction within each time window by providing a method to account for censoring events while also reducing the potential skewing effects of concentrated events (in our case, death or censor) observed at distant times (before or after the time increment)37. We used an equidistant time discretization scheme37 for the results presented here (alternatively, this can also be done based on event distribution via Kaplan-Meier quartiles). All data preprocessing was conducted using scikit-learn70, pandas71, NumPy72, and torchtuples, and plots were generated using matplotlib73 and Plotly74.

Model training and statistical analysis

After data handling, a multilayer perceptron ANN was trained using the logistic hazards time-to-event modeling37,65 functionality from the PyCox package37,75 and PyTorch. For the results presented here, the ANN was constructed with a single hidden layer (multilayer perceptron ANNs are able to approximate the mapping of any function from one finite space to another with a single hidden layer76), rectified linear unit (ReLU) activation functions were used, network weights were optimized by Adaptive Moment Estimation (Adam), and training epochs were terminated when the error function of the validation set was minimized (Supplemental Fig. 4). For simplicity, the number of nodes in the hidden layer was selected to be the arithmetic mean of the number of nodes in the input layer (n = number of features) and the output layer (n = 20), rounded down to the nearest integer if needed. Hyperparameters were tuned using randomized search with cross-validation via the scikit-learn 1.4.1 modules BaseEstimator and RandomizedSearchCV with a target of minimizing the IPCW Brier score (see below), followed by manual hyperparameter tuning as needed. Other hyperparameters used for the hybrid ML + mechanistic model include batch normalization, a dropout rate of 0.2 (to prevent overfitting), a per-epoch batch size of 50 (roughly 85% of the training set), a learning rate of 0.07, and early stopping (for this small study, all training cycles finished under 512 epochs). No feature selection or elimination was performed in this study; all available features were included in all analyses regardless of importance. No tuning was performed for the randomized data grouping (patients being assigned to train, validation, or test sets) shown in Fig. 2B. At all times, the ANN training algorithm remained agnostic to the test set, which was used to evaluate the predictive accuracy of the trained neural network only after training was complete.
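A condensed, self-contained sketch of this setup using PyCox and torchtuples is shown below; all arrays are synthetic stand-ins for the real cohort, and the split sizes and feature count are illustrative only:

```python
# Condensed sketch of the discretization + training step with PyCox and torchtuples;
# all data here are synthetic stand-ins, not the study's actual arrays.
import numpy as np
import torchtuples as tt
from pycox.models import LogisticHazard

np.random.seed(0)
n = 93                                                  # cohort size in this study
x = np.random.randn(n, 10).astype('float32')            # hypothetical feature matrix
durations = np.random.exponential(300, n).astype('float32')  # hypothetical times (days)
events = np.random.binomial(1, 0.65, n)                 # 1 = death observed, 0 = censored

# hypothetical train/validation/test split (cf. Fig. 2B)
x_train, x_val, x_test = x[:60], x[60:80], x[80:]
d_train, d_val, durations_test = durations[:60], durations[60:80], durations[80:]
e_train, e_val, events_test = events[:60], events[60:80], events[80:]

num_durations = 20                                      # n = 20 time increments
labtrans = LogisticHazard.label_transform(num_durations)  # equidistant grid by default
y_train = labtrans.fit_transform(d_train, e_train)
y_val = labtrans.transform(d_val, e_val)

in_features = x_train.shape[1]
hidden = [(in_features + num_durations) // 2]           # floor of mean of in/out widths
net = tt.practical.MLPVanilla(in_features, hidden, labtrans.out_features,
                              batch_norm=True, dropout=0.2)

model = LogisticHazard(net, tt.optim.Adam(0.07), duration_index=labtrans.cuts)
model.fit(x_train, y_train, batch_size=50, epochs=512,
          callbacks=[tt.callbacks.EarlyStopping()],     # stop at minimum validation loss
          val_data=(x_val, y_val))
```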
The trained ANN was tasked with predicting survival for each patient in the test set; individual patient parameters for the test set are shown in Supplementary Table 1. The accuracy of per-patient survival predictions was then assessed using 1) an event-time concordance index39 (similar to the standard C-index, but calculated over time based on an extension of the C-index with right censoring proposed by Harrell77); 2) the time-dependent error in predicted hazard functions using the inverse probability of censoring weighting (IPCW; a method of accounting for right-censored events in the data by approximating the score based on the inverse probability of censoring) Brier score (a method of scoring error that compares predicted likelihoods of an event against whether it was observed by calculating the mean squared difference between the predicted probability and the actual outcome)29,35; and 3) the IPCW negative binomial log-likelihood35 (NBLL; a log-likelihood estimate of prediction error weighted by the inverse of the censoring distribution). The Brier score assesses accuracy based on the mean square distance between predicted probabilities and measured outcomes, but outcomes are not known in the case of censor events; weighting by the inverse probability of censoring adjusts Brier scores in a way that retains their original interpretation under right-censored conditions35. Both the IPCW Brier score and the IPCW negative binomial log-likelihood are integrated across the time dimension to provide the single scores reported herein (see Supplemental Fig. 3 for more details). Note that the IPCW NBLL is similar to the loss function used for DL model training (the mean negative log-likelihood of the hazard parameterization model37), but weighted by the inverse probability of censoring.

In order to test the stability and reliability of the trained hybrid mechanistic + DL model and the associated feature importances in the ANN, we performed k-fold validation, in which the steps shown in Fig. 2B–D were repeated for all permutations of training, validation, and test sets by dividing the data into n = 5 groups of ~20% of patients each (n = 20 total folds). All hyperparameters were held constant for each fold. Model stability was assessed by comparing the loss function (here, the mean negative log-likelihood as described in ref. 37) for the validation cohort (the set that determines when model training stops) for each fold; note that for this application, error in the loss function for the held-out data set is preferable to the C-index or other statistical evaluators of model accuracy, which may not be appropriate for distinguishing between models36. The stability of the time-dependent C-index was also examined across folds, and descriptive statistics describing variations between folds were calculated. K-fold validation was also performed on DL models trained on only clinical data and on only MBs under the same protocol described above.
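Continuing the sketch above, these three metrics can be computed with PyCox’s EvalSurv class (here censor_surv='km' requests Kaplan-Meier estimation of the censoring distribution for the IPCW weighting; the time grid is illustrative):

```python
# Sketch of the evaluation step with PyCox, continuing the hypothetical example above.
import numpy as np
from pycox.evaluation import EvalSurv

surv = model.predict_surv_df(x_test)          # survival curves, one column per patient
ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')  # IPCW via KM

ctd = ev.concordance_td()                     # time-dependent (event-time) concordance
grid = np.linspace(durations_test.min(), durations_test.max(), 100)
ibs = ev.integrated_brier_score(grid)         # IPCW Brier score, integrated over time
inbll = ev.integrated_nbll(grid)              # IPCW NBLL, integrated over time
```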
Characterizing the trained hybrid deep learning model by feature importances

Finally, we sought to quantify the feature importances; that is, how much each model input (these are features: both mechanistic model parameters and clinical measures) contributes to how the artificial neural network makes predictions, based on 1) loss function minimization (as a surrogate for direct statistical accuracy, under the assumption that minimizing the loss function maximizes accuracy) and 2) statistical methods that directly calculate rubrics of model accuracy. This is possible because the DL model is trained by minimizing the negative log-likelihood of the hazard parameterization model loss function, while model predictive accuracy may be directly evaluated using event-time concordance, IPCW Brier score, and IPCW negative binomial log-likelihood. It is important to note that these two approaches provide complementary but distinct information: the loss function is minimized based on the training and validation sets (thus evaluating importances based on how the trained neural network makes predictions), while statistical evaluation of prediction accuracy is performed using the test set (these data are withheld during model training, enabling evaluation of feature importance based on the accuracy of model predictions; see Fig. 2), and feature importance analysis may be performed independently against each of these functions. The result is a two-pronged approach to robustly evaluate whether both clinical and mechanistic model-derived measures are advantageous within a single DL model for maximum predictive accuracy.

Because the DL method used here results in a discrete logistic regression curve for each time increment (e.g., Fig. 2D), there is an associated feature importance for each distinct logistic curve (corresponding to each time increment). This is analogous to classification models where a unique output node is generated for each possible label, and likewise features may contribute in different ways to each label. As a result, feature importance analysis yields an array sized [number of features] × [number of outputs]. These may either be examined individually for each output node (in our case, a single logistic function corresponding to one ‘step’ in Fig. 3), which we refer to herein as local feature importance, or they may be combined across all output nodes (we have summed them here, but averages could also be used) to determine their effects on the full model; we refer to these as global importances.

We used standard methods to study feature importances from the ANN trained via loss function minimization, including several backpropagation methods: integrated gradients78, integrated gradients with smoothing79, DeepLift80, and DeepLiftSHAP80,81 (which approximates SHapley Additive exPlanations (SHAP) values using the DeepLift approach); feature ablation (a method of determining importances via feature perturbation); and SHAP values81 (an approach derived from game theory that attempts to maximize the gain from each player while ensuring the gain is at least as much as each player would yield independently). In order to directly examine how much each feature contributes to the statistical accuracy when the trained ANN makes survival predictions on new patient data, we passed the event-time concordance evaluator, IPCW Brier score evaluator, and IPCW negative log-likelihood evaluator to a feature permutation algorithm using a wrapper from the eli5 package, which estimates importances using a leave-one-out approach and any black-box scoring rubric. Importantly, in order to preserve the integrity of the trained ANN between calls to different importance calculation algorithms, we first saved the trained ANN to the hard drive immediately after training was complete, and then the original ANN was re-imported between each importance calculation. Feature importance calculations were done using the Captum82, SHAP81, and eli583 packages.
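As an illustration of the local-versus-global distinction described above, the sketch below computes integrated-gradients attributions per output node (time increment) with Captum and sums them into global importances; it continues the hypothetical example above and shows only one of the attribution methods used in the study:

```python
# Sketch of local (per-output-node) and global (summed) attributions with Captum,
# continuing the hypothetical example above; net is the trained PyTorch module.
import torch
from captum.attr import IntegratedGradients

net.eval()                                    # freeze batch-norm/dropout for attribution
inputs = torch.tensor(x_test, dtype=torch.float32)
ig = IntegratedGradients(net)

# local importances: one [n_patients x n_features] attribution matrix per time increment
local_attr = [ig.attribute(inputs, target=j).detach().numpy()
              for j in range(num_durations)]

# global importances: attributions summed across all output nodes (time increments)
global_attr = sum(a.sum(axis=0) for a in local_attr)
```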
