DeepSurv: Personalized Treatment Recommender System Using A Cox Proportional Hazards Deep Neural Network

06/02/2016 ∙ by Jared Katzman, et al.

Medical practitioners use survival models to explore and understand the relationships between patients' covariates (e.g. clinical and genetic features) and the effectiveness of various treatment options. Standard survival models like the linear Cox proportional hazards model require extensive feature engineering or prior medical knowledge to model treatment interaction at an individual level. While nonlinear survival methods, such as neural networks and survival forests, can inherently model these high-level interaction terms, they have yet to be shown as effective treatment recommender systems. We introduce DeepSurv, a Cox proportional hazards deep neural network and state-of-the-art survival method for modeling interactions between a patient's covariates and treatment effectiveness in order to provide personalized treatment recommendations. We perform a number of experiments training DeepSurv on simulated and real survival data. We demonstrate that DeepSurv performs as well as or better than other state-of-the-art survival models and validate that DeepSurv successfully models increasingly complex relationships between a patient's covariates and their risk of failure. We then show how DeepSurv models the relationship between a patient's features and the effectiveness of different treatment options, and how DeepSurv can therefore be used to provide individual treatment recommendations. Finally, we train DeepSurv on real clinical studies to demonstrate how its personalized treatment recommendations would increase the survival time of a set of patients. The predictive and modeling capabilities of DeepSurv will enable medical researchers to use deep neural networks as a tool in their exploration, understanding, and prediction of the effects of a patient's characteristics on their risk of failure.




1 Introduction

Medical researchers use survival models to evaluate the significance of prognostic variables in outcomes such as death or cancer recurrence and subsequently inform patients of their treatment options (YehJAMA, ; royston2013external, ; bair2004semi, ; cheng2013development, ). One standard survival model is the Cox proportional hazards model (CPH) (cox1992regression, ). The CPH is a semiparametric model that calculates the effects of observed covariates on the risk of an event occurring (e.g. ‘death’). The model assumes that a patient’s risk of failure is a linear combination of the patient’s covariates. This assumption is referred to as the linear proportional hazards condition. However, in many applications, such as providing personalized treatment recommendations, it may be too simplistic to assume that the risk function is linear. As such, a richer family of survival models is needed to better fit survival data with nonlinear risk functions.

To model nonlinear survival data, researchers have applied three main types of neural networks to the problem of survival analysis. These include variants of: (i) classification methods (see details in liestbl1994survival, ; street1998neural, ), (ii) time-encoded methods (see details in franco2005artificial, ; biganzoli1998feed, ), and (iii) risk-predicting methods (see details in faraggi1995neural, ). This third type is a feed-forward neural network (NN) that estimates an individual's risk of failure. In fact, the Faraggi-Simon network can be viewed as a nonlinear extension of the Cox proportional hazards model.

Risk neural networks can learn highly complex, nonlinear relationships between prognostic features and an individual's risk of failure. For example, when the success of a treatment option depends on an individual's features, the network learns this relationship without prior feature selection or domain expertise. The network can then provide a personalized recommendation based on the computed risk of each treatment.

However, previous studies have shown mixed results on the ability of NNs to predict risk. For instance, researchers have applied the Faraggi-Simon network with various extensions, but they have failed to demonstrate improvements over the linear Cox model (see sargent2001comparison , xiang2000comparison and mariani1997prognostic ). One possible explanation is that NN practice was less developed than it is today. To the best of our knowledge, NNs have not previously outperformed standard methods for survival analysis (e.g. the CPH). Our manuscript shows that this is no longer the case: with modern techniques, risk NNs achieve state-of-the-art performance and can be used for a variety of medical applications.

The goals of this paper are: (i) to show that the application of deep learning to survival analysis performs as well as or better than other survival methods in predicting risk; and (ii) to demonstrate that the deep neural network can be used as a personalized treatment recommender system and a useful framework for further medical research.

We propose a modern Cox proportional hazards deep neural network, henceforth referred to as DeepSurv, as the basis for a treatment recommender system. We make the following contributions. First, we show that DeepSurv performs as well as or better than other survival analysis methods on survival data with both linear and nonlinear risk functions. Second, we include an additional categorical variable representing a patient's treatment group to illustrate how the network can learn complex relationships between an individual's covariates and the effect of a treatment. Our experiments validate that the network successfully models the treatment's risk within a population. Third, we use DeepSurv to provide treatment recommendations tailored to a patient's observed features. We confirm our results on real clinical studies, which further demonstrate the power of DeepSurv. Finally, we show that the recommender system supports medical practitioners in providing personalized treatment recommendations that could potentially increase the median survival time for a set of patients.

The organization of the manuscript is as follows: in Section 2, we provide a brief background on survival analysis. In Section 3, we present our contributions, including an explanation of our implementation of DeepSurv and our proposed recommender system. In Section 4, we describe the experimental design and results. Section 5 concludes the manuscript.

2 Background

In this section, we define survival data and the approaches for modeling a population’s survival and failure rate. Additionally, we discuss linear and nonlinear survival models and their limitations.

2.1 Survival data

Survival data consists of three elements: a patient's baseline data $x$, a failure event time $T$, and an event indicator $E$. If an event (e.g. death) is observed, $T$ is the time elapsed between the collection of the baseline data and the occurrence of the event, and the event indicator is $E = 1$. If an event is not observed, $T$ is the time elapsed between the collection of the baseline data and the last contact with the patient (e.g. the end of the study), and the event indicator is $E = 0$. In this case, the patient is said to be right-censored. Standard regression methods treat right-censored observations as a type of missing data; such observations are typically discarded, which can introduce bias into the model. Modeling right-censored data therefore requires special consideration or the use of a survival model.
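As a minimal sketch of how a right-censored observation arises (the names `SurvivalRecord` and `make_record` are hypothetical, and a single known censoring time per patient is assumed), the observed time is the earlier of the event time and the censoring time, and $E$ records which one came first:

```python
from dataclasses import dataclass
from typing import List, Sequence


@dataclass
class SurvivalRecord:
    """One patient: baseline covariates x, observed time T, event indicator E."""
    x: List[float]
    T: float  # time to the event, or to the last contact if right-censored
    E: int    # 1 if the event (e.g. death) was observed, 0 if right-censored


def make_record(x: Sequence[float], event_time: float, censor_time: float) -> SurvivalRecord:
    """Observe min(event_time, censor_time); the event is seen only if it comes first."""
    if event_time <= censor_time:
        return SurvivalRecord(list(x), event_time, 1)
    return SurvivalRecord(list(x), censor_time, 0)


# A death at t = 3 before the study ends at t = 5 is observed (E = 1);
# a death at t = 7 is not, so the patient is right-censored at t = 5 (E = 0).
observed = make_record([0.2, -0.5], event_time=3.0, censor_time=5.0)
censored = make_record([0.2, -0.5], event_time=7.0, censor_time=5.0)
```

A survival model keeps the censored record and uses the fact that the patient survived at least until $T$, instead of discarding it.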

Survival and hazard functions are the two fundamental functions in survival analysis. The survival function is denoted by $S(t) = \Pr(T > t)$, which signifies the probability that an individual has 'survived' beyond time $t$. The hazard function $\lambda(t)$ is a measure of risk at time $t$; a greater hazard signifies a greater risk of death. The hazard function is defined as:

$$\lambda(t) = \lim_{\delta \to 0} \frac{\Pr(t \le T < t + \delta \mid T \ge t)}{\delta} \tag{1}$$
A proportional hazards model is a common method for modeling an individual's survival given their baseline data $x$. The model assumes that the hazard function is composed of two functions: a baseline hazard function, $\lambda_0(t)$, and a risk function, $h(x)$, denoting the effects of an individual's covariates. The hazard function is assumed to have the form $\lambda(t \mid x) = \lambda_0(t) \cdot e^{h(x)}$.
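To illustrate why this form is called 'proportional', the sketch below evaluates $\lambda(t \mid x) = \lambda_0(t) \cdot e^{h(x)}$ for two patients under an assumed Weibull baseline hazard (chosen only for illustration). The baseline cancels in the ratio, so the hazard ratio $e^{h(x_a) - h(x_b)}$ does not depend on $t$:

```python
import math


def baseline_hazard(t: float) -> float:
    # Assumed Weibull baseline hazard (shape 1.5, scale 1.0), for illustration only.
    shape, scale = 1.5, 1.0
    return (shape / scale) * (t / scale) ** (shape - 1)


def hazard(t: float, h_x: float) -> float:
    """Proportional hazards form: lambda(t | x) = lambda_0(t) * exp(h(x))."""
    return baseline_hazard(t) * math.exp(h_x)


# Two patients with risk values h(x_a) = 0.8 and h(x_b) = -0.4.
h_a, h_b = 0.8, -0.4
for t in (0.5, 1.0, 2.0, 4.0):
    ratio = hazard(t, h_a) / hazard(t, h_b)
    # The baseline cancels, so every line prints exp(0.8 + 0.4) = exp(1.2), about 3.3201.
    print(f"t = {t}: hazard ratio = {ratio:.4f}")
```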

2.2 Linear Survival Models

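The CPH is a proportional hazards model that estimates the risk function $h(x)$ by a linear function $\hat{h}_\beta(x) = \beta^T x$. To perform Cox regression, one tunes the weights $\beta$ to optimize the Cox partial likelihood. The partial likelihood is the product, over the event times $T_i$, of the probability that the event occurred to individual $i$, given the set of individuals still at risk at time $T_i$. The Cox partial likelihood is parameterized by $\beta$ and defined as

$$L_c(\beta) = \prod_{i : E_i = 1} \frac{\exp\big(\hat{h}_\beta(x_i)\big)}{\sum_{j \in \mathcal{R}(T_i)} \exp\big(\hat{h}_\beta(x_j)\big)} = \prod_{i : E_i = 1} \frac{\exp(\beta^T x_i)}{\sum_{j \in \mathcal{R}(T_i)} \exp(\beta^T x_j)} \tag{2}$$

where the values $T_i$, $E_i$, and $x_i$ are the respective event time, event indicator, and baseline data for the $i$-th observation. The product is defined over the set of patients with an observable event, $E_i = 1$. The risk set $\mathcal{R}(t) = \{i : T_i \ge t\}$ is the set of patients still at risk of failure at time $t$.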
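A direct, $O(n^2)$ sketch of the negative log of Equation 2 for a linear risk $\beta^T x$ follows. The function name is hypothetical; with tied event times this formula corresponds to Breslow's approximation, and censored patients enter only through the risk sets:

```python
import math
from typing import Sequence


def cox_neg_log_partial_likelihood(
    beta: Sequence[float],
    X: Sequence[Sequence[float]],
    T: Sequence[float],
    E: Sequence[int],
) -> float:
    """-log L_c(beta) from Equation 2 with the linear risk h(x) = beta^T x."""
    risk = [sum(b * v for b, v in zip(beta, x)) for x in X]
    nll = 0.0
    for i in range(len(T)):
        if E[i] != 1:
            continue  # censored patients appear only inside risk sets
        # Risk set R(T_i): everyone whose observed time is at least T_i (includes i).
        denom = sum(math.exp(risk[j]) for j in range(len(T)) if T[j] >= T[i])
        nll -= risk[i] - math.log(denom)
    return nll


# With beta = 0 every risk is equal, so each event contributes log |R(T_i)|:
# here log 3 + log 2 + log 1 = log 6.
print(cox_neg_log_partial_likelihood([0.0], [[1.0], [2.0], [3.0]], [1.0, 2.0, 3.0], [1, 1, 1]))
```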

In many applications, for example modeling nonlinear gene interactions, we cannot assume the data satisfies the linear proportional hazards condition. In this case, the CPH model would require computing high-level interaction terms. This becomes prohibitively expensive as the number of features and interactions increases. Therefore, a more complex nonlinear model is needed.

2.3 Nonlinear survival models

The Faraggi-Simon method is a feed-forward neural network that provides the basis for a nonlinear proportional hazards model. faraggi1995neural experimented with a single hidden layer network with two or three nodes. Their model requires no prior assumption about the form of the risk function other than continuity. Instead, the NN computes nonlinear features from the training data and calculates their linear combination to estimate the risk function. Similar to Cox regression, the network optimizes a modified Cox partial likelihood: they replace the linear combination of features $\beta^T x$ in Equation 2 with the output of the network, $\hat{h}_\theta(x)$.

As previous research suggests, the Faraggi-Simon network has not been shown to outperform the linear CPH (faraggi1995neural, ; xiang2000comparison, ; mariani1997prognostic, ). Furthermore, to the best of our knowledge, we are the first to apply modern deep learning techniques to the Cox proportional hazards loss function.

Another popular machine learning approach to modeling patients' risk function is the random survival forest (RSF) (rsf2, ; rsf3, ). The random survival forest is a tree method that produces an ensemble estimate for the cumulative hazard function.

A more recent deep learning approach models the event time according to a Weibull distribution with parameters given by latent variables generated by a deep exponential family (bleiDSA, ).

3 Methods

In this section, we describe our methodology for providing personalized treatment recommendations using DeepSurv. First, we describe the architecture and training details of DeepSurv, an open source Python module that applies recent deep learning techniques to a nonlinear Cox proportional hazards network. Second, we define DeepSurv as a prognostic model and show how to use the network’s predicted risk function to provide personalized treatment recommendations.

3.1 DeepSurv

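DeepSurv is a multi-layer perceptron that predicts a patient's risk of death. The output of the network is a single node, which estimates the risk function $\hat{h}_\theta(x)$ parameterized by the weights of the network $\theta$. Similar to the Faraggi-Simon network, we set the loss function to be the negative log partial likelihood of Equation 2, averaged over the patients with an observed event:

$$l(\theta) = -\frac{1}{N_{E=1}} \sum_{i : E_i = 1} \Big( \hat{h}_\theta(x_i) - \log \sum_{j \in \mathcal{R}(T_i)} e^{\hat{h}_\theta(x_j)} \Big) \tag{3}$$

where $N_{E=1}$ is the number of patients with an observed event.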
We allow a deep architecture (i.e. more than one hidden layer) and apply modern techniques such as weight decay regularization, Rectified Linear Units (ReLU) (nair2010rectified, ) with batch normalization (ioffe2015batch, ), Scaled Exponential Linear Units (SELU) (gunter2017self, ), dropout (srivastava2014dropout, ), gradient descent optimization algorithms (Stochastic Gradient Descent and Adaptive Moment Estimation (Adam) (KingmaB14, )), Nesterov momentum (nesterov2013gradient, ), gradient clipping (pascanu2012understanding, ), and learning rate scheduling (senior2013empirical, ).
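Equation 3 only needs the network's predicted risks, so it can be computed without an $O(n^2)$ loop over risk sets. One common trick, sketched here in plain Python as an assumption rather than a description of DeepSurv's Theano code, is to sort patients by descending time so that each risk set is a prefix, and to maintain a running log-sum-exp. Patients with tied times are added together before their terms are computed, so the risk sets match $\mathcal{R}(t) = \{i : T_i \ge t\}$. The weight decay penalty is omitted:

```python
import math
from typing import Sequence


def _log_add_exp(a: float, b: float) -> float:
    """Numerically stable log(exp(a) + exp(b)), allowing a = -inf."""
    if a == -math.inf:
        return b
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def deepsurv_loss(pred: Sequence[float], T: Sequence[float], E: Sequence[int]) -> float:
    """Equation 3 without weight decay: the average over observed events of
    -(h_theta(x_i) - log sum_{j in R(T_i)} exp(h_theta(x_j)))."""
    order = sorted(range(len(T)), key=lambda k: T[k], reverse=True)
    log_risk_sum = -math.inf  # log of sum of exp(pred) over patients with T_j >= current time
    total, n_events, pos = 0.0, 0, 0
    while pos < len(order):
        # Gather every patient tied at the current time before scoring any of them.
        t = T[order[pos]]
        group = []
        while pos < len(order) and T[order[pos]] == t:
            group.append(order[pos])
            pos += 1
        for k in group:
            log_risk_sum = _log_add_exp(log_risk_sum, pred[k])
        for k in group:
            if E[k] == 1:
                total -= pred[k] - log_risk_sum
                n_events += 1
    return total / n_events if n_events else 0.0
```

Because the loss depends on the predictions only through differences $\hat{h}_\theta(x_i) - \hat{h}_\theta(x_j)$, adding a constant to every prediction leaves it unchanged.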

To tune the network's hyper-parameters, we perform a random hyper-parameter search (bergstra2012random, ). For more technical details, see Appendix A.
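A random hyper-parameter search samples configurations independently from a search space and keeps the best one on the validation set. In this sketch the search space, its ranges, and the `evaluate` callback (which would train DeepSurv and return, for example, a validation C-index) are all hypothetical:

```python
import math
import random
from typing import Callable, Dict, Tuple

# Hypothetical search space: (kind, low, high); these ranges are illustrative only.
SEARCH_SPACE = {
    "learning_rate": ("log_uniform", 1e-5, 1e-1),
    "num_hidden_layers": ("int", 1, 4),
    "hidden_units": ("int", 10, 100),
    "dropout": ("uniform", 0.0, 0.5),
    "l2_weight_decay": ("log_uniform", 1e-6, 1e-1),
}


def sample_config(rng: random.Random) -> Dict[str, float]:
    config = {}
    for name, (kind, low, high) in SEARCH_SPACE.items():
        if kind == "log_uniform":
            config[name] = math.exp(rng.uniform(math.log(low), math.log(high)))
        elif kind == "int":
            config[name] = rng.randint(low, high)
        else:
            config[name] = rng.uniform(low, high)
    return config


def random_search(
    evaluate: Callable[[Dict[str, float]], float], n_trials: int, seed: int = 0
) -> Tuple[float, Dict[str, float]]:
    """Sample n_trials independent configurations and return (best score, best config)."""
    rng = random.Random(seed)
    best = None
    for _ in range(n_trials):
        config = sample_config(rng)
        score = evaluate(config)  # e.g. validation C-index of a model trained with config
        if best is None or score > best[0]:
            best = (score, config)
    return best
```

Sampling learning rates and weight decay on a log scale is a common choice because their useful values span several orders of magnitude.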

3.2 Treatment recommender system

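In a clinical study, patients are subject to different levels of risk based on their relevant prognostic features and which treatment they undergo. We generalize this assumption as follows. Let all patients in a given study be assigned to one of $n$ treatment groups $\tau \in \{0, 1, \ldots, n-1\}$. We assume each treatment $i$ to have an independent risk function $e^{h_i(x)}$. Collectively, the hazard function becomes:

$$\lambda(t; x \mid \tau = i) = \lambda_0(t) \cdot e^{h_i(x)} \tag{4}$$

For any patient, the network should be able to accurately predict the risk $h_i(x)$ of being prescribed a given treatment $i$. Then, based on the assumption that each individual has the same baseline hazard function $\lambda_0(t)$, we can take the log of the hazards ratio to calculate the personal risk of prescribing one treatment option over another. We define this difference of log hazards as the recommender function, $\mathrm{rec}_{ij}(x)$:

$$\mathrm{rec}_{ij}(x) = \log\left(\frac{\lambda(t; x \mid \tau = i)}{\lambda(t; x \mid \tau = j)}\right) = \log\left(\frac{\lambda_0(t) \cdot e^{h_i(x)}}{\lambda_0(t) \cdot e^{h_j(x)}}\right) = h_i(x) - h_j(x) \tag{5}$$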
The recommender function can be used to provide personalized treatment recommendations. We first pass a patient through the network once in treatment group $i$ and again in treatment group $j$, and take the difference. When a patient receives a positive recommendation, $\mathrm{rec}_{ij}(x) > 0$, treatment $i$ leads to a higher risk of death than treatment $j$; hence, the patient should be prescribed treatment $j$. Conversely, a negative recommendation, $\mathrm{rec}_{ij}(x) < 0$, indicates that treatment $i$ is more effective and leads to a lower risk of death than treatment $j$, and we recommend treatment $i$.
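The recommendation rule in Equation 5 reduces to comparing two forward passes. In the sketch below, `risk(x, tau)` is a hypothetical stand-in for a trained network that takes the treatment group as an input covariate, and `toy_risk` is an invented example model; ties ($\mathrm{rec}_{ij}(x) = 0$) default to treatment $i$, which is an arbitrary choice:

```python
from typing import Callable, Sequence

# Hypothetical interface: risk(x, tau) returns the predicted risk h_tau(x) for covariates x
# under treatment tau, e.g. a trained network that receives tau as an extra input.
RiskFn = Callable[[Sequence[float], int], float]


def rec(risk: RiskFn, x: Sequence[float], i: int, j: int) -> float:
    """rec_ij(x) = h_i(x) - h_j(x), the log hazard ratio of treatment i versus j (Equation 5)."""
    return risk(x, i) - risk(x, j)


def recommend(risk: RiskFn, x: Sequence[float], i: int = 0, j: int = 1) -> int:
    # Positive rec_ij(x): treatment i is riskier, so recommend j; otherwise recommend i.
    return j if rec(risk, x, i, j) > 0 else i


def toy_risk(x: Sequence[float], tau: int) -> float:
    # Toy model (hypothetical): treatment 0 has constant risk, treatment 1 depends on x[0].
    return 0.0 if tau == 0 else 2.0 * x[0]


print(recommend(toy_risk, [0.5]))   # rec_01 = 0 - 1 = -1 < 0, so treatment 0
print(recommend(toy_risk, [-0.5]))  # rec_01 = 0 - (-1) = 1 > 0, so treatment 1
```

With more than two treatments, the same idea amounts to choosing the treatment with the smallest predicted $h_i(x)$.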

DeepSurv's architecture holds an advantage over the CPH because it calculates the recommender function without an a priori specification of treatment interaction terms. In contrast, the CPH model computes a constant recommender function unless treatment interaction terms are added to the model; see Appendix B for more details. Discovering relevant interaction terms is expensive because it requires extensive experimentation or prior biological knowledge of treatment outcomes. Therefore, DeepSurv is more cost-effective than the CPH for this task.
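A short derivation, using hypothetical coefficients $\beta_\tau$ (treatment effect) and $\gamma$ (treatment-by-covariate interactions) and a binary treatment $\tau \in \{0, 1\}$, shows why the linear CPH yields a constant recommender function unless interaction terms are added. This is a sketch consistent with Equation 5, not a reproduction of Appendix B:

```latex
\begin{aligned}
h_\beta(x, \tau) &= \beta^T x + \beta_\tau \tau, \qquad \tau \in \{0, 1\} \\
\mathrm{rec}_{10}(x) &= h_\beta(x, 1) - h_\beta(x, 0) = \beta_\tau
  \quad \text{(identical for every patient)} \\[4pt]
h_\beta(x, \tau) &= \beta^T x + \beta_\tau \tau + \tau \, \gamma^T x
  \quad \text{(with interaction terms)} \\
\mathrm{rec}_{10}(x) &= \beta_\tau + \gamma^T x \quad \text{(varies with } x \text{)}
\end{aligned}
```

In the first model the sign of $\beta_\tau$ is shared by all patients, so every patient receives the same recommendation; personalization appears only once the interaction vector $\gamma$ is specified, which is exactly the a priori choice DeepSurv avoids.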

4 Results

We perform four sets of experiments: (i) simulated survival data, (ii) real survival data, (iii) simulated treatment data, and (iv) real treatment data. First, we use simulated data to show how DeepSurv successfully learns the true risk function of a population. Second, we validate the network’s predictive ability by training DeepSurv on real survival data. Third, we simulate treatment data to verify that the network models multiple risk functions in a population based on the specific treatment a patient undergoes. Fourth, we demonstrate how DeepSurv provides treatment recommendations and show that DeepSurv’s recommendations improve a population’s survival rate. For more technical details on the experiments, see Appendix A.

In addition to training DeepSurv on each dataset, we run a linear CPH regression for a baseline comparison. We also fit an RSF to compare DeepSurv against a state-of-the-art nonlinear survival model. Even though we can compare the RSF's predictive accuracy to DeepSurv's, we do not measure the RSF's performance on modeling a simulated dataset's true risk function $h(x)$, because the RSF calculates the cumulative hazard function rather than the hazard function $\lambda(t; x)$.

4.1 Evaluation

Survival data

To evaluate the models' predictive accuracy on the survival data, we measure the concordance index (C-index) as outlined by harrell1984regression . The C-index is the most common metric used in survival analysis and measures how well a model predicts the ordering of patients' death times. For context, a C-index of 0.5 is the average C-index of a random model, whereas a C-index of 1 is a perfect ranking of death times. We perform bootstrapping (efron1994introduction, ) and sample the test set with replacement to obtain confidence intervals.
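A minimal sketch of Harrell's C-index and a percentile bootstrap over the test set follows. It assumes that a higher predicted risk should correspond to an earlier event, counts a pair as comparable only when the patient with the shorter time had an observed event, and scores tied predictions as 0.5; the function names and the percentile interval are illustrative choices:

```python
import random
from typing import List, Sequence, Tuple


def concordance_index(T: Sequence[float], E: Sequence[int], risk: Sequence[float]) -> float:
    """Harrell's C-index: the fraction of comparable pairs whose predicted risks are
    ordered consistently with their observed times (higher risk should fail earlier)."""
    concordant, comparable = 0.0, 0
    for i in range(len(T)):
        if E[i] != 1:
            continue  # the earlier time in a comparable pair must be an observed event
        for j in range(len(T)):
            if T[i] < T[j]:
                comparable += 1
                if risk[i] > risk[j]:
                    concordant += 1.0
                elif risk[i] == risk[j]:
                    concordant += 0.5  # tied predictions get half credit
    return concordant / comparable  # raises ZeroDivisionError if no pair is comparable


def bootstrap_ci(T, E, risk, n_boot: int = 1000, alpha: float = 0.05, seed: int = 0) -> Tuple[float, float]:
    """Percentile bootstrap interval for the C-index, resampling the test set with replacement."""
    rng = random.Random(seed)
    n = len(T)
    stats: List[float] = []
    for _ in range(n_boot):
        idx = [rng.randrange(n) for _ in range(n)]
        try:
            stats.append(concordance_index([T[k] for k in idx], [E[k] for k in idx], [risk[k] for k in idx]))
        except ZeroDivisionError:
            continue  # skip resamples with no comparable pairs
    stats.sort()
    lower = stats[int(alpha / 2 * len(stats))]
    upper = stats[int((1 - alpha / 2) * len(stats)) - 1]
    return lower, upper
```

The Lifelines function `lifelines.utils.concordance_index` expects scores for which higher means longer survival, so predicted risks would be negated before being passed to it.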

Experiment          | CPH | DeepSurv | RSF
Simulated Linear    |     |          |
Simulated Nonlinear |     |          |
WHAS                |     |          |
SUPPORT             |     |          |
METABRIC            |     |          |
Simulated Treatment |     |          |
Rotterdam & GBSG    |     |          |
Table 1: Experimental Results for All Experiments: C-index (95% Confidence Interval)

Treatment recommendations

We determine the recommended treatment for each patient in the test set using DeepSurv and the RSF. We do not calculate the recommended treatment for CPH; without preselected treatment-interaction terms, the CPH model will compute a constant recommender function and recommend the same treatment option for all patients. This would effectively be comparing the survival rates between the control and experimental groups. DeepSurv and the RSF are capable of predicting an individual’s risk per treatment because each computes relevant interaction terms. For DeepSurv, we choose the recommended treatment by calculating the recommender function (Equation 5). Because the RSF predicts a cumulative hazard for each patient, we choose the treatment with the minimum cumulative hazard.

Once we determine the recommended treatment, we identify two subsets of patients: those whose treatment group aligns with the model’s recommended treatment (Recommendation) and those who do not undergo the recommended treatment (Anti-Recommendation). We calculate the median survival time of each subset to determine if a model’s treatment recommendations increase the survival rate of the patients. We then perform a log-rank test to validate whether the difference between the two subsets is significant.
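The Recommendation/Anti-Recommendation comparison needs two pieces: splitting patients by whether their assigned treatment matches the recommended one, and a median survival time per subset. The sketch below computes the Kaplan-Meier median as the first event time at which the estimated survival curve falls to 0.5 or below (one common convention) and returns `None` when the median is not reached; the function names are hypothetical:

```python
from typing import List, Optional, Sequence, Tuple


def km_median(T: Sequence[float], E: Sequence[int]) -> Optional[float]:
    """Kaplan-Meier median: first event time at which the estimated S(t) is <= 0.5.
    Returns None if the curve never drops to 0.5 (median not reached)."""
    surv = 1.0
    for t in sorted({t for t, e in zip(T, E) if e == 1}):
        at_risk = sum(1 for ti in T if ti >= t)
        deaths = sum(1 for ti, ei in zip(T, E) if ti == t and ei == 1)
        surv *= 1.0 - deaths / at_risk  # product-limit update at each event time
        if surv <= 0.5:
            return t
    return None


def split_by_recommendation(
    assigned: Sequence[int], recommended: Sequence[int]
) -> Tuple[List[int], List[int]]:
    """Indices of patients who did (Recommendation) or did not (Anti-Recommendation)
    receive the treatment a model recommended for them."""
    rec = [k for k, (a, r) in enumerate(zip(assigned, recommended)) if a == r]
    anti = [k for k, (a, r) in enumerate(zip(assigned, recommended)) if a != r]
    return rec, anti
```

Applying `km_median` to the times and event indicators indexed by each subset gives one median survival time per subset.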

Experiment       | DeepSurv Rec | DeepSurv Anti-Rec | RSF Rec | RSF Anti-Rec
Rotterdam & GBSG |              |                   |         |
Table 2: Experimental Results for Treatment Recommendations: Median Survival Time (months)

4.2 Simulated survival data

In this section, we perform two experiments with simulated survival data: one with a linear risk function and one with a nonlinear (Gaussian) risk function. The advantage of using simulated datasets is that we can ascertain whether DeepSurv can successfully model the true risk function instead of overfitting random noise.

For each experiment, we generate training, validation, and testing sets of observations. Each observation represents a patient vector with $d = 10$ covariates, each drawn from a uniform distribution. We generate the death time $T$ according to an exponential Cox model (austin2012generating, ):

$$T \sim \mathrm{Exp}\big(\lambda_0 \cdot e^{h(x)}\big) \tag{6}$$

where $\lambda_0$ is a constant baseline hazard.
Details of the simulated data generation are found in Appendix C.

In both experiments, the risk function only depends on two of the ten covariates, and we demonstrate that DeepSurv discerns the relevant covariates from the noise. We then choose a censoring time to represent the ‘end of study’ such that about 90 percent of the patients have an observed event in the dataset.
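The sketch below simulates data in the spirit of Equation 6 and the censoring rule above. Several choices are assumptions, not values taken from the text: covariates uniform on $[-1, 1)$, $\lambda_0 = 0.1$, the linear coefficients, and the Gaussian's peak and scale. Death times are drawn by inverse-transform sampling, $T = -\log U / (\lambda_0 e^{h(x)})$, and the censoring time is set at the 90th percentile of the death times so that 90 percent of patients have an observed event:

```python
import math
import random
from typing import Callable, List, Sequence, Tuple


def linear_risk(x: Sequence[float]) -> float:
    # Hypothetical linear risk that uses only the first two covariates.
    return x[0] + 2.0 * x[1]


def gaussian_risk(x: Sequence[float], peak: float = 5.0, r: float = 0.5) -> float:
    # Hypothetical Gaussian risk in the first two covariates, centered at the origin.
    return math.log(peak) * math.exp(-(x[0] ** 2 + x[1] ** 2) / (2 * r ** 2))


def simulate(
    n: int,
    risk_fn: Callable[[Sequence[float]], float],
    d: int = 10,
    lam0: float = 0.1,
    event_fraction: float = 0.9,
    seed: int = 0,
) -> Tuple[List[List[float]], List[float], List[int]]:
    rng = random.Random(seed)
    X = [[rng.uniform(-1.0, 1.0) for _ in range(d)] for _ in range(n)]
    # Exponential Cox model: T ~ Exp(lam0 * exp(h(x))), drawn by inverse-transform sampling.
    death = [-math.log(1.0 - rng.random()) / (lam0 * math.exp(risk_fn(x))) for x in X]
    # 'End of study': censor at the event_fraction quantile of the death times.
    cutoff = sorted(death)[int(event_fraction * n) - 1]
    T = [min(t, cutoff) for t in death]
    E = [1 if t <= cutoff else 0 for t in death]
    return X, T, E
```

Only the first two of the ten covariates enter either risk function, so the remaining eight act as noise that a model must learn to ignore.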

4.2.1 Linear risk experiment

We first simulate patients with a linear risk function $h(x)$ (Equation 7) for $x \in \mathbb{R}^d$, so that the linear proportional hazards assumption holds true.
Because the linear proportional hazards assumption holds true, we expect the linear CPH to accurately model the risk function in Equation 7.

Our results (see Table 1) demonstrate that DeepSurv performs as well as the standard linear Cox regression and better than RSF in predictive ability.

(a) True
(b) CPH
(c) DeepSurv
Figure 1: Predicted risk surfaces and errors for the simulated survival data with a linear risk function, with respect to a patient's covariates $x_0$ and $x_1$. (a) The true risk $h(x)$ for each patient. (b) The predicted risk surface of $\hat{h}_\beta(x)$ from the linear CPH model parameterized by $\beta$. (c) The output of DeepSurv, $\hat{h}_\theta(x)$, predicts a patient's risk. (d) The absolute error between the true risk $h(x)$ and CPH's predicted risk $\hat{h}_\beta(x)$. (e) The absolute error between the true risk $h(x)$ and DeepSurv's predicted risk $\hat{h}_\theta(x)$.

Figure 1 demonstrates how DeepSurv more accurately models the risk function compared to the linear CPH. Figure 1(a) plots the true risk function for all patients in the test set. As shown in Figure 1(b), the CPH's estimated risk function does not perfectly model the true risk for a patient. In contrast, as shown in Figure 1(c), DeepSurv better estimates the true risk function.

To quantify these differences, Figures 1(d) and 1(e) show that the CPH's estimated risk has a substantially larger absolute error than that of DeepSurv, particularly for patients with a high positive risk. We also calculate the mean squared error (MSE) between each model's predicted risk and the true risk values, and DeepSurv attains the lower MSE. Even though DeepSurv and the CPH have similar predictive abilities, this demonstrates that DeepSurv is better than the CPH at modeling the true risk function of the population.

4.2.2 Nonlinear risk experiment

We set the risk function $h(x)$ (Equation 8) to be a Gaussian function of the first two covariates, with a fixed peak value and scale factor.
The surface of the risk function is depicted in Figure 2(a). Because this risk function is nonlinear, we do not expect the CPH to predict it properly without adding quadratic terms of the covariates to the model. We expect DeepSurv to reconstruct the Gaussian risk function and successfully predict a patient's risk. Lastly, we expect the RSF and DeepSurv to accurately rank the order of patients' deaths.

The C-index results in Table 1 show that DeepSurv outperforms the linear CPH and predicts as well as the RSF. In addition, DeepSurv correctly learns nonlinear relationships between a patient's covariates and their risk. As shown in Figure 2, DeepSurv is more successful than the linear CPH in modeling the true risk function. Figure 2(b) demonstrates that the linear CPH regression fails to identify the first two covariates as significant. The CPH has a C-index of about 0.5, which is equivalent to the performance of randomly ranking death times. Meanwhile, Figure 2(c) demonstrates that DeepSurv reconstructs the Gaussian relationship between the first two covariates and a patient's risk.

(a) True
(b) CPH
(c) DeepSurv
Figure 2: Risk surfaces of the nonlinear test set with respect to a patient's covariates $x_0$ and $x_1$. (a) The calculated true risk $h(x)$ (Equation 8) for each patient. (b) The predicted risk surface of $\hat{h}_\beta(x)$ from the linear CPH model parameterized by $\beta$. The linear CPH predicts a constant risk. (c) The output of DeepSurv, $\hat{h}_\theta(x)$, is the estimated risk function.

4.3 Real survival data experiments

We compare the performance of the CPH, the RSF, and DeepSurv on three datasets from real studies: the Worcester Heart Attack Study (WHAS), the Study to Understand Prognoses Preferences Outcomes and Risks of Treatment (SUPPORT), and The Molecular Taxonomy of Breast Cancer International Consortium (METABRIC). Because previous research has not shown neural networks to outperform the CPH, our goal is to demonstrate that DeepSurv does indeed have state-of-the-art predictive ability in practice on real survival datasets.

4.3.1 Worcester Heart Attack Study (WHAS)

The Worcester Heart Attack Study (WHAS) investigates the effects of a patient's factors on acute myocardial infarction (MI) survival (hosmer9780471754992, ). The dataset consists of 1,638 observations and 5 features: age, sex, body-mass-index (BMI), left heart failure complications (CHF), and order of MI (MIORD). We reserve 20 percent of the dataset as a testing set. A total of 42.12 percent of patients died during the study, with a median death time of 516.0 days. As shown in Table 1, DeepSurv outperforms the CPH; however, the RSF outperforms DeepSurv.

4.3.2 Study to Understand Prognoses Preferences Outcomes and Risks of Treatment (SUPPORT)

The Study to Understand Prognoses Preferences Outcomes and Risks of Treatment (SUPPORT) is a larger study that researches the survival time of seriously ill hospitalized adults (knaus1995support, ). The dataset consists of 9,105 patients and 14 features for which almost all patients have observed entries (age, sex, race, number of comorbidities, presence of diabetes, presence of dementia, presence of cancer, mean arterial blood pressure, heart rate, respiration rate, temperature, white blood cell count, serum sodium, and serum creatinine). We drop patients with any missing features and reserve 20 percent of the dataset as a testing set. A total of 68.10 percent of patients died during the study, with a median death time of 58 days.
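The preprocessing described here (dropping incomplete records, then holding out 20 percent for testing) can be sketched as follows. The plain-Python row format, the treatment of `None` and NaN as missing, and the fixed seed are assumptions; the text does not say whether the split was stratified:

```python
import math
import random
from typing import List, Sequence, Tuple

Row = Tuple  # one patient's feature values (hypothetical plain-Python format)


def drop_incomplete(rows: Sequence[Row]) -> List[Row]:
    """Keep only patients with no missing feature (None or NaN counts as missing)."""
    def is_missing(v) -> bool:
        return v is None or (isinstance(v, float) and math.isnan(v))
    return [row for row in rows if not any(is_missing(v) for v in row)]


def train_test_split(
    rows: Sequence[Row], test_fraction: float = 0.2, seed: int = 0
) -> Tuple[List[Row], List[Row]]:
    """Shuffle once with a fixed seed and hold out test_fraction of the patients."""
    rng = random.Random(seed)
    idx = list(range(len(rows)))
    rng.shuffle(idx)
    n_test = int(round(test_fraction * len(rows)))
    test = [rows[i] for i in idx[:n_test]]
    train = [rows[i] for i in idx[n_test:]]
    return train, test
```

Dropping incomplete rows before splitting keeps the 80/20 proportions on the data that is actually used.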

As shown in Table 1, DeepSurv performs as well as the RSF and better than the CPH on this larger study. This validates DeepSurv's ability to predict the ranking of patients' risks on real survival data.

4.3.3 Molecular Taxonomy of Breast Cancer International Consortium (METABRIC)

The Molecular Taxonomy of Breast Cancer International Consortium (METABRIC) uses gene and protein expression profiles to determine new breast cancer subgroups in order to help physicians provide better treatment recommendations.

The METABRIC dataset consists of gene expression data and clinical features for 1,980 patients, and 57.72 percent have an observed death due to breast cancer, with a median survival time of 116 months (curtis2012genomic, ). We prepare the dataset in line with the Immunohistochemical 4 plus Clinical (IHC4+C) test, which is a common prognostic tool for evaluating treatment options for breast cancer patients (lakhanpal2016ihc4, ). We join the 4 gene indicators (MKI67, EGFR, PGR, and ERBB2) with a patient's clinical features (hormone treatment indicator, radiotherapy indicator, chemotherapy indicator, ER-positive indicator, and age at diagnosis). We then reserve 20 percent of the patients as the test set.

Table 1 shows that DeepSurv performs better than both the CPH and the RSF. This result demonstrates DeepSurv's ability to model the risk effects of gene expression data, and suggests that future research could establish DeepSurv as a prognostic tool comparable to common medical tests such as the IHC4+C.

4.5 Treatment recommender system experiments

In this section, we perform two experiments to demonstrate the effectiveness of DeepSurv's treatment recommender system. First, we simulate treatment data by including an additional covariate in the simulated data from Section 4.2.2. Second, after demonstrating DeepSurv's modeling and recommendation capabilities, we apply the recommender system to a real dataset used to study the effects of hormone treatment on breast cancer patients. We show that DeepSurv can successfully provide personalized treatment recommendations. Our results suggest that if all patients followed the network's recommended treatment options, patients' survival times would increase significantly.

4.5.1 Simulated treatment data

We randomly assign each simulated patient in the dataset to one of two treatment groups, $\tau \in \{0, 1\}$, with equal probability. All of the patients in group $\tau = 0$ are 'unaffected' by the treatment (e.g. given a placebo) and have a constant risk function $h_0(x)$. The other group, $\tau = 1$, is prescribed a treatment with Gaussian effects (Equation 8) and has a risk function $h_1(x)$.
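A sketch of this treatment simulation: each patient gets $\tau \in \{0, 1\}$ with equal probability, and the true risk is $h_\tau(x)$. The constant value of $h_0$ and the Gaussian parameters of $h_1$ below are hypothetical (the constant is chosen so that the better treatment differs across patients), and `true_recommendation` applies Equation 5 to the true risks:

```python
import math
import random
from typing import List, Sequence, Tuple


def h0(x: Sequence[float]) -> float:
    # Control group (e.g. placebo): constant risk; the value 1.0 is hypothetical.
    return 1.0


def h1(x: Sequence[float], peak: float = 5.0, r: float = 0.5) -> float:
    # Treated group: hypothetical Gaussian effect of the first two covariates.
    return math.log(peak) * math.exp(-(x[0] ** 2 + x[1] ** 2) / (2 * r ** 2))


def assign_and_score(X: Sequence[Sequence[float]], seed: int = 0) -> Tuple[List[int], List[float]]:
    """Assign tau in {0, 1} with equal probability; return each tau and the true risk h_tau(x)."""
    rng = random.Random(seed)
    taus = [rng.randint(0, 1) for _ in X]
    risks = [h1(x) if tau == 1 else h0(x) for x, tau in zip(X, taus)]
    return taus, risks


def true_recommendation(x: Sequence[float]) -> int:
    # rec_10(x) = h1(x) - h0(x); a positive value means treatment 1 is riskier, so recommend 0.
    return 0 if h1(x) - h0(x) > 0 else 1
```

Under these assumed parameters, treatment 1 is harmful near the origin of the $(x_0, x_1)$ plane and beneficial far from it, so a model must learn the interaction to recommend well.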

Figure 3 illustrates the network's success in modeling both treatments' risk functions for patients. Figure 3(a) plots the true risk distribution $h_1(x)$. As expected, Figure 3(b) shows that the network models a constant risk for a patient in treatment $\tau = 0$, independent of a patient's covariates. Figure 3(c) shows how DeepSurv models the Gaussian effects of a patient's covariates on their treatment risk. To further quantify these results, Table 1 shows that DeepSurv has the largest C-index. Because the network accurately reconstructs the risk function, we expect that it will provide accurate treatment recommendations for new patients.

(a) True
(b) DeepSurv
(c) DeepSurv
Figure 3: Treatment risk surfaces as a function of a patient's relevant covariates $x_0$ and $x_1$. (a) The true risk $h_1(x)$ if all patients in the test set were given treatment $\tau = 1$. We then manually set all treatment groups to either $\tau = 0$ or $\tau = 1$. (b) The predicted risk $\hat{h}_0(x)$ for patients with treatment group $\tau = 0$. (c) The network's predicted risk $\hat{h}_1(x)$ for patients in treatment group $\tau = 1$.
(a) Effect of DeepSurv's Treatment Recommendations (Simulated Data)
(b) Effect of RSF's Treatment Recommendations (Simulated Data)
Figure 4: Kaplan-Meier estimated survival curves with confidence intervals for the patients who were given the treatment concordant with a method's recommended treatment (Recommendation) and the subset of patients who were not (Anti-Recommendation). We perform a log-rank test to validate the significance of the difference between each pair of survival curves.

In Figure 4, we plot the Kaplan-Meier survival curves for both the Recommendation and Anti-Recommendation subsets for each method. Figure 4(a) shows that the survival curve for the Recommendation subset is shifted to the right, which signifies an increase in survival time for the population following DeepSurv's recommendations. This is further quantified by the median survival times summarized in Table 2. The log-rank p-value for DeepSurv's recommendations is small enough that we can reject the null hypothesis that DeepSurv's recommendations would not affect the population's survival time. As shown in Table 2, the subset of patients that follow the RSF's recommendations have a shorter survival time than those who do not follow the RSF's recommended treatment. Therefore, we could take the RSF's recommendations and provide the patients with the opposite treatment option to increase median survival time; however, Figure 4(b) shows that this improvement would not be statistically significant. While both DeepSurv and the RSF are able to compute treatment interaction terms, DeepSurv is more successful in recommending personalized treatments.
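The log-rank test used to compare the Recommendation and Anti-Recommendation curves can be sketched as follows (Appendix A reports using Lifelines, which provides `lifelines.statistics.logrank_test`). The sketch sums observed minus expected events in group A over the distinct event times, uses the hypergeometric variance, and converts the 1-degree-of-freedom chi-square statistic to a p-value with $\operatorname{erfc}(\sqrt{\chi^2 / 2})$; it assumes at least one event so that the variance is positive:

```python
import math
from typing import Sequence, Tuple


def logrank_test(T_a: Sequence[float], E_a: Sequence[int],
                 T_b: Sequence[float], E_b: Sequence[int]) -> Tuple[float, float]:
    """Two-sample log-rank test; returns (chi-square statistic, p-value) with 1 degree of freedom."""
    event_times = sorted({t for t, e in zip(T_a, E_a) if e == 1} | {t for t, e in zip(T_b, E_b) if e == 1})
    o_minus_e, variance = 0.0, 0.0
    for t in event_times:
        n_a = sum(1 for x in T_a if x >= t)  # group A still at risk at t
        n_b = sum(1 for x in T_b if x >= t)
        d_a = sum(1 for x, e in zip(T_a, E_a) if x == t and e == 1)  # deaths at t
        d_b = sum(1 for x, e in zip(T_b, E_b) if x == t and e == 1)
        n, d = n_a + n_b, d_a + d_b
        o_minus_e += d_a - d * n_a / n  # observed minus expected deaths in group A
        if n > 1:
            variance += d * (n_a / n) * (n_b / n) * (n - d) / (n - 1)  # hypergeometric variance
    chi2 = o_minus_e ** 2 / variance
    p_value = math.erfc(math.sqrt(chi2 / 2.0))  # P(chi-square with 1 df > chi2)
    return chi2, p_value
```

For two groups with identical times the statistic is 0 and the p-value is 1; the more one group's events precede the other's, the larger the statistic.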

4.5.2 Rotterdam & German Breast Cancer Study Group (GBSG)

We first train DeepSurv on breast cancer data from the Rotterdam tumor bank (foekens2000urokinase, ) and construct a recommender system to provide treatment recommendations to patients from a study by the German Breast Cancer Study Group (GBSG) (schumacher1994randomized, ). The Rotterdam tumor bank dataset contains records for 1,546 patients with node-positive breast cancer, and nearly 90 percent of the patients have an observed death time. The testing data from the GBSG contains complete data for 686 patients (56 percent are censored) in a randomized clinical trial that studied the effects of chemotherapy and hormone treatment on survival rate. We preprocess the data as outlined by altman2000we .

We first validate DeepSurv’s performance against the RSF and CPH baseline. We then plot the two survival curves: the survival times of those who followed the recommended treatment and those who did not. If the recommender system is effective, we expect the population with the recommended treatments to survive longer than those who did not take the recommended treatment.

(a) Effect of DeepSurv's Treatment Recommendations (GBSG)
(b) Effect of RSF's Treatment Recommendations (GBSG)
Figure 5: Kaplan-Meier estimated survival curves with confidence intervals for the patients who were given the treatment concordant with a method's recommended treatment (Recommendation) and the subset of patients who were not (Anti-Recommendation). We perform a log-rank test to validate the significance of the difference between each pair of survival curves.

Table 1 shows that DeepSurv provides improved predictive ability relative to the CPH and the RSF. In Figure 5, we plot the Kaplan-Meier survival curves for both the Recommendation subset and the Anti-Recommendation subset for each method. Figure 5(a) shows that the survival curve for DeepSurv's Recommendation subset is significantly different from that of the Anti-Recommendation subset, and Table 2 shows that DeepSurv's recommendations increase the median survival time of the population. Figure 5(b) demonstrates that the RSF is unable to provide significant treatment recommendations, despite an increase in median survival times (see Table 2). The results of this experiment demonstrate DeepSurv's superior modeling capabilities and validate its ability to provide personalized treatment recommendations on real clinical data. Moreover, we can train DeepSurv on survival data from one clinical study and transfer what it learns to provide personalized treatment recommendations to a different population of breast cancer patients.

5 Conclusion

In conclusion, we demonstrated that the use of deep learning in survival analysis allows for (i) higher performance due to the flexibility of the model and (ii) effective treatment recommendations based on the predicted effect of treatment options on an individual’s risk. We validated that, in most experiments, DeepSurv predicts patients’ risk as well as or better than other linear and nonlinear survival methods. We experimented on increasingly complex survival datasets and demonstrated that DeepSurv learns complex and nonlinear features without a priori feature selection or domain expertise. We then demonstrated that DeepSurv provides better personalized treatment recommendations than random survival forests, a state-of-the-art survival method. We also released a Python module that implements DeepSurv and scripts for running reproducible experiments in Docker; see

for more details. The success of DeepSurv’s predictive, modeling, and recommending abilities paves the way for future research in deep neural networks and survival analysis. DeepSurv can lead to various extensions, such as the use of convolutional neural networks to predict risk from medical imaging. With further large-scale research, DeepSurv has the potential to supplement traditional survival analysis methods and become a standard method for medical practitioners to study and recommend personalized treatment options.


This research was partially funded by a National Institutes of Health grant [1R01HG008383-01A1 to Y.K.] and supported by a National Science Foundation Award [DMS-1402254 to A.C.].


Appendix A Experimental Details

We fit all linear CPH models and compute all Kaplan-Meier estimates, C-index statistics, and log-rank tests using the lifelines Python package. DeepSurv is implemented in Theano with the Python package Lasagne. We use the R package randomForestSRC to fit RSFs. All experiments are run in Docker containers so that they are easily reproducible. We use the FloydHub base image for the DeepSurv Docker container.
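For readers unfamiliar with the C-index, here is a minimal pure-Python sketch of Harrell's concordance index. It is not the lifelines implementation: pairs with tied times are simply skipped rather than handled specially. It also takes a risk score where higher means earlier failure, whereas `lifelines.utils.concordance_index` expects scores where higher means longer survival, so predicted risks would be negated before being passed to lifelines.

```python
from typing import Sequence


def concordance_index(
    times: Sequence[float], predicted_risk: Sequence[float], events: Sequence[int]
) -> float:
    """Harrell's C-index: fraction of comparable pairs ranked correctly by risk.

    A pair (i, j) is comparable when patient i has an observed event and
    times[i] < times[j]; it is concordant when predicted_risk[i] > predicted_risk[j].
    Tied risk scores count as half-concordant.
    """
    concordant = 0.0
    comparable = 0
    n = len(times)
    for i in range(n):
        if not events[i]:
            continue  # a censored patient cannot anchor a comparable pair
        for j in range(n):
            if times[i] < times[j]:
                comparable += 1
                if predicted_risk[i] > predicted_risk[j]:
                    concordant += 1.0
                elif predicted_risk[i] == predicted_risk[j]:
                    concordant += 0.5
    return concordant / comparable
```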

The hyper-parameters of the network include the depth and size of the network, learning rate, regularization coefficient, dropout rate, learning rate decay constant, and momentum. We run the random hyper-parameter search proposed by Bergstra and Bengio [25] using the Python package Optunity. We use the Sobol solver [36, 37] to sample each hyper-parameter from a predefined range and evaluate the performance of each configuration using k-fold cross-validation. We then choose the configuration with the largest validation C-index to avoid models that overfit. The hyper-parameters we use in all experiments are summarized in Appendix A.1.
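The selection procedure (sample configurations, score each by its mean validation C-index across k folds, keep the best) can be sketched as below. This is an illustrative stand-in, not Optunity's API. It replaces the Sobol sequence with plain uniform random sampling, and `train_and_score` is a hypothetical callback that would train DeepSurv on the training indices and return the validation C-index. The number of folds is left as a required argument because its value is not stated here.

```python
import random
from typing import Callable, Dict, Iterator, List, Optional, Tuple

Config = Dict[str, float]


def k_fold_indices(n: int, k: int, seed: int = 0) -> Iterator[Tuple[List[int], List[int]]]:
    """Yield (train_indices, validation_indices) for each of k shuffled folds."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    folds = [idx[f::k] for f in range(k)]
    for f in range(k):
        train = [i for g in range(k) if g != f for i in folds[g]]
        yield train, folds[f]


def random_search(
    search_space: Dict[str, Tuple[float, float]],
    train_and_score: Callable[[Config, List[int], List[int]], float],
    n_samples: int,
    n_patients: int,
    k: int,
    seed: int = 0,
) -> Tuple[Optional[Config], float]:
    """Return the sampled configuration with the largest mean validation score."""
    rng = random.Random(seed)
    best_config, best_score = None, float("-inf")
    for _ in range(n_samples):
        # Uniform sampling stands in for the Sobol sequence used in the paper.
        config = {name: rng.uniform(lo, hi) for name, (lo, hi) in search_space.items()}
        scores = [train_and_score(config, tr, va) for tr, va in k_fold_indices(n_patients, k, seed)]
        mean_score = sum(scores) / len(scores)
        if mean_score > best_score:
            best_config, best_score = config, mean_score
    return best_config, best_score


# Hypothetical toy scorer that ignores the data split and peaks at learning_rate = 0.01.
def toy_score(config: Config, train: List[int], val: List[int]) -> float:
    return -(config["learning_rate"] - 0.01) ** 2


best, score = random_search({"learning_rate": (0.0, 0.1)}, toy_score, n_samples=20, n_patients=30, k=3)
```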

A.1 Model Hyper-parameters

We tune DeepSurv’s hyper-parameters by running a random hyper-parameter search using the Python package Optunity. The table below summarizes the hyper-parameters we use for each experiment’s DeepSurv network.

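Hyper-parameter | Sim Linear | Sim Nonlinear | WHAS | SUPPORT | METABRIC | Sim Treatment | GBSG
Optimizer | sgd | sgd | adam | adam | adam | adam | adam
Activation | SELU | ReLU | ReLU | SELU | SELU | SELU | SELU
# Dense Layers | | | | | | |
# Nodes / Layer | | | | | | |
Learning Rate (LR) | | | | | | |
Reg | | | | | | |
Dropout | | | | | | |
LR Decay | | | | | | |
Momentum | | | | | | |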

Table 3: DeepSurv’s Experimental Hyper-parameters

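We applied inverse time decay to the learning rate at each epoch:

$$\eta_t = \frac{\eta_0}{1 + \lambda t},$$

where $\eta_0$ is the initial learning rate, $\lambda$ is the learning rate decay constant, and $t$ is the epoch number.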
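The schedule above divides the initial rate by (1 + decay · epoch), so the rate shrinks hyperbolically rather than by a constant factor per epoch as in exponential decay. Here is a one-function sketch with illustrative parameter names. The example values are arbitrary and are not the tuned values from Table 3.

```python
from typing import List


def inverse_time_decay(initial_lr: float, decay: float, epoch: int) -> float:
    """Learning rate after `epoch` epochs: initial_lr / (1 + decay * epoch)."""
    return initial_lr / (1.0 + decay * epoch)


# Arbitrary illustrative values (not the paper's tuned hyper-parameters).
schedule: List[float] = [inverse_time_decay(0.1, 0.5, t) for t in range(4)]
```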

Appendix B CPH Recommender Function

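Let each patient in the dataset have a set of $n$ features $x = (x_0, x_1, \ldots, x_{n-1})$, in which one feature is a treatment variable $x_0 = \tau$. The CPH model estimates the risk function as a linear combination of the patient’s features, $\hat{h}_\beta(x) = \beta^T x = \beta_0 \tau + \sum_{k=1}^{n-1} \beta_k x_k$. When we calculate the recommender function for the CPH model, we show that the model returns a constant function independent of the patient’s features:

$$\mathrm{rec}_{ij}(x) = \log \frac{\lambda(t; x \mid \tau = i)}{\lambda(t; x \mid \tau = j)} = \hat{h}_\beta(x \mid \tau = i) - \hat{h}_\beta(x \mid \tau = j) = \beta_0 i - \beta_0 j = \beta_0 (i - j),$$

since every term that does not involve $\tau$ cancels.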
The CPH will therefore recommend the same treatment option to all patients, depending only on whether the model estimates the treatment weight to be positive or negative. Thus, the CPH does not provide personalized treatment recommendations. Instead, the CPH determines whether the treatment is effective and, if so, recommends it to all patients. In an experiment, when we calculate which patients took the CPH’s recommendation, the Recommendation and Anti-Recommendation subgroups simply coincide with the trial’s treatment and control groups. Therefore, calculating treatment recommendations using the CPH provides little value to the experiments in terms of comparing the models’ recommendations.
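The argument above can be checked numerically. The sketch below evaluates the recommender function (log-risk under treatment i minus log-risk under treatment j) for a linear risk function and for a toy nonlinear risk with a treatment interaction. The coefficients and the interaction term are made up for illustration; only the linear case corresponds to the CPH derivation.

```python
from typing import Callable, List, Sequence

RiskFn = Callable[[Sequence[float]], float]


def rec_ij(risk: RiskFn, x: Sequence[float], i: int, j: int, treatment_index: int = 0) -> float:
    """Log-risk with treatment i minus log-risk with treatment j.

    A positive value means treatment i carries higher risk, so j is preferred.
    """
    x_i: List[float] = list(x)
    x_j: List[float] = list(x)
    x_i[treatment_index] = i
    x_j[treatment_index] = j
    return risk(x_i) - risk(x_j)


beta = [0.7, 1.2, -0.4]  # hypothetical CPH coefficients; beta[0] is the treatment weight


def linear_risk(x: Sequence[float]) -> float:
    return sum(b * v for b, v in zip(beta, x))


def interaction_risk(x: Sequence[float]) -> float:
    # Hypothetical nonlinear risk: the treatment effect depends on feature x[1].
    return linear_risk(x) - 2.0 * x[0] * x[1]


patients = [[0, 0.5, -1.0], [1, -0.3, 0.8], [0, 0.9, 0.1]]
linear_recs = [rec_ij(linear_risk, x, 1, 0) for x in patients]         # all equal beta[0]
nonlinear_recs = [rec_ij(interaction_risk, x, 1, 0) for x in patients]  # vary by patient
```

The linear model returns the same value for every patient, whereas the interaction term makes the recommendation depend on each patient’s features, which is the behavior a nonlinear model such as DeepSurv can capture.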

Appendix C Simulated Data Generation

Each patient’s baseline information $x$ is drawn from a uniform distribution on . For datasets that also involve treatment, the patient’s treatment status $\tau \in \{0, 1\}$ is drawn from a Bernoulli distribution.
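The Cox proportional hazards model assumes that the baseline hazard function $\lambda_0(t)$ is shared across all patients. The initial death time is generated according to an exponential random variable with mean $\mu$, which we denote $u \sim \mathrm{Exp}(\mu)$. The individual death time is then generated by

$$T = u \, e^{-h(x)},$$

where $h(x)$ is the simulated log-risk function. Because $u$ has constant hazard $1/\mu$, $T$ has hazard $\lambda(t; x) = \frac{1}{\mu} e^{h(x)}$, which satisfies the proportional hazards assumption with constant baseline hazard $\lambda_0(t) = 1/\mu$.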
These times are then right censored at an end time to represent the end of a trial. The end time is chosen such that 90 percent of people have an observed death time.

Because we cannot observe any death time $T$ beyond the end time threshold $t_{\text{end}}$, we denote the final observed outcome time as

$$Z = \min(T, t_{\text{end}}).$$
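A minimal standard-library sketch of this simulation. The covariate range [-1, 1], the example mean `mean_baseline`, and the example log-risk function are assumptions for illustration, because those specifics are not given in this section. The end time is set to the death time below which the chosen fraction (90 percent here) of simulated patients fall.

```python
import math
import random
from typing import Callable, List, Sequence, Tuple


def simulate_survival(
    n: int,
    d: int,
    log_risk: Callable[[Sequence[float]], float],
    mean_baseline: float,
    observed_fraction: float = 0.9,
    seed: int = 0,
) -> Tuple[List[List[float]], List[float], List[int], float]:
    rng = random.Random(seed)
    # Assumed covariate range [-1, 1] for illustration.
    xs = [[rng.uniform(-1.0, 1.0) for _ in range(d)] for _ in range(n)]
    # u ~ Exp(mean = mean_baseline); expovariate takes the rate 1 / mean. T = u * exp(-h(x)).
    deaths = [rng.expovariate(1.0 / mean_baseline) * math.exp(-log_risk(x)) for x in xs]
    # Choose the end time so that round(observed_fraction * n) deaths are observed.
    n_observed = max(1, round(observed_fraction * n))
    end_time = sorted(deaths)[n_observed - 1]
    observed_times = [min(t, end_time) for t in deaths]   # Z = min(T, t_end)
    events = [1 if t <= end_time else 0 for t in deaths]  # 1 = death observed
    return xs, observed_times, events, end_time


# Arbitrary example: linear log-risk in two covariates, mean baseline time 5.0.
xs, z, e, t_end = simulate_survival(n=100, d=2, log_risk=lambda x: x[0] + 0.5 * x[1], mean_baseline=5.0)
```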
  • [1] RW Yeh, EA Secemsky, DJ Kereiakes, et al. Development and validation of a prediction rule for benefit and harm of dual antiplatelet therapy beyond 1 year after percutaneous coronary intervention. JAMA, 315(16):1735–1749, 2016.
  • [2] Patrick Royston and Douglas G Altman. External validation of a Cox prognostic model: principles and methods. BMC medical research methodology, 13(1):1, 2013.
  • [3] Eric Bair and Robert Tibshirani. Semi-supervised methods to predict patient survival from gene expression data. PLoS Biol, 2(4):e108, 2004.
  • [4] Wei-Yi Cheng, Tai-Hsien Ou Yang, and Dimitris Anastassiou. Development of a prognostic model for breast cancer survival in an open challenge environment. Science translational medicine, 5(181):181ra50–181ra50, 2013.
  • [5] David R Cox. Regression models and life-tables. In Breakthroughs in statistics. Springer, 1992.
  • [6] Knut Liestøl, Per Kragh Andersen, and Ulrich Andersen. Survival analysis and neural nets. Statistics in medicine, 13(12):1189–1200, 1994.
  • [7] W Nick Street. A neural network model for prognostic prediction. In ICML, pages 540–546, 1998.
  • [8] Leonardo Franco, José M Jerez, and Emilio Alba. Artificial neural networks and prognosis in medicine. Survival analysis in breast cancer patients. In ESANN, pages 91–102. i6doc, 2005.
  • [9] Elia Biganzoli, Patrizia Boracchi, Luigi Mariani, and Ettore Marubini. Feed forward neural networks for the analysis of censored survival data: a partial logistic regression approach. Statistics in medicine, 17(10):1169–1186, 1998.
  • [10] David Faraggi and Richard Simon. A neural network model for survival data. Statistics in medicine, 14(1):73–82, 1995.
  • [11] Daniel J Sargent. Comparison of artificial neural networks with other statistical approaches. Cancer, 91(S8):1636–1642, 2001.
  • [12] Anny Xiang, Pablo Lapuerta, Alex Ryutov, Jonathan Buckley, and Stanley Azen. Comparison of the performance of neural network methods and cox regression for censored survival data. Computational statistics & data analysis, 34(2):243–257, 2000.
  • [13] L Mariani, D Coradini, E Biganzoli, P Boracchi, E Marubini, S Pilotti, B Salvadori, R Silvestrini, U Veronesi, R Zucali, et al. Prognostic factors for metachronous contralateral breast cancer: a comparison of the linear cox regression model and its artificial neural network extension. Breast cancer research and treatment, 44(2):167–178, 1997.
  • [14] H. Ishwaran and U.B. Kogalur. Random survival forests for R. R News, 7(2):25–31, October 2007.
  • [15] H. Ishwaran, U.B. Kogalur, E.H. Blackstone, and M.S. Lauer. Random survival forests. Ann. Appl. Statist., 2(3):841–860, 2008.
  • [16] Rajesh Ranganath, Adler Perotte, Noémie Elhadad, and David Blei. Deep survival analysis. In Finale Doshi-Velez, Jim Fackler, David Kale, Byron Wallace, and Jenna Wiens, editors, Proceedings of the 1st Machine Learning for Healthcare Conference, volume 56 of Proceedings of Machine Learning Research, pages 101–114, Northeastern University, Boston, MA, USA, 18–19 Aug 2016. PMLR.
  • [17] Vinod Nair and Geoffrey E. Hinton. Rectified linear units improve restricted Boltzmann machines. In Johannes Fürnkranz and Thorsten Joachims, editors, Proceedings of the 27th International Conference on Machine Learning (ICML-10), pages 807–814. Omnipress, 2010.
  • [18] Sergey Ioffe and Christian Szegedy. Batch normalization: Accelerating deep network training by reducing internal covariate shift. In International Conference on Machine Learning, pages 448–456, 2015.
  • [19] Günter Klambauer, Thomas Unterthiner, Andreas Mayr, and Sepp Hochreiter. Self-normalizing neural networks. arXiv preprint, jun 2017.
  • [20] Nitish Srivastava, Geoffrey Hinton, Alex Krizhevsky, Ilya Sutskever, and Ruslan Salakhutdinov. Dropout: A simple way to prevent neural networks from overfitting. The Journal of Machine Learning Research, 15(1):1929–1958, 2014.
  • [21] Diederik Kingma and Jimmy Ba. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980, 2014.
  • [22] Yu Nesterov. Gradient methods for minimizing composite functions. Mathematical Programming, 140(1):125–161, 2013.
  • [23] Razvan Pascanu, Tomas Mikolov, and Yoshua Bengio. Understanding the exploding gradient problem. Computing Research Repository (CoRR) abs/1211.5063, 2012.
  • [24] Alan Senior, Georg Heigold, Marc’Aurelio Ranzato, and Ke Yang. An empirical study of learning rates in deep neural networks for speech recognition. In Acoustics, Speech and Signal Processing (ICASSP), 2013 IEEE International Conference on, pages 6724–6728. IEEE, 2013.
  • [25] James Bergstra and Yoshua Bengio. Random search for hyper-parameter optimization. The Journal of Machine Learning Research, 13(1):281–305, 2012.
  • [26] Frank E Harrell, Kerry L Lee, Robert M Califf, David B Pryor, and Robert A Rosati. Regression modeling strategies for improved prognostic prediction. Statistics in medicine, 3(2):143–152, 1984.
  • [27] Bradley Efron and Robert J Tibshirani. An introduction to the bootstrap. CRC press, 1994.
  • [28] Peter C Austin. Generating survival times to simulate cox proportional hazards models with time-varying covariates. Statistics in medicine, 31(29):3946–3958, 2012.
  • [29] David W. Hosmer Jr., Stanley Lemeshow, and Susanne May. Applied Survival Analysis: Regression Modeling of Time to Event Data. Wiley-Interscience, 2008.
  • [30] William A Knaus, Frank E Harrell, Joanne Lynn, Lee Goldman, Russell S Phillips, Alfred F Connors, Neal V Dawson, William J Fulkerson, Robert M Califf, Norman Desbiens, et al. The SUPPORT prognostic model: objective estimates of survival for seriously ill hospitalized adults. Annals of internal medicine, 122(3):191–203, 1995.
  • [31] Christina Curtis, Sohrab P Shah, Suet-Feung Chin, Gulisa Turashvili, Oscar M Rueda, Mark J Dunning, Doug Speed, Andy G Lynch, Shamith Samarajiwa, Yinyin Yuan, et al. The genomic and transcriptomic architecture of 2,000 breast tumours reveals novel subgroups. Nature, 486(7403):346–352, 2012.
  • [32] Roopa Lakhanpal, Ivana Sestak, Bruce Shadbolt, Genevieve M Bennett, Michael Brown, Tessa Phillips, Yanping Zhang, Amanda Bullman, and Angela Rezo. IHC4 score plus clinical treatment score predicts locoregional recurrence in early breast cancer. The Breast, 29:147–152, 2016.
  • [33] John A Foekens, Harry A Peters, Maxime P Look, Henk Portengen, Manfred Schmitt, Michael D Kramer, Nils Brünner, Fritz Jänicke, Marion E Meijer-van Gelder, Sonja C Henzen-Logmans, et al. The urokinase system of plasminogen activation and prognosis in 2780 breast cancer patients. Cancer research, 60(3):636–643, 2000.
  • [34] M Schumacher, G Bastert, H Bojar, K Huebner, M Olschewski, W Sauerbrei, C Schmoor, C Beyerle, RL Neumann, and HF Rauschecker. Randomized 2 x 2 trial evaluating hormonal treatment and the duration of chemotherapy in node-positive breast cancer patients. German Breast Cancer Study Group. Journal of Clinical Oncology, 12(10):2086–2093, 1994.
  • [35] Douglas G Altman and Patrick Royston. What do we mean by validating a prognostic model? Statistics in medicine, 19(4):453–473, 2000.
  • [36] Ilya M Sobol. Uniformly distributed sequences with an additional uniform property. USSR Computational Mathematics and Mathematical Physics, 16(5):236–242, 1976.
  • [37] Bennett L. Fox. Algorithm 647: Implementation and relative efficiency of quasirandom sequence generators. ACM Trans. Math. Softw., 12(4):362–376, December 1986.