Time-to-event (TTE) predictions are extensively used by medical statisticians. Traditional logistic regression methods are not suited to modelling both the event and the time to that event as the outcome. Non-parametric models such as the Kaplan-Meier estimator Kaplan and Meier (1958), and the semi-parametric Cox proportional hazards model and its extensions Cox (1972); Recknor and Gross (1994), face the challenge of adjusting for multiple or time-varying covariates. The recent development of data-adaptive models such as deep neural networks Gensheimer and Narasimhan (2019) and the Super Learner Golmakani and Polley (2020) enables the efficient estimation of individual survival curves with static and longitudinal data, yet relatively little has been written about the implications of these explanatory techniques in the context of event time prediction.
The strengths of explanatory survival analysis have been applied in data-adaptive predictive models to improve the estimation accuracy of survival curves. In DeepHit Lee et al. (2020), a rank loss function is designed to evaluate whether the model can order observations by their expected time to failure; in DeepSurv Katzman et al. (2018), the authors approximate the Cox proportional hazards function using a densely connected neural network; and in WTTE-RNN Martinsson (2016), the predicted event time is assumed to follow a Weibull distribution whose parameters are estimated using a recurrent neural network. These models contrast with conventional binary predictors such as the recurrent neural networks proposed in the 2019 PhysioNet Challenge Reyna et al. (2019), where the prediction of TTE was equated to a longitudinal binary classification problem.
In our recently proposed counterfactual dynamic survival model (CDSM), we relaxed major limitations of the three aforementioned models. Specifically, we do not assume proportional hazards or any parametric form in our model. At the same time, we allow longitudinal covariates and quantify the uncertainty of the neural network estimates using Bayesian dense layers. The focus of our previous work was model development and its application to causal inference. In this study, we focus on the predictive power of CDSM as an outcome model and explore how biologically plausible survival curve estimates can improve TTE predictions.
In Section 2, we describe the methodology used to estimate the survival outcomes and predict the time to event. Section 3 introduces a set of case studies and model evaluation techniques. Results are presented in Section 4. We end our study with a discussion.
2 Predicting the time to event with survival curves
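To formalize the framework for longitudinal survival outcomes, we follow the notation of previous studies Imai and Strauss (2011); Zhu and Gallego (2020). Suppose we observe a sample $O$ of $n$ observations independently generated by an unknown distribution $P_0$:

$$O := \left(X_i(t),\; A_i(t),\; Y_i(t),\; t_i = \min(t_{s,i},\, t_{c,i})\right), \quad i = 1, \dots, n,$$

where $t = 1, 2, \dots, \Theta$ is the time at the upper limit of each time interval (e.g., hours, months or years), $\Theta$ is the maximum of patients’ follow-up time and $X_i(t)$ are baseline covariates at time $t$; $A_i(t)$ is the exposure condition at time $t$, with $A_i(t) = 1$ if observation $i$ receives the treatment and $A_i(t) = 0$ otherwise; $Y_i(t)$ denotes the outcome at time $t$, with $Y_i(t) = 1$ if $i$ experienced an event and $Y_i(t) = 0$ otherwise; $t_i$ is determined by the event time $t_{s,i}$ or the censoring time $t_{c,i}$, whichever happened first.

For each individual $i$, we define the conditional hazard rate $h_i(t)$ as the probability of failure in interval $(t-1, t]$:

$$h_i(t) = \Pr\left(Y_i(t) = 1 \mid Y_i(t-1) = 0,\; \bar{A}_i(t),\; \bar{X}_i(t)\right), \tag{1}$$

where $\bar{A}_i(t)$ and $\bar{X}_i(t)$ are the history of treatments and covariates until time $t$. The conditional probability of surviving to the end of interval $(t-1, t]$ is given by the probability chain rule:

$$S_i(t) = \prod_{j=1}^{t} \left(1 - h_i(j)\right). \tag{2}$$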
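The chain-rule relationship between interval hazards and the survival curve can be sketched in a few lines of Python. This is a minimal illustration under the standard discrete-time assumption that survival to the end of an interval is the running product of one minus each interval hazard; the function name `survival_from_hazards` and the example hazards are hypothetical and are not taken from the authors' code.

```python
from itertools import accumulate
from operator import mul


def survival_from_hazards(hazards):
    """Return the survival curve implied by a sequence of interval hazards.

    Entry t of the result is the product of (1 - hazard) over all
    intervals up to and including t (discrete-time chain rule).
    """
    if any(h < 0.0 or h > 1.0 for h in hazards):
        raise ValueError("each hazard must be a probability in [0, 1]")
    # Per-interval survival probabilities, then their running product.
    return list(accumulate((1.0 - h for h in hazards), mul))


# Hypothetical hazards for one patient over five intervals (illustration only).
hazards = [0.0, 0.1, 0.2, 0.5, 0.5]
survival = survival_from_hazards(hazards)
print([round(s, 3) for s in survival])
```

Because every factor lies between zero and one, the resulting curve is non-increasing by construction, which is the property that a plain binary classifier does not guarantee.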
We define our target outcome similar to a multivariate logistic regression but with an additional term to capture the event and censoring:
is the time index of vector.
Conventional predictive models fit the multivariate logistic outcome in Equation (3) using binary classifiers, where researchers have to set a probability threshold to classify whether an event will occur (see the hazard threshold in Figure 1). For instance, one can use the Nelder-Mead method to locate the optimal probability threshold by minimizing the distance between the actual and predicted event times, yet this in-sample threshold might not be optimal for predicting the TTE in a new cohort.
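To make the threshold-based rule concrete, the sketch below predicts the event time as the first interval whose estimated hazard reaches a threshold, then picks the threshold that minimizes the in-sample timing error. It uses a plain grid search rather than the Nelder-Mead method mentioned above, purely to stay dependency-free; all function names, the candidate grid and the hazard values are hypothetical.

```python
def first_crossing_time(hazards, threshold):
    """1-based index of the first interval whose hazard reaches the threshold.

    Returns len(hazards) + 1 when the threshold is never reached.
    """
    for t, h in enumerate(hazards, start=1):
        if h >= threshold:
            return t
    return len(hazards) + 1


def mean_abs_timing_error(hazard_rows, true_times, threshold):
    """Mean absolute gap between predicted and true event times."""
    errors = [
        abs(first_crossing_time(row, threshold) - t)
        for row, t in zip(hazard_rows, true_times)
    ]
    return sum(errors) / len(errors)


def best_threshold(hazard_rows, true_times, candidates):
    """Grid search: the candidate with the smallest in-sample timing error."""
    return min(
        candidates,
        key=lambda c: mean_abs_timing_error(hazard_rows, true_times, c),
    )


# Hypothetical predicted hazards for three patients (illustration only).
hazard_rows = [
    [0.05, 0.12, 0.42, 0.71],
    [0.02, 0.33, 0.61, 0.90],
    [0.01, 0.02, 0.06, 0.37],
]
true_times = [3, 2, 4]
candidates = [k / 20 for k in range(1, 20)]  # 0.05, 0.10, ..., 0.95

chosen = best_threshold(hazard_rows, true_times, candidates)
in_sample_error = mean_abs_timing_error(hazard_rows, true_times, chosen)
```

The chosen threshold is tuned to these three rows only, which mirrors the concern raised above: nothing guarantees that it remains a good cut-off for a new cohort.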
This study attempts to learn from the biological survival curve and uses the inflection point, $t^{*}$, of the survival curve in Equation (2) to signify the event time. We define it as the time point at which the second derivative of the estimated survival curve $\hat{S}(t)$ equals zero:

$$\left.\frac{\partial^{2} \hat{S}(t)}{\partial t^{2}}\right|_{t = t^{*}} = 0. \tag{4}$$
In Figure 1, we can see that the hazard rate increases rapidly after the inflection point, which indicates a high probability of experiencing an event. The uncertainty of the estimated survival curve quantifies the uncertainty of the predicted event time.
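On a curve estimated at discrete time points, the inflection point described above has to be located numerically. The sketch below is one possible approach, assuming a central second difference as a stand-in for the second derivative and taking the first sign change as the inflection point; it is not the authors' implementation, and the function names and the logistic test curve are hypothetical.

```python
import math


def second_differences(curve):
    """Central second differences curve[t-1] - 2*curve[t] + curve[t+1]."""
    return [
        curve[t - 1] - 2.0 * curve[t] + curve[t + 1]
        for t in range(1, len(curve) - 1)
    ]


def inflection_index(curve):
    """Index of the first point where the second difference changes sign.

    Returns None when the curvature never changes sign.
    """
    d2 = second_differences(curve)
    for k in range(len(d2) - 1):
        left, right = d2[k], d2[k + 1]
        if (left < 0.0 <= right) or (left > 0.0 >= right):
            # d2[k] belongs to curve index k + 1; keep the side closer to zero.
            return k + 1 if abs(left) <= abs(right) else k + 2
    return None


# Hypothetical survival curve: a decreasing logistic centred at t = 10.
curve = [1.0 / (1.0 + math.exp(t - 10)) for t in range(21)]
t_star = inflection_index(curve)
print(t_star)
```

A noisy estimated curve can change curvature several times, so in practice some smoothing before differencing would probably be needed; that step is left out of this sketch.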
3 Study design and databases
We built and then validated a survival outcome model based on the retrospective analysis of three static databases and three dynamic longitudinal databases. A summary of these data sets is presented in Table 1.
| Database | Sample | Covariates | Unique Time Points | % Censored |
|---|---|---|---|---|
| CPRD AF | 18102 | 53 | 20 × 3 months | 82% |
The static data sets were provided by the DeepSurv Python package Katzman et al. (2018), which includes:
- The Study to Understand Prognoses and Preferences for Outcomes and Risks of Treatments (SUPPORT);
- The Molecular Taxonomy of Breast Cancer International Consortium (METABRIC); and
- The Rotterdam tumor bank and German Breast Cancer Study Group (GBSG).
The longitudinal data sets are:
- The Medical Information Mart for Intensive Care version III (MIMIC-III), an open-access, anonymized database of 61,532 admissions from 2001–2012 in six ICUs at a Boston teaching hospital Johnson et al. (2016).
- The 2019 PhysioNet sepsis prediction challenge data set Reyna et al. (2019) (PhysioNet), containing more than 3.3 million admissions from 2003–2016 in 459 ICUs across the United States.
- The Clinical Practice Research Datalink data set Herrett et al. (2015) (CPRD AF), comparing vitamin K antagonists (VKAs) and non-vitamin K antagonist oral anticoagulants (NOACs) in preventing three combined outcomes (ischemic attack, major bleeding and death) in patients with non-valvular atrial fibrillation (AF).
In both MIMIC-III and PhysioNet, we define a sepsis event as a suspected infection (prescription of antibiotics and sampling of bodily fluids for microbiological culture) combined with evidence of organ dysfunction, defined by a two-point deterioration of the SOFA score Seymour (2016). We follow previous papers Reyna et al. (2019); Komorowski et al. (2018) for data extraction and processing. For PhysioNet, we combined data from hospitals A and B, and used hospital location (A or B) as the synthetic treatment condition. For MIMIC-III, we define the treatment as the use of mechanical ventilation (MV). For CPRD AF, the outcome of interest is the first occurrence of the combined outcomes of major bleeding, death and stroke. The treatment is the use of NOACs, with the use of VKAs as the control.
For the static data sets, we discretized the time points into windows of 50 time steps and censored all steps that do not form a complete window (i.e. windows for the SUPPORT data set, windows for METABRIC and windows for GBSG). For the longitudinal data sets, we considered the first 20 time stamps for each patient (i.e., the first 20 2-hour intervals for PhysioNet and MIMIC-III, and the first 20 months for AF). We split each database into an estimation data set (70% of the original data for training and 10% for validation) and a testing data set (20% of the original data).
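The preprocessing described above can be sketched as follows. This is an illustration under stated assumptions rather than the pipeline actually used: it reads "windows of 50 time steps" as fixed-size windows, it splits at the patient level with a seeded shuffle, and all function names and the 1,000 example ids are hypothetical.

```python
import random


def to_window_index(time_step, window_size=50):
    """Map a 0-based raw time step to the index of the window containing it."""
    return time_step // window_size


def n_complete_windows(n_time_steps, window_size=50):
    """Number of full windows; trailing steps that do not fill one are excluded."""
    return n_time_steps // window_size


def split_ids(ids, train_pct=70, valid_pct=10, seed=0):
    """Shuffle patient ids and split them into train / validation / test lists."""
    shuffled = list(ids)
    random.Random(seed).shuffle(shuffled)
    # Integer arithmetic avoids floating-point rounding in the split sizes.
    n_train = len(shuffled) * train_pct // 100
    n_valid = len(shuffled) * valid_pct // 100
    return (
        shuffled[:n_train],
        shuffled[n_train:n_train + n_valid],
        shuffled[n_train + n_valid:],
    )


# Hypothetical cohort of 1,000 patient ids (illustration only).
train_ids, valid_ids, test_ids = split_ids(range(1000))
print(len(train_ids), len(valid_ids), len(test_ids))
```

Splitting by patient rather than by row keeps all time points of one patient in the same partition, which matters for the longitudinal data sets; whether the original study split this way is an assumption here.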
3.1 Model evaluation
We performed an evaluation of the estimations of survival curves and predictions of event time using the three metrics described below:
- The area under the receiver operating characteristic curve (AUROC) and C-Index: we use AUROC and Harrell’s C-index Harrell et al. (1982) to evaluate the models’ discrimination performance. Both indicators are calculated using the multivariate logistic outcome in Equation (3).
- Utility distance (Distance Score): we define the distance metric to evaluate the predicted event time as:
where the predicted event time is defined in Equation (4) and the reference time is the true event/censoring time.
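For readers unfamiliar with the concordance measure listed above, the following is a minimal pure-Python sketch of Harrell's C-index for right-censored data. It assumes that each subject is summarized by a single risk score (for example, one minus the estimated survival probability at a fixed horizon, which is an assumption and not necessarily how the scores in this study were derived); the function name and the toy data are hypothetical.

```python
def harrell_c_index(times, events, risk_scores):
    """Harrell's concordance index for right-censored data.

    A pair (i, j) is comparable when subject i had an observed event and a
    shorter time than subject j. The pair is concordant when subject i also
    has the higher risk score; ties in risk score count as one half.
    """
    concordant = 0.0
    comparable = 0
    n = len(times)
    for i in range(n):
        if not events[i]:
            continue  # a censored subject cannot be the earlier member of a pair
        for j in range(n):
            if times[i] < times[j]:
                comparable += 1
                if risk_scores[i] > risk_scores[j]:
                    concordant += 1.0
                elif risk_scores[i] == risk_scores[j]:
                    concordant += 0.5
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable


# Hypothetical follow-up times, event indicators (1 = event, 0 = censored)
# and risk scores for four subjects (illustration only).
times = [2, 4, 6, 8]
events = [1, 1, 0, 1]
risk = [0.9, 0.7, 0.4, 0.1]

c_perfect = harrell_c_index(times, events, risk)
c_reversed = harrell_c_index(times, events, risk[::-1])
c_constant = harrell_c_index(times, events, [0.5] * 4)
```

Unlike AUROC on a per-interval binary label, this measure only rewards a model for ordering subjects correctly in time, which is why the two metrics can disagree.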
We compared the following algorithms on the estimation of survival curves and the prediction of event time:
- Dynamic Bayesian survival causal model (D-Surv, referred to as CDSM elsewhere in this paper): the model targets the outcome defined in Equation (3) by training two counterfactual sub-networks for treated and control observations. If no treatment variable is defined, we create two copies of the original data set, with the first marked as receiving the treatment and the second as under control. The loss function of D-Surv has three components: 1) the partial log-likelihood loss of the joint distribution of the first hitting time and the corresponding event or right-censoring; 2) the rank loss function, which captures the concordance score defined in survival analysis; and 3) the calibration loss function, which minimizes the selection bias in treatment assignment. Please refer to our previous paper for details.
- Plain recurrent neural network with survival outcomes (RNN): the model modifies D-Surv by removing the counterfactual sub-networks and the third loss component. No treatment variable has to be specified in this model.
- Plain recurrent neural network with binary outcomes (RNN Binary): the model directly predicts the longitudinal outcome in Equation (3) using the mean squared error loss function.
- DeepHit Lee et al. (2018): the model uses the same loss functions as the RNN but does not capture the history of covariates, and is only evaluated on the static databases.
Model construction and training used Python 3.8.0 with TensorFlow 2.3.0 and TensorFlow Probability 0.11.0 Abadi et al. (2015) (code available at https://github.com/EliotZhu/DSurv).
4 Results
In Table 2, we observed that CDSM, RNN and DeepHit had similar performance on the estimation of survival curves (see the concordance index) and the prediction of event time (see the distance score) in the three static testing data sets. However, in terms of AUROC, RNN Binary outperformed the others, although it had a lower C-Index.
The counterfactual sub-networks and the selection bias calibration loss function in CDSM did not affect the estimation accuracy, resulting in equivalent performance among CDSM, RNN and DeepHit in the static, non-causal survival estimations.
All metrics are averaged over estimation windows using testing data sets. The best value in each metric is in bold.
A similar trend was observed when we evaluated CDSM, RNN and RNN Binary using the longitudinal databases (see the estimation data set evaluations in Table 3). However, in the corresponding testing data sets, D-Surv significantly outperformed RNN Binary, especially on the C-Index and distance score. The survival outcome in Equation (3) and the concordance loss functions defined in D-Surv/RNN produced nominal survival curves, whereas RNN Binary only maximized the discrimination performance on the binary indicator of whether sepsis had occurred (i.e., the estimated survival probabilities for the AF testing data set were stacked at zeros and ones, as shown in Figure 2 (a)).
All metrics are averaged over 20 estimation windows using either estimation (the default) or testing data sets (specified in brackets). The best value in each metric is in bold.
| Dataset | PhysioNet (hours) | | | MIMIC-III (hours) | | | AF (months) | | |
|---|---|---|---|---|---|---|---|---|---|
| Metrics | CDSM | RNN | RNN Binary | CDSM | RNN | RNN Binary | CDSM | RNN | RNN Binary |
| Distance Score (test) | 3.047 | 3.635 | 3.767 | 2.291 | 2.375 | 2.339 | 1.116 | 1.035 | 1.240 |
| Score Std (test) | 11.522 | 6.813 | 0.080 | 1.743 | 0.557 | 0.038 | 3.118 | 1.469 | 0.011 |
The nominal survival curves produced by CDSM made it possible to apply Equation (4) to locate the inflection point as the event time. This is a better approach than choosing a probability threshold to construct a binary classifier. In Figure 2 (b), the error of the predicted event time is sensitive to the chosen probability threshold: the average timing difference ranged from -4.3 to 9.5 months over a small threshold range of 0.99 to 0.9999. In contrast, after applying the inflection point to determine the event time, the predicted time accurately tracked the true time in Figure 2 (c), with most predictions occurring ahead of the true AF event time. On average, the predicted time is 1.720 months ahead of the true AF event time and 1.014 months ahead of the true AF censoring time. CDSM allows threshold-free prediction of the individual event time and early intervention for patients who might be prone to the event.
5 Discussion
This study demonstrated that injecting knowledge from survival analysis into the design of a recurrent neural network can significantly improve the prediction of time-to-event outcomes. Our proposed outcome model, CDSM, fits the joint distribution of both failure and censored observations. Conventional machine learning algorithms for binary discrimination can maximize evaluation scores such as AUROC, but fail to provide meaningful survival curves and reliable predictions of event time. The major drawback of these algorithms, as identified by our empirical study, is that they do not take censoring into account and show a significant drop in accuracy when evaluated on the testing database.
This work was supported by the National Health and Medical Research Council, project grant no. 1125414.
- Abadi et al. (2015). TensorFlow: Large-Scale Machine Learning on Heterogeneous Systems.
- Cox (1972). Regression models and life tables (with discussion). J. Roy. Statist. Soc. Ser. B 34, pp. 187–220.
- Harrell et al. (1982). Evaluating the yield of medical tests. Journal of the American Medical Association 247 (18), pp. 2543–2546.
- Gensheimer and Narasimhan (2019). A scalable discrete-time survival model for neural networks. PeerJ 7, e6257.
- Golmakani and Polley (2020). Super Learner for Survival Data Prediction. The International Journal of Biostatistics 0 (0).
- Herrett et al. (2015). Data Resource Profile: Clinical Practice Research Datalink (CPRD). International Journal of Epidemiology 44 (3), pp. 827–836.
- Imai and Strauss (2011). Estimation of Heterogeneous Treatment Effects from Randomized Experiments, with Application to the Optimal Planning of the Get-Out-the-Vote Campaign. Political Analysis 19 (1), pp. 1–19.
- Johnson et al. (2016). MIMIC-III, a freely accessible critical care database. Sci Data 3, 160035.
- Kaplan and Meier (1958). Nonparametric estimation from incomplete observations. Journal of the American Statistical Association 53 (282), pp. 457–481.
- Katzman et al. (2018). DeepSurv: personalized treatment recommender system using a Cox proportional hazards deep neural network. BMC Medical Research Methodology 18 (1), pp. 24.
- Komorowski et al. (2018). The Artificial Intelligence Clinician learns optimal treatment strategies for sepsis in intensive care. Nature Medicine 24 (11), pp. 1716–1720.
- Lee et al. (2020). Dynamic-DeepHit: A Deep Learning Approach for Dynamic Survival Analysis With Competing Risks Based on Longitudinal Data. IEEE Transactions on Biomedical Engineering 67 (1), pp. 122–133.
- Lee et al. (2018). DeepHit: A Deep Learning Approach to Survival Analysis With Competing Risks. In AAAI, pp. 2314–2321.
- Martinsson (2016). WTTE-RNN: Weibull time to event recurrent neural network.
- Recknor and Gross (1994). Fitting Survival Data to a Piecewise Linear Hazard Rate in the Presence of Covariates. Biometrical Journal.
- Reyna et al. (2019). Early Prediction of Sepsis from Clinical Data – the PhysioNet Computing in Cardiology Challenge (version 1.0.0).
- Seymour (2016). Assessment of clinical criteria for sepsis: For the third international consensus definitions for sepsis and septic shock (Sepsis-3). J. Am. Med. Assoc. 315, pp. 762–774.
- Zhu and Gallego (2020). Targeted Estimation of Heterogeneous Treatment Effect in Observational Survival Analysis. Journal of Biomedical Informatics, pp. 103474.