An implementation of the fine-grained survival analysis with deep recurrent neural network.
Survival analysis is a hotspot in statistical research for modeling time-to-event information with data censorship handling, which has been widely used in many applications such as clinical research, information systems and other fields with survivorship bias. Many works have been proposed for survival analysis, ranging from traditional statistical methods to machine learning models. However, the existing methodologies either utilize counting-based statistics on the segmented data, or make a prior assumption on the event probability distribution w.r.t. time. Moreover, few works consider sequential patterns within the feature space. In this paper, we propose a Deep Recurrent Survival Analysis model which combines deep learning for conditional probability prediction at the fine-grained level of the data, and survival analysis for tackling the censorship. By capturing the time dependency through modeling the conditional probability of the event for each sample, our method predicts the likelihood of the true event occurrence and estimates the survival rate over time, i.e., the probability of the non-occurrence of the event, for the censored data. Meanwhile, without assuming any specific form of the event probability distribution, our model shows great advantages over previous works in fitting various sophisticated data distributions. In experiments on three real-world tasks from different fields, our model significantly outperforms the state-of-the-art solutions under various metrics.
Recent advances in modern technology have made abundant data collection available for time-to-event information, which facilitates observing and tracking events of interest. However, for various reasons, many events are lost to tracking during the observation period, which makes the data censored. For such samples we only know that the true time to the occurrence of the event is larger than, smaller than, or within the observation time; these cases are called right-censored, left-censored and interval-censored, respectively [Lee and Wang2003], and they give rise to survivorship bias. Survival analysis, a.k.a. time-to-event analysis [Lee et al.2018], is a typical statistical methodology for modeling time-to-event data while handling censorship. It is a traditional research problem that has been studied for decades.
The goal of survival analysis is to estimate the time until the occurrence of a particular event of interest, which can be regarded as a regression problem [Lee and Wang2003, Wu, Yeh, and Chen2015]. It can also be viewed as predicting the probability of the event occurring over the whole timeline [Wang et al.2016, Lee et al.2018]. Specifically, given the information of the observed object, survival analysis predicts the probability of the event occurrence at each time point.
Nowadays, survival analysis is widely used in real-world applications, such as clinical analysis in medical research [Zhu et al.2017b, Luck et al.2017, Katzman et al.2018], which takes diseases as events and predicts the survival time of patients; customer lifetime estimation in information systems [Jing and Smola2017, Grob et al.2018], which estimates the time until the next visit of a user; and market modeling in game theory [Wu, Yeh, and Chen2015, Wang et al.2016], which predicts the event (i.e., winning) probability over the whole referral space.
Because of these essential real-world applications, researchers in both academia and industry have devoted great effort to studying survival analysis in recent decades. Many works on survival analysis take the view of traditional statistical methodology. Among them, the Kaplan-Meier estimator [Kaplan and Meier1958] is based on non-parametric counting statistics and forecasts the survival rate at a coarse-grained level, where different observed objects may share the same forecasting result, which is not suitable for recent personalized applications. The Cox proportional hazard method [Cox1992] and its variants such as Lasso-Cox [Tibshirani1997] assume a specific stochastic process or base distribution with semi-parametric scaling coefficients for fine-tuning the final survival rate prediction. Other parametric methods make specific distributional assumptions, such as the Exponential distribution [Lee and Wang2003] and the Weibull distribution [Ranganath et al.2016]. These methods pre-assume distributional forms for the survival rate function, which may not generalize well in real-world situations.
Recently, deep learning, i.e., deep neural networks, has attracted huge attention and has been introduced to survival analysis in many tasks [Ranganath et al.2016, Grob et al.2018, Lee et al.2018]. However, many deep learning models for survival analysis [Katzman et al.2018, Ranganath et al.2016] actually use the deep neural network as an enhanced feature extraction method [Lao et al.2017, Grob et al.2018] and, worse still, rely on assumptions about the base distribution for the survival rate prediction, and thus also suffer from the generalization problem. Lately, Lee et al. [Lee et al.2018] proposed a deep learning method for modeling the event probability without assumptions about the probability distribution. Nevertheless, they regard the event probability estimation as a pointwise prediction problem and ignore the sequential patterns within neighboring time slices. Moreover, the gradient signal is too sparse and has little effect on most of the prediction outputs of this model, which is not effective enough for modeling time-to-event data.
Considering all these drawbacks of the existing literature, in this paper we propose our Deep Recurrent Survival Analysis (DRSA) model for predicting the survival rate over time at a fine-grained level, i.e., for each individual sample. To the best of our knowledge, this is the first work utilizing an auto-regressive model to capture the sequential patterns of the features over time in survival analysis.
Our model proposes a novel modeling view for time-to-event data, which aims at flexibly modeling the survival probability function rather than making any assumption about the form of the distribution. Specifically, DRSA predicts the conditional probability of the event at each time given that the event has not occurred before, and combines these predictions through the probability chain rule to estimate both the probability density function and the cumulative distribution function of the event over time, eventually forecasting the survival rate at each time, which is more reasonable and mathematically efficient for survival analysis. We train the DRSA model by end-to-end optimization through maximum likelihood estimation, not only on the observed events among the uncensored data, but also on the censored samples to reduce the survivorship bias. Through these modeling methods, our DRSA model can capture the sequential patterns embedded in the feature space along time, and output more effective distributions for each individual sample at a fine-grained level. Comprehensive experiments over three large-scale real-world datasets demonstrate that our model achieves significant improvements against state-of-the-art models under various metrics.
The event occurrence information of some samples may be lost, due to the limited observation period or to losing track during the study procedure [Wang, Li, and Reddy2017], which is called data censorship. When dealing with time-to-event information, a more complex learning problem is to estimate the probability of the event occurrence at each time, especially for those samples without tracking logs after (or before) the observation time, which are defined as right-censored (or left-censored) [Wang, Li, and Reddy2017]. Survival analysis is a typical statistical methodology for modeling time-to-event data while handling censorship. There are two main streams of survival analysis.
The first view is based on traditional statistics and falls into three categories. (i) Non-parametric methods, including the Kaplan-Meier estimator [Kaplan and Meier1958] and the Nelson-Aalen estimator [Andersen et al.2012], are solely based on counting statistics, which is too coarse-grained to perform personalized modeling. (ii) Semi-parametric methods, such as the Cox proportional hazard model [Cox1992] and its variant Lasso-Cox [Tibshirani1997], assume some base distribution functions with scaling coefficients for fine-tuning the final survival rate prediction. (iii) Parametric models assume that the survival time or its logarithm follows a particular theoretical distribution, such as the Exponential distribution [Lee and Wang2003] or the Weibull distribution [Ranganath et al.2016]. These methods either rely on statistical counting information or pre-assume distributional forms for the survival rate function, and therefore do not generalize very well in real-world situations.
The second school of survival analysis takes the machine learning perspective. The survival random forest, first proposed in [Gordon and Olshen1985], derives from the standard decision tree by modeling the censored data [Wang et al.2016], while its idea is mainly based on counting-based statistics. Other machine learning methodologies include Bayesian models [Ranganath et al.2015, Ranganath et al.2016], support vector machines [Khan and Zubek2008] and multi-task learning solutions [Li et al.2016, Alaa and van der Schaar2017]. Deep learning models have also emerged in recent years. Faraggi and Simon [Faraggi and Simon1995] first embedded a neural network into the Cox model to improve covariate relationship modeling. Since then, many works have applied deep neural networks to well-studied statistical models to improve feature extraction and survival analysis through end-to-end learning, such as [Ranganath et al.2016, Luck et al.2017, Lao et al.2017, Katzman et al.2018, Grob et al.2018]. Almost all the above models assume particular distribution forms and thus also suffer from the generalization problem in practice.
Biganzoli et al. [Biganzoli et al.1998] utilized a neural network directly to predict the survival rate for each sample, and Lisboa et al. [Lisboa et al.2003] extended it to a Bayesian network method. In [Lee et al.2018], the authors proposed a feed-forward deep model to directly predict the probability density values at each time point and sum them to estimate the survival rate. However, in that model, the gradient signal is quite sparse for the prediction outputs from the neural network. Moreover, to the best of our knowledge, none of the related literature considers the sequential patterns within the feature space over time. We propose a recurrent neural network model that predicts the conditional probability of the event at each time and estimates the survival rate through the probability chain rule, which captures the sequential dependency patterns between neighboring time slices and back-propagates the gradient more efficiently.
Due to its adequate model capacity and the support of big data, deep learning, a.k.a. deep neural networks, has drawn great attention in recent decades, ranging from computer vision [Krizhevsky, Sutskever, and Hinton2012] and speech recognition [Graves, Mohamed, and Hinton2013] to natural language processing [Bahdanau, Cho, and Bengio2015]. Among these models, the recurrent neural network (RNN), whose idea first emerged two decades ago, and its variants such as long short-term memory (LSTM) [Hochreiter and Schmidhuber1997] employ memory structures to model the conditional probability, which captures dynamic sequential patterns. In this paper we borrow the idea of the RNN and carefully design the modeling methodology for survival function regression.
In this section, we formulate the survival analysis problem and discuss the details of our proposed model. We take the view of right-censorship which is the most common scenario in survival analysis [Kaplan and Meier1958, Cox1992, Wang, Li, and Reddy2017, Lee et al.2018].
We define $z$ as the variable of the true occurrence time of the event of interest, if it has been tracked. For simplicity we refer to the occurrence of the event of interest as the event, and define the probability density function (P.D.F.) of the true event time as $p(z)$, i.e., the probability that the event truly occurs at time $z$.
Given the P.D.F. of the event time, we can derive the survival rate at each time $t$ from the C.D.F. as
$$S(t) = \Pr(z > t) = \int_t^{\infty} p(z)\,dz, \tag{1}$$
which is the probability of the observed object surviving, i.e., the event not occurring, until the observed time $t$. The event rate, i.e., the probability of the event occurring no later than the observing time $t$, is then
$$W(t) = \Pr(z \le t) = 1 - S(t) = \int_0^{t} p(z)\,dz. \tag{2}$$
The data of the survival analysis logs are represented as a set of $N$ triples $\{(x^i, z^i, t^i)\}_{i=1}^{N}$, where $t^i > 0$ is the observation time for the given sample. Here $z^i$ is left unknown (and marked as null) for the censored samples, for which the true event time was not observed. $x^i$ is the feature vector of the observation, which encodes different information under various scenarios.
Our goal is to model the distribution $p(z)$ of the true event time over all the historical time-to-event logs while handling the censored data, for which the true event time is unknown. So the main problem of survival analysis is to estimate the probability distribution $p(z \mid x)$ of the event time with regard to the sample feature $x$, for each sample. Formally speaking, the derived model is a "mapping" function $T$ which learns the patterns within the data and predicts the event time distribution over the time space as
$$p(z \mid x) = T(x). \tag{3}$$
First of all, we present the definition of the conditional hazard rate over continuous time as
$$h(t) = \lim_{\Delta t \to 0} \frac{\Pr(t < z \le t + \Delta t \mid z > t)}{\Delta t}, \tag{4}$$
which models the instantaneous occurrence probability of the event at time $t$ given that the event has not occurred before. Note that the concept of hazard rate has been commonly utilized in the survival analysis literature [Cox1992, Faraggi and Simon1995, Luck et al.2017].
In the discrete context, a set of $L$ time slices $0 < t_1 < t_2 < \dots < t_L$ is obtained, which arises from the finite precision of time determinations. Analogously, we may group continuous time into $l = 1, 2, \dots, L$ uniformly divided disjoint intervals $V_l = (t_{l-1}, t_l]$, where $t_0 = 0$ and $t_l$ is the last observation interval boundary for the given sample, i.e., the tracked observation time in the logs. $V_L$ is the last time interval in the whole data space. This setting is well suited to our task and has been widely used in clinical research [Li et al.2016, Lee et al.2018], information systems [Jing and Smola2017, Grob et al.2018] and other related fields [Wu, Yeh, and Chen2015, Wang et al.2016].
As such, our event rate function and survival rate function over the discrete time space are
$$W(t_l) = \Pr(z \le t_l) = \sum_{j \le l} \Pr(z \in V_j), \qquad S(t_l) = \Pr(z > t_l) = \sum_{j > l} \Pr(z \in V_j), \tag{5}$$
where the input to the two functions is the observed time $t_l$ from the log. The discrete event time probability function at the $l$-th time interval is
$$p_l = \Pr(z \in V_l) = W(t_l) - W(t_{l-1}) = S(t_{l-1}) - S(t_l). \tag{6}$$
The discrete conditional hazard rate $h_l$ is defined as the conditional probability
$$h_l = \Pr(z \in V_l \mid z > t_{l-1}) = \frac{\Pr(z \in V_l)}{\Pr(z > t_{l-1})} = \frac{p_l}{S(t_{l-1})}, \tag{7}$$
which approximates the continuous conditional hazard rate function $h(t)$ in Eq. (4) as the intervals $V_l$ become infinitesimal.
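As a minimal numerical sketch of the discrete-time definitions above (not taken from the published code), the snippet below starts from per-interval event probabilities and derives the event rate, the survival rate and the discrete hazard rate. The function name `discrete_survival_quantities` and the toy probabilities are assumptions made for illustration, and the probabilities are assumed to sum to one over the $L$ intervals.

```python
def discrete_survival_quantities(p):
    """Given p[l-1] = Pr(z in V_l) for l = 1..L, return the lists (W, S, h).

    W[l-1] is the event rate at t_l, S[l-1] the survival rate at t_l, and
    h[l-1] the discrete hazard of interval V_l. Assumes sum(p) == 1.
    """
    W, S, h = [], [], []
    cumulative = 0.0
    for p_l in p:
        survival_before = 1.0 - cumulative  # survival at t_{l-1}; equals 1 at t_0
        # Hazard = interval probability / survival just before the interval.
        # It is defined as 0 here when nothing survives (a convention, not from the paper).
        h.append(p_l / survival_before if survival_before > 0 else 0.0)
        cumulative += p_l
        W.append(cumulative)
        S.append(1.0 - cumulative)
    return W, S, h


# Toy example with four intervals (hypothetical values).
W, S, h = discrete_survival_quantities([0.1, 0.2, 0.3, 0.4])
print("event rate   :", W)
print("survival rate:", S)
print("hazard rate  :", h)
```

In this toy example the hazard of the last interval is (numerically) one, because every sample that survives the first three intervals must experience the event in the fourth.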
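So far, we have presented the discrete time model and discussed the death (i.e., event) and survival probabilities over the discrete time space. We now propose our DRSA model, based on a recurrent neural network $f_\theta$ with parameter $\theta$, which captures the sequential patterns for the conditional probability $h^i_l$ at every time interval $V_l$ for the $i$-th sample.
The detailed structure of the DRSA network is illustrated in Figure 1. At each time interval $V_l$, the $l$-th RNN cell predicts the instantaneous hazard rate $h^i_l$ given the sample feature $x^i$ and the current time $t_l$, conditioned upon the previous events, as
$$h^i_l = \Pr(z \in V_l \mid z > t_{l-1}, x^i; \theta) = f_\theta(x^i, t_l \mid r_{l-1}), \tag{8}$$
where $f_\theta$ is the RNN function taking $(x^i, t_l)$ as input and $h^i_l$ as output. $r_{l-1}$ is the hidden vector calculated by the RNN cell at the previous time step, which contains the information about the conditional. It is quite natural to use a recurrent cell to model the conditional probability over time [Bahdanau, Cho, and Bengio2015]. In our paper we implement the RNN function as a standard LSTM unit [Hochreiter and Schmidhuber1997], which has been widely used in sequence data modeling. The details of our RNN architecture can be found in our supplemental materials and in the code published in the experiment part.
From Eqs. (5), (7) and (8), the survival rate function and the corresponding event rate function at the observation time $t^i$ of the $i$-th sample follow from the probability chain rule as
$$S(t^i \mid x^i; \theta) = \Pr(z > t^i \mid x^i; \theta) = \prod_{l:\, l \le l^i} (1 - h^i_l), \tag{9}$$
$$W(t^i \mid x^i; \theta) = 1 - S(t^i \mid x^i; \theta) = 1 - \prod_{l:\, l \le l^i} (1 - h^i_l), \tag{10}$$
where $l^i$ is the time interval index for the $i$-th sample at $t^i$.
Moreover, taking Eqs. (6) and (7) into consideration, the probability of the event time $z$ lying in the interval $V_{l^i}$ for the $i$-th sample is
$$p^i_{l^i} = \Pr(z \in V_{l^i} \mid x^i; \theta) = h^i_{l^i} \prod_{l:\, l < l^i} (1 - h^i_l). \tag{11}$$
By means of the probability chain rule, the model connects all the outputs of the conditional hazard rate $h$ at each individual time to the final prediction, i.e., the probability $p(z)$ of the true event time $z$ and the survival rate $S(t)$ at each time $t$. This feed-forward calculation guarantees that the gradient signal from the loss function can be transmitted through back-propagation more effectively compared with [Lee et al.2018], as discussed below.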
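The chain-rule combination described above can be sketched in a few lines. This is an illustrative example under stated assumptions, not the released implementation: the hazards are given as a plain list (standing in for the per-interval outputs of the recurrent cells), and the function name `chain_rule` is hypothetical.

```python
def chain_rule(hazards):
    """Turn per-interval hazards h_1..h_L into survival, event-rate and
    event-time probabilities for every interval.

    S[l-1] = prod_{j <= l} (1 - h_j)        survival rate after interval l
    W[l-1] = 1 - S[l-1]                     event rate after interval l
    p[l-1] = h_l * prod_{j < l} (1 - h_j)   probability the event falls in interval l
    """
    S, W, p = [], [], []
    survive = 1.0  # running product of (1 - h_j) over the intervals seen so far
    for h_l in hazards:
        p.append(h_l * survive)   # uses the product over j < l
        survive *= 1.0 - h_l      # extend the product to j <= l
        S.append(survive)
        W.append(1.0 - survive)
    return S, W, p


# Toy hazards for four intervals (hypothetical values).
S, W, p = chain_rule([0.1, 0.25, 0.5, 1.0])
print("survival  :", S)
print("event rate:", W)
print("event prob:", p)
```

Because the last toy hazard equals one, the event probabilities sum to one and the survival rate reaches zero after the final interval.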
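Since there is no ground truth for either the event time distribution or the survival rate, we maximize the log-likelihood over the empirical data distribution to learn our deep model. Specifically, we take three objectives as our losses.
The first loss minimizes the negative log-likelihood of the true event time $z = z^i$ over the uncensored logs:
$$\mathcal{L}_z = -\log \prod_{(x^i, z^i) \in D_{\text{uncensored}}} \Pr(z \in V_{l^i} \mid x^i; \theta) = -\sum_{(x^i, z^i) \in D_{\text{uncensored}}} \Big[ \log h^i_{l^i} + \sum_{l:\, l < l^i} \log(1 - h^i_l) \Big], \tag{12}$$
where $l^i$ is the index of the interval of the true event time, $z^i \in V_{l^i}$.
The second loss minimizes the negative partial log-likelihood of the event rate over the uncensored logs:
$$\mathcal{L}_{\text{uncensored}} = -\log \prod_{(x^i, t^i) \in D_{\text{uncensored}}} \Pr(t^i \ge z \mid x^i; \theta) = -\sum_{(x^i, t^i) \in D_{\text{uncensored}}} \log \Big[ 1 - \prod_{l:\, l \le l^i} (1 - h^i_l) \Big], \tag{13}$$
where $l^i$ here denotes the interval index of the observation time $t^i$. This loss adds more supervision on the predictions over the time range $(z^i, t^i)$ for the uncensored data than methods [Katzman et al.2018, Lee et al.2018, Tibshirani1997] that supervise only at the true event time $z^i$.
Although the censored logs do not contain the true event time, we do know that the true event time is greater than the logged observation time $t^i$. We incorporate the partial log-likelihood embedded in the censored logs as the third loss, to correct the learning bias of our model:
$$\mathcal{L}_{\text{censored}} = -\log \prod_{(x^i, t^i) \in D_{\text{censored}}} \Pr(z > t^i \mid x^i; \theta) = -\sum_{(x^i, t^i) \in D_{\text{censored}}} \sum_{l:\, l \le l^i} \log(1 - h^i_l). \tag{14}$$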
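For a single sample, the three losses described above can be written directly in terms of the predicted hazards. The following is a minimal sketch under stated assumptions rather than the published training code: hazards are a plain list with values strictly between 0 and 1, interval indices are 1-based, and the function names `loss_z`, `loss_uncensored` and `loss_censored` are hypothetical.

```python
import math


def loss_z(hazards, l_event):
    """Negative log-likelihood that the event falls in interval l_event:
    -[log h_{l_event} + sum_{l < l_event} log(1 - h_l)]."""
    log_survive_before = sum(math.log(1.0 - h) for h in hazards[:l_event - 1])
    return -(math.log(hazards[l_event - 1]) + log_survive_before)


def loss_uncensored(hazards, l_obs):
    """Negative log of the event rate at the observation interval l_obs:
    -log[1 - prod_{l <= l_obs} (1 - h_l)]."""
    survive = 1.0
    for h in hazards[:l_obs]:
        survive *= 1.0 - h
    return -math.log(1.0 - survive)


def loss_censored(hazards, l_obs):
    """Negative log of the survival rate at the observation interval l_obs:
    -sum_{l <= l_obs} log(1 - h_l)."""
    return -sum(math.log(1.0 - h) for h in hazards[:l_obs])


# Toy hazards for three intervals (hypothetical values).
hazards = [0.1, 0.25, 0.5]
print(loss_z(hazards, l_event=2))         # event observed in interval 2
print(loss_uncensored(hazards, l_obs=3))  # event known to happen by interval 3
print(loss_censored(hazards, l_obs=3))    # event not yet observed at interval 3
```

Summing `loss_z` and `loss_uncensored` over the uncensored samples, and `loss_censored` over the censored samples, gives the dataset-level objectives.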
In this section, we analyze some intrinsic properties of our deep survival model.
First, we analyze the model effectiveness of DRSA. In [Lee et al.2018], the proposed deep model directly predicts the event probability $p(t)$ and combines these predictions to estimate the survival rate as $S(t) = \sum_{t' > t} p(t')$, while ignoring the sequential patterns. As a result, the gradient signal only affects the prediction at time $t$ individually. In contrast, as Figure 1 shows, in our DRSA model the supervision is back-propagated directly through the chain rule calculation to all the units, with strict mathematical derivation, which transmits the gradient more efficiently and effectively. We also explicitly model the sequential patterns by conditional hazard rate prediction, and we illustrate the advantage of this in the experiments.
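The claim that the supervision reaches all the units up to the event time can be checked numerically for the event-time likelihood. The sketch below treats the hazards as free parameters (an assumption for illustration, since in the model they are outputs of the recurrent network) and estimates the partial derivatives by central finite differences. It says nothing about other models; the names `neg_log_event_prob` and `finite_diff_grad` are hypothetical.

```python
import math


def neg_log_event_prob(hazards, l_event):
    """-log of the chain-rule event probability for interval l_event (1-based)."""
    log_p = math.log(hazards[l_event - 1])
    log_p += sum(math.log(1.0 - h) for h in hazards[:l_event - 1])
    return -log_p


def finite_diff_grad(hazards, l_event, eps=1e-6):
    """Central finite-difference estimate of d(loss)/d(h_k) for every interval k."""
    grads = []
    for k in range(len(hazards)):
        up = list(hazards)
        down = list(hazards)
        up[k] += eps
        down[k] -= eps
        diff = neg_log_event_prob(up, l_event) - neg_log_event_prob(down, l_event)
        grads.append(diff / (2.0 * eps))
    return grads


# Toy hazards for four intervals; the event falls in interval 3 (hypothetical values).
grads = finite_diff_grad([0.1, 0.2, 0.3, 0.4], l_event=3)
for k, g in enumerate(grads, start=1):
    print(f"d loss / d h_{k} = {g:.4f}")
```

Analytically, the derivative is $1/(1-h_k)$ for every interval before the event, $-1/h_k$ at the event interval, and zero afterwards, so every hazard up to the event interval receives a nonzero gradient from a single observed event.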
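Next, we take the view of censorship prediction. Each sample has a censoring status that indicates survival at the given observation time:
$$c^i(t^i) = \begin{cases} 0, & \text{if } t^i \ge z^i, \\ 1, & \text{otherwise } (t^i < z^i). \end{cases} \tag{15}$$
In the tracking logs, each uncensored sample $(x^i, z^i, t^i)$ has $c^i = 0$. For the censored logs that lose tracking at the observation time, the true event time $z^i$ is unknown, but the tracker knows that $z^i > t^i$, thus $c^i = 1$.
Moreover, for the uncensored data, it is natural to "push down" the probability of survival $S(t^i)$, while for the censored data we need to "pull up" $S(t^i)$, since we "observe event not occurred" at time $t^i$. However, using only $\mathcal{L}_z$ to supervise the prediction of $p(z^i)$ at time $z^i$ in Eq. (12) is insufficient. We therefore incorporate the two partial likelihood losses $\mathcal{L}_{\text{uncensored}}$ and $\mathcal{L}_{\text{censored}}$ in Eqs. (13) and (14). Taken together, they describe the classification of the survival status of each sample at time $t^i$:
$$\mathcal{L}_c = \mathcal{L}_{\text{uncensored}} + \mathcal{L}_{\text{censored}} = -\sum_{(x^i, t^i) \in D_{\text{full}}} \Big\{ c^i \log S(t^i \mid x^i; \theta) + (1 - c^i) \log \big[ 1 - S(t^i \mid x^i; \theta) \big] \Big\}, \tag{16}$$
which is the cross entropy loss for predicting the survival status at time $t^i$ given $x^i$ over all the data $D_{\text{full}}$.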
Combining all the objective functions, our goal is to minimize the negative log-likelihood over all the data samples, including both uncensored and censored data:
$$\arg\min_{\theta} \; \alpha \mathcal{L}_z + (1 - \alpha) \mathcal{L}_c, \tag{17}$$
where $\theta$ is the model parameter in Eq. (8) and the hyperparameter $\alpha$ controls the balance between the two loss values. Specifically, $\alpha$ keeps the magnitudes of the two losses at the same level to stabilize the model training.
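A minimal sketch of the combined objective follows, assuming it takes the weighted form $\alpha \mathcal{L}_z + (1-\alpha)\mathcal{L}_c$, where $\mathcal{L}_z$ is the event-time likelihood loss and $\mathcal{L}_c$ the cross entropy on the survival status. The sample layout `(hazards, l_obs, l_event)`, with `l_event = None` marking a censored sample, and the function names are assumptions for illustration, not the data format of the published code.

```python
import math


def log_survival(hazards, l_index):
    """log of the survival rate after interval l_index (1-based):
    sum of log(1 - h_l) over l <= l_index. Hazards must lie strictly in (0, 1)."""
    return sum(math.log(1.0 - h) for h in hazards[:l_index])


def combined_loss(samples, alpha):
    """Weighted sum alpha * L_z + (1 - alpha) * L_c over a list of samples.

    Each sample is (hazards, l_obs, l_event); l_event is None when censored.
    """
    loss_z = 0.0  # event-time likelihood loss, uncensored samples only
    loss_c = 0.0  # cross entropy on survival status at the observation time
    for hazards, l_obs, l_event in samples:
        log_s = log_survival(hazards, l_obs)
        if l_event is None:
            loss_c -= log_s                            # censored: -log S
        else:
            loss_c -= math.log(1.0 - math.exp(log_s))  # uncensored: -log(1 - S)
            loss_z -= math.log(hazards[l_event - 1]) + log_survival(hazards, l_event - 1)
    return alpha * loss_z + (1.0 - alpha) * loss_c


# Two toy samples (hypothetical values).
samples = [
    ([0.1, 0.25, 0.5], 3, 2),    # uncensored: event in interval 2, observed up to interval 3
    ([0.2, 0.2, 0.2], 2, None),  # censored after interval 2
]
print(combined_loss(samples, alpha=0.25))
```

Setting `alpha=1.0` recovers the event-time likelihood alone, and `alpha=0.0` recovers the survival-status cross entropy alone, which makes it easy to inspect the magnitude of each term when choosing the balance.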
We also analyze the model efficiency in the supplemental material of this paper and the time complexity of model inference is the same as the traditional RNN model which has proven practical efficiency in the industrial applications [Zhang et al.2014].
We evaluate our model against strong baselines in three real-world tasks. Moreover, we have published the implementation code for reproducible experiments (code link: https://github.com/rk2900/drsa).
We evaluate all the compared models in three real-world tasks. We have also published the processed full datasets (sampled data are included in the published code; link to the three processed full datasets: https://goo.gl/nUFND4).
CLINIC is a dataset for tracking patient clinical status [Knaus et al.1995]. Here the goal of survival analysis is to estimate the time until the event (death), and to predict the probability of the event with waning effects of baseline physiologic variables over time.
MUSIC is a user lifetime analysis dataset [Jing and Smola2017] that contains roughly 1,000 users with their entire listening history from 2004 to 2009 on last.fm, a famous online music service. Here the event is a user visit to the music service, and the goal is to predict the time elapsed from the last visit of a user to her next visit.
BIDDING is a real-time bidding dataset in the computational advertising field [Ren et al.2018, Wang et al.2016]. In this scenario, the time corresponds to the bid price of the bidder and the event is winning the auction. The feature contains the auction request information. Many researchers [Wu, Yeh, and Chen2015, Wang et al.2016] have utilized survival analysis for unbiased winning probability estimation of a single auction while handling the losing (censored) logs, for which the true winning price is unknown.
The statistics of the three datasets are provided in Table 1. We split the CLINIC and MUSIC datasets into training and test sets with ratios of 4:1 and 6:1, respectively. For feature engineering, all the datasets have been one-hot encoded for both categorical and numerical features. The original BIDDING data had already been feature engineered and split into training and test datasets. Note that the true event times of all the test data have been preserved for the performance evaluation. Since all times in these datasets are integer values, we bucketize time into discrete intervals of unit size, and the maximal time interval number $L$ is equal to the largest integer time in each dataset. The discussion of various interval sizes is included in the supplemental materials.
The first evaluation metric is the time-dependent concordance index (C-index), which is the most common evaluation used in survival analysis [Li et al.2016, Luck et al.2017, Lee et al.2018] and measures how well a model predicts the ordering of sample event times. That is, given the observing time $t$, two samples $d^1 = (x^1, z^1)$ with large event time $z^1 > t$ and $d^2 = (x^2, z^2)$ with small event time $z^2 \le t$ should be ordered as $d^1 \prec d^2$, where $d^1$ is placed before $d^2$. This evaluation is quite similar to the area under the ROC curve (AUC) metric in binary classification tasks [Wang, Li, and Reddy2017]. From the ranking view of event probability estimation at time $t$, the C-index assesses the ordering performance among all the uncensored and censored pairs at $t$ in the test data.
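Read as an AUC-style ranking measure at a fixed observing time, the metric above can be sketched as follows. This is an illustrative approximation under stated assumptions, not necessarily the exact evaluation protocol used in the experiments: each sample has a predicted event rate at the observing time and a binary label saying whether its event occurred by that time, tied scores count as half, and the function name `c_index_at` is hypothetical.

```python
def c_index_at(scores, event_by_t):
    """Fraction of (event, no-event) pairs that the scores rank correctly.

    scores[i]     : predicted probability that the event of sample i occurs by time t
    event_by_t[i] : True if the event of sample i actually occurred by time t
    A pair is concordant when the sample with the event has the higher score.
    """
    with_event = [s for s, e in zip(scores, event_by_t) if e]
    without_event = [s for s, e in zip(scores, event_by_t) if not e]
    if not with_event or not without_event:
        raise ValueError("need at least one sample with and one without the event")
    concordant = 0.0
    for s_event in with_event:
        for s_survive in without_event:
            if s_event > s_survive:
                concordant += 1.0
            elif s_event == s_survive:
                concordant += 0.5  # ties count as half a concordant pair
    return concordant / (len(with_event) * len(without_event))


# Four toy test samples (hypothetical values).
scores = [0.9, 0.4, 0.6, 0.2]
event_by_t = [True, False, True, True]
print(c_index_at(scores, event_by_t))
```

In the toy example two of the three (event, no-event) pairs are ordered correctly, so the index is 2/3; a value of 0.5 corresponds to random ordering and 1.0 to a perfect ordering.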
We also use the average negative log probability (ANLP) to evaluate the regression performance of the different forecasting models. ANLP assesses the likelihood of the co-occurrence of the test sample with the corresponding true event time, which corresponds to the event time likelihood loss in Eq. (12). Here we compute ANLP as
$$\text{ANLP} = -\frac{1}{|D_{\text{test}}|} \sum_{(x^i, z^i) \in D_{\text{test}}} \log p(z^i \mid x^i), \tag{18}$$
where $p(z \mid x)$ is the learned time-to-event probability function of each model.
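As a small worked example of the ANLP computation, the sketch below averages the negative log of the probability that each model assigns to the true event time of a test sample. The function name `anlp` and the toy probabilities are assumptions for illustration; the probabilities are assumed to be strictly positive.

```python
import math


def anlp(event_probs):
    """Average negative log probability.

    event_probs[i] is the probability the model assigns to the true event time
    of test sample i; every value must be strictly positive.
    """
    if not event_probs:
        raise ValueError("need at least one test sample")
    return -sum(math.log(p) for p in event_probs) / len(event_probs)


# Two toy test samples (hypothetical values): 0.5 and 0.25 on the true event times.
print(anlp([0.5, 0.25]))
```

Lower values are better: a model that places probability one on every true event time would reach an ANLP of zero.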
Finally, we conduct significance tests to verify the statistical significance of the performance improvement of our model w.r.t. the baseline models. Specifically, we deploy a Mann-Whitney U test [Mason and Graham2002] under the C-index metric, and a t-test [Bhattacharya and Habtzghi2002] under the ANLP metric.
We compare our model with two traditional statistic methods and five machine learning methods including two deep learning models.
KM is Kaplan-Meier estimator [Kaplan and Meier1958] which is a statistic-based non-parametric method counting on the event probability at each time over the whole set of samples.
STM is a survival random tree [Wang et al.2016] model which splits the data into small segments using between-node heterogeneity and utilizes Kaplan-Meier estimator to derive the survival analysis for each segment.
MTLSA is the recently proposed multi-task learning with survival analysis model [Li et al.2016]. It transforms the original survival analysis problem into a series of binary classification subproblems, and uses a multi-task learning method to model the event probability at different time.
DeepSurv is a Cox proportional hazard model with deep neural network [Katzman et al.2018] for feature extraction upon the sample covariates.
DeepHit is a deep neural network model [Lee et al.2018] which predicts the probability of the event over the whole time space with the input $x$. This method achieved state-of-the-art performance in survival analysis problems.
DRSA is our proposed model, described above. The implementation details can be found in the supplemental materials and the published code.
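To make the simplest baseline in the list above concrete, here is a minimal sketch of the Kaplan-Meier product-limit estimator. It is a textbook-style illustration, not the baseline code used in the experiments; the function name `kaplan_meier` and the toy data are assumptions.

```python
def kaplan_meier(times, observed):
    """Kaplan-Meier survival curve.

    times[i]    : event time or censoring time of sample i
    observed[i] : True if the event was observed, False if the sample was censored
    Returns a list of (t, S(t)) pairs, one per distinct time with an observed event.
    """
    data = sorted(zip(times, observed))
    at_risk = len(data)  # samples still under observation just before the current time
    survival = 1.0
    curve = []
    i = 0
    while i < len(data):
        t = data[i][0]
        events = 0
        removed = 0
        # Group all samples sharing the same time t.
        while i < len(data) and data[i][0] == t:
            if data[i][1]:
                events += 1
            removed += 1
            i += 1
        if events:
            survival *= 1.0 - events / at_risk
            curve.append((t, survival))
        at_risk -= removed  # both events and censored samples leave the risk set
    return curve


# Five toy samples (hypothetical values); False marks a censored sample.
curve = kaplan_meier([1, 2, 2, 3, 4], [True, True, False, True, False])
for t, s in curve:
    print(t, round(s, 4))
```

The estimator depends only on the times and censoring flags, never on the features, which is why it produces the same curve for every sample.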
We present the evaluation results according to the category to which each compared model belongs. KM and Lasso-Cox are statistics-based methods, while Gamma, STM and MTLSA are machine learning based models. The remaining models are deep neural network models with an end-to-end learning paradigm.
The left part of Table 2 shows the performance of the event rate estimation, i.e., the C-index metric. From the table, we observe the following facts. (i) Deep learning models show relatively better C-index performance, which may be due to their higher model capacity for feature extraction. (ii) Not only among the deep learning models, but also over all the compared methods, our DRSA model achieves the best C-index scores with significant improvements on all the datasets, which demonstrates the effectiveness of our model. (iii) The traditional statistical models, i.e., KM and Lasso-Cox, provide stable performance over the three datasets. (iv) The models with pre-assumptions about the event time distribution, i.e., Gamma and DeepSurv, do not perform well, because the strong assumptions lack generalization in real-world tasks. (v) Within the deep models, making no assumption about the latent distribution of the time-to-event data allows DeepHit and DRSA to model the data flexibly and perform better.
ANLP is a metric to measure the regression performance on the true event time prediction, i.e., the forecasting of the likelihood of the true event time.
The right part of Table 2 shows findings similar to the C-index results discussed above, e.g., our DRSA model has the best performance among all the methods. Moreover, STM segments the data well, so it achieves relatively better performance than the other non-deep machine learning methods. With effective sequential pattern mining over time, our DRSA model performs better than the other deep models. Note that DeepHit directly predicts the probability of the time-to-event, while our modeling method is based on hazard rate prediction and optimizes through the probability chain rule. The results reflect the advantage of sequential pattern mining with the novel modeling perspective of our model.
To illustrate the training and convergence of the DRSA model, we plot the learning curves and the C-index results on the CLINIC and MUSIC datasets in Figure 2. Recall that our model optimizes over two loss functions, i.e., the ANLP loss $\mathcal{L}_z$ and the cross entropy loss $\mathcal{L}_c$. From the figure, we find that DRSA converges quickly: the values of both loss functions drop to stable convergence at about the first complete iteration over the whole training dataset. Moreover, the two losses are optimized alternately and facilitate each other during training, which indicates the learning stability of our model.
Figure 3 illustrates the estimated survival rate curve over time and the forecasted event time probability for an arbitrarily selected test sample $(x^i, z^i, t^i)$. Note that the KM model makes the same prediction for all the samples in the dataset, so it is not personalized. Our DRSA model accurately places the highest probability on the true event time $z^i$, which explains the ANLP results, where DRSA achieved the best scores. DeepHit directly predicts the probability of the event time without considering the previous conditional, and it has no supervision on the predictions in the time range $(z^i, t^i)$, which makes the gradient signal too sparse, acting only on the true event time $z^i$. As a result, DeepHit does not place the probability well over the whole time space.
In this ablation study, we compare the model performance under the three losses. DRSA$_{\text{unc}}$ optimizes under $(\mathcal{L}_z + \mathcal{L}_{\text{uncensored}})$ over only the uncensored data, and DRSA$_{\text{cen}}$ optimizes under $(\mathcal{L}_z + \mathcal{L}_{\text{censored}})$ without the loss $\mathcal{L}_{\text{uncensored}}$. Note that our full model DRSA optimizes under all three losses, as stated in Eq. (17). From Table 3, we find that both partial likelihood losses $\mathcal{L}_{\text{uncensored}}$ and $\mathcal{L}_{\text{censored}}$ contribute to the final prediction. Moreover, our DRSA trained with all three losses achieves the best performance, which reflects the effectiveness of our classification loss $\mathcal{L}_c$ in Eq. (16), which optimizes the C-index metric directly.
In this paper, we comprehensively surveyed survival analysis works from the modeling view and discussed their pros and cons. To model time-to-event data flexibly over time, we proposed a deep recurrent neural network with a novel modeling view for conditional hazard rate prediction. The probability chain rule connects the predicted hazard rates at each time for event time probability forecasting and survival rate estimation. The experiments on three large-scale datasets in three real-world tasks from different fields illustrate the significant advantages of our model over strong baselines, including the state-of-the-art model.
The corresponding authors Weinan Zhang and Yong Yu thank the support of National Natural Science Foundation of China (61632017, 61702327, 61772333), Shanghai Sailing Program (17YF1428200).