horseshoe-bnn
A Bayesian Neural Network with a horseshoe prior for improved interpretability
Clinical decision making is challenging because of pathological complexity, as well as the large amounts of heterogeneous data generated as part of routine clinical care. In recent years, machine learning tools have been developed to aid this process. Intensive care unit (ICU) admissions represent the most data-dense and time-critical patient care episodes. In this context, prediction models may help clinicians determine which patients are most at risk and prioritize care. However, flexible tools such as artificial neural networks (ANNs) suffer from a lack of interpretability, which limits their acceptability to clinicians. In this work, we propose a novel interpretable Bayesian neural network architecture which offers both the flexibility of ANNs and interpretability in terms of feature selection. In particular, we employ a sparsity-inducing prior distribution in a tied manner to learn which features are important for outcome prediction. We evaluate our approach on the task of mortality prediction using two real-world ICU cohorts. In collaboration with clinicians we found that, in addition to the predicted outcomes, our approach can provide novel insights into the importance of different clinical measurements. This suggests that our model can support medical experts in their decision making process.
Clinicians often need to make critical decisions, for example, about treatments or patient scheduling, based on available data and personal expertise. The increasing prevalence of electronic health records (EHRs) means that routinely collected medical data are increasingly available in machine-readable form. In some areas such as the intensive care unit (ICU), the data density may be so large that it becomes difficult for clinicians to fully appreciate relationships and patterns in clinical records. At the same time, ICU patients are the most severely ill. Life-supporting treatments are not only expensive and limited but may be associated with potentially catastrophic side effects. It is in this context that appropriate treatments must be delivered in a time-critical manner based on accurate appraisal of all available information. A degree of automated data analysis may help clinicians navigate the current ‘data deluge’ and make the best-informed decisions. The recent success of artificial intelligence and machine learning in various real-world applications suggests that such technology can be the key to unlocking the full potential of medical data and can help with making real-time decisions (Shipp et al., 2002; Tan and Gilbert, 2003; Caruana et al., 2015; Bouton et al., 2016; Chen et al., 2017; Hamesse et al., 2018).
The high data density of the ICU is ideal for applying machine learning methods to assist with clinical decision making. All medical decision making is predicated on the prediction of future outcomes. Especially relevant to the ICU are predictions surrounding patient mortality. As a consequence, several machine learning studies published over the last years have focused on this task (Joshi and Szolovits, 2012; Celi et al., 2012; Ghassemi et al., 2014; Caballero B. and Akella, 2015; Meiring et al., 2018). Most of these studies concentrated on improving previously published measures of performance such as discrimination, specificity and sensitivity.
Deploying machine learning solutions in ICUs to support life-or-death decision making is challenging and requires high prediction accuracy. Artificial neural networks (ANNs) are powerful machine learning models that have been successful in several highly complex real-world tasks (Collobert et al., 2011; Boulanger-Lewandowski et al., 2012; Bojarski et al., 2016; Silver et al., 2016). The non-linearity of ANNs allows them to capture complex dependencies, a quality which often results in high predictive performance. Despite their widespread success, predictions from ANNs lack interpretability; instead, ANNs often function as black boxes. For example, after training an ANN on the task of outcome prediction it is difficult to determine which input features are relevant for making predictions. This is highly undesirable in the medical domain: making potentially life-changing decisions without being able to clearly justify them is unacceptable to both clinicians and patients. As a consequence, the application of ANNs in practice has been limited. Advancing the interpretability of such networks is a key step towards increasing their impact in healthcare applications.
In this work, we propose an interpretable machine learning model based on a Bayesian neural network (BNN) (MacKay, 1992; Hinton and Van Camp, 1993; Blundell et al., 2015; Hernández-Lobato and Adams, 2015; Louizos et al., 2017; Ghosh and Doshi-Velez, 2017) for outcome prediction in the ICU. Our proposed method offers not only the flexibility of ANNs but also interpretable predictions in terms of feature importance — inspecting the model parameters directly shows which features are essential for prediction and which are considered irrelevant. We choose a BNN as the underlying architecture because it explicitly models the uncertainty in a dataset, as well as in the model parameters and predictions. This replicates the inherently probabilistic nature of clinical decision making. We further utilize sparsity inducing prior distributions to allow for feature selection.
We propose a novel model based on a BNN that automatically selects relevant features for the prediction task at hand. In particular, we propose the use of tied sparsity-inducing prior distributions, where the same sparsity prior is shared among all weights connected to the same input feature. As a consequence, after training the model, all weights connected to certain input features will be close to zero, indicating that the corresponding input features are not relevant for prediction. For the prior distribution we use the horseshoe prior because of its sparsity-inducing and heavy-tailed nature. Moreover, we derive a variational inference method to estimate the proposed model. We verify the effectiveness of our method using a synthetic dataset and apply it to two real-world ICU cohorts: MIMIC-III (Johnson et al., 2016) and CENTER-TBI (Maas et al., 2014).
We contribute a machine learning model for outcome prediction in the ICU which provides fully probabilistic predictions as well as insights into the relevance of input features. Such results can help clinicians with making treatment decisions and communicating with patients’ families. The predictions and insights can further be used for resource allocation purposes, for example, to decide whether a patient should be discharged in order to accommodate a new patient. More importantly, our method provides insights into how predictions are made. In particular, it allows us to determine which medical measurements are relevant for the task of outcome prediction. This is an important advance because model interpretability is essential if critical decisions based on diagnostic support systems are to be accepted by both clinicians and the public.
The rest of this work is organized as follows. We review related literature in Section 2 and present technical details of our proposed BNN architecture in Section 3. In Section 4 we describe the experimental set-up and evaluate our method on two real-world ICU cohorts: MIMIC-III (Section 4.2) and CENTER-TBI (Section 4.3).
In linear models sparsity is typically induced using a suitable prior distribution over the model parameters. The well-known LASSO (Least Absolute Shrinkage and Selection Operator) (Tibshirani, 1996) produces a sparse estimate of the parameters in a linear model by penalizing the $\ell_1$ norm of the parameter vector, which shrinks it towards zero. A Bayesian interpretation of the LASSO is discussed in Park and Casella (2008), showing that the $\ell_1$ regularizer can be interpreted as a Laplace prior over the parameters, and that the LASSO estimate is equivalent to a maximum a posteriori estimate of the linear coefficients given the data.
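To make the MAP correspondence concrete, the sketch below (plain Python, a toy dataset, and hypothetical values for the noise variance $\sigma^2$ and Laplace scale $b$) checks that, under a Gaussian likelihood and an i.i.d. Laplace prior, the negative log-posterior equals the LASSO objective divided by $2\sigma^2$ plus a constant, provided the penalty weight is $\lambda = 2\sigma^2 / b$. The two objectives therefore share the same minimizer. The function names are illustrative only.

```python
import math

def residuals(beta, X, y):
    return [yi - sum(b * xij for b, xij in zip(beta, xi)) for xi, yi in zip(X, y)]

def lasso_objective(beta, X, y, lam):
    """Squared error plus an l1 penalty with weight lam."""
    return sum(r * r for r in residuals(beta, X, y)) + lam * sum(abs(b) for b in beta)

def neg_log_posterior(beta, X, y, noise_var, laplace_scale):
    """-log p(beta | y) up to log p(y): Gaussian likelihood + i.i.d. Laplace(0, laplace_scale) prior."""
    nll = sum(r * r for r in residuals(beta, X, y)) / (2 * noise_var)
    nll += 0.5 * len(y) * math.log(2 * math.pi * noise_var)
    neg_log_prior = sum(abs(b) / laplace_scale + math.log(2 * laplace_scale) for b in beta)
    return nll + neg_log_prior

# Toy data and hypothetical hyperparameters
X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
y = [1.0, 2.0, 2.5]
noise_var, b = 0.5, 0.25
lam = 2 * noise_var / b  # penalty weight implied by the Laplace prior

# The gap between the two (rescaled) objectives should not depend on beta
gaps = [neg_log_posterior(beta, X, y, noise_var, b) - lasso_objective(beta, X, y, lam) / (2 * noise_var)
        for beta in ([0.0, 0.0], [1.0, -2.0], [0.3, 1.7])]
print(gaps)
```

Because the gap is constant in $\beta$, minimizing either objective yields the same estimate; only the interpretation of $\lambda$ changes.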
Sparsity-inducing prior distributions beyond the Laplace distribution have been proposed to improve feature selection results. As a prevalent example, the spike and slab prior places a mixture distribution on each parameter, comprising a point mass at zero (the spike) and an absolutely continuous density (the slab) (Mitchell and Beauchamp, 1988). Since its introduction it has been employed in several settings (Ishwaran et al., 2005; Carbonetto et al., 2012; Scheipl et al., 2012; Hernández-Lobato et al., 2015a; Malsiner-Walli and Wagner, 2018). However, the spike and slab prior requires a careful choice of the mixture weights and the variance of the “slab”. Another popular choice for introducing sparsity is the horseshoe prior (Carvalho et al., 2009), which places a half-Cauchy prior over the scale (standard deviation) of the Gaussian prior over the parameters. The heavy tails of the horseshoe distribution allow coefficients associated with important features to remain large, while at the same time the tall spike at the origin encourages shrinkage of the other parameters. Compared to the spike and slab prior, the horseshoe prior is more flexible because no mixing proportions need to be set. Furthermore, Bhadra et al. (2017) showed that the horseshoe prior is computationally more efficient. Therefore, in the rest of this work we employ the horseshoe prior as the sparsity-inducing prior for feature selection.
Less work has been directed towards the application of sparsity-inducing priors in non-linear models. One of the first approaches was Automatic Relevance Determination applied to BNNs (MacKay et al., 1994), which fits the prior variance of each individual parameter by maximizing the marginal likelihood. However, this approach fails to scale to large datasets as it involves the inversion of large matrices.
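The "tall spike plus heavy tails" behaviour of the horseshoe prior can be seen by sampling. The sketch below is an illustration under stated assumptions, not part of the method: it uses a fixed global scale of 1 (hypothetical), draws a local scale from a half-Cauchy distribution by inverse-CDF sampling, and compares how often horseshoe and standard Gaussian draws fall very close to zero or far out in the tails.

```python
import math
import random

def half_cauchy(rng, scale=1.0):
    # Inverse-CDF sampling: scale * tan(pi * U / 2) ~ C+(0, scale) for U ~ Uniform(0, 1)
    return scale * math.tan(math.pi * rng.random() / 2)

def horseshoe(rng, global_scale=1.0):
    # w | lam ~ N(0, (global_scale * lam)^2) with local scale lam ~ C+(0, 1)
    return rng.gauss(0.0, global_scale * half_cauchy(rng))

def fraction(xs, predicate):
    return sum(1 for x in xs if predicate(x)) / len(xs)

rng = random.Random(0)
n = 20000
hs = [horseshoe(rng) for _ in range(n)]
gs = [rng.gauss(0.0, 1.0) for _ in range(n)]

near_zero = (fraction(hs, lambda w: abs(w) < 0.1), fraction(gs, lambda w: abs(w) < 0.1))
far_out = (fraction(hs, lambda w: abs(w) > 5.0), fraction(gs, lambda w: abs(w) > 5.0))
print("P(|w| < 0.1): horseshoe %.3f vs Gaussian %.3f" % near_zero)
print("P(|w| > 5):   horseshoe %.3f vs Gaussian %.3f" % far_out)
```

The horseshoe places more mass both very near zero (shrinkage of irrelevant coefficients) and far from zero (large coefficients survive), which is exactly the combination exploited for feature selection.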
Louizos et al. (2017) and Ghosh and Doshi-Velez (2017) applied a horseshoe prior to prune inactive hidden units from BNNs, thereby achieving better compression. The models were trained using variational inference (Jordan et al., 1999; Beal, 2003) and the gradient of the loss function was approximated using unbiased Monte Carlo estimates. Our work uses the same inference techniques; however, we focus on selecting input features, not hidden units.
ICU clinicians are faced with a large variety of data types including EHRs, monitoring data and diagnostic data. The heterogeneity and volume of these data, along with the short timescales over which clinical changes can occur, make the ICU setting a good test case for many prediction problems in healthcare. Over the past years researchers have explored the application of artificial intelligence to ICU tasks, like the prediction of outcomes or length of stay (Desautels et al., 2016; Churpek et al., 2016; Vranas et al., 2017; Nemati et al., 2018; Meiring et al., 2018).
Because of its central importance to patients and clinicians, outcome prediction is one of the most widely studied tasks within this field. Models can be divided into two categories: those using only static features (Knaus et al., 1985; Le Gall et al., 1993; Lemeshow et al., 1993; Elixhauser et al., 1998; Steyerberg et al., 2008) and those utilizing information about the temporal evolution of features (Joshi and Szolovits, 2012; Ghassemi et al., 2014; Caballero B. and Akella, 2015; Harutyunyan et al., 2017; Che et al., 2018).
Most approaches based on static features model only linear relationships or rely on manual feature engineering. Manual feature engineering scales poorly, and prevents models from automatically discovering patterns in the data. Linear models are easy to interpret, because the importance of input features can directly be inferred from the magnitude of the associated model coefficients. This is appealing for transparent clinical decision making and highly desirable for real-world applications for two reasons. Firstly, decisions without a justification are unacceptable both medicolegally and to clinicians and patients. Secondly, interpretability allows the model to be queried to gain novel insights into data which may be biologically hypothesis-generating. However, the capacity of linear models is limited. In most real world problems the relationship between input features and target values is non-linear or may involve complex interactions between predictors. Consequently, more powerful approaches are needed to model such data well.
In this work, we propose a model for mortality prediction named HorseshoeBNN. In contrast to previous work (Joshi and Szolovits, 2012; Celi et al., 2012; Ghassemi et al., 2014; Caballero B. and Akella, 2015) our model is able to both capture non-linear relationships and learn which input features are important for prediction, thereby making the model interpretable.
In this section, we describe our proposed BNN architecture for mortality prediction, with a horseshoe prior that enables feature selection for better interpretability. We first revisit the BNN, a type of ANN which explicitly incorporates uncertainty by introducing distributions over the model parameters (Section 3.1). Section 3.2 discusses the specific prior distribution we employ, the horseshoe prior, which induces sparsity in the first layer of the BNN, and describes the architecture of our model, the HorseshoeBNN. Section 3.3 addresses the computational methods used to implement it.
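Given an observed dataset $\mathcal{D} = \{(\mathbf{x}_n, y_n)\}_{n=1}^{N}$, we want to determine a model that fits the data well and generalizes to unseen cases. In the context of the mortality prediction task studied in this work, the vector $\mathbf{x}_n$ consists of different medical measurements and $y_n \in \{0, 1\}$ is a binary indicator of the outcome for a specific patient. In parametric modelling a popular model for such tasks is the ANN (Rumelhart et al., 1986; Goodfellow et al., 2016), which defines highly non-linear functions by stacking multiple layers of simple non-linear transformations. As an example, a fully connected ANN with $L$ hidden layers defines the function $f$ in the following way:
$$\mathbf{h}^{(0)} = \mathbf{x}, \qquad \mathbf{h}^{(l)} = \sigma\big(\mathbf{W}^{(l)} \mathbf{h}^{(l-1)}\big), \quad l = 1, \dots, L, \qquad f(\mathbf{x}) = \sigma_{\text{out}}\big(\mathbf{W}^{(L+1)} \mathbf{h}^{(L)}\big). \tag{1}$$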
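Here $\sigma(\cdot)$ represents an activation function, which is usually a simple non-linear transformation, e.g. sigmoid or ReLU (Nair and Hinton, 2010). Depending on the task, the non-linearity $\sigma_{\text{out}}(\cdot)$ for the output layer may be an identity function for regression, or a sigmoid/softmax function for binary/multiclass classification. Bias terms can also be included in each layer by defining $\mathbf{h}^{(l)} \leftarrow [\mathbf{h}^{(l)}, 1]$. In the rest of this work we use $\boldsymbol{\theta} = \{\mathbf{W}^{(l)}\}_{l=1}^{L+1}$ to represent all parameters of an ANN, and we denote the function defined by Equation (1) as $f_{\boldsymbol{\theta}}(\mathbf{x})$ to emphasize the dependence of the function values on $\boldsymbol{\theta}$.
Instead of directly predicting the response with a deterministic function $f_{\boldsymbol{\theta}}$, BNNs start from a probabilistic description of the modelling task and estimate the uncertainty of the parameters given the data. Concretely, the network parameters $\boldsymbol{\theta}$ are considered random variables, and a prior distribution $p(\boldsymbol{\theta})$ is selected to represent the prior belief about their configuration. Assuming that the observed data are independent and identically distributed (i.i.d.), the likelihood function of $\boldsymbol{\theta}$ is defined as
$$p(\mathcal{D} \mid \boldsymbol{\theta}) = \prod_{n=1}^{N} p(y_n \mid \mathbf{x}_n, \boldsymbol{\theta}), \tag{2}$$
where, in the case of a binary classification task like the one presented in this work, the label $y_n$ is a scalar, and
$$p(y_n \mid \mathbf{x}_n, \boldsymbol{\theta}) = \mathrm{Bernoulli}\big(y_n;\, f_{\boldsymbol{\theta}}(\mathbf{x}_n)\big). \tag{3}$$
For regression tasks, we have $p(y_n \mid \mathbf{x}_n, \boldsymbol{\theta}) = \mathcal{N}\big(y_n;\, f_{\boldsymbol{\theta}}(\mathbf{x}_n), \sigma_y^2\big)$. After observing the training data $\mathcal{D}$, a posterior distribution of the network weights is defined by Bayes’ rule:
$$p(\boldsymbol{\theta} \mid \mathcal{D}) = \frac{p(\mathcal{D} \mid \boldsymbol{\theta})\, p(\boldsymbol{\theta})}{\int p(\mathcal{D} \mid \boldsymbol{\theta})\, p(\boldsymbol{\theta})\, d\boldsymbol{\theta}}. \tag{4}$$
This posterior distribution represents the updated belief of how likely the network parameters are given the observations. With the posterior distribution one can predict the response $y^*$ of an unseen input $\mathbf{x}^*$ using the predictive distribution:
$$p(y^* \mid \mathbf{x}^*, \mathcal{D}) = \int p(y^* \mid \mathbf{x}^*, \boldsymbol{\theta})\, p(\boldsymbol{\theta} \mid \mathcal{D})\, d\boldsymbol{\theta}. \tag{5}$$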
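The forward pass of Equation (1) and a Monte Carlo estimate of the predictive distribution in Equation (5) can be sketched in a few lines. This is a minimal illustration under assumptions: one ReLU hidden layer, a sigmoid output, biases handled by appending a constant 1 to each layer input, and a hypothetical fully factorized Gaussian distribution standing in for the weight posterior (as in a GaussianBNN; the HorseshoeBNN draws its first-layer weights differently). All parameter values are made up.

```python
import math
import random

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def forward(x, layers):
    """Eq. (1) for binary classification: ReLU hidden layers, sigmoid output.
    Each layer is a list of weight rows; a constant 1 is appended to the layer input,
    so the last entry of every row acts as a bias."""
    h = list(x)
    for i, W in enumerate(layers):
        h = h + [1.0]
        z = [sum(w * hj for w, hj in zip(row, h)) for row in W]
        h = [max(0.0, zi) for zi in z] if i < len(layers) - 1 else z
    return sigmoid(h[0])  # single output unit -> p(y = 1 | x, theta)

def sample_layers(means, stds, rng):
    """One draw of all weights from a fully factorized Gaussian (reparametrized)."""
    return [[[m + s * rng.gauss(0.0, 1.0) for m, s in zip(mrow, srow)]
             for mrow, srow in zip(Wm, Ws)]
            for Wm, Ws in zip(means, stds)]

def predictive_prob(x, means, stds, n_samples=500, seed=0):
    """Monte Carlo estimate of Eq. (5): average of f_theta(x) over sampled theta."""
    rng = random.Random(seed)
    return sum(forward(x, sample_layers(means, stds, rng)) for _ in range(n_samples)) / n_samples

# Hypothetical parameters: 2 inputs -> 3 hidden units -> 1 output
means = [
    [[0.5, -1.0, 0.0], [1.0, 1.0, -0.5], [-0.3, 0.8, 0.1]],
    [[1.2, -0.7, 0.4, 0.0]],
]
stds = [[[0.3] * len(row) for row in W] for W in means]
x_new = [0.2, 1.5]
print(forward(x_new, means), predictive_prob(x_new, means, stds))
```

Averaging probabilities over weight samples, rather than plugging in a single weight setting, is what lets the BNN report calibrated uncertainty; with all standard deviations set to zero the estimate collapses to the deterministic network output.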
The prior distribution $p(\boldsymbol{\theta})$ captures the prior belief about which model parameters are likely to generate the target outputs $y$, before observing any data. When focusing on feature selection, sparsity-inducing priors are of particular interest. In this work, we use a horseshoe prior (Carvalho et al., 2009), which in its simplest form can be described as
$$w \mid \tau \sim \mathcal{N}(0, \tau^2), \qquad \tau \sim \mathcal{C}^{+}(0, b_0), \tag{6}$$
where $\mathcal{C}^{+}(0, b_0)$ is the half-Cauchy distribution and $b_0$ is a scale parameter. The probability density function of the horseshoe prior with $b_0 = 1$ is illustrated in Figure 1. It has a sharp peak around zero and wide tails. This encourages shrinkage of weights that do not contribute to prediction, while at the same time allowing large weights to remain large.
For feature selection we propose a horseshoe prior for the first layer of a BNN, using a shared half-Cauchy distribution to control the shrinkage of all weights connected to the same input feature. Specifically, denoting as $w_{kj}$ the weight connecting the $j$-th component of the input vector $\mathbf{x}$ to the $k$-th node in the first hidden layer, the associated horseshoe prior is given by
$$w_{kj} \mid \tau_j, v \sim \mathcal{N}(0, \tau_j^2 v^2), \qquad \tau_j \sim \mathcal{C}^{+}(0, b_0), \qquad v \sim \mathcal{C}^{+}(0, b_g). \tag{7}$$
The layer-wide scale $v$ tends to shrink all weights in the layer, whereas the local shrinkage parameter $\tau_j$ allows for reduced shrinkage of all weights related to a specific input feature $x_j$. As a consequence, certain features of the input vector are selected whereas others are ignored. For the bias node we use a Gaussian prior distribution. It is important to note that a simpler solution which keeps all feature weights small but does not encourage significant shrinkage would not be sufficient to perform feature selection: large weights in deeper layers of the network might increase the influence of irrelevant input features.
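The effect of tying can be illustrated by drawing first-layer weights from the prior of Equation (7). In the sketch below the sizes (50 hidden units, 17 input features) mirror the MIMIC-III set-up, but the scales $b_0 = b_g = 1$ are hypothetical placeholders, not the settings used in our experiments. Because every weight leaving feature $j$ shares $\tau_j$, a small $\tau_j$ shrinks the whole column of weights for that feature at once.

```python
import math
import random

def half_cauchy(rng, scale):
    # Inverse-CDF sampling: scale * tan(pi * U / 2) ~ C+(0, scale) for U ~ Uniform(0, 1)
    return scale * math.tan(math.pi * rng.random() / 2)

def sample_tied_horseshoe(n_hidden, n_features, b0, bg, rng):
    """Draw first-layer weights from the tied prior of Eq. (7).
    tau[j] is shared by every weight leaving input feature j; v is shared by the whole layer."""
    v = half_cauchy(rng, bg)
    tau = [half_cauchy(rng, b0) for _ in range(n_features)]
    W = [[rng.gauss(0.0, tau[j] * v) for j in range(n_features)] for _ in range(n_hidden)]
    return W, tau, v

rng = random.Random(1)
W, tau, v = sample_tied_horseshoe(n_hidden=50, n_features=17, b0=1.0, bg=1.0, rng=rng)

# Largest weight magnitude leaving each input feature (one column of W per feature)
col_max = [max(abs(W[k][j]) for k in range(len(W))) for j in range(len(tau))]
j_small, j_large = tau.index(min(tau)), tau.index(max(tau))
print(f"feature {j_small}: tau={tau[j_small]:.3g}, max|w|={col_max[j_small]:.3g}")
print(f"feature {j_large}: tau={tau[j_large]:.3g}, max|w|={col_max[j_large]:.3g}")
```

An untied prior (one local scale per weight) would shrink individual weights, but would not remove a feature as a whole; tying makes the decision per input feature.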
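We model the prior of the weights in the second layer of the HorseshoeBNN by a Gaussian distribution, which helps prevent overfitting (Blundell et al., 2015). The complete network architecture is given in Figure 1. Although we perform our experiments with a single hidden layer, the model can easily be enlarged by adding more hidden layers with Gaussian priors.
A direct parameterization of the half-Cauchy prior can lead to instabilities in variational inference for BNNs. Therefore, we follow Ghosh and Doshi-Velez (2017) and reparametrize the horseshoe prior using auxiliary parameters:
$$\tau_j^2 \mid \lambda_j \sim \mathrm{IG}\Big(\tfrac{1}{2}, \tfrac{1}{\lambda_j}\Big), \quad \lambda_j \sim \mathrm{IG}\Big(\tfrac{1}{2}, \tfrac{1}{b_0^2}\Big), \qquad v^2 \mid \vartheta \sim \mathrm{IG}\Big(\tfrac{1}{2}, \tfrac{1}{\vartheta}\Big), \quad \vartheta \sim \mathrm{IG}\Big(\tfrac{1}{2}, \tfrac{1}{b_g^2}\Big), \tag{8}$$
where $\mathrm{IG}$ denotes the inverse Gamma distribution. After adding the auxiliary variables to the horseshoe prior, the prior over all the unobserved random variables $\boldsymbol{\theta} = \{\mathbf{W}^{(1)}, \mathbf{W}^{(2)}, \boldsymbol{\tau}, v, \boldsymbol{\lambda}, \vartheta\}$ is
$$p(\boldsymbol{\theta}) = p(\vartheta)\, p(v \mid \vartheta) \prod_{j} \Big[ p(\lambda_j)\, p(\tau_j \mid \lambda_j) \prod_{k} p(w_{kj} \mid \tau_j, v) \Big]\, p(\mathbf{W}^{(2)}),$$
where $p(\mathbf{W}^{(2)})$ is the Gaussian prior on the second-layer weights (the Gaussian priors on the biases are omitted for brevity).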
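The auxiliary-variable construction rests on a known identity: if $\lambda \sim \mathrm{IG}(1/2, 1/b_0^2)$ and $\tau^2 \mid \lambda \sim \mathrm{IG}(1/2, 1/\lambda)$, then $\tau \sim \mathcal{C}^{+}(0, b_0)$. The sketch below checks this empirically, using the fact that the median of $\mathcal{C}^{+}(0, b_0)$ is $b_0$. The value $b_0 = 2$ and the sample size are arbitrary choices for illustration; the inverse Gamma sampler is built from Python's `random.gammavariate`, which takes a shape and a scale.

```python
import random
import statistics

def inv_gamma(rng, shape, scale):
    """Sample InvGamma(shape, scale): if X ~ Gamma(shape, rate=scale), then 1/X ~ InvGamma(shape, scale)."""
    # random.gammavariate is parametrized by (shape, scale), so rate=scale becomes scale=1/scale
    return 1.0 / rng.gammavariate(shape, 1.0 / scale)

def tau_via_auxiliary(rng, b0):
    """Eq. (8): lam ~ IG(1/2, 1/b0^2) and tau^2 | lam ~ IG(1/2, 1/lam), so that tau ~ C+(0, b0)."""
    lam = inv_gamma(rng, 0.5, 1.0 / b0 ** 2)
    return inv_gamma(rng, 0.5, 1.0 / lam) ** 0.5

rng = random.Random(0)
b0 = 2.0
samples = [tau_via_auxiliary(rng, b0) for _ in range(20000)]
# The median of C+(0, b0) is b0, since P(tau < b0) = (2 / pi) * arctan(1) = 1/2
print(statistics.median(samples))
```

Both conditionals are inverse Gamma, which is what makes the closed-form updates for the auxiliary variables possible later in this section.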
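For most BNN architectures both the posterior distribution $p(\boldsymbol{\theta} \mid \mathcal{D})$ and the predictive distribution $p(y^* \mid \mathbf{x}^*, \mathcal{D})$ are intractable because the integrals have no analytic form. To address this issue we fit a simpler distribution $q(\boldsymbol{\theta}) \approx p(\boldsymbol{\theta} \mid \mathcal{D})$ and later replace $p(\boldsymbol{\theta} \mid \mathcal{D})$ with $q(\boldsymbol{\theta})$ in prediction. More specifically, we define
$$q(\boldsymbol{\theta}) = q(\vartheta)\, q(v) \prod_{j} \Big[ q(\lambda_j)\, q(\tau_j) \prod_{k} q(w_{kj}) \Big]\, q(\mathbf{W}^{(2)}), \tag{9}$$
and use factorized Gaussian distributions for the weights in upper layers:
$$q(\mathbf{W}^{(2)}) = \prod_{i,k} \mathcal{N}\big(w^{(2)}_{ik};\, \mu^{(2)}_{ik}, (\sigma^{(2)}_{ik})^2\big).$$
To ensure non-negativity of the shrinkage parameters, we consider a log-normal approximation to the posterior of $\tau_j$ and $v$, i.e.
$$q(\tau_j) = \mathrm{LogNormal}(\mu_{\tau_j}, \sigma_{\tau_j}^2), \qquad q(v) = \mathrm{LogNormal}(\mu_v, \sigma_v^2). \tag{10}$$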
In the horseshoe prior (see Eq. 7) the weights $w_{kj}$ and the scales $\tau_j$ and $v$ are strongly correlated. This leads to strong correlations in the posterior distribution, with pathological geometries that are hard to approximate. Betancourt and Girolami (2015) and Ingraham and Marks (2016) show that this problem can be mitigated by reparametrizing the weights in the horseshoe layer as follows:
$$\beta_{kj} \sim \mathcal{N}(0, 1), \qquad w_{kj} = \tau_j\, v\, \beta_{kj}, \tag{11}$$
and, equivalently, parametrizing the approximate distribution as
$$q(\beta_{kj}) = \mathcal{N}(\mu_{\beta_{kj}}, \sigma_{\beta_{kj}}^2). \tag{12}$$
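A single reparametrized draw of the first-layer weights under this non-centered parametrization can be sketched as follows: draw $\tau_j$ and $v$ from their log-normal approximations, draw $\beta_{kj}$ from its Gaussian approximation, and multiply. The parameter values are hypothetical; the example gives feature 1 a very small $\tau_1$ to show that its entire column of weights is shrunk, while the $\beta_{kj}$ themselves stay on a unit scale.

```python
import math
import random

def sample_first_layer(mu_beta, sig_beta, mu_tau, sig_tau, mu_v, sig_v, rng):
    """One reparametrized draw of the first-layer weights using Eqs. (10)-(12).
    mu_beta, sig_beta: [n_hidden][n_features] parameters of q(beta_kj) = N(mu, sig^2)
    mu_tau, sig_tau:   per-feature parameters of q(tau_j) = LogNormal(mu, sig^2)
    mu_v, sig_v:       parameters of q(v) = LogNormal(mu, sig^2)"""
    def eps():
        return rng.gauss(0.0, 1.0)
    tau = [math.exp(m + s * eps()) for m, s in zip(mu_tau, sig_tau)]  # always positive
    v = math.exp(mu_v + sig_v * eps())
    return [[tau[j] * v * (mb + sb * eps())  # w_kj = tau_j * v * beta_kj, Eq. (11)
             for j, (mb, sb) in enumerate(zip(mrow, srow))]
            for mrow, srow in zip(mu_beta, sig_beta)]

# Hypothetical parameters: feature 1 has a very small tau, so its whole column is shrunk
mu_beta = [[1.0, 1.0], [2.0, -1.0]]
sig_beta = [[0.1, 0.1], [0.1, 0.1]]
mu_tau, sig_tau = [0.0, math.log(0.01)], [0.1, 0.1]
rng = random.Random(0)
print(sample_first_layer(mu_beta, sig_beta, mu_tau, sig_tau, 0.0, 0.1, rng))
```

Every random quantity is a deterministic function of the variational parameters and standard normal noise, so gradients with respect to the parameters can flow through the sample.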
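Because the log-likelihood term does not depend on $\lambda_j$ or $\vartheta$, one can show that the optimal approximations $q(\lambda_j)$ and $q(\vartheta)$ are inverse Gamma distributions with distributional parameters dependent on $q(\tau_j)$ and $q(v)$ (Ghosh and Doshi-Velez, 2017).
We fit the variational posterior by minimizing the Kullback-Leibler (KL) divergence $\mathrm{KL}[q(\boldsymbol{\theta}) \,\|\, p(\boldsymbol{\theta} \mid \mathcal{D})]$. One can show that this KL divergence minimization task is equivalent to maximizing the evidence lower bound (ELBO) (Jordan et al., 1999; Beal, 2003; Zhang et al., 2018):
$$\mathcal{L}(q) = \mathbb{E}_{q(\boldsymbol{\theta})}\big[\log p(\mathcal{D} \mid \boldsymbol{\theta})\big] - \mathrm{KL}\big[q(\boldsymbol{\theta}) \,\|\, p(\boldsymbol{\theta})\big] = \log p(\mathcal{D}) - \mathrm{KL}\big[q(\boldsymbol{\theta}) \,\|\, p(\boldsymbol{\theta} \mid \mathcal{D})\big].$$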
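The inverse Gamma result can be made explicit. Assuming the auxiliary parametrization $\tau_j^2 \mid \lambda_j \sim \mathrm{IG}(1/2, 1/\lambda_j)$, $\lambda_j \sim \mathrm{IG}(1/2, 1/b_0^2)$ and $q(\tau_j) = \mathrm{LogNormal}(\mu_{\tau_j}, \sigma_{\tau_j}^2)$, the terms involving $\lambda_j$ are $-2\log\lambda_j - (1/\tau_j^2 + 1/b_0^2)/\lambda_j$; taking the expectation under $q(\tau_j)$ gives $q(\lambda_j) = \mathrm{IG}\big(1,\ \mathbb{E}_q[1/\tau_j^2] + 1/b_0^2\big)$ with $\mathbb{E}_q[1/\tau_j^2] = \exp(-2\mu_{\tau_j} + 2\sigma_{\tau_j}^2)$. This is our own mean-field derivation offered as a sketch, not code from the cited work; the same argument with $v$ and $b_g$ yields $q(\vartheta)$. The function name and test values are illustrative.

```python
import math
import random

def optimal_q_lambda(mu_tau, sig_tau, b0):
    """(shape, scale) of the optimal inverse-Gamma q(lambda_j), given q(tau_j) = LogNormal(mu_tau, sig_tau^2).
    Terms of log p(tau_j^2 | lambda_j) + log p(lambda_j) that involve lambda_j:
        -2 log(lambda_j) - (1 / tau_j^2 + 1 / b0^2) / lambda_j
    Taking the expectation over q(tau_j) gives the log-density of IG(1, E[1/tau_j^2] + 1/b0^2)."""
    # log(1/tau^2) = -2 log(tau) ~ N(-2 mu, 4 sig^2), hence E[1/tau^2] = exp(-2 mu + 2 sig^2)
    e_inv_tau_sq = math.exp(-2.0 * mu_tau + 2.0 * sig_tau ** 2)
    return 1.0, e_inv_tau_sq + 1.0 / b0 ** 2

# Monte Carlo check of the closed-form expectation for hypothetical q(tau_j) parameters
mu, sig = 0.3, 0.2
rng = random.Random(0)
mc = sum(math.exp(-2.0 * rng.gauss(mu, sig)) for _ in range(100000)) / 100000
print(optimal_q_lambda(mu, sig, b0=1.0), mc + 1.0)
```

Because these updates are closed-form, $\lambda_j$ and $\vartheta$ need no gradient-based optimization; only the parameters of $q(\beta_{kj})$, $q(\tau_j)$, $q(v)$ and the upper layers are learned by stochastic gradients.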
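Since the ELBO still lacks an analytic form due to the non-linearity of the BNN, we apply black box variational inference (Ranganath et al., 2014) to compute an unbiased estimate of the ELBO by sampling $\boldsymbol{\theta} \sim q(\boldsymbol{\theta})$. More specifically, because the $q$ distribution is constructed as a product of (log-)normal distributions, we apply the reparametrization trick (Kingma and Welling, 2013; Rezende et al., 2014) to draw samples from the variational distribution: $\boldsymbol{\theta} = g(\boldsymbol{\epsilon}, \boldsymbol{\phi})$ with $\boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$, where $\boldsymbol{\phi}$ denotes the variational parameters; for example, $\beta_{kj} = \mu_{\beta_{kj}} + \sigma_{\beta_{kj}} \epsilon$ and $\tau_j = \exp(\mu_{\tau_j} + \sigma_{\tau_j} \epsilon')$. Furthermore, stochastic optimization techniques are employed to allow for mini-batch training, which enables the variational inference algorithm to scale to large datasets. Combining both, the doubly stochastic approximation to the ELBO is
$$\tilde{\mathcal{L}}(\boldsymbol{\phi}) = \frac{N}{M} \sum_{n \in \mathcal{S}} \log p\big(y_n \mid \mathbf{x}_n, g(\boldsymbol{\epsilon}, \boldsymbol{\phi})\big) - \mathrm{KL}\big[q_{\boldsymbol{\phi}}(\boldsymbol{\theta}) \,\|\, p(\boldsymbol{\theta})\big], \qquad \boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}),$$
where $\mathcal{S}$ is a mini-batch of $M$ data points. This estimate is used as the objective for the stochastic gradient ascent training of the variational parameters $\boldsymbol{\phi}$.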
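The doubly stochastic estimator can be sketched for the simplest case covered by our baselines: a Bayesian logistic regression (LinearGaussian) with a factorized Gaussian $q$ and a $\mathcal{N}(0, s^2)$ prior, so that the KL term is analytic. This is a minimal illustration under those assumptions; for the HorseshoeBNN the likelihood would use the network of Equation (1) and the KL would include the horseshoe terms, which are not shown. The data and parameter values are toy placeholders, and the gradients, which would normally come from automatic differentiation, are not computed here.

```python
import math
import random

def kl_gauss(mu, sig, prior_sig):
    """KL[N(mu, sig^2) || N(0, prior_sig^2)] for a single weight."""
    return math.log(prior_sig / sig) + (sig ** 2 + mu ** 2) / (2 * prior_sig ** 2) - 0.5

def log_bernoulli(y, logit):
    """Numerically stable log p(y | logit) for y in {0, 1}."""
    z = logit if y == 1 else -logit  # log(1 - sigmoid(a)) = log sigmoid(-a)
    # log sigmoid(z), written to avoid overflow in exp for large |z|
    return -math.log1p(math.exp(-z)) if z > 0 else z - math.log1p(math.exp(z))

def elbo_estimate(mu, sig, prior_sig, X, y, batch_size, rng):
    """Doubly stochastic ELBO: one reparametrized weight sample, one random mini-batch,
    with the mini-batch log-likelihood rescaled by N / M."""
    n = len(X)
    w = [m + s * rng.gauss(0.0, 1.0) for m, s in zip(mu, sig)]  # theta = g(eps, phi)
    batch = rng.sample(range(n), batch_size)
    loglik = sum(log_bernoulli(y[i], sum(wj * xj for wj, xj in zip(w, X[i]))) for i in batch)
    kl = sum(kl_gauss(m, s, prior_sig) for m, s in zip(mu, sig))
    return (n / batch_size) * loglik - kl

# Toy data: first column is a constant bias feature
X = [[1.0, 0.5], [1.0, -1.2], [1.0, 2.0], [1.0, 0.0]]
y = [1, 0, 1, 0]
rng = random.Random(0)
estimates = [elbo_estimate([0.1, 0.8], [0.5, 0.5], 1.0, X, y, batch_size=2, rng=rng) for _ in range(1000)]
print(sum(estimates) / len(estimates))
```

The $N/M$ factor keeps the mini-batch log-likelihood an unbiased estimate of the full-data term, so averaging many such estimates approximates the ELBO while each optimization step only touches $M$ data points.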
In this section, we focus on evaluating our method on two real-world ICU cohorts: MIMIC-III (Johnson et al., 2016) and CENTER-TBI (Maas et al., 2014). We also present a validation study on a synthetic feature selection dataset in Appendix A, showing that our approach can fully recover the ground truth set of features. The two ICU cohort evaluations share the same experimental set-up, which is summarized below. Details of the two cohorts and experimental results are presented in Sections 4.2 and 4.3, respectively.
We compare our proposed HorseshoeBNN with three related models for outcome prediction in ICUs. In this section we summarize all evaluated models.
LinearGaussian: a linear model with a Gaussian prior distribution on all weights (Bishop, 2006). This is the most commonly used Bayesian model for prediction tasks.
GaussianBNN: a standard BNN with a Gaussian prior distribution on all weights (Blundell et al., 2015).
LinearHorseshoe: a linear model with a horseshoe prior distribution on all weights (Carvalho et al., 2009). This model extends the LinearGaussian model by using sparsity inducing prior distributions for performing feature selection.
HorseshoeBNN: our novel extension of the GaussianBNN with a tied horseshoe prior distribution on the weights in the first layer. This enables the model to perform feature selection.
All models are trained until convergence using 10-fold cross-validation and the ADAM optimizer (Kingma and Ba, 2014). We use 50 hidden units for the MIMIC-III cohort and 100 hidden units for CENTER-TBI. The full list of hyperparameter settings can be found in Appendix B. The code for the experiments will be available at (link to be added).
MIMIC-III (‘Medical Information Mart for Intensive Care’) (Johnson et al., 2016) is a comprehensive, publicly available intensive care database collected at a large tertiary care hospital. The data include laboratory measurements, medications, notes from care providers and other features. In total, the database contains medical records of more than 46K ICU patients over 11 years. Since its introduction, it has been used by both the medical and the machine learning communities for various analyses (Desautels et al., 2016; Choi et al., 2017; Che et al., 2018; Kurniati et al., 2018).
We preprocess MIMIC-III using code introduced in Harutyunyan et al. (2017), focusing on the task of mortality prediction. For all measurements we remove invalid feature values beyond the allowed range. The ranges are defined by medical experts as shown in Table 4 in Appendix B. Because we do not focus on prediction based on dynamic features, we reduce the time series for each patient by computing the mean value of each feature in the first 48 hours of the patient’s stay in the ICU. We use mean imputation for missing values. The final cohort contains 17903 samples and 17 features. 86.5% of patients contained in the dataset survived (positive outcome) and 13.5% died (negative outcome). A complete list of features is displayed in Table 4 in Appendix B.
We compare the predictive performance in terms of error rate (lower is better), the area under the receiver operating characteristic curve (AUROC) (higher is better) and negative predictive log-likelihood (lower is better). The results are presented in Table 1. The non-linear models (GaussianBNN, HorseshoeBNN) show a small improvement compared to the linear models (LinearGaussian, LinearHorseshoe). The HorseshoeBNN performs on par with or slightly better than the GaussianBNN, possibly because feature selection helps to reduce overfitting. In addition, we present the confusion matrix for each method in Figure 2. As illustrated in Figure 2, the non-linear models perform better due to an improvement in correctly predicting the outcome of deceased patients. The fact that all models perform worse on deceased than on surviving patients can be explained by the class imbalance in the data.
Model | Error rate | AUROC | Negative Log-Likelihood |
---|---|---|---|
LinearGaussian | 0.129 ± 0.008 | 0.807 ± 0.013 | 0.321 ± 0.013 |
GaussianBNN | 0.123 ± 0.008 | 0.830 ± 0.014 | 0.304 ± 0.013 |
LinearHorseshoe | 0.130 ± 0.008 | 0.807 ± 0.013 | 0.320 ± 0.013 |
HorseshoeBNN | 0.122 ± 0.006 | 0.831 ± 0.013 | 0.304 ± 0.012 |
Results of the different models for the task of mortality prediction tested on the MIMIC-III cohort. The mean value and one standard deviation of each metric over 10-fold cross-validation is presented.
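For readers who want to reproduce metrics of this kind, the sketch below computes error rate, AUROC (as the probability that a randomly chosen positive is ranked above a randomly chosen negative, ties counting one half) and average negative log-likelihood from predicted probabilities, plus a mean and standard deviation over folds. It assumes labels in {0, 1}, a 0.5 decision threshold, and the sample standard deviation; the exact conventions used for the tables above are not specified here, and the example data are made up.

```python
import math
import statistics

def error_rate(y, p, threshold=0.5):
    """Fraction of samples whose thresholded prediction disagrees with the label."""
    return sum((pi >= threshold) != (yi == 1) for yi, pi in zip(y, p)) / len(y)

def auroc(y, p):
    """AUROC as the probability that a random positive is scored above a random negative (ties count 1/2)."""
    pos = [pi for yi, pi in zip(y, p) if yi == 1]
    neg = [pi for yi, pi in zip(y, p) if yi == 0]
    wins = sum(1.0 if a > b else 0.5 if a == b else 0.0 for a in pos for b in neg)
    return wins / (len(pos) * len(neg))

def mean_nll(y, p, eps=1e-12):
    """Average negative log-likelihood of binary labels under predicted probabilities p(y = 1)."""
    return -sum(math.log(max(pi, eps)) if yi == 1 else math.log(max(1.0 - pi, eps))
                for yi, pi in zip(y, p)) / len(y)

def summarize(per_fold_values):
    """Mean and sample standard deviation over cross-validation folds."""
    return statistics.mean(per_fold_values), statistics.stdev(per_fold_values)

y = [1, 0, 1, 0]
p = [0.9, 0.2, 0.4, 0.6]
print(error_rate(y, p), auroc(y, p), round(mean_nll(y, p), 4))  # → 0.5 0.75 0.5403
```

The pairwise AUROC is quadratic in the number of samples; for cohorts the size of MIMIC-III a rank-based implementation would be preferable, but the definition is the same.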
An important part of our results is constituted by the interpretability of the model predictions. Both the LinearHorseshoe model and our proposed HorseshoeBNN are trained to attribute relevance to the different input features. For the LinearHorseshoe model each input feature is associated with a single distribution. Therefore, the magnitude of the mean of each distribution, shown in the middle panel of Figure 3, directly reflects the relative importance of the corresponding input feature. In the HorseshoeBNN each input feature is associated with a vector of 50 distributions (the number of hidden units). We plot the average of the mean value of these distributions in the right panel of Figure 3.
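Under the variational approximation the posterior mean of a first-layer weight factorizes as $\mathbb{E}_q[w_{kj}] = \mathbb{E}_q[\tau_j]\,\mathbb{E}_q[v]\,\mu_{\beta_{kj}}$, with $\mathbb{E}[\tau_j] = \exp(\mu_{\tau_j} + \sigma_{\tau_j}^2/2)$ for a log-normal. The sketch below averages these means over hidden units to give one score per input feature, as described above. The optional `use_abs` flag is our own addition, not part of the described procedure: it averages magnitudes instead, which avoids positive and negative weights of an important feature cancelling out. All parameter values are hypothetical.

```python
import math

def lognormal_mean(mu, sig):
    return math.exp(mu + 0.5 * sig ** 2)

def feature_importance(mu_beta, mu_tau, sig_tau, mu_v, sig_v, use_abs=False):
    """Average over hidden units of the posterior mean of w_kj = tau_j * v * beta_kj.
    Under the factorized q, E[w_kj] = E[tau_j] * E[v] * mu_beta[k][j]."""
    ev = lognormal_mean(mu_v, sig_v)
    n_hidden = len(mu_beta)
    scores = []
    for j, (mt, st) in enumerate(zip(mu_tau, sig_tau)):
        means = [lognormal_mean(mt, st) * ev * row[j] for row in mu_beta]
        scores.append(sum(abs(m) if use_abs else m for m in means) / n_hidden)
    return scores

# Hypothetical parameters: 2 hidden units, 2 input features
mu_beta = [[1.0, 1.0], [1.0, -1.0]]
print(feature_importance(mu_beta, [0.0, 0.0], [0.0, 0.0], 0.0, 0.0))
print(feature_importance(mu_beta, [0.0, 0.0], [0.0, 0.0], 0.0, 0.0, use_abs=True))
```

In this toy case feature 1 has weights of opposite sign: its signed average is zero although its weights are as large as those of feature 0, which is why the choice between the two summaries matters when reading such plots.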
When comparing the weight values we find that the two models agree on the importance of most features. Notable differences lie in the importance attributed to pH, systolic blood pressure and Glasgow coma scale.
Whereas the LinearHorseshoe model finds pH to be irrelevant, this is not the case for the HorseshoeBNN. This might be explained by the fact that the healthy range for pH is very narrow. Negative outcomes depend on pH values in a non-monotonic way — both too high and too low values are dangerous. Therefore, the non-linear model (i.e. the HorseshoeBNN) might be able to capture the importance of this feature better than the linear model.
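The argument about pH can be illustrated numerically. The sketch below uses a hypothetical U-shaped risk curve centred at pH 7.4 (roughly the middle of the normal range; the slope and offset are made up) on a symmetric grid of pH values. The covariance between pH and risk is essentially zero, so a model that is linear in pH sees no signal, while the covariance between the deviation $|{\rm pH} - 7.4|$ and risk is clearly positive.

```python
import math

def risk(ph, optimum=7.4, slope=40.0, offset=3.0):
    """Hypothetical U-shaped mortality risk: rises as pH moves away from `optimum` in either direction."""
    return 1.0 / (1.0 + math.exp(-(slope * abs(ph - optimum) - offset)))

def covariance(xs, ys):
    mx, my = sum(xs) / len(xs), sum(ys) / len(ys)
    return sum((x - mx) * (y - my) for x, y in zip(xs, ys)) / len(xs)

phs = [7.4 + d / 100 for d in range(-20, 21)]  # symmetric grid from 7.20 to 7.60
r = [risk(p) for p in phs]
cov_linear = covariance(phs, r)                          # ~0: no linear trend in pH
cov_deviation = covariance([abs(p - 7.4) for p in phs], r)  # > 0: risk tracks the deviation
print(cov_linear, cov_deviation)
```

A non-linear model can represent the deviation internally, whereas a linear model would need it to be engineered as an input feature.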
For blood pressure, three values are recorded by the ICU: systolic, diastolic, and mean blood pressure. Any one of these values can be used to establish a baseline blood pressure for a given patient. The LinearHorseshoe model captures only one of them, namely diastolic blood pressure. In contrast, the HorseshoeBNN captures both the diastolic and systolic blood pressure. In the clinical context, combining diastolic and systolic values provides additional information about the waveform of a patient’s blood pressure. This information can be of clinical relevance. Our results suggest that the HorseshoeBNN is able to recognize the importance of this additional information, whereas the LinearHorseshoe model is not.
The feature Glasgow coma scale total is selected only by the LinearHorseshoe model, but considered irrelevant by the HorseshoeBNN. It is important to note that the Glasgow coma scale components have a deterministic relationship: the total scale equals the sum of the verbal response, motor response and eye opening. This redundancy means that the data can be modelled using any combination of these features. Two features, namely height and capillary refill rate, are considered irrelevant by both models. However, it should be noted that these features are missing for a large percentage of patients in the cohort. Consequently, we cannot conclude anything about their importance.
Overall our results on feature selection suggest that, compared to a linear model, a non-linear model like our HorseshoeBNN can provide additional insights into the importance of input features.
The CENTER-TBI (Collaborative European NeuroTrauma Effectiveness Research in Traumatic Brain Injury) study is a longitudinal observational study that was conducted across Europe and Israel (Maas et al., 2014). The core study collected data from 5400 patients with a diagnosis of traumatic brain injury (TBI) for a period of up to two years post injury. Patients were divided into three strata with different clinical care paths. In this work we focus on the stratum with the most severe brain injury: patients seen in the emergency room and admitted to the ICU. The cohort contains a broad range of clinical data including baseline demographics, mechanism of injury, prehospital clinical course (e.g. vital signs and Glasgow coma scale), brain computed tomography (CT) reports, and many other features.
We predict mortality based on the features listed in Table 6 using release 1.0 of the CENTER-TBI cohort. We remove the data of patients for which no outcome was reported. To address missing values we use zero-imputation for binary features and mean imputation for continuous and ordinal features, as suggested by clinicians. The final cohort contains 1613 samples and 39 features. 75% of patients contained in the cohort survived, 25% deceased. A complete list of features is displayed in Table 6 in Appendix B.
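The type-dependent imputation described above can be sketched as follows. The table layout (rows of values with `None` marking a missing entry) and the choice of which columns are binary are hypothetical; in practice the feature types would come from the feature list in Table 6.

```python
import statistics

def impute(rows, binary_cols):
    """Fill missing values (None) column by column: zero for binary features,
    the observed column mean for continuous and ordinal features."""
    n_cols = len(rows[0])
    fill = []
    for j in range(n_cols):
        if j in binary_cols:
            fill.append(0)
        else:
            observed = [r[j] for r in rows if r[j] is not None]
            fill.append(statistics.mean(observed))
    return [[fill[j] if r[j] is None else r[j] for j in range(n_cols)] for r in rows]

# Hypothetical columns: 0 = binary indicator, 1 = continuous (e.g. pH), 2 = ordinal score
rows = [[1, 7.3, None], [None, None, 2], [0, 7.5, 4]]
completed = impute(rows, binary_cols={0})
print(completed)
```

Zero-imputation treats a missing binary indicator as "event not recorded", while mean imputation leaves the average of a continuous feature unchanged; both are simple baselines that the conclusion proposes to improve upon.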
The resulting metric values of our experiment are summarized in Table 2 and Figure 4. We again observe that the HorseshoeBNN performs on par with or slightly better than the linear models (LinearGaussian, LinearHorseshoe). The GaussianBNN has the highest error rate and the lowest AUROC of all models. This could be due to the large amount of noise in the data, which the GaussianBNN might be modelling, thereby overfitting the data. In contrast, the HorseshoeBNN removes input features, which makes the model less likely to overfit. The confusion matrices in Figure 4 show that the BNN-based models are better at predicting the outcome for deceased patients, which is consistent with our previous experiment.
Model | Error rate | AUROC | Negative Log-Likelihood |
---|---|---|---|
LinearGaussian | 0.185 ± 0.024 | 0.871 ± 0.041 | 0.393 ± 0.057 |
GaussianBNN | 0.195 ± 0.033 | 0.869 ± 0.042 | 0.390 ± 0.063 |
LinearHorseshoe | 0.180 ± 0.025 | 0.874 ± 0.041 | 0.383 ± 0.046 |
HorseshoeBNN | 0.179 ± 0.026 | 0.873 ± 0.041 | 0.380 ± 0.045 |
The relative importance of the different input features as determined by the LinearHorseshoe model and the HorseshoeBNN is shown in Figure 5. Features marked in bold are included in the IMPACT model (Steyerberg et al., 2008), a model commonly used for outcome prediction in TBI. As evident in the figure, there is little difference between the LinearHorseshoe model and the HorseshoeBNN. Both models attribute similar relevance to most features. Furthermore, we can observe a large overlap between features used in the IMPACT model and features considered important by our horseshoe models.
Interestingly, our models attribute little relevance to events of hypotension, although such events are considered important from a clinical perspective. A clinically plausible explanation for this is that hypotension is often unrecognised, leading to many false negatives for this feature. Similar to hypotension, both models consider the feature hemoglobin irrelevant for mortality prediction. This might be explained by the small role of hemoglobin in the IMPACT model. It is also possible that the small effect of hemoglobin is already accounted for by the features platelet count or international normalized ratio. Furthermore, both models consider certain features important that are not included in the IMPACT model. These features are: sex, heart rate, pH, international normalized ratio and absent basal cistern.
When comparing with the results for MIMIC-III we observe different results for the feature pH. Whereas both horseshoe models trained on CENTER-TBI determine pH to be important, this is not the case for MIMIC-III. For the latter only the non-linear model, that is, the HorseshoeBNN, selects this feature. This may be due to a difference between the patient groups contained in MIMIC-III and CENTER-TBI. This assumption is further supported by the distribution of pH values in the cohorts. A more detailed discussion of the findings, including figures is given in Appendix B.1. A further difference can be observed for the blood pressure features. This can again be explained by the heterogeneity of the datasets. Inspecting the relationship between systolic and diastolic blood pressure in both datasets shows that the two features are strongly correlated for CENTER-TBI but not for MIMIC-III. A more elaborate discussion can be found in Appendix B.1.
In this work, we proposed a novel model, the HorseshoeBNN, for performing interpretable patient outcome prediction. Our method extends traditional BNNs to perform feature selection using sparsity-inducing prior distributions in a tied manner. Our architecture offers two main advantages. Firstly, being based on a BNN, it is a non-linear, fully probabilistic method that is highly compatible with the clinical decision making process. Secondly, with our proposed extensions, the model is able to learn which input features are important for prediction, thereby making it interpretable, which is highly desirable in the clinical domain.
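To make the idea of a tied sparsity-inducing prior concrete, the following minimal NumPy sketch draws first-layer weights from a horseshoe prior in which every weight leaving input feature $d$ shares one local scale $\lambda_d$, while a single global scale $\tau$ acts on the whole layer. The parameterization is an assumption: both scales are taken to be half-Cauchy, with their scale parameters set to the value 1.0 listed in the hyperparameter tables. `sample_tied_horseshoe_weights` is a hypothetical helper, not the authors' implementation.

```python
import numpy as np


def sample_tied_horseshoe_weights(n_inputs, n_hidden, global_scale=1.0,
                                  local_scale=1.0, rng=None):
    """Draw an (n_inputs, n_hidden) weight matrix from a tied horseshoe prior.

    Hypothetical sketch: w_dh ~ N(0, tau^2 * lambda_d^2), with
    tau ~ C+(0, global_scale) and lambda_d ~ C+(0, local_scale).
    """
    rng = np.random.default_rng(rng)
    # Global shrinkage: one half-Cauchy scalar shared by the whole layer.
    tau = global_scale * np.abs(rng.standard_cauchy())
    # Local shrinkage: one half-Cauchy scale per input feature (tied across its row).
    lam = local_scale * np.abs(rng.standard_cauchy(size=n_inputs))
    # Every weight leaving feature d is scaled by the same tau * lambda_d.
    z = rng.standard_normal(size=(n_inputs, n_hidden))
    weights = tau * lam[:, None] * z
    return weights, lam, tau


W, lam, tau = sample_tied_horseshoe_weights(n_inputs=10, n_hidden=50, rng=0)
# Effective per-feature scale; values near zero switch a whole row off.
print(np.round(tau * lam, 3))
```

Because the local scale multiplies an entire row, shrinking $\lambda_d$ towards zero removes feature $d$ from every hidden unit at once. This is why a tied prior, rather than one independent scale per weight, can act as feature selection.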
We worked closely with clinicians and evaluated our model on two real-world ICU cohorts. We showed that our proposed HorseshoeBNN can provide additional insights into the importance of input features. Together with its ability to provide uncertainty estimates, this means the HorseshoeBNN could be used to support clinicians in their decision making process. In view of the high-dimensional, complex nature of medical data and the high relevance of outcome prediction in healthcare, our method could be useful not only in ICUs but in other medical settings as well. Our work illustrates how close collaboration between computational and clinical experts can lead to methodological advances suitable for translation into tools for patient benefit.
In future work, we will extend our model to work with the entire time series (Yoon et al., 2018), as dynamic prediction is a key area of interest in the ICU. Utilizing information about the evolution of features over time might not only improve predictive accuracy but also provide additional insights into temporal changes of measurement values.
Moreover, we will use more sophisticated methods for missing value imputation to obtain better predictive performance (Ma et al., 2018; Little and Rubin, 2019). Finally, contemporary digital healthcare is fundamentally transdisciplinary and we would like to continue working with medical experts to explore how to deploy our method in a real-world clinical setting.
CENTER-TBI is supported by the European Union Seventh Framework Programme (FP7, grant 602150), with additional funding provided by the Hannelore Kohl Foundation (Germany) and by the non-profit organization One Mind For Research (directly to INCF).
Variational algorithms for approximate Bayesian inference. PhD thesis, University of London, 2003.
GRAM: graph-based attention model for healthcare representation learning. In KDD Conference, pages 787–795. ACM, 2017.
Probabilistic backpropagation for scalable learning of Bayesian neural networks. In Proceedings of the 32nd International Conference on Machine Learning, pages 1861–1869, 2015.
Expectation propagation in linear regression models with spike-and-slab priors. Machine Learning, 99(3):437–487, Jun 2015b.
In Proceedings of the sixth annual conference on Computational learning theory, pages 5–13. ACM, 1993.
In this section, we verify the sparsity-inducing capacity of our LinearHorseshoe model by repeating an experiment proposed in Hernández-Lobato et al. (2015b). In the experiment, a data matrix $\mathbf{X}$ with 75 datapoints is sampled from the unit hypersphere. Target values are computed using $\mathbf{y} = \mathbf{X}\mathbf{w} + \boldsymbol{\epsilon}$, where $\boldsymbol{\epsilon}$ represents Gaussian distributed noise with standard deviation 0.005. The 512-dimensional weight vector $\mathbf{w}$ is sparse, that is, only 20 randomly selected components are non-zero.
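The data-generating process of this validation experiment can be sketched in a few lines of NumPy. Two details are assumptions, not taken from this section: rows are drawn uniformly on the surface of the unit hypersphere by normalising isotropic Gaussian vectors, and the 20 non-zero weights are drawn from a standard Gaussian. `make_sparse_regression_data` is a hypothetical name.

```python
import numpy as np


def make_sparse_regression_data(n=75, d=512, k=20, noise_std=0.005, rng=None):
    """Generate (X, y, w) with y = X w + eps and a k-sparse w (hypothetical helper)."""
    rng = np.random.default_rng(rng)
    # Rows uniform on the unit hypersphere: normalise isotropic Gaussian draws.
    X = rng.standard_normal((n, d))
    X /= np.linalg.norm(X, axis=1, keepdims=True)
    # Sparse weight vector with k randomly selected non-zero components.
    w = np.zeros(d)
    support = rng.choice(d, size=k, replace=False)
    w[support] = rng.standard_normal(k)  # assumption: Gaussian non-zero values
    # Targets with additive Gaussian noise of standard deviation noise_std.
    y = X @ w + noise_std * rng.standard_normal(n)
    return X, y, w


X, y, w = make_sparse_regression_data(rng=0)
print(X.shape, y.shape, np.count_nonzero(w))
```

With 75 observations and 512 unknowns the system is heavily underdetermined, so recovering $\mathbf{w}$ depends on the prior favouring sparse solutions. That is the property the experiment is designed to test.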
Because the target values depend linearly on the data matrix $\mathbf{X}$, we use a linear model with a horseshoe prior to obtain an estimate $\hat{\mathbf{w}}$ of the weight vector, given by the mean of the posterior distribution. We evaluate the model using the reconstruction error and find results similar to those reported in Hernández-Lobato et al. (2015b). Hyperparameter settings can be found in Table 3. The results of this validation experiment show that our model is capable of accurately recovering sparse weight vectors.
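This section does not define the reconstruction error. A common choice in the sparse-recovery literature is the relative $\ell_2$ error $\lVert \hat{\mathbf{w}} - \mathbf{w} \rVert_2 / \lVert \mathbf{w} \rVert_2$. The sketch below assumes that definition and should not be read as the exact metric used here.

```python
import numpy as np


def reconstruction_error(w_hat, w_true):
    """Relative l2 reconstruction error (assumed definition)."""
    return np.linalg.norm(w_hat - w_true) / np.linalg.norm(w_true)


w = np.array([0.0, 1.5, 0.0, -2.0])
print(reconstruction_error(w, w))            # 0.0 for a perfect estimate
print(reconstruction_error(np.zeros(4), w))  # 1.0 for the all-zero estimate
```

Under this definition, the all-zero estimate scores 1.0. Averaging the metric over independent realizations of the data, as done over twenty realizations here, gives the reported figure.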
Hyperparameter | Value |
---|---|
Number of weight samples during training | 10 |
Number of weight samples during testing | 100 |
Batch size | 64 |
Learning rate | 0.001 |
Number of epochs | 2000000 |
Global shrinkage parameter of Horseshoe prior | 1.0 |
Local shrinkage parameter of Horseshoe prior | 1.0 |
The average reconstruction error over twenty realizations of the experiment is . This error is slightly higher than, but of the same order of magnitude as, the error of reported in Hernández-Lobato et al. (2015b). The difference can be explained by the fact that we use variational inference to approximate the posterior distribution, whereas Hernández-Lobato et al. (2015b) use Markov chain Monte Carlo techniques, which are known to give a more accurate approximation of the posterior distribution. An example of a sparse signal and the reconstructed weights is shown in Figure 6.
Feature | Unit | Min threshold | Max threshold |
---|---|---|---|
Height | cm | | 250 |
Temperature | °C | | 49 |
Blood pH | - | 6 | 8 |
Fraction of inspired oxygen | - | 0 | 1 |
Capillary refill time | seconds | 0 | - |
Heart rate | bpm | 0 | 300 |
Systolic blood pressure | mmHg | | 275 |
Diastolic blood pressure | mmHg | | 150 |
Mean blood pressure | mmHg | | 190 |
Weight | kg | 0 | 250 |
Glucose | mg/dL | 0 | 1250 |
Respiratory rate | breaths/min | 0 | 150 |
Oxygen saturation | % | | 100 |
Glasgow coma scale eye response | - | 1 | 4 |
Glasgow coma scale motor response | - | 1 | 6 |
Glasgow coma scale verbal response | - | 1 | 5 |
Glasgow coma scale total | - | 1 | 15 |
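Applying plausibility thresholds such as those in the table can be sketched as follows. Several details are assumptions, because the text does not state them. Out-of-range values are treated as missing (NaN), so that a later imputation step can handle them. Bounds are inclusive. Only a hypothetical subset of the thresholds (heart rate, blood pH, glucose) is encoded, and `mask_implausible` is an illustrative name.

```python
import numpy as np

# Hypothetical subset of the thresholds above: feature -> (min, max); None = unbounded.
THRESHOLDS = {
    "heart_rate": (0.0, 300.0),
    "blood_ph": (6.0, 8.0),
    "glucose": (0.0, 1250.0),
}


def mask_implausible(values, feature, thresholds=THRESHOLDS):
    """Return a float copy of `values` with out-of-range entries set to NaN."""
    lo, hi = thresholds[feature]
    out = np.asarray(values, dtype=float).copy()
    if lo is not None:
        out[out < lo] = np.nan  # below the minimum threshold
    if hi is not None:
        out[out > hi] = np.nan  # above the maximum threshold
    return out


print(mask_implausible([7.35, 5.2, 7.1, 9.0], "blood_ph"))
```

Masking rather than clipping avoids inventing boundary values for what are most likely recording errors.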
Hyperparameter | Value |
---|---|
Number of weight samples during training | 10 |
Number of weight samples during testing | 100 |
Batch size | 64 |
Number of hidden units | 50 |
Learning rate | 0.001 |
Number of epochs | 5000 |
Standard deviation of Gaussian prior | 1.0 |
Global shrinkage parameter of Horseshoe prior | 1.0 |
Local shrinkage parameter of Horseshoe prior | 1.0 |
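The "number of weight samples during testing" refers to Monte Carlo averaging over weights drawn from the approximate posterior. The sketch below illustrates this for a one-hidden-layer BNN with 50 hidden units, matching the table above. Several choices are assumptions for illustration: a factorised Gaussian (mean-field) posterior, a ReLU hidden layer, a sigmoid output for mortality, and no bias terms. `mc_predictive_probability` and the parameter names are hypothetical.

```python
import numpy as np


def mc_predictive_probability(x, params, n_samples=100, rng=None):
    """Monte Carlo estimate of p(death | x) under a mean-field Gaussian posterior.

    Returns the mean and standard deviation of the per-sample probabilities.
    """
    rng = np.random.default_rng(rng)
    probs = np.empty((n_samples, x.shape[0]))
    for s in range(n_samples):
        # One posterior draw per layer: w = mu + sigma * eps, eps ~ N(0, I).
        w1 = params["w1_mu"] + params["w1_sigma"] * rng.standard_normal(params["w1_mu"].shape)
        w2 = params["w2_mu"] + params["w2_sigma"] * rng.standard_normal(params["w2_mu"].shape)
        h = np.maximum(x @ w1, 0.0)  # ReLU hidden layer (assumed activation)
        logits = h @ w2              # one output logit per patient
        probs[s] = 1.0 / (1.0 + np.exp(-logits))
    # The mean is the prediction; the spread is a simple uncertainty estimate.
    return probs.mean(axis=0), probs.std(axis=0)


rng = np.random.default_rng(0)
D, H = 5, 50  # 5 hypothetical input features, 50 hidden units
params = {
    "w1_mu": 0.1 * rng.standard_normal((D, H)),
    "w1_sigma": np.full((D, H), 0.05),
    "w2_mu": 0.1 * rng.standard_normal(H),
    "w2_sigma": np.full(H, 0.05),
}
x = rng.standard_normal((3, D))  # three hypothetical patients
mean, std = mc_predictive_probability(x, params, n_samples=100, rng=1)
print(mean, std)
```

Using more samples at test time (100) than during training (10) trades compute for a lower-variance estimate of the predictive distribution once training is finished.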
Category | Features |
---|---|
General | Age, gender, prior alcohol use, history of anti-coagulants |
Injury | Cause of injury, time of injury |
Condition on arrival | Heart rate, respiratory rate, temperature, SpO₂, systolic blood pressure, diastolic blood pressure, arterial O₂ tension, CO₂ tension, pH, assessment of airway, breathing and circulation, episode of hypoxia or hypotension |
Neurological assessment | Glasgow Coma Score, pupil reaction |
Initial imaging | Marshall classification, depressed skull fracture, subarachnoid hemorrhage, midline shift, absent basal cisterns, extradural hematoma |
Blood chemistry tests | Glucose, sodium, albumin, calcium, hemoglobin, hematocrit, white blood cell count, C-reactive protein, platelet count, international normalized ratio, activated partial thromboplastin time, fibrinogen |
Hyperparameter | Value |
---|---|
Number of weight samples during training | 10 |
Number of weight samples during testing | 100 |
Batch size | 64 |
Number of hidden units | 100 |
Learning rate | 0.001 |
Number of epochs | 5000 |
Standard deviation of Gaussian prior | 1.0 |
Global shrinkage parameter of Horseshoe prior | 1.0 |
Local shrinkage parameter of Horseshoe prior | 1.0 |
As discussed in Section 4.3, we observed different results for the feature pH when comparing MIMIC-III and CENTER-TBI. Whereas both horseshoe models trained on CENTER-TBI determine pH to be important (see Figure 5), this is not the case for MIMIC-III, where only the non-linear model, that is, the HorseshoeBNN, selects this feature (see Figure 3). This can be explained by the difference between the patient groups contained in MIMIC-III and CENTER-TBI. Whereas MIMIC-III contains a very heterogeneous group of patients, CENTER-TBI contains only patients who experienced a traumatic brain injury. Since this is a different disease, the outcome predictors are expected to differ considerably. Figure 7 illustrates the distribution of pH in both datasets. When computing the KL divergence between the pH distributions of patients who survived and patients who died, we observe that the divergence is larger for CENTER-TBI. This might explain why both models determine pH to be important for CENTER-TBI, whereas only the non-linear model includes pH for MIMIC-III.
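The KL-divergence comparison can be approximated from samples by binning both groups on a shared grid. The text does not specify the estimator, so the following choices are all assumptions: histogram binning, the small smoothing constant for empty bins, and the direction KL(died ‖ survived). The pH values in the usage line are synthetic and purely illustrative. `histogram_kl` is a hypothetical helper.

```python
import numpy as np


def histogram_kl(samples_p, samples_q, bins=30, value_range=None, eps=1e-10):
    """Estimate KL(P || Q) from two samples using shared-bin histograms."""
    if value_range is None:
        all_vals = np.concatenate([samples_p, samples_q])
        value_range = (all_vals.min(), all_vals.max())
    p, edges = np.histogram(samples_p, bins=bins, range=value_range)
    q, _ = np.histogram(samples_q, bins=edges)  # reuse the same bin edges
    # Smooth and normalise so empty bins do not cause log(0) or division by zero.
    p = (p + eps) / (p + eps).sum()
    q = (q + eps) / (q + eps).sum()
    return float(np.sum(p * np.log(p / q)))


rng = np.random.default_rng(0)
ph_survived = rng.normal(7.38, 0.05, size=500)  # synthetic values, illustration only
ph_died = rng.normal(7.30, 0.10, size=100)
print(histogram_kl(ph_died, ph_survived, value_range=(6.0, 8.0)))
```

A larger value indicates that the pH distributions of the two outcome groups are more separable. This matches the intuition that even a linear model can then exploit the feature.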
A further difference was observed for blood pressure. Both CENTER-TBI horseshoe models indicate that only systolic blood pressure is relevant, whereas for MIMIC-III the HorseshoeBNN determined both systolic and diastolic blood pressure to be important. Again, this could be explained by the heterogeneity of the datasets. CENTER-TBI predominantly contains patients with relatively isolated brain injuries that do not affect the circulation. Absolute blood pressure (e.g. systolic) is therefore most important, as this pressure perfuses the brain. In contrast, MIMIC-III includes a large number of patients in shock (e.g. with sepsis). The severity of shock, which is less likely to occur in CENTER-TBI patients, is clinically likely to be a strong determinant of outcome and would be expected to be reflected in diastolic blood pressure. This is further supported by the relationship between systolic and diastolic blood pressure in both datasets: Figure 8 shows that the two features are strongly correlated for CENTER-TBI but not for MIMIC-III. This might explain why only systolic blood pressure is considered relevant for CENTER-TBI, whereas both systolic and diastolic blood pressure are considered relevant for MIMIC-III.
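The per-cohort correlation check can be sketched as a Pearson correlation over paired readings. The assumption here is that pairs with a missing value in either measurement are dropped. `sbp_dbp_correlation` is a hypothetical name, and the readings in the usage line are made up.

```python
import numpy as np


def sbp_dbp_correlation(systolic, diastolic):
    """Pearson correlation of paired systolic/diastolic readings, skipping NaN pairs."""
    sbp = np.asarray(systolic, dtype=float)
    dbp = np.asarray(diastolic, dtype=float)
    ok = ~(np.isnan(sbp) | np.isnan(dbp))  # keep only complete pairs
    return float(np.corrcoef(sbp[ok], dbp[ok])[0, 1])


print(sbp_dbp_correlation([100, 118, 135, 150, np.nan], [62, 71, 80, 88, 90]))
```

When the correlation is close to 1, the second feature carries little extra information. A sparsity-inducing prior can then plausibly keep only one of the pair, which is consistent with the CENTER-TBI result.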