Survival modeling is key to precision oncology, wherein cancer management and treatment planning are personalized to a patient's clinical, pathological, demographic, and genomic characteristics. Aided by the digitization of medical records, several studies over the past four decades have collected survival data based on longitudinal follow-up of various patient cohorts. Modeling survival in a cohort from the covariates known at the time of prediction is a complex task because the number of covariates to be taken into account can be large, and these covariates can be entangled with each other through interdependencies and interactions. A good survival model should give both (a) accurate survival estimates and (b) a well-calibrated measure of uncertainty. In this work we address the second problem, which has not received the attention it deserves.
Survival models have evolved primarily for accuracy over the last few decades. The Cox proportional hazards model (Cox-PH) proposed by Cox (1972) is one of the oldest and most popular statistical models for predicting survival. Cox-PH models survival in a cohort through a hazard function and assumes that a patient's relative log-risk of treatment failure (disease recurrence or death) at any time is a linear combination of the patient's covariates that scales the underlying hazard function, which is another restrictive assumption on top of the proportionality of hazards. Multi-task logistic regression (MTLR) was proposed by Yu et al. (2011) as a remedy to the assumption of temporal constancy of relative risk between two patients, which led to increased prediction accuracy over Cox-PH. MTLR uses multi-task regression and joint likelihood maximization to model the log-risk in each time interval as a linear combination of the covariates. Recently, neural-MTLR was proposed by Fotso (2018) to move away from the linearity assumption as well and to further increase prediction accuracy. Neural-MTLR models nonlinear interactions among covariates as features extracted by the lower layers of a neural network, whose final layers are the same as those of an MTLR model.
Most survival models aim for high accuracy on test cohorts that are very similar to their training cohorts. Consequently, even when their predictions are off on the test inputs, or the test inputs are outliers with respect to the training input distribution, they may predict (wrongly) with high confidence. For example, the above-mentioned survival models are unable to assess per-patient uncertainty in survival predictions. Uncertainty calibration is important if survival prediction models are to be deployed in clinical settings. The prediction of any model is usually untrustworthy when the test data from a new patient is out of the training distribution (OOD). In such OOD cases, it is important to involve human experts, and hence it is important to identify such cases by having the model rightfully express high uncertainty or low confidence. Bayesian neural networks (BNNs) provide a framework to capture the uncertainties inherent to both the data (data uncertainty) and the limitations of the model (model uncertainty). We propose a Bayesian extension of (neural) MTLR that can capture patient-specific survival uncertainties. Capturing uncertainty in the prediction also helps us handle heterogeneous data and analyze prognostically important covariates. Furthermore, we incorporate a prior in our model to sparsify a large number of input covariates.
2 Theoretical Background and Proposed Method
2.1 Survival models
The setting that we assume has a set of covariates $\mathbf{x} \in \mathbb{R}^d$ associated with each patient, a time to adverse event $t$ (usually death or disease recurrence), and an event indicator $\delta \in \{0, 1\}$. The event indicator $\delta = 1$ means that the patient died after a time interval of $t$. Patients with $\delta = 0$ are called right-censored, indicating that the patient was surviving (or living disease-free) at time $t$, but survival beyond that time is unknown.
The survival function and hazard function are two important outcomes of survival models. The survival function, $S(t) = \Pr(T > t)$, is the probability that a patient survives beyond time $t$. The hazard function is given by $h(t) = \lim_{\Delta t \to 0} \Pr(t \le T < t + \Delta t \mid T \ge t)/\Delta t$, which is the rate at which an individual who has survived up to time $t$ fails within an extra infinitesimal amount of time $\Delta t$. Cox-PH (Cox, 1972) models the hazard function at time $t$ for a given vector of input covariates $\mathbf{x}$ in terms of an underlying baseline hazard function $h_0(t)$ and linear weights $\boldsymbol{\beta}$ for the covariates as follows: $h(t \mid \mathbf{x}) = h_0(t)\exp(\boldsymbol{\beta}^\top \mathbf{x})$.
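The proportional hazards property means that the hazard ratio between two patients reduces to a ratio of their relative risks, with the baseline hazard canceling. A minimal sketch, using hypothetical covariate and weight values not taken from the paper:

```python
import numpy as np

def cox_relative_risk(x, beta):
    """Relative risk exp(beta^T x) under a Cox proportional hazards model.

    The baseline hazard h0(t) cancels when comparing two patients, so the
    hazard ratio depends only on their covariates.
    """
    return np.exp(x @ beta)

# Illustrative covariates and weights (hypothetical values).
beta = np.array([0.4, -0.2, 0.1])
x_a = np.array([1.0, 0.0, 2.0])
x_b = np.array([0.0, 1.0, 1.0])

# The hazard ratio between patients a and b is constant over time.
hazard_ratio = cox_relative_risk(x_a, beta) / cox_relative_risk(x_b, beta)
```

Note that the ratio equals `exp((x_a - x_b) @ beta)`, independent of `h0(t)`.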
MTLR (Yu et al., 2011) assumes a series of $m$ logistic regression models for $m$ time intervals $[\tau_0 = 0, \tau_1), \ldots, [\tau_{m-1}, \tau_m)$, where $m$ is chosen based on the desired fineness of temporal variation and the size of the training data, as follows: $\Pr(y_j = 1 \mid \mathbf{x}) = \left(1 + \exp(-\boldsymbol{\theta}_j^\top \mathbf{x} - b_j)\right)^{-1}$. The parameters $\boldsymbol{\theta}_j$ and $b_j$ depend on the time interval $j$, whereas the input vector $\mathbf{x}$ is the same for all regression models. However, the outputs of these logistic regression models are not independent, because a death event at time $\tau_j$ would mean a death event at all subsequent time points. We encode the output of the regression models using an $m$-dimensional binary sequence $\mathbf{y} = (y_1, \ldots, y_m)$, where $y_j = 0$ means that the patient is living at time $\tau_j$ and $y_j = 1$ means that the patient is dead at time $\tau_j$. Thus, once we encounter a $y_j = 1$, all subsequent $y_{j'}$ are bound to be 1, so only $m + 1$ of the sequences are valid. A smoothness prior on the parameters across time ensures that the predictions are not noisy. The probability of observing a sequence is the likelihood of the model. It is generalized from the logistic regression models as follows: $\Pr(\mathbf{y} \mid \mathbf{x}) = \exp\left(\sum_{j=1}^{m} y_j (\boldsymbol{\theta}_j^\top \mathbf{x} + b_j)\right)/Z(\mathbf{x})$, where $Z(\mathbf{x}) = \sum_{k=0}^{m} \exp\left(\sum_{j=k+1}^{m} (\boldsymbol{\theta}_j^\top \mathbf{x} + b_j)\right)$ sums over the valid sequences.
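Because only $m+1$ of the $2^m$ binary sequences are valid, the normalizer can be computed by enumerating them directly. A minimal sketch of this computation (function and variable names are our own, not from the paper):

```python
import numpy as np

def mtlr_sequence_probs(x, theta, b):
    """Probability of each of the m+1 valid survival sequences under MTLR.

    x: (d,) covariates, theta: (m, d) interval weights, b: (m,) biases.
    Valid sequence k (0-indexed) has the event in interval k, i.e. y_j = 1
    for all j >= k; k = m is the all-zeros sequence (alive throughout).
    """
    s = theta @ x + b                       # per-interval scores, shape (m,)
    m = len(s)
    # Score of sequence k is the sum of scores over intervals after the event.
    scores = np.array([s[k:].sum() for k in range(m + 1)])
    scores -= scores.max()                  # subtract max for numerical stability
    p = np.exp(scores)
    return p / p.sum()                      # softmax over the m+1 valid sequences
```

With all-zero parameters, every valid sequence is equally likely, which is a quick sanity check on the normalization.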
The MTLR loss function for uncensored patients is obtained by taking the negative logarithm of the joint likelihood and adding regularization terms for temporal smoothness of the parameters and the resultant predictions, as follows:

$$\min_{\boldsymbol{\Theta}} \; \frac{\lambda_1}{2} \sum_{j=1}^{m} \lVert \boldsymbol{\theta}_j \rVert^2 + \frac{\lambda_2}{2} \sum_{j=1}^{m-1} \lVert \boldsymbol{\theta}_{j+1} - \boldsymbol{\theta}_j \rVert^2 - \sum_{i=1}^{n} \log \Pr(\mathbf{y}^{(i)} \mid \mathbf{x}^{(i)}), \quad (1)$$

where $\lambda_1$ and $\lambda_2$ are hyperparameters that control the amount of smoothing in the parameters and $n$ is the number of patients.
For right-censored patients (those who are lost to follow-up), there is more than one consistent binary sequence of $y_j$'s. In this case the likelihood of the patient is the sum of the likelihoods of all possible sequences. The overall likelihood for a censored patient whose last contact was closest to time point $\tau_k$ is given as follows: $\Pr(T \ge \tau_k \mid \mathbf{x}) = \sum_{\mathbf{y} \in \mathcal{Y}_k} \Pr(\mathbf{y} \mid \mathbf{x})$, where $\mathcal{Y}_k$ is the set of valid sequences with $y_j = 0$ for all $j \le k$.
2.2 Variational Inference
A feed-forward neural network trained with gradient descent arrives at point estimates of its weights. In Bayesian NNs (BNNs) (Neal, 2012), by contrast, the weights are not point estimates but a parameterized probability distribution. Our task is to find a distribution over the parameters given the input data, i.e., the posterior $p(\mathbf{w} \mid \mathcal{D})$. With this posterior, we can predict the test output $y^*$ for a new test input $\mathbf{x}^*$ by marginalizing the likelihood over the parameters: $p(y^* \mid \mathbf{x}^*, \mathcal{D}) = \int p(y^* \mid \mathbf{x}^*, \mathbf{w})\, p(\mathbf{w} \mid \mathcal{D})\, d\mathbf{w}$. However, even for modest-sized NNs, the number of parameters prohibits an analytical calculation of this posterior, and one has to resort to approximate inference methods. We define an approximating variational distribution $q(\mathbf{w} \mid \boldsymbol{\phi})$ with parameters $\boldsymbol{\phi}$. The Kullback-Leibler (KL) divergence between the proposed posterior and the true posterior is then minimized with respect to the parameters $\boldsymbol{\phi}$.
Minimizing the KL divergence is equivalent to minimizing the variational free energy (Friston et al., 2007; Blundell et al., 2015), where the latter is often computed on mini-batches for computational tractability. We then estimate the cost using an unbiased Monte Carlo (MC) approximation for each mini-batch as follows:

$$\mathcal{F}(\mathcal{D}, \boldsymbol{\phi}) \approx \sum_{s=1}^{S} \log q(\mathbf{w}^{(s)} \mid \boldsymbol{\phi}) - \log p(\mathbf{w}^{(s)}) - \log p(\mathcal{D} \mid \mathbf{w}^{(s)}), \qquad \mathbf{w}^{(s)} \sim q(\mathbf{w} \mid \boldsymbol{\phi}).$$
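This Monte Carlo estimate can be sketched in a few lines of NumPy. The sketch below assumes a fully factorized Gaussian posterior with a softplus-parameterized scale and a standard normal prior (a common choice following Blundell et al. (2015); the posterior actually used in this work is spike and slab):

```python
import numpy as np

rng = np.random.default_rng(0)

def log_gauss(w, mu, sigma):
    """Elementwise log-density of N(mu, sigma^2) at w."""
    return -0.5 * np.log(2 * np.pi * sigma**2) - (w - mu)**2 / (2 * sigma**2)

def mc_free_energy(mu, rho, log_likelihood, n_samples=10):
    """Unbiased MC estimate of the variational free energy for a
    factorized Gaussian posterior q(w) = N(mu, softplus(rho)^2) and a
    standard normal prior p(w) = N(0, 1).

    log_likelihood(w) -> log p(D | w) is supplied by the caller.
    """
    sigma = np.log1p(np.exp(rho))           # softplus keeps sigma positive
    total = 0.0
    for _ in range(n_samples):
        eps = rng.standard_normal(mu.shape)
        w = mu + sigma * eps                # reparameterized sample w ~ q
        total += (log_gauss(w, mu, sigma).sum()    # log q(w | phi)
                  - log_gauss(w, 0.0, 1.0).sum()   # - log p(w)
                  - log_likelihood(w))             # - log p(D | w)
    return total / n_samples
```

The reparameterization `w = mu + sigma * eps` is what makes this estimate differentiable with respect to the variational parameters.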
2.3 Proposed probabilistic weights to model uncertainty
We assume the posterior and the prior on the weights to be spike and slab, which is standard for sparse linear models (Mitchell and Beauchamp, 1988; George and McCulloch, 1993; Titsias and Lázaro-Gredilla, 2011). Recently, a closed-form expression for the KL divergence between a spike-and-slab posterior and a spike-and-slab prior was derived by Tonolini et al. (2019), which we utilize in this work. The prior probability density is given as follows: $p(w_i) = \alpha\, \mathcal{N}(w_i; 0, \sigma_p^2) + (1 - \alpha)\, \delta_0(w_i)$, where $\delta_0$ is the Dirac delta function centered at zero. The sparsity of the solution can be increased under this prior by decreasing $\alpha$ from one towards zero. The posterior is chosen to be of a similar form, given as: $q(w_i) = \gamma_i\, \mathcal{N}(w_i; \mu_i, \sigma_i^2) + (1 - \gamma_i)\, \delta_0(w_i)$, where $\gamma_i$, $\mu_i$, and $\sigma_i$ are the free parameters of the neural network. This choice of posterior not only allows us to derive an analytical lower bound for the KL divergence between the assumed posterior and the prior but also gives an additional degree of freedom compared to a fully factorized Gaussian.
In order to quantify data uncertainty, we use the standard trick of predicting not only the mean but also the variance of the survival probability (Kendall and Gal, 2017). Our overall prediction now becomes a sample drawn from this Gaussian, as follows: $\hat{y} = \hat{\mu}(\mathbf{x}) + \hat{\sigma}(\mathbf{x})\, \epsilon$ with $\epsilon \sim \mathcal{N}(0, 1)$, where $\hat{\mu}$ and $\hat{\sigma}$ are approximated using Monte Carlo samples. The loss function for a mini-batch of our Bayesian variant is given as follows:

$$\mathcal{L} = \ell_{\mathrm{NLL}} + \frac{1}{M} \sum_{i=1}^{P} \mathrm{KL}\left(q(w_i) \,\|\, p(w_i)\right),$$

where $\ell_{\mathrm{NLL}}$ is the negative log-likelihood defined as in equation 1, $M$ is the number of mini-batches, and $P$ is the total number of parameters. One can see that setting $\gamma_i = 1$ and $\alpha = 1$ reduces this expression to the fully factorized Gaussian posterior and prior used in variational autoencoders (Kingma and Welling, 2013).
We used a simple one-hidden-layer Bayesian neural network with a spike-and-slab prior and posterior, and ReLU activation in all but the final layer. The number of inputs to the network is equal to the number of covariates, and the number of outputs is equal to the number of time intervals. Instead of a fully connected structure from the input layer to the hidden layer, we assume only a one-to-one mapping to simulate variable elimination based on the sparsity-inducing prior, as shown in Figure 1.
3 Experiments

Using a subset of 47 of the PAM50 gene expressions and the clinical variables that were common to both the TCGA-BRCA (The Cancer Genome Atlas) and METABRIC (Curtis et al., 2012) datasets, we trained on one dataset and tested on the other to obtain results on model accuracy. We combined both datasets and held out samples at random for the experiments on variable importance and uncertainty estimation.
3.1 Survival predictions
The C-index and the Integrated Brier Score (IBS) are two commonly used metrics for analyzing the accuracy of survival models on censored data; the former is a generalization of the area under the ROC curve (AUC), and the latter is the average weighted squared distance between the observed and predicted survival. Thus, a higher C-index and a lower IBS imply a more accurate model. Table 1 shows that our method performs better than Cox-PH, MTLR, and a comparable neural-MTLR model with a single hidden layer.
| Model | C-index | IBS |
|---|---|---|
| Cox-PH | 0.65 ± 0.10 | 0.20 ± 0.07 |
| MTLR | 0.68 ± 0.06 | 0.21 ± 0.06 |
| N-MTLR | 0.68 ± 0.02 | 0.16 ± 0.04 |
| Our Method | 0.71 ± 0.05 | 0.12 ± 0.02 |
3.2 Ranking prognostic features
We obtained feature importance for each input feature based on the distribution of weights learned by the network from the input layer to the hidden layer. We interpreted the ratio of the mean and the standard deviation of the weight associated with a feature as its signal-to-noise ratio (SNR). For the spike-and-slab posterior, the SNR for feature $i$ is given by: $\mathrm{SNR}_i = |\gamma_i \mu_i| \,/\, \sqrt{\gamma_i \sigma_i^2 + \gamma_i (1 - \gamma_i) \mu_i^2}$. We observe in Figure 2 that age at diagnosis, lymph node metastasis, and tumor stage are among the top three prognostically important features. Among the genomic signatures, BCL2 is an antiapoptotic protein whose prognostic role in breast cancer is considered to be subtype-specific; it is a good prognostic marker mostly for Luminal A breast cancers (Corces et al., 2018). CDC20 is an oncoprotein that promotes the development and progression of breast cancer, and its overexpression is often associated with poor short-term survival, specifically in triple-negative breast cancers (Karra et al., 2014). CDC25L (a.k.a. RASGRF1) plays a key role in tumor cell proliferation and inflammation through the mitogen-activated protein kinase (MAPK) pathway in breast cancers (Rodrigues-Ferreira et al., 2012). Similarly, the PTTG1 gene enhances the migratory and invasive properties of breast cancer cells by inducing epithelial-to-mesenchymal transition (Tetreault et al., 2013). Some of these genes have not found a regular place in risk assessment assays, whereas our method may be able to systematically suggest their direct association with survival.
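Under the spike-and-slab posterior, the weight moments are E[w_i] = γ_i μ_i and Var[w_i] = γ_i σ_i² + γ_i(1 − γ_i) μ_i², which give the SNR directly. A sketch with hypothetical learned parameters (not values from the trained model):

```python
import numpy as np

def spike_slab_snr(gamma, mu, sigma):
    """Signal-to-noise ratio |E[w]| / Std[w] of a spike-and-slab posterior.

    E[w]   = gamma * mu
    Var[w] = gamma * sigma^2 + gamma * (1 - gamma) * mu^2
    """
    mean = gamma * mu
    var = gamma * sigma**2 + gamma * (1 - gamma) * mu**2
    return np.abs(mean) / np.sqrt(var)

# Rank features by SNR (hypothetical per-feature posterior parameters).
gamma = np.array([0.9, 0.2, 0.99])
mu = np.array([1.5, 0.1, -2.0])
sigma = np.array([0.3, 0.5, 0.4])
ranking = np.argsort(-spike_slab_snr(gamma, mu, sigma))
```

Note that for γ_i = 1 this reduces to the familiar Gaussian SNR |μ_i|/σ_i, while a small inclusion probability γ_i pushes the score towards zero.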
3.3 Low confidence on out-of-distribution (OOD) test data
In order to demonstrate the use of quantifying uncertainty, we divided the entire data (TCGA + METABRIC) into old (age ≥ 60 years) and young (age < 60 years) patients, where 60 years is the median age of patients in the dataset. We trained the model on 80% of the old patients and tested it on the remaining 20% of old patients as well as on all of the young patients. We define the mean uncertainty score associated with a survival prediction as the mean of the standard deviations in model predictions (over 50 forward passes) across all time points. The test cohort of younger patients was successfully identified as OOD, as their mean uncertainty was 110% higher than that of the test cohort of older patients, as shown in Figure 3. Similarly, we trained another model on a subset of patients with low cancer stage and saw a 43% higher mean uncertainty score for higher-stage patients (OOD) compared to the held-out lower-stage patients.
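The mean uncertainty score described above can be computed from the stacked predictions of repeated stochastic forward passes. A minimal sketch (the function name is ours):

```python
import numpy as np

def mean_uncertainty_score(survival_samples):
    """Mean uncertainty score for one patient.

    survival_samples: array of shape (n_passes, n_timepoints) holding the
    predicted survival probabilities from repeated stochastic forward passes.
    Returns the standard deviation across passes, averaged over time points.
    """
    return survival_samples.std(axis=0).mean()
```

A cohort would be flagged as OOD when its average score is substantially higher than that of held-out in-distribution patients; the threshold is a judgment call for the deployment setting.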
We proposed a Bayesian framework for modeling survival prediction that not only gives more accurate predictions but is also able to select prognostically important features in the data and detect test samples that are out of the training distribution. This makes our model more interpretable and trustworthy due to its well-calibrated uncertainty estimates. Our approach is a step in the direction of training models that go beyond a singular focus on test prediction accuracy to that of recognizing uncertainty appropriately in new cohorts and producing new biological insights. Such approaches should be further tested in larger multi-institutional and multi-cohort settings.
- Weight uncertainty in neural networks. arXiv preprint arXiv:1505.05424. Cited by: §2.2.
- The chromatin accessibility landscape of primary human cancers. Science 362 (6413), pp. eaav1898. Cited by: §3.2.
- Regression models and life-tables. Journal of the Royal Statistical Society. Series B (Methodological) 34 (2), pp. 187–220. Cited by: §1, §2.1.
- The genomic and transcriptomic architecture of 2,000 breast tumours reveals novel subgroups. Nature 486 (7403), pp. 346–352. Cited by: §3.
- Deep neural networks for survival analysis based on a multi-task framework. Cited by: §1, §2.1.
- Variational free energy and the laplace approximation. Neuroimage 34 (1), pp. 220–234. Cited by: §2.2.
- Variable selection via Gibbs sampling. Journal of the American Statistical Association 88 (423), pp. 881–889. Cited by: §2.3.
- Cdc20 and securin overexpression predict short-term breast cancer survival. British journal of cancer 110 (12), pp. 2905–2913. Cited by: §3.2.
- What uncertainties do we need in Bayesian deep learning for computer vision? In Advances in Neural Information Processing Systems 30. Cited by: §2.3.
- Auto-encoding variational Bayes. Cited by: §2.3.
- Bayesian variable selection in linear regression. Journal of the American Statistical Association 83 (404), pp. 1023–1032. Cited by: §2.3.
- Bayesian learning for neural networks. Vol. 118, Springer Science & Business Media. Cited by: §2.2.
- Angiotensin ii facilitates breast cancer cell migration and metastasis. PloS one 7 (4). Cited by: §3.2.
- Krüppel-like factors in cancer. Nature Reviews Cancer 13 (10), pp. 701. Cited by: §3.2.
- The Cancer Genome Atlas. Note: https://www.cancer.gov/tcga. Accessed: 2020-03-18. Cited by: §3.
- Spike and slab variational inference for multi-task and multiple kernel learning. In Advances in Neural Information Processing Systems 24, J. Shawe-Taylor, R. S. Zemel, P. L. Bartlett, F. Pereira, and K. Q. Weinberger (Eds.), pp. 2339–2347. Cited by: §2.3.
- Variational sparse coding. Uncertainty in Artificial Intelligence. Cited by: §2.3.
- Learning patient-specific cancer survival distributions as a sequence of dependent regressors. In Advances in Neural Information Processing Systems 24, J. Shawe-Taylor, R. S. Zemel, P. L. Bartlett, F. Pereira, and K. Q. Weinberger (Eds.), pp. 1845–1853. Cited by: §1, §2.1.