Countdown Regression: Sharp and Calibrated Survival Predictions

06/21/2018 · by Anand Avati, et al.

Personalized probabilistic forecasts of time to event (such as mortality) can be crucial in decision making, especially in the clinical setting. Inspired by ideas from the meteorology literature, we approach this problem through the paradigm of maximizing sharpness of prediction distributions, subject to calibration. In regression problems, it has been shown that optimizing the continuous ranked probability score (CRPS) instead of maximum likelihood leads to sharper prediction distributions while maintaining calibration. We introduce the Survival-CRPS, a generalization of the CRPS to the time to event setting, and present right-censored and interval-censored variants. To holistically evaluate the quality of predicted distributions over time to event, we present the Survival-AUPRC evaluation metric, an analog to area under the precision-recall curve. We apply these ideas by building a recurrent neural network for mortality prediction, using an Electronic Health Record dataset covering millions of patients. We demonstrate significant benefits in models trained by the Survival-CRPS objective instead of maximum likelihood.


1 Introduction

Having patient-specific predictions of time to an event such as mortality or bone fracture allows caregivers to make better informed decisions around patient care. Historically, prognosis scores have served as simple tools to stratify patient risk within a predefined time window (db9910e0f7a64a75866d354a536fece3; PMID:25613983). However, such models tend to be too simplistic to be widely useful. They are often estimated from a large population of patients, and do not take into account patient-specific information to make individualized predictions (yu_learning_2011). Meanwhile, the adoption of Electronic Health Record (EHR) systems over the past few decades has resulted in the collection of observational data on millions of patients spanning multiple years. This data enables the development of patient-specific prediction models using machine learning. Such models are applicable to the larger patient population without being specific to a disease type or demographic, which makes it possible to develop novel workflows in care delivery. For example, a high predicted probability of 3-12 month mortality could proactively notify palliative care teams of otherwise overlooked patients with end-of-life needs (avati_improving_2017).

One way to obtain patient-specific survival predictions is to treat the problem as probabilistic classification; that is, training a binary classifier to predict whether the event occurs by a particular time of interest (avati_improving_2017; rajkomar_scalable_2018). However, such an approach has drawbacks. First, the model is specific to the time of interest it was trained upon: it is not straightforward to take a model trained to predict probabilities of 1-year mortality and obtain predictions of 6-month mortality from it. Second, it is not usually possible to use data on all patients. For example, if a patient has only 3 months of history in the EHR system, that patient can be included in the 1-year mortality prediction task neither as a positive case nor as a negative one. Third, the process of constructing the data set implicitly conditions on the future outcome to select prediction times, since evaluation is performed only at times looking backward from the event of interest. It has been shown that evaluation metrics can be overly optimistic relative to real world performance as a result (sherman_leveraging_2017).

Figure 1: Example of a patient’s predicted distributions for age of death under different models. Our proposed techniques improve sharpness of predicted distributions, subject to calibration. Repeated interactions (indicated by darker color) between the patient and the EHR yield more confident predictions of time of death.

An alternative approach to the problem is survival prediction; that is, predicting time to event by estimating a distribution over future time. In this setting, traditional survival analysis methods such as the Cox proportional hazards model (Cox1972) or accelerated failure time models (aft) are capable of handling data with censored observations (cases in which the event was not observed, but we know that the event did not occur up to a certain time). This addresses concerns raised by the classification approach, but there are a few nuances. First, traditional models typically make strong assumptions, such as proportional hazards or linearity. Second, challenges of low prevalence often arise when these methods are applied to large-scale observational datasets with heavy censoring, which is the case in real EHR data. Third, these survival analysis methods are typically evaluated on point estimates of risk, such as 10-year probabilities of events, rather than holistic measures of the quality of the predicted distributions (goff_2013_2014; ranganath_deep_2016; lee_deephit:_2018). Common evaluation metrics include the C-statistic (uno_evaluating_2007), log-loss (yu_learning_2011), and mean squared error (katzman_deepsurv:_2018). While useful for the purposes of relative risk stratification, model comparisons made using point estimates leave the quality of uncertainty in predicted distributions unmeasured. If a point prediction is far off, it is penalized by the same amount whether the model was confident or not (that is, whether the predicted distribution had low or high variance).

In contrast, forecasts in the field of meteorology are typically made as full prediction distributions over all weather conditions given past and current observations (Gneiting2008). Evaluation of predictive performance is assessed by the paradigm of maximizing the sharpness of the predictive distribution, subject to calibration (gneiting_probabilistic_2014). The intuition behind this paradigm is that probabilities have to be calibrated in order to be correct. However, that does not necessarily make them useful (one could always predict the marginal probability of an outcome without looking at the data, and still be well calibrated). The usefulness of a prediction distribution lies in its sharpness, or how well its mass concentrates. In summary, uncalibrated predictions (sharp or not) are useless, calibrated but non-sharp predictions are correct but less useful, and calibrated and sharp distributions are most useful.

To improve the sharpness of prediction distributions in the survival setting, we propose the use of proper scoring rules beyond maximum likelihood as the training objective. Proper scoring rules are known to measure calibration, and any model trained with a proper scoring rule will tend to maintain calibration (gneiting_probabilistic_2014). For our purposes, we focus on the continuous ranked probability score (CRPS), which has been used as an objective to improve sharpness in the regression setting (Gneiting2008; mohammadi_meta-heuristic_2016; mohammadi_optimization_2015). We generalize the CRPS to the survival setting, calling it the Survival-CRPS, with right-censored and interval-censored extensions. To our knowledge this is the first time any scoring rule other than maximum likelihood has been successfully applied to a large-scale survival prediction task.

Summary of contributions. Our main contributions are as follows. (1) We introduce the proper scoring rule Survival-CRPS, a generalization of CRPS, as an objective in survival prediction, and present its right-censored and interval-censored variants. (2) We propose a new metric, Survival-AUPRC, inspired by the paradigm of maximizing sharpness subject to calibration, to holistically measure the quality of a prediction distribution with respect to a possibly censored outcome. (3) We give practical recommendations for the mortality prediction task: use the log-normal parameterization and interval censoring when training. (4) We employ the above techniques and demonstrate their efficacy by training a deep recurrent neural network model for accurate survival prediction of patient mortality using EHR data.

2 Countdown Regression

Parametric survival prediction methods model the time to an event of interest with a family of probability distributions, uniquely identified by the distribution parameters. The survival function, denoted $S(t)$, is a monotonically decreasing function over the positive reals with $S(0) = 1$ and $\lim_{t \to \infty} S(t) = 0$. The survival function represents the probability of an individual not having the event of interest up to a given time. Every survival function has a corresponding cumulative distribution function (CDF), denoted $F(t) = 1 - S(t)$, and probability density function (PDF), denoted $f(t) = \frac{d}{dt}F(t)$. The choice of the family of probability distributions implies assumptions made about the nature of the data generating process.
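To make these relationships concrete, the following minimal sketch (using scipy; the parameter values are arbitrary) evaluates $S$, $F$, and $f$ for a log-normal time-to-event distribution:

import numpy as np
from scipy.stats import lognorm

# Log-normal over time to event, with underlying Gaussian parameters mu, sigma.
mu, sigma = np.log(5.0), 0.8
dist = lognorm(s=sigma, scale=np.exp(mu))

t = 4.0
F = dist.cdf(t)      # CDF: probability the event occurs by time t
S = 1.0 - F          # survival function: probability of surviving past t
f = dist.pdf(t)      # PDF: derivative of the CDF at t

# S is monotonically decreasing, with S(0) = 1 and S(t) -> 0 as t grows.
assert np.isclose(S, dist.sf(t))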

We denote the medical record of a patient $p$ as $\{(x^{(p)}_i, a^{(p)}_i)\}_{i=1}^{n^{(p)}}$, where $i$ denotes the interaction number of this patient with the health record, $x^{(p)}_i$ is the set of features corresponding to the $i$-th interaction, $a^{(p)}_i$ is the age at the time of interaction $i$, $e^{(p)}$ is the age of death or age of last known (alive) encounter, and $c^{(p)} \in \{0, 1\}$ is a censoring indicator where $c^{(p)} = 0$ means the age of death is $e^{(p)}$, and $c^{(p)} = 1$ means the age of death is at least $e^{(p)}$. For each $i$ we define the quantity $y^{(p)}_i = e^{(p)} - a^{(p)}_i$, which represents the corresponding time to event or time to censoring.

Traditional methods in survival analysis are designed to handle right-censored outcomes, but we observe that in many common scenarios outcomes are actually interval-censored. In the context of mortality prediction, for example, we know that humans almost never live past 120 years of age. Therefore, we assume that the true age of death lies between $e$ and $120$ years, implying that the true time to death lies between $0$ and $y_{\max} = 120 - a_i$. We omit patient superscripts and interaction subscripts for succinctness where possible. We note that although our notation focuses on the problem of mortality prediction, our techniques generalize to any time to event task of interest.

2.1 Survival-CRPS: proper scoring rules as training objectives

A scoring rule is a measure of the quality of a probabilistic forecast. A forecast over a continuous outcome is a probability density function $f$ over all possible outcomes, with corresponding cumulative distribution function $F$. In reality, we observe some actual outcome $y$. A scoring rule $\mathcal{S}$ takes a predicted distribution $F$ and an actual outcome $y$, and returns a loss $\mathcal{S}(F, y)$. It is considered a proper scoring rule if, for all possible distributions $G$,

$\mathbb{E}_{y \sim G}\left[\mathcal{S}(G, y)\right] \le \mathbb{E}_{y \sim G}\left[\mathcal{S}(F, y)\right],$

and strictly proper when equality holds if and only if $F = G$ (Gneiting2008). A proper scoring rule is one in which the expected score is minimized by the distribution with respect to which the expectation is taken. Intuitively, it rewards a model for being honest by predicting what it actually believes (SavageElicitation). When a proper scoring rule is employed as a loss function, it naturally forces the model to output calibrated probabilities (gneiting_probabilistic_2014).

There are many commonly used proper scoring rules. Perhaps the most widely used is the logarithmic scoring rule, equivalent to the maximum likelihood objective:

$\mathcal{S}_{\text{log}}(F, y) = -\log f(y).$

In the presence of possibly censored data, we maximize the density for observed outcomes, and the tail or interval mass for censored outcomes; this remains a proper scoring rule (dawid_theory_2014).
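As a sketch of this censoring-aware logarithmic score (the function and parameter names here are ours, with a log-normal family as used later in the paper):

import numpy as np
from scipy.stats import lognorm

def neg_log_score(mu, sigma, y, censored, y_max=None):
    """-log f(y) for an observed event; -log of the tail mass (right-censored)
    or interval mass (interval-censored) when the event is unobserved."""
    dist = lognorm(s=sigma, scale=np.exp(mu))
    if not censored:
        return -np.log(dist.pdf(y))                  # density at the observed time
    if y_max is None:
        return -np.log(dist.sf(y))                   # mass in (y, infinity)
    return -np.log(dist.cdf(y_max) - dist.cdf(y))    # mass in (y, y_max)

# A patient last seen alive 2 years after prediction time, lifespan bound 100:
print(neg_log_score(mu=np.log(5.0), sigma=0.8, y=2.0, censored=True, y_max=100.0))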

However, the logarithmic scoring rule is asymmetric, and harshly penalizes predictions that are wrong yet confident. This makes the training process sensitive to outliers and, in general, conservative in its predictions (that is, hesitant to make sharp predictions) (StrictlyProper).

Another proper scoring rule for forecasts over continuous outcomes is the CRPS (ProbForeCalibSharp), defined as

$\text{CRPS}(F, y) = \int_{-\infty}^{\infty} \left(F(t) - \mathbb{1}\{t \ge y\}\right)^2 dt = \int_{-\infty}^{y} F(t)^2\, dt + \int_{y}^{\infty} \left(1 - F(t)\right)^2 dt.$

The CRPS has been used in regression as an objective function that yields sharper predicted distributions compared to maximum likelihood, while maintaining calibration (Gneiting2008). Intuition for the CRPS is best gained by analyzing the latter expression and noting that the two integral terms correspond to the two shaded regions in Figure 2(a). The CRPS reduces to zero exactly when the predicted distribution places all of its mass on the point of the true outcome, or equivalently, when the shaded region vanishes entirely.
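The integral definition can be checked against the well-known closed form of the CRPS for a Gaussian (a sketch; the grid bounds and resolution are arbitrary choices):

import numpy as np
from scipy.stats import norm
from scipy.integrate import trapezoid

mu, sigma, y = 0.0, 1.0, 0.7

# Integral form: (F(t) - 1{t >= y})^2 integrated over a wide grid.
t = np.linspace(mu - 12 * sigma, mu + 12 * sigma, 200_001)
F = norm.cdf(t, mu, sigma)
crps_numeric = trapezoid((F - (t >= y)) ** 2, t)

# Closed form for the Gaussian, with z = (y - mu) / sigma.
z = (y - mu) / sigma
crps_closed = sigma * (z * (2 * norm.cdf(z) - 1) + 2 * norm.pdf(z) - 1 / np.sqrt(np.pi))

assert np.isclose(crps_numeric, crps_closed, atol=1e-4)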

In the context of time to event predictions we propose the Survival-CRPS, which accounts for the possibility of right-censored or interval-censored data:

$\text{S-CRPS}_{\text{right}}(F, (y, c)) = \int_{0}^{y} F(t)^2\, dt + \mathbb{1}\{c = 0\} \int_{y}^{\infty} \left(1 - F(t)\right)^2 dt,$

$\text{S-CRPS}_{\text{intvl}}(F, (y, c)) = \int_{0}^{y} F(t)^2\, dt + \mathbb{1}\{c = 0\} \int_{y}^{\infty} \left(1 - F(t)\right)^2 dt + \mathbb{1}\{c = 1\} \int_{y_{\max}}^{\infty} \left(1 - F(t)\right)^2 dt,$

where $y$ is the observed time to event or censoring, $c$ is the censoring indicator, and $y_{\max}$ is the time by which the event must have occurred.

(a) Uncensored
(b) Right-censored
(c) Interval-censored
Figure 2: Graphical intuition for the Survival-CRPS scoring rule. For uncensored observations, we minimize mass before and after the observed time of event. For right-censored observations, we minimize mass before observed time of censoring. For interval-censored observations, we minimize mass before observed time of censoring, and mass after the time by which event must have occurred.

Note that when $c = 0$, both of the above expressions are equivalent to the original CRPS. Again, the intuition behind the Survival-CRPS is best understood by analyzing the second expressions (without the indicator terms) and mapping each integral to the corresponding shaded region in Figure 2(b) and Figure 2(c). The Survival-CRPS behaves like the original CRPS when the time of event is uncensored. For censored outcomes, it penalizes the predicted mass that occurs before the time of censoring and, if interval-censored, also the mass after the time by which the event must have occurred.

Both variants of the Survival-CRPS are proper scoring rules. They are special cases of the threshold-weighted CRPS (Ranjan_2011), where the weighting function is an indicator over the uncensored regions.
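A numerical sketch of the two variants (log-normal prediction; the geometric grid and its truncation point are implementation conveniences, since the remaining tail mass is negligible):

import numpy as np
from scipy.stats import lognorm
from scipy.integrate import trapezoid

def survival_crps(mu, sigma, y, c, y_max=None, n=200_000):
    """Right-censored variant when y_max is None, interval-censored otherwise.
    c = 0 for an observed event, c = 1 for a censored record."""
    dist = lognorm(s=sigma, scale=np.exp(mu))
    t = np.geomspace(1e-6, 1e5, n)          # dense near 0, truncated far tail
    F = dist.cdf(t)
    score = trapezoid(np.where(t < y, F ** 2, 0.0), t)                 # mass before y
    if c == 0:
        score += trapezoid(np.where(t >= y, (1 - F) ** 2, 0.0), t)     # mass past event
    elif y_max is not None:
        score += trapezoid(np.where(t >= y_max, (1 - F) ** 2, 0.0), t) # mass past y_max
    return score

# With c = 0 both variants reduce to the ordinary CRPS:
print(survival_crps(np.log(5.0), 0.8, y=4.0, c=0))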

2.2 Evaluation by sharpness subject to calibration

Calibration assesses how well forecasted event probabilities match observed event frequencies. It is crucial in the development of useful predictive models, especially for clinical decision-making. In binary prediction tasks without censoring, the Hosmer-Lemeshow test statistic (hosmer_lemeshow) is commonly used to assess goodness-of-fit by comparing observed versus predicted event probabilities at quantiles of the predicted probabilities. Extensions to account for censoring have been proposed (Gronnesby1996; dagostnio; demler_tests_2015), but these methods apply only to predictions of dichotomous outcomes within a particular time frame (for example, 1-year risk of mortality).

There is no widely accepted method for evaluating the calibration of a set of entire prediction distributions, over multiple time frames, in the survival setting. D-calibration (andres_novel_2018) has recently been proposed as a method for holistic evaluation, but it handles censored observations by assuming that the true times to death are uniformly distributed past the times of censoring in the predicted distributions. When censored observations far outnumber uncensored ones, this can lead to overly optimistic assessments of calibration. Another option is to evaluate observed event times on the cumulative density scale of the predicted distributions, using a Kaplan-Meier estimate to account for censoring (harrell_rms). Again, this method has limitations in the heavily censored setting, as the quantiles in the tail of the predicted cumulative densities have few uncensored observations, and will rarely yield well calibrated values.

We instead employ the following method to measure calibration. We compare predicted cumulative densities against observed event frequencies, evaluated at quantiles of predicted cumulative density. Right-censored observations are removed from consideration in quantiles that correspond to times after their points of censoring. Interval-censored observations are similarly removed from consideration in quantiles that correspond to times after censoring, but are additionally re-introduced in quantiles that correspond to times past the time by which the event must have occurred (in the mortality prediction task, this corresponds to 120 years of age).
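A vectorized sketch of this procedure (the function and variable names are ours): for each quantile level q, compare q to the observed event frequency among records whose status at the predicted q-quantile time is knowable.

import numpy as np
from scipy.stats import lognorm

def calibration_curve(mu, sigma, y, c, y_max, qs=np.linspace(0.1, 0.9, 9)):
    """mu, sigma: per-record log-normal parameters (arrays); y: event or
    censoring time; c: 1 if censored; y_max: time by which the event must
    have occurred (120 years of age minus current age)."""
    censored = c.astype(bool)
    observed = []
    for q in qs:
        t_q = lognorm(s=sigma, scale=np.exp(mu)).ppf(q)     # predicted q-quantile time
        event_by_tq = ~censored & (y <= t_q)                # event seen before t_q
        forced_event = censored & (y_max <= t_q)            # re-introduced past y_max
        knowable = ~censored | (y >= t_q) | (y_max <= t_q)  # drop unknowable cases
        observed.append((event_by_tq | forced_event)[knowable].mean())
    return qs, np.array(observed)   # well calibrated when observed is close to qs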

Subject to calibration, we strive for prediction distributions that are sharp (i.e., concentrated). There are several metrics that could be used for measuring sharpness, such as variance or entropy. In the context of time to event predictions, holding two distributions with vastly different means to the same standard of variance or entropy would be unfair (for example, we would want lower variance for a prediction distribution with a mean of a day than for one with a mean of a year). Instead, we use the coefficient of variation (CoV) as a measure of sharpness, defined as the ratio of the standard deviation to the mean, $\text{CoV} = \sigma / \mu$.

2.3 Survival-AUPRC: holistic evaluation of a time to event prediction distribution

Since sharpness is only a function of the predicted distributions, a measure of sharpness is only meaningful if the model is sufficiently calibrated. We now propose a metric that measures how concentrated the mass of the prediction distribution is around the true outcome, robust to miscalibration. The idea is similar to the area under a precision-recall curve, except that here it is with respect to only one predicted distribution and one outcome. We first consider the uncensored case. As an analog to precision, we consider intervals relative to the true time of event, defined by ratios. For example, a region of precision $q = 0.9$ around an event that occurs at time $y$ is the interval $[0.9\,y,\; y/0.9]$. Corresponding to this region of precision, the analog of recall is the mass assigned by the predicted distribution to this interval, $F(y/q) - F(q\,y)$. By sweeping the precision $q$ over the full range from 0 to 1, we obtain the survival precision-recall curve. The area under this curve measures how quickly the predicted mass concentrates around the true outcome as we expand the precision window.

The highest possible score is 1, attained when the predicted distribution is a Dirac delta centered at the time of the outcome. The lowest possible score is 0, when the predicted distribution is infinitely dispersed. The mean of the Survival-AUPRC scores across examples provides an overall measure of the quality of the predictions.

The aforementioned metric only applies when the event outcome is uncensored. In the case of censored observations, we use the same analogy, but with the right end of the precision intervals defined with respect to the time by which the event must have occurred in the interval-censored case, or infinity in the right-censored case. All three cases can be computed with the numerical routine sketched below.
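A sketch of that routine (Appendices D-F give closed forms for the log-normal):

import numpy as np
from scipy.stats import lognorm
from scipy.integrate import trapezoid

def survival_auprc(dist, y_lo, y_hi, n=10_000):
    """Outcome known to lie in [y_lo, y_hi]: uncensored if y_lo == y_hi,
    right-censored if y_hi is inf, interval-censored otherwise."""
    q = np.linspace(1e-9, 1.0, n)                      # precision sweep
    recall = dist.cdf(y_hi / q) - dist.cdf(q * y_lo)   # predicted mass in interval
    return trapezoid(recall, q)

dist = lognorm(s=0.5, scale=5.0)
print(survival_auprc(dist, 4.0, 4.0))        # uncensored event at t = 4
print(survival_auprc(dist, 4.0, np.inf))     # right-censored at t = 4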

2.4 Recurrent neural network model

We apply our techniques to the mortality prediction task by building a multilayer recurrent neural network (RNN) with parameters $\theta$, denoted $g_\theta$, that takes as input a sequence of features (in our case, information about a patient recorded in the EHR for each interaction they had with the hospital) and predicts the parameters of a parametric probability distribution over time to death at each timestep. The network depends only on data from the current and previous timesteps, not the future. The approach here is similar to the recently proposed Weibull time to event RNN (martinsson_model_2016), though we generalize to any choice of noise distribution. The distributions output at each timestep are used to construct an overall loss,

$\mathcal{L}(\theta) = \sum_{p=1}^{P} \sum_{i=1}^{n^{(p)}} \text{S-CRPS}\!\left(F_{\theta}\!\left(x^{(p)}_{1:i}\right),\; \left(y^{(p)}_i, c^{(p)}\right)\right),$

where $P$ is the total number of patients in the training set, $n^{(p)}$ is the sequence length for patient $p$, and $F_\theta$ denotes the distribution parameterized by the output of the RNN. It is the sequential and monotonically decreasing predicted times to event that inspire the name Countdown Regression.
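In pseudocode-like PyTorch, the loss accumulates one Survival-CRPS term per interaction (a sketch; rnn and survival_crps stand in for the architecture of Section 3 and the objective of Section 2.1):

import torch

def training_loss(rnn, batch, survival_crps):
    """batch: iterable of (features, y, c) per patient, where y holds the
    remaining time to event/censoring at each timestep (the "countdown")
    and c is the patient's censoring indicator."""
    total, count = 0.0, 0
    for features, y, c in batch:
        mu, sigma = rnn(features)         # per-timestep distribution parameters
        total = total + survival_crps(mu, sigma, y, c).sum()
        count += len(y)
    return total / count                  # average over all timesteps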

2.5 Choice of log-normal noise distribution

Common parametric distributions over time to event used in traditional survival analysis include the Weibull, log-normal, log-logistic, and gamma (to be sufficiently expressive in model space, we seek distributions with at least two parameters). We choose the log-normal distribution because the other distributions either involve the Beta function in their density, or involve the pattern $(t/\alpha)^{\beta}$, where $\alpha$ and $\beta$ are parameters output from the neural network. We found these patterns to be highly sensitive to the inputs and to suffer from numerical instability.

For the log-normal distribution, a closed form expression for the CRPS is well known (Baran_2015). However, a closed form expression for the Survival-CRPS does not exist. We perform a change of variables to express the integral terms as finite integrals, and approximate them numerically with the trapezoid rule. When training, we back-propagate through the trapezoidal approximation. Details are given in Appendices B and C. We note that the approximation formulas are themselves proper scoring rules, as they are weighted sums of Brier scores. Closed form expressions for the Survival-AUPRC are given in Appendices D, E, and F.
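A differentiable sketch of the interval-censored approximation (PyTorch; the change of variables maps each semi-infinite integral onto (0, 1], the u = 0 endpoint contributes zero, and n = 32 follows Appendix A):

import torch

def lognormal_cdf(t, mu, sigma):
    # Phi((log t - mu) / sigma), computed via erf.
    return 0.5 * (1 + torch.erf((torch.log(t) - mu) / (sigma * 2 ** 0.5)))

def crps_intvl(mu, sigma, y, c, y_max, n=32):
    """Trapezoidal approximation of the interval-censored Survival-CRPS for a
    single log-normal prediction (scalar tensors); gradients flow through it."""
    u = torch.linspace(1.0 / n, 1.0, n)
    w = torch.full((n,), 1.0 / n)
    w[-1] = 0.5 / n                       # trapezoid endpoint weight at u = 1
    F = lambda t: lognormal_cdf(t, mu, sigma)
    before = y * (F(y * u) ** 2 * w).sum()                       # t = y * u
    tail = lambda a: a * (((1 - F(a / u)) / u) ** 2 * w).sum()   # t = a / u
    return before + (1 - c) * tail(y) + c * tail(y_max)

mu = torch.tensor(1.6, requires_grad=True)
loss = crps_intvl(mu, torch.tensor(0.8), torch.tensor(4.0),
                  torch.tensor(1.0), torch.tensor(100.0))
loss.backward()                           # gradient w.r.t. mu is now available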

3 Experiments

We run experiments on the mortality prediction task to evaluate four different training objectives: maximum likelihood with right censoring (MLE-RIGHT) and with interval censoring (MLE-INTVL), and our scoring rule based losses, Survival-CRPS with right censoring (CRPS-RIGHT) and with interval censoring (CRPS-INTVL). For interval censoring we assume a maximum lifespan of 120 years.

The neural network architecture is kept identical for all four experiments and implemented in PyTorch (paszke2017automatic). The input at each timestep consists of both real-valued (for example, age of patient) and discrete-valued (for example, ICD codes) data. Discrete data is embedded into a trainable real-valued vector space, and the vectors corresponding to the codes recorded at a given timestep are combined into a weighted mean by a soft self-attention mechanism. All real-valued inputs are appended to the averaged embedding vector. We also provide the real-valued features to every layer by appending them to the output of the previous layer. The input vector feeds into a fully connected layer, followed by multiple recurrent layers. We use the Swish activation function (swish) and layer normalization (layernorm) at every layer. Recurrent layers are defined using GRU units (gru) with layer normalization inside. After the set of recurrent layers, the network has multiple branches, one per parameter of the survival distribution (for the log-normal, $\mu$ and $\sigma$). The final layer in each branch has scalar output, optionally enforced positive with the softplus function, $\log(1 + e^x)$. We use Bernoulli dropout (dropout) at all fully connected layers and variational RNN dropout (variational-rnn) in the recurrent layers, with a dropout probability of 0.5. Optimization is performed with the Adam optimizer (kingma_ba) at a fixed learning rate of 1e-3.
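A condensed sketch of this architecture (module sizes and names are ours; dropout and layer normalization are omitted for brevity):

import torch
import torch.nn as nn
import torch.nn.functional as Fn

class CountdownRNN(nn.Module):
    def __init__(self, n_codes, embed_dim=64, real_dim=2, hidden=128):
        super().__init__()
        self.embed = nn.Embedding(n_codes, embed_dim)
        self.attn = nn.Linear(embed_dim, 1)        # soft self-attention scores
        self.fc = nn.Linear(embed_dim + real_dim, hidden)
        self.gru = nn.GRU(hidden + real_dim, hidden, batch_first=True)
        self.mu_head = nn.Linear(hidden, 1)
        self.sigma_head = nn.Linear(hidden, 1)

    def forward(self, codes, reals):
        # codes: (T, max_codes) ids per day; reals: (T, real_dim), e.g. age.
        e = self.embed(codes)                              # (T, max_codes, d)
        a = torch.softmax(self.attn(e), dim=1)             # weights over codes
        day = (a * e).sum(dim=1)                           # weighted mean embedding
        x = Fn.silu(self.fc(torch.cat([day, reals], -1)))  # Swish activation
        h, _ = self.gru(torch.cat([x, reals], -1).unsqueeze(0))
        h = h.squeeze(0)
        mu = self.mu_head(h).squeeze(-1)                     # log-normal mu
        sigma = Fn.softplus(self.sigma_head(h)).squeeze(-1)  # enforced positive
        return mu, sigma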

3.1 Data

We use electronic health records, with IRB approval, from the STARR Data Warehouse (previously known as STRIDE) for training and evaluation (Lowe2009STRIDEPlatform). The warehouse contains de-identified data for over 3 million patients (about 2.6% having a recorded date of death), spanning approximately 27 years. Each timestep in a patient's sequence corresponds to all the data in the EHR for a given day; only days with any data have a corresponding timestep. We use diagnostic codes, medication order codes, lab test order codes, encounter type codes, and demographics (age and gender). Each code has a randomly initialized embedding vector as a trainable parameter. The 3 million patients correspond to 51 million overall timesteps, and were randomly split in the ratio 8:1:1 into train, validation, and test sets.

3.2 Results

Metric | MLE-RIGHT | MLE-INTVL | CRPS-RIGHT | CRPS-INTVL
Calibration slope | 1.125 ± 3e-4 | 1.139 ± 3e-4 | 1.003 ± 3e-4 | 0.959 ± 5e-4
Mean coefficient of variation | 18.42 ± 5e-3 | 0.911 ± 4e-4 | 0.332 ± 1e-4 | 0.301 ± 1e-4
Mean prob. of survival to age 120 yrs | 0.754 ± 2e-5 | 0.045 ± 3e-5 | 0.015 ± 3e-5 | 0.005 ± 1e-6
Dead: mean Surv-AUPRC (uncensored) | 0.233 ± 2e-4 | 0.319 ± 3e-4 | 0.343 ± 4e-4 | 0.366 ± 4e-4
Alive: mean Surv-AUPRC (interval-censored) | 0.407 ± 6e-5 | 0.963 ± 2e-5 | 0.977 ± 3e-5 | 0.976 ± 3e-5
Table 1: Metrics measuring sharpness and calibration for models trained on the right-censored and interval-censored variants of the maximum likelihood and Survival-CRPS objectives (values shown as estimate ± error).
Figure 3: Calibration plots for each of the models. We compare predicted cumulative densities against observed event frequencies, evaluated at quantiles of predicted cumulative density. Right-censored observations are removed from consideration in quantiles past the times of censoring; interval-censored observations are additionally re-introduced in quantiles corresponding to times past 120 years.

We first verify that all models are reasonably well calibrated (Figure 3). Both the coefficient of variation and the Survival-AUPRC metrics suggest that the Survival-CRPS with interval censoring yields the sharpest prediction distributions (Table 1). Inspecting the mass past 120 years of age shows that a prediction model naively trained with maximum likelihood can assign more than 75% of its mass to unreasonable regions, which is highly undesirable for the purpose of prediction. This behavior is largely due to the low prevalence of uncensored examples, which is typical of real world EHR data sets. As a result, the loss on the censored examples, which can be minimized by pushing mass as far to the right as possible, dominates the loss on the small number of uncensored examples.

By predicting an entire distribution over time to death, the same model can be used to make classification predictions at various time points, highlighting the flexibility of our approach. When evaluated as probabilistic predictions of 6-month, 1-year, and 5-year mortality, our model remains well calibrated with high discriminative ability (Figure 4).
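Deriving such dichotomous predictions from one predicted distribution is a single CDF evaluation per horizon (a sketch; the parameter values are illustrative):

import numpy as np
from scipy.stats import lognorm

mu, sigma = np.log(3.0), 0.7                     # one patient's predicted parameters
dist = lognorm(s=sigma, scale=np.exp(mu))
for horizon in (0.5, 1.0, 5.0):                  # 6 months, 1 year, 5 years
    print(f"P(death within {horizon} yr) = {dist.cdf(horizon):.3f}")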

Figure 4: Discrimination and calibration of predictions from the interval-censored Survival-CRPS model, evaluated as predictions for a dichotomous outcome at 6 months, 1 year, and 5 years.
Figure 5: Median predicted time to death (with 95% intervals) for individual patients from the interval-censored Survival-CRPS model. Our model gives more confident predictions upon repeated interactions between patients and the EHR. True times to death generally lie within predicted intervals.

4 Related Work

Recent works have demonstrated the potential to significantly improve patient care by making predictions with deep learning models on EHR data (avati_improving_2017; rajkomar_scalable_2018), but these have been limited to treating the task as binary classification over a fixed time frame. Predicting survival curves instead of dichotomous outcomes has been explored (yu_learning_2011; lee_deephit:_2018), but only over finite-length horizons. Deep survival analysis (ranganath_deep_2016) has been proposed, but is limited to a fixed-shape Weibull (bypassing the stability concerns we raised, but limited in expressivity). DeepSurv (katzman_deepsurv:_2018) uses a Cox proportional hazards model, which similarly makes a set of inflexible assumptions. The WTTE-RNN (martinsson_model_2016) has a network architecture similar to ours, but is also limited to a Weibull distribution. All of the aforementioned models have been optimized only for maximum likelihood, rather than more robust proper scoring rules. The work in NIPS2017_DMGP also predicts patient-specific full survival curves, but its use of Gaussian processes makes it difficult to scale to millions of patients.

5 Conclusion

Better survival prediction models can be built by exploring objectives beyond maximum likelihood, and evaluation metrics that assess the holistic quality of predicted distributions instead of point estimates. We introduce the Survival-CRPS objective, motivated by the fact that the CRPS scoring rule is known to yield sharp prediction distributions while maintaining calibration. There may be other scoring rules that work even better, leaving avenues for future work. To evaluate, we introduce the Survival-AUPRC metric, which captures the degree to which a prediction distribution concentrates around the observed time of event. We demonstrate success in large-scale survival prediction by using a deep recurrent model employing a log-normal parameterization. By predicting an entire distribution over time to event, we circumvent issues associated with binary classification, while our model still yields accurate predictions when evaluated on dichotomous outcomes at particular times. The impact of having meaningfully accurate survival models is tremendous, especially in healthcare. We hope our work will be useful to those looking to build and deploy such models.

Acknowledgments

We thank the PyTorch team, particularly for the erf implementation that allowed use of the log-normal distribution. We thank Baran Sandor, Sebastian Lerch, Alejandro Schuler, Jeremy Irvin, and Russell Greiner for valuable feedback.

References

Appendix

A. Integral Identities

Let $\Phi_{\mu,\sigma}$ be the CDF of a Gaussian distribution with mean $\mu$ and variance $\sigma^2$. Hence $F(t) = \Phi_{\mu,\sigma}(\log t)$ is the CDF of a log-normal distribution with parameters $\mu$ and $\sigma^2$. For some integer $n$ (typically 32 in our experiments), we define $T_n[g]$ to be the following integral over the unit interval, approximated by the trapezoidal rule:

$T_n[g] = \int_0^1 g(u)\, du \approx \frac{1}{n}\left(\frac{g(u_0) + g(u_n)}{2} + \sum_{k=1}^{n-1} g(u_k)\right),$

where $u_k = k/n$ and $g$ is a function on $[0, 1]$. We further define the finite-integral identities used below, obtained by the changes of variable $t = y\,u$ and $t = y/u$ respectively:

$I_{-}(F, y) = \int_0^y F(t)^2\, dt = y \int_0^1 F(y\,u)^2\, du \approx y\, T_n\!\left[u \mapsto F(y\,u)^2\right],$

$I_{+}(F, y) = \int_y^\infty \left(1 - F(t)\right)^2 dt = y \int_0^1 \frac{\left(1 - F(y/u)\right)^2}{u^2}\, du \approx y\, T_n\!\left[u \mapsto \frac{(1 - F(y/u))^2}{u^2}\right].$

(Both integrands tend to 0 as $u \to 0$ for the log-normal, so the trapezoidal approximation is well defined.)

B. Survival-CRPS for log-normal (right-censored)

For a general continuous prediction distribution $F$, with actual time to outcome $y$ and censoring indicator $c$ ($c = 1$ when censored), we generalize the CRPS to the right-censored Survival-CRPS score as:

$\text{S-CRPS}_{\text{right}}(F, (y, c)) = \int_{-\infty}^{y} F(t)^2\, dt + (1 - c) \int_{y}^{\infty} \left(1 - F(t)\right)^2 dt.$

In the above expression $F$ would generally be in a family of continuous distributions over the entire real line (e.g. Gaussian). Alternately, one could also use a family of distributions over the positive reals (e.g. log-normal), in which case the Survival-CRPS becomes:

$\text{S-CRPS}_{\text{right}}(F, (y, c)) = I_{-}(F, y) + (1 - c)\, I_{+}(F, y).$

For the case of $F$ being log-normal, $F(t) = \Phi_{\mu,\sigma}(\log t)$, the expression becomes

$\text{S-CRPS}_{\text{right}} \approx y\, T_n\!\left[u \mapsto \Phi_{\mu,\sigma}(\log(y\,u))^2\right] + (1 - c)\, y\, T_n\!\left[u \mapsto \frac{\left(1 - \Phi_{\mu,\sigma}(\log(y/u))\right)^2}{u^2}\right].$

C. Survival-CRPS for log-normal (interval-censored)

We further extend the right-censored Survival-CRPS to the case of interval censoring. This is particularly useful for all-cause mortality prediction, where we assume the event must occur by time $y_{\max}$. Using the same notation as before, the interval-censored Survival-CRPS is:

$\text{S-CRPS}_{\text{intvl}}(F, (y, c)) = I_{-}(F, y) + (1 - c)\, I_{+}(F, y) + c\, I_{+}(F, y_{\max}).$

For the case of $F$ being log-normal, the expression becomes

$\text{S-CRPS}_{\text{intvl}} \approx y\, T_n\!\left[u \mapsto \Phi_{\mu,\sigma}(\log(y\,u))^2\right] + (1 - c)\, y\, T_n\!\left[u \mapsto \frac{\left(1 - \Phi_{\mu,\sigma}(\log(y/u))\right)^2}{u^2}\right] + c\, y_{\max}\, T_n\!\left[u \mapsto \frac{\left(1 - \Phi_{\mu,\sigma}(\log(y_{\max}/u))\right)^2}{u^2}\right].$

D. Survival-AUPRC for log-normal (interval-censored)

We start with the most general case (interval censoring). For a general continuous prediction distribution $F$ with an interval outcome $[y_l, y_u]$, we define the Survival-AUPRC as

$\text{S-AUPRC}(F, [y_l, y_u]) = \int_0^1 \left(F(y_u/q) - F(q\, y_l)\right) dq.$

Specifically for the case of the log-normal, where $\Phi$ is the CDF of the standard normal, $z_l = (\log y_l - \mu)/\sigma$, and $z_u = (\log y_u - \mu)/\sigma$:

$\text{S-AUPRC} = \Phi(z_u) + e^{\sigma z_u + \sigma^2/2}\, \Phi(-z_u - \sigma) - \Phi(z_l) + e^{-\sigma z_l + \sigma^2/2}\, \Phi(z_l - \sigma).$
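Since the closed form above is reconstructed from the definition, a quick numerical cross-check is prudent (a sketch with arbitrary parameter values):

import numpy as np
from scipy.stats import norm
from scipy.integrate import trapezoid

mu, sigma, y_l, y_u = 1.0, 0.6, 2.0, 4.0
z_l, z_u = (np.log(y_l) - mu) / sigma, (np.log(y_u) - mu) / sigma

closed = (norm.cdf(z_u) + np.exp(sigma * z_u + sigma ** 2 / 2) * norm.cdf(-z_u - sigma)
          - norm.cdf(z_l) + np.exp(-sigma * z_l + sigma ** 2 / 2) * norm.cdf(z_l - sigma))

q = np.linspace(1e-8, 1.0, 200_001)
recall = norm.cdf((np.log(y_u / q) - mu) / sigma) - norm.cdf((np.log(q * y_l) - mu) / sigma)
assert np.isclose(closed, trapezoid(recall, q), atol=1e-4)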

E. Survival-AUPRC for log-normal (right-censored)

For a general continuous prediction distribution $F$ with an interval outcome $[y, \infty)$, we define the Survival-AUPRC as

$\text{S-AUPRC}(F, [y, \infty)) = \int_0^1 \left(1 - F(q\, y)\right) dq.$

Specifically for the case of the log-normal, where $\Phi$ is the CDF of the standard normal and $z = (\log y - \mu)/\sigma$ (following Appendix D with $y_u \to \infty$),

$\text{S-AUPRC} = 1 - \Phi(z) + e^{-\sigma z + \sigma^2/2}\, \Phi(z - \sigma).$

F. Survival-AUPRC for log-normal (uncensored)

For a general continuous prediction distribution $F$ with a point outcome $y$, we define the Survival-AUPRC as

$\text{S-AUPRC}(F, y) = \int_0^1 \left(F(y/q) - F(q\, y)\right) dq.$

Specifically for the case of the log-normal, where $\Phi$ is the CDF of the standard normal and $z = (\log y - \mu)/\sigma$ (following Appendix D with $y_l = y_u = y$),

$\text{S-AUPRC} = e^{\sigma z + \sigma^2/2}\, \Phi(-z - \sigma) + e^{-\sigma z + \sigma^2/2}\, \Phi(z - \sigma).$