Interpretable Deep Neural Networks for Patient Mortality Prediction: A Consensus-based Approach

05/14/2019 ∙ by Shaeke Salman, et al. ∙ Florida State University

Deep neural networks have achieved remarkable success in challenging tasks. However, the black-box approach to training and testing such networks is not acceptable in critical applications. In particular, the existence of adversarial examples and the overgeneralization of networks to irrelevant inputs make it difficult, if not impossible, to explain the decisions of commonly used neural networks. In this paper, we analyze the underlying mechanism of generalization in deep neural networks and propose an (n, k) consensus algorithm that is insensitive to adversarial examples and, at the same time, able to reject irrelevant samples. Furthermore, the consensus algorithm improves classification accuracy by using multiple trained deep neural networks. To handle the complexity of deep neural networks, we cluster linear approximations and use cluster means to capture feature importance. Due to weight symmetry, a small number of clusters is sufficient to produce a robust interpretation. Experimental results on a health dataset show the effectiveness of our algorithm in enhancing the prediction accuracy and interpretability of deep neural network models for one-year patient mortality prediction.




1 Introduction

Deep neural networks have demonstrated significant successes in many challenging tasks and applications [LeCun et al., 2015]. Even though they have been employed in numerous real-world applications to enhance the user experience, their adoption in healthcare and clinical practice has been slow. Among the inherent difficulties, the complexity of these models remains a huge challenge [Ahmad et al., 2018], as it is not clear how they arrive at their predictions [Zhang et al., 2018]. In medical practice, it is infeasible to rely only on predictions made by a black-box model to guide decision-making for patients. Ideally, medical professionals should be able to point out under which conditions the model will work. Any incorrect prediction, such as an erroneous diagnosis, may lead to serious medical errors, which are currently the third leading cause of death in the United States [Makary and Daniel, 2016]. This issue has been raised, and the need for interpretable deep learning models has been identified [Adadi and Berrada, 2018].

However, it is not clear how to improve interpretability while retaining the accuracy of deep neural networks. Deep neural networks have improved application performance by capturing complex interdependent relationships among input variables. To make the situation worse, these models are typically overparameterized, i.e., they have more parameters than the number of training samples [Zhang et al., 2016]. Overparameterization simplifies the optimization problem of finding good solutions [Allen-Zhu et al., 2018]; however, the resulting solutions are even more complex and more difficult to interpret. Consequently, enhancing interpretability is difficult without handling the complexity of deep neural networks.

Because commonly used activation functions (ReLU, sigmoid, tanh, and so on) are piecewise linear or can be well approximated by piecewise linear functions, neural networks that use them partition the input space into (approximately) linear regions. Due to weight symmetry [Hu et al., 2018], many of the different linear regions should be equivalent. In addition, gradient-based optimization results in similar linear regions for similar inputs, as their gradients tend to be similar. By clustering the linear regions, we can reduce the number of distinct linear regions exponentially and at the same time improve robustness. To further improve performance, we train multiple models and use the consensus among them to reduce their sensitivity to incidental features (thereby avoiding adversarial examples) and to reduce the overgeneralization of individual models to irrelevant inputs. We demonstrate the effectiveness of deep neural network models and the proposed algorithms on one-year mortality prediction for patients diagnosed with acute myocardial infarction (AMI) or post myocardial infarction (PMI) in the MIMIC-III database.

The paper is organized as follows. In Section 2, we discuss generalization and overgeneralization in the context of deep neural networks and present the proposed deep (n, k) consensus-based classification algorithm. In Section 3, we describe a new interpretability method. Section 4 illustrates the effectiveness of the proposed algorithms in enhancing one-year mortality prediction via experiments. In Section 5, we review recent studies that are closely related to our work. Section 6 concludes the paper with a brief summary and plans for future work.

2 Generalization and Overgeneralization in Deep Neural Networks

Fundamentally, a neural network approximates the underlying but unknown function with a parameterized function $f(x; w)$, where $x$ is the input and $w$ is a vector that includes all the parameters (weights and biases). Given a deep learning model and a training dataset, there are two fundamental problems to be solved: optimization and generalization. The optimization problem deals with finding the parameters $w$ by minimizing a loss function on the training set. Overparameterization in deep neural networks makes this problem easier to solve by increasing the number of good solutions exponentially [Wu et al., 2017]. Using the MNIST dataset [Lecun et al., 1998], we have trained a neural network three times with different initial values and obtained three different solutions. We then interpolated in the parameter space spanned by these three solutions. Fig. 1 clearly shows that numerous good solutions exist and can be reached.
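The interpolation experiment behind Fig. 1 can be sketched as follows. This is a minimal sketch, assuming the plane is parameterized as w(α, β) = w1 + α(w2 − w1) + β(w3 − w1) over three flattened parameter vectors; the paper does not state its exact parameterization, and the names `plane_point`, `loss_grid`, and `toy_loss` are hypothetical.

```python
import numpy as np

def plane_point(w1, w2, w3, alpha, beta):
    """Point in the plane through three parameter vectors (assumed parameterization)."""
    return w1 + alpha * (w2 - w1) + beta * (w3 - w1)

def loss_grid(w1, w2, w3, loss_fn, alphas, betas):
    """Evaluate loss_fn at every (alpha, beta); rows index beta, columns index alpha."""
    grid = np.empty((len(betas), len(alphas)))
    for i, beta in enumerate(betas):
        for j, alpha in enumerate(alphas):
            grid[i, j] = loss_fn(plane_point(w1, w2, w3, alpha, beta))
    return grid

# Toy stand-in for a validation loss: squared distance to a fixed vector.
target = np.array([1.0, -1.0, 0.5])
toy_loss = lambda w: float(np.sum((w - target) ** 2))

w1 = np.zeros(3)
w2 = np.array([2.0, 0.0, 0.0])
w3 = np.array([0.0, -2.0, 1.0])
alphas = betas = np.linspace(-0.5, 1.5, 5)
grid = loss_grid(w1, w2, w3, toy_loss, alphas, betas)
```

In an actual experiment, `toy_loss` would be replaced by loading each interpolated vector into the network and computing its validation loss or accuracy.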

Figure 1: Good solutions can be reached from random initialization in overparameterized deep neural networks. The parameters are interpolated in the plane spanned by three good parameter vectors by varying two interpolation coefficients. The left panel shows the validation loss of the model in the plane, while the right panel shows the validation accuracy.
Figure 2: Left: Bar plot showing how the five models trained on the MNIST dataset agree on the overgeneralized samples (e.g., dog samples of the CIFAR10 dataset). Right: Bar plot showing how the five models classify the adversarial examples generated by one of the models (m3); the bars denote the classification of the samples to true labels, adversarial labels, and other labels, respectively.

Figure 3: Percentage of classified overgeneralized samples with (5,5) consensus. The samples are from the CIFAR10 dataset.

Since there are numerous good solutions, understanding their differences and commonalities is essential to developing more effective methods based on multiple models. Toward a systematic understanding of deep neural network models in the input space, one must consider the behavior of these models on typical, irrelevant, and adversarial inputs (inputs that are “computed” intentionally to degrade system performance). As a representative example, we have trained five different deep neural networks on the MNIST dataset. To study systematically how the models respond to irrelevant images, we have used images from the CIFAR10 dataset, as they do not contain valid handwritten digits; we have cropped the images and converted them to the MNIST input format. Fig. 2 (left) shows how the five models agree on irrelevant samples by showing, for each sample, the maximum number of models that assign it the same label. It shows that the models respond (almost) randomly to such irrelevant inputs.
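The agreement statistic plotted in Fig. 2 (left) can be computed as sketched below. This assumes the predictions of all models are stored as an integer array of shape (n_models, n_samples); the function name `max_agreement` and the example labels are hypothetical.

```python
import numpy as np

def max_agreement(predictions, n_classes):
    """predictions: (n_models, n_samples) array of predicted labels.
    Returns, per sample, the largest number of models that agree on one label."""
    # Count votes per class for each sample (column); result is (n_classes, n_samples).
    counts = np.apply_along_axis(np.bincount, 0, predictions, minlength=n_classes)
    return counts.max(axis=0)

# Hypothetical labels from five models on three irrelevant inputs.
preds = np.array([[1, 2, 3],
                  [1, 2, 4],
                  [1, 5, 6],
                  [1, 2, 7],
                  [1, 2, 8]])
agree = max_agreement(preds, n_classes=10)
# Bar heights: how many samples have 0, 1, ..., 5 agreeing models.
histogram = np.bincount(agree, minlength=preds.shape[0] + 1)
```

If the models responded randomly over ten classes, most samples would fall into the low-agreement bars of such a histogram.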

We have also generated adversarial examples using the sign of the loss gradient (the fast gradient sign method [Goodfellow et al., 2014]) as the search direction and a binary search to find the minimum step size required to change the class label to another class. By perturbing inputs using one of the models (m3), we investigate how the other models respond to the perturbed inputs, i.e., the adversarial examples. Fig. 2 (right) shows the classification results of the five models on the adversarial images generated by m3. Clearly, the other four models classify most of the perturbed examples correctly.
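The adversarial search described above can be sketched as follows. This is a minimal sketch under stated assumptions: a hypothetical linear softmax classifier stands in for model m3 so that the input gradient is analytic, the step size is searched in [0, `eps_max`], and the binary search assumes that once the label flips along the sign direction it stays flipped (true for this two-class linear toy, not guaranteed for deep networks). All function names are illustrative.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def predict(W, b, x):
    return int(np.argmax(W @ x + b))

def fgsm_direction(W, b, x, y):
    """Sign of the cross-entropy gradient w.r.t. the input for a linear softmax model."""
    p = softmax(W @ x + b)
    p[y] -= 1.0              # d loss / d logits = softmax - one_hot(y)
    return np.sign(W.T @ p)  # chain rule through logits = W x + b

def min_flip_step(W, b, x, y, eps_max=1.0, tol=1e-6):
    """Binary search for the smallest step along the sign direction that changes the label."""
    d = fgsm_direction(W, b, x, y)
    if predict(W, b, x + eps_max * d) == y:
        return None, d       # no label change within eps_max
    lo, hi = 0.0, eps_max    # invariant: label unchanged at lo, changed at hi
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        if predict(W, b, x + mid * d) != y:
            hi = mid
        else:
            lo = mid
    return hi, d

W, b = np.eye(2), np.zeros(2)
x = np.array([0.6, 0.4])     # classified as class 0
eps, d = min_flip_step(W, b, x, y=0)
```

For this toy input the label flips once 0.4 + ε exceeds 0.6 − ε, so the search converges to ε ≈ 0.1.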

2.1 Deep (n, k) Consensus Algorithm

As all the models classify training samples accurately, they generate similar linear regions and should behave similarly at training samples. Motivated by the overgeneralization and adversarial examples shown in Fig. 2, we propose to use consensus among different models to differentiate extrinsically classified samples from intrinsically, or consistently, classified samples (CCS). A sample is considered consistently classified if multiple models assign it to the same class with high probability. In contrast, extrinsic factors such as the randomness of weight initialization or oversensitivity to accidental features are responsible for the classification of extrinsic samples. As such random factors cannot be expected to occur consistently across multiple models, their effect can be reduced exponentially by using more models. To tolerate accidental oversensitivity of a small number of models, we propose the deep (n, k) consensus algorithm (a preliminary version of the algorithm was introduced in [Salman and Liu, 2019], but no justification was provided there), which is given in Algorithm 1. Note that $p$ is a vector with one value for each class, as it is computed class-wise. Essentially, the algorithm requires consensus among $k$ out of $n$ trained models in order for a sample to be classified; $\theta$, a threshold parameter, is used to decide whether the prediction probability of a model is sufficiently high.

0:  Trained models $M_1, M_2, \ldots, M_n$, input $x$, and threshold $\theta$
1:  Apply each of the models to classify $x$ and retain the probabilities for each class as $p_1, p_2, \ldots, p_n$
2:  Compute $p$ by finding, for each class, the minimum among the top $k$ of $p_1, \ldots, p_n$
3:  If $\max(p) \ge \theta$,
4:      Classify $x$ as the class with maximum $p$
5:  Else
6:      Reject $x$ (mark it as ambiguous)
7:  Endif
Algorithm 1: Deep (n, k) consensus-based classification
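Algorithm 1 can be sketched in a few lines of NumPy. This sketch assumes each model's class probabilities for one input are stacked into an (n, n_classes) array and uses `>=` for the threshold comparison, which is an assumption since the comparator is not recoverable from the text; `deep_consensus` is a hypothetical name.

```python
import numpy as np

def deep_consensus(probs, k, theta):
    """Deep (n, k) consensus for one input.

    probs: (n, n_classes) array; row i holds model i's class probabilities.
    Returns (label or None, consensus vector p)."""
    n = probs.shape[0]
    if not 1 <= k <= n:
        raise ValueError("k must satisfy 1 <= k <= n")
    # Per class, the k-th largest probability equals the minimum among the top k models.
    p = np.sort(probs, axis=0)[n - k]
    if p.max() >= theta:
        return int(np.argmax(p)), p
    return None, p  # rejected: marked as ambiguous

# Five hypothetical models; the last one disagrees with the rest.
probs = np.array([[0.90, 0.10],
                  [0.80, 0.20],
                  [0.70, 0.30],
                  [0.95, 0.05],
                  [0.20, 0.80]])
label55, p55 = deep_consensus(probs, k=5, theta=0.5)  # rejected
label54, p54 = deep_consensus(probs, k=4, theta=0.5)  # classified as class 0
```

Unlike majority voting, which always returns some label, this rule can refuse to classify, which is what lets it reject irrelevant inputs.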
Figure 4: Accuracy and percentage of classified samples when adversarial examples created by model-3 are evaluated with (5,5) and (5,4) consensus. (a) Accuracy of the models with (5,5) consensus. (b) Percentage of the classified samples with (5,5) consensus. (c) Accuracy of the models with (5,4) consensus. (d) Percentage of the classified samples with (5,4) consensus.

To illustrate the effectiveness of the proposed algorithm, Figs. 3 and 4 show the results on the irrelevant and adversarial samples. The majority of the irrelevant samples are rejected by a (5, 5) consensus algorithm, as shown in Fig. 3; note that a (5, 4) algorithm would also be effective, although it could not reject the samples on which four models agree accidentally. Fig. 4 shows that with a (5, 4) consensus algorithm we can classify most of the adversarial examples correctly; the only rejected samples are several that model m5 also misclassified. Clearly, as model m3 is oversensitive to the adversarial examples, a (5, 5) algorithm will reject all the adversarial images. Note that the proposed algorithm is different from ensemble methods [Ju et al., 2017], which combine multiple models via voting to improve performance. For example, ensemble methods will not be able to handle most of the irrelevant samples, whereas our algorithm is very effective on them (as shown in Fig. 3).

3 A New Interpretability Method

With a robust way to handle irrelevant and adversarial inputs, we propose a new method to interpret the decisions of trained deep neural networks, based on the observations that such networks behave linearly in local regions and that these linear regions form clusters due to weight symmetry.

Figure 5: Outputs from the penultimate layer for classes 3 and 5 respectively.
Figure 6: Outputs from the penultimate layer for model-3 centered at a training sample (from class 5) along a random direction (left) and along the direction to another sample in the same class (right).

To illustrate the local linear behavior of such networks visually, Fig. 5 shows the outputs for two classes over a two-dimensional plane centered at a sample, for a typical trained deep neural network. The outputs are indeed very close to linear planes.

The linear approximation, while simple, reveals rich deep neural network model behavior in the neighborhood of a sample. Interesting characteristics of the model can be uncovered by walking along certain directions from the sample. For example, adversarial examples are clearly evident along the direction shown in Fig. 6 (left), where the classification changes quickly just outside a small neighborhood of the origin (where the given training sample is). On the other hand, Fig. 6 (right) shows robust classification along that particular direction.

Figure 7: Change of accuracy and percentage of CCS samples with deep (n, k) consensus on the one-year patient mortality prediction dataset. (a) Increase of intrinsic accuracy with (5,5) consensus. (b) Percentage of CCS samples with (5,5) consensus. (c) Increase of intrinsic accuracy with (5,4) consensus. (d) Percentage of CCS samples with (5,4) consensus.

More formally, assuming that the last layer of the network is a softmax layer, we can analyze the outputs of the penultimate layer. Using the notation introduced earlier, these outputs can be written as

$$ z = g(x; w), \qquad (1) $$

where $g$ is a vector-valued function. Since the model is locally close to linear, if we perturb an input $x$ by a small value $\Delta x$, Equation 1 can be approximated by its first-order Taylor expansion around $x$:

$$ g(x + \Delta x; w) \approx g(x; w) + J(x)\,\Delta x. \qquad (2) $$

Here, $J(x)$ is the Jacobian matrix of $g$, defined by $J_{ij}(x) = \partial g_i(x; w) / \partial x_j$. Note that the gradient, or more generally the Jacobian matrix, has been used in a number of methods to enhance interpretability (e.g., [Simonyan et al., 2013; Sundararajan et al., 2017]).

However, the Jacobian matrix only reflects the change in each output individually. As classification is inherently a discriminative task, what matters locally is the difference between the two largest outputs. In the binary case, we can write the difference of the two outputs as

$$ g_1(x + \Delta x; w) - g_2(x + \Delta x; w) \approx \big(g_1(x; w) - g_2(x; w)\big) + \big(J_1(x) - J_2(x)\big)\,\Delta x, \qquad (3) $$

where $J_1$ and $J_2$ are the first and second rows of $J$. In general, we need to analyze the difference between the top two, or more generally the top few, outputs locally. The Jacobian difference vector $J_1 - J_2$ determines how changes in the features contribute to the decision, i.e., the local feature importance. This allows us to explain why the deep neural network model behaves in a particular way in the neighborhood of a sample. Note that the first part of Equation 3, i.e., $g_1(x; w) - g_2(x; w)$, is important for achieving high accuracy. However, the local Jacobian matrices, while important, are not robust. In addition, many of the local Jacobian matrices behave similarly due to weight symmetry [Hu et al., 2018]. To increase the robustness of the interpretation and at the same time reduce complexity, we propose to cluster the Jacobian difference vectors.
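The Jacobian difference vector and the first-order approximation of Equation 3 can be checked numerically on a small network. This is a minimal sketch, assuming a hypothetical one-hidden-layer tanh network whose pre-softmax outputs play the role of g; the paper's models are deeper, and in practice the Jacobian would come from an automatic-differentiation framework rather than the hand-derived formula used here.

```python
import numpy as np

def forward(params, x):
    """Pre-softmax outputs g(x) of a toy one-hidden-layer tanh network."""
    W1, b1, W2, b2 = params
    return W2 @ np.tanh(W1 @ x + b1) + b2

def jacobian(params, x):
    """Analytic Jacobian J[i, j] = d g_i / d x_j of the toy network."""
    W1, b1, W2, _ = params
    a = W1 @ x + b1
    # d tanh(a)/da = 1 - tanh(a)^2 scales each row of W1
    return W2 @ ((1.0 - np.tanh(a) ** 2)[:, None] * W1)

def jacobian_difference(params, x):
    """Row difference J_top1 - J_top2 for the two largest outputs at x."""
    g = forward(params, x)
    top1, top2 = np.argsort(g)[::-1][:2]
    J = jacobian(params, x)
    return J[top1] - J[top2], int(top1), int(top2)

rng = np.random.default_rng(0)
params = (rng.standard_normal((8, 5)), rng.standard_normal(8),
          rng.standard_normal((2, 8)), rng.standard_normal(2))
x = rng.random(5)
d, i, j = jacobian_difference(params, x)

# First-order check of Equation 3 for a small perturbation.
dx = 1e-5 * rng.standard_normal(5)
lhs = forward(params, x + dx)[i] - forward(params, x + dx)[j]
rhs = forward(params, x)[i] - forward(params, x)[j] + d @ dx
```

The residual `lhs - rhs` shrinks quadratically with the size of `dx`, which is what "locally linear" means here.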

The Jacobian difference vectors can be clustered using K-means or other clustering algorithms. In this paper, we identify consistent clusters using the correlation coefficients of the Jacobian difference vectors of the training samples. To create a cluster, we first identify the pair with the highest correlation. We then expand the cluster by adding the sample with the highest correlation to all the samples already in the cluster. This can be done efficiently by computing, for each remaining sample, its minimum correlation to the samples already in the cluster, and then choosing the sample for which this minimum is largest. We add samples iteratively until this maximum falls below a threshold. To avoid small clusters, we also impose a minimum cluster size. We repeat the clustering process to identify more clusters. Due to the equivalence of local linear models, the number of clusters is expected to be small; our experimental results support this (see Section 4.4 for examples). Note that neural networks still have different biases at different samples, which enables them to classify samples with high accuracy using a small number of linear models. This could explain the apparent deep neural network paradox [Zhang et al., 2016]: while deep neural networks have many parameters, they typically also generalize well, contradicting traditional statistical learning theory [Vapnik, 1998]. We will investigate this further and more systematically.
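The greedy correlation-based clustering described above can be sketched as follows. This is a minimal sketch under assumptions the text leaves open: clusters smaller than `min_size` are discarded and their members are not revisited, the process stops when no remaining pair exceeds the threshold, and the values of `threshold` and `min_size` are placeholders. The function name is hypothetical.

```python
import numpy as np

def greedy_correlation_clusters(V, threshold=0.9, min_size=3):
    """V: (n_samples, n_features) Jacobian difference vectors, one per row."""
    C = np.corrcoef(V)                      # pairwise Pearson correlations between rows
    remaining = set(range(V.shape[0]))
    clusters = []
    while len(remaining) >= 2:
        idx = np.array(sorted(remaining))
        sub = C[np.ix_(idx, idx)].copy()
        np.fill_diagonal(sub, -np.inf)      # ignore self-correlation
        a, b = np.unravel_index(np.argmax(sub), sub.shape)
        if sub[a, b] < threshold:
            break                           # no sufficiently correlated pair left
        cluster = [idx[a], idx[b]]          # seed with the most correlated pair
        candidates = [i for i in idx if i not in cluster]
        while candidates:
            # Each candidate's weakest correlation to the current cluster members
            min_corr = C[np.ix_(candidates, cluster)].min(axis=1)
            best = int(np.argmax(min_corr))
            if min_corr[best] < threshold:
                break
            cluster.append(candidates.pop(best))
        remaining -= set(cluster)
        if len(cluster) >= min_size:
            clusters.append(sorted(int(i) for i in cluster))
    return clusters

# Two hypothetical groups of nearly parallel difference vectors.
rng = np.random.default_rng(1)
base1, base2 = rng.standard_normal(50), rng.standard_normal(50)
V = np.vstack([base1 + 0.05 * rng.standard_normal(50) for _ in range(5)] +
              [base2 + 0.05 * rng.standard_normal(50) for _ in range(5)])
clusters = greedy_correlation_clusters(V, threshold=0.9, min_size=3)
```

Using the minimum correlation to existing members, rather than the mean, keeps every pair inside a cluster above the threshold.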

We perform the clustering separately for each model. The clusters from different models can support each other through strong correlations between their means, and they can also complement each other by capturing different aspects of the data. Given a new sample, we estimate the Jacobian difference vector for each model and compare it with the cluster means to identify the clusters that provide the strongest support. This allows us to check not only that the new sample is classified correctly, but also that its interpretation is consistent with the interpretations of the training samples. With multiple models, each with its own clusters, the interpretation is more robust.
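Matching a new sample against the cluster means of each model might look like the sketch below. It assumes the new sample's Jacobian difference vector has already been computed for every model and that each model's cluster means are stacked into an array; Pearson correlation is used as the similarity, mirroring the clustering step. The function name and data are hypothetical.

```python
import numpy as np

def best_supporting_clusters(diff_vectors, cluster_means):
    """diff_vectors: one Jacobian difference vector per model (new sample).
    cluster_means: per model, an (n_clusters, n_features) array of cluster means.
    Returns, per model, (index of best-matching cluster, its correlation)."""
    results = []
    for v, means in zip(diff_vectors, cluster_means):
        corrs = np.array([np.corrcoef(v, m)[0, 1] for m in means])
        best = int(np.argmax(corrs))
        results.append((best, float(corrs[best])))
    return results

# Hypothetical cluster means for two models with four input features.
means_model_a = np.array([[1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 1.0, 0.0]])
means_model_b = np.array([[1.0, 2.0, 3.0, 4.0], [4.0, 3.0, 2.0, 1.0]])
new_vectors = [np.array([0.9, 0.1, 0.0, 1.1]), np.array([1.0, 2.0, 3.0, 5.0])]
support = best_supporting_clusters(new_vectors, [means_model_a, means_model_b])
```

A low best correlation across models would signal that the sample's local behavior is not backed by any cluster learned from training data.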

Specification | Model 1 | Model 2 | Model 3 | Model 4 | Model 5
Neurons | 400 | 500 | 600 | 400 | 550
Activation | tanh | tanh | selu | elu | relu
Optimization | SGD | Adamax | RMSprop | Adam | Adagrad
Bias initialization | zeros | ones | constant | random normal | random normal
Weight initialization | random uniform | random uniform | random normal | random uniform | random normal
Table 1: Implementation details of the five individual models
Model | Accuracy | ROC | Precision | Recall | F-measure
1 | 0.8508 | 0.7891 | 0.8504 | 0.8508 | 0.8438
2 | 0.7395 | 0.7979 | 0.8364 | 0.7395 | 0.7502
3 | 0.8121 | 0.7695 | 0.8089 | 0.8121 | 0.8101
4 | 0.8471 | 0.7847 | 0.8462 | 0.8471 | 0.8399
5 | 0.8268 | 0.7784 | 0.8227 | 0.8268 | 0.8232
Table 2: Evaluation results of the five individual models
Contribution Features List (mean, min, max, SD) **
Positive sodium (138.63, 118.18, 139.66, 3.32), alkaline phosphatase (ALP) (98.81, 13, 172, 65.80), alanine aminotransferase (ALT) (83.93, 44, 5509, 226.65), creatinine (1.37, 0.15, 15.70, 0.78), blood urea nitrogen (BUN) (26.11, 5.10, 140.66, 10.82), lactate dehydrogenase (LD) (456.39, 100, 5664, 176.72), age of admission (70.62, 18.70, 100.76, 13.34), heart rate (82.34, 36.84, 132.66, 12.31), aspartate aminotransferase (AST) (138.81, 2, 13511.7, 72.70), troponin T (2.30, 2.30, 24.80, 1.70), respiratory rate (19.35, 8.73, 42.69, 3.23), potassium (4.18, 4.23, 4.24, 0.36), cancer_Positive, cortisol (33.01, 22, 238.2, 4.45), cholesterol ratio (4.03, 4.03, 6.8, 1.95)
Neutral C-Reactive Protein (CRP) (52.03, 0.09, 273.75, 7.65), no genitourinary, marital_MARRIED, orthopaedic, other cardiac pacemaker implantation, no coronary bypass with cardiac catheterization or percutaneous Cardiac Procedure, endocrinology, hematological, marital_SINGLE
Negative white blood cell count (11.30, 0.45, 107.67, 4.78), hemoglobin (10.96, 4.31, 18.7, 1.53), chloride (103.81, 80.42, 125.61, 4.29), triglycerides (139.54, 1, 1983, 69.16), bilirubin (0.9131, 0.1, 31.13, 0.07), bicarbonate (24.82, 7, 47.57, 3.58), albumin (3.20, 3.20, 3.30, 0.47), systolic blood pressure (106.67, 20, 334.78, 21.49), creatine kinase (1.37, 9.5, 29579, 931.52), cancer_negative, cardiac valve and other major cardiothoracic procedures with cardiac catheterization, Brain natriuretic peptide (BNP), coronary bypass without cardiac catheterization, coronary bypass with PTCA, cardiac defibrillator implant without cardiac catheterization, coronary bypass without cardiac catheterization or percutaneous cardiac procedure, cardiac valve and other major cardiothoracic procedures without cardiac catheterization
Table 3: Some examples of positively, negatively and neutrally contributing features to the “died within a year” class. (** unusual values are not uncommon in electronic health record (EHR) data due to various reasons [Weiskopf and Weng, 2013]).

4 Experimental Results on One-year Mortality Prediction

4.1 Dataset

The Medical Information Mart for Intensive Care III (MIMIC-III) database is a large, publicly available database of de-identified, comprehensive health-related data. It includes fine-grained clinical data of more than forty thousand patients who stayed in critical care units of the Beth Israel Deaconess Medical Center between 2001 and 2012. It contains data associated with 53,432 admissions of patients aged 16 years or above in the critical care units [Johnson et al., 2016].

In this study, only admissions with International Classification of Diseases, Ninth Revision (ICD-9) codes 410.0-411.0 (AMI, PMI) or 412.0 (old myocardial infarction) are considered. These criteria return 5436 records. We use both structured and unstructured data to train the deep neural network models. Structured data include admission-level information on the admission, demographics, treatment, laboratory and chart values, and comorbidities. Unstructured data are obtained from the discharge summaries associated with each admission (79 features). We represent each discharge summary with a document embedding computed as the average of the word embedding vectors (200-dimensional, trained with word2vec on the Wikipedia+PubMed+PMC corpus) of the words in the summary.
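The averaged document embedding can be sketched as below. This assumes pretrained word vectors are available as a plain dictionary from token to vector and that tokenization is lowercase whitespace splitting; the actual preprocessing of the discharge summaries is not described in the text, and the tiny 3-dimensional vocabulary stands in for the 200-dimensional word2vec vectors.

```python
import numpy as np

def document_embedding(text, word_vectors, dim):
    """Average the vectors of in-vocabulary words; return zeros if none are found."""
    vecs = [word_vectors[w] for w in text.lower().split() if w in word_vectors]
    if not vecs:
        return np.zeros(dim)
    return np.mean(vecs, axis=0)

# Tiny hypothetical vocabulary (the paper uses 200-dimensional vectors).
wv = {"chest": np.array([1.0, 0.0, 0.0]),
      "pain": np.array([0.0, 1.0, 0.0]),
      "stable": np.array([0.0, 0.0, 1.0])}
emb = document_embedding("Chest pain resolved", wv, dim=3)  # "resolved" is out of vocabulary
```

Out-of-vocabulary words are simply skipped here; a real pipeline might also strip punctuation or de-identification placeholders first.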

4.2 Results from Individual Models

Five different deep neural network models are trained for the purpose of this work. Each of these models consists of three dense layers and a softmax layer for classification. Table 1 provides implementation details of these models.

The five models are trained on the same randomly selected 90% of the records in the dataset and evaluated on the remaining 10%. All values are normalized to the range [0, 1]. The evaluation results of the five models are provided in Table 2. The overall accuracy, while varying from model to model, is in general agreement with that of other methods.
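The 90/10 split and the [0, 1] normalization can be sketched as follows. The text does not say whether the scaling statistics were computed on the full dataset or on the training split; this sketch assumes min-max scaling fitted on the training split only, with test values clipped to [0, 1] and constant columns mapped to 0. The function name and seed are hypothetical.

```python
import numpy as np

def split_and_normalize(X, train_frac=0.9, seed=0):
    """Random train/test split, then min-max scaling to [0, 1] fitted on the training part."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(X.shape[0])
    n_train = int(train_frac * X.shape[0])
    train_idx, test_idx = idx[:n_train], idx[n_train:]
    lo = X[train_idx].min(axis=0)
    span = X[train_idx].max(axis=0) - lo
    span[span == 0] = 1.0  # constant columns map to 0 instead of dividing by zero
    def scale(A):
        return np.clip((A - lo) / span, 0.0, 1.0)
    return scale(X[train_idx]), scale(X[test_idx]), train_idx, test_idx

# Synthetic stand-in with the same number of records as the study (5436).
X = np.random.default_rng(1).normal(size=(5436, 3))
X_train, X_test, train_idx, test_idx = split_and_normalize(X)
```

With 5436 records this yields 4892 training and 544 test rows.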

4.3 Results from the deep (n, k) Consensus Algorithm

Here we illustrate the results of the proposed deep (n, k) consensus algorithm. Fig. 7 illustrates its effectiveness on the one-year mortality prediction task by comparing the results from the individual models with those from the consensus of the models. Fig. 7(a) and 7(b) show that when the threshold is low, (5,5) consensus achieves around 90% accuracy, which is substantially higher than that of any single model, with around 75% of the test samples classified. We also check the effect of the (5,4) and (5,3) versions on the same dataset and observe that (5,4) consensus (Fig. 7(c) and 7(d)) works best for this one-year mortality prediction dataset. At a low threshold, it provides around 88% accuracy with around 89% of the test samples classified. In all the (n, k) cases, we observe that the fraction of correctly classified samples among the consistently classified ones increases with the threshold.
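The two quantities reported in Fig. 7, accuracy on the consistently classified samples and the percentage of samples classified, can be computed for each threshold as sketched here. The sketch assumes the class-wise consensus vectors from Algorithm 1 have already been computed for every test sample and stacked into an (n_samples, n_classes) array; the function name and toy values are hypothetical.

```python
import numpy as np

def accuracy_and_coverage(consensus, labels, theta):
    """consensus: (n_samples, n_classes) class-wise consensus values from Algorithm 1.
    Returns (accuracy on accepted samples, fraction of samples accepted)."""
    accepted = consensus.max(axis=1) >= theta
    coverage = float(accepted.mean())
    if not accepted.any():
        return float("nan"), 0.0
    preds = consensus[accepted].argmax(axis=1)
    accuracy = float((preds == labels[accepted]).mean())
    return accuracy, coverage

consensus = np.array([[0.90, 0.10],
                      [0.60, 0.30],
                      [0.20, 0.70],
                      [0.40, 0.45]])
labels = np.array([0, 1, 1, 0])
sweep = {theta: accuracy_and_coverage(consensus, labels, theta) for theta in (0.5, 0.65)}
```

Raising the threshold trades coverage for accuracy: in this toy example coverage drops from 75% to 50% while accuracy on the accepted samples rises from 2/3 to 1.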

4.4 Interpretability Models

To examine the proposed method systematically, we first compute the Jacobian difference vectors of the training samples and then compute their pairwise correlations. Fig. 8 shows an example. The mean correlation over the 11,963,386 pairs is 0.876, with a standard deviation of 0.035. This deep neural network model thus yields highly consistent and robust interpretations across all the training samples. Note that the biases differ across training samples, which enables high classification accuracy.
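As a consistency check on the reported pair count, and assuming the pairs are all unordered pairs of distinct training samples drawn from the 90% training split of the 5436 records, the number of pairs works out exactly:

$$ n_{\text{train}} = \lfloor 0.9 \times 5436 \rfloor = \lfloor 4892.4 \rfloor = 4892, \qquad \binom{4892}{2} = \frac{4892 \times 4891}{2} = \frac{23{,}926{,}772}{2} = 11{,}963{,}386. $$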

Figure 8: Left: Distributions of the pairwise correlation coefficients between the difference of the Jacobian rows for a binary classification problem. Right: Average of the Jacobian difference vector.
Figure 9: Averages of the Jacobian difference vectors for a set of highly correlated clusters. The thicker smooth curve depicts the average over the whole set. The other curves show the average of each cluster in that set. Left: First cluster subset. Right: Second cluster subset.

Since the correlations are consistently high, we calculate the average of the Jacobian difference vectors, which is shown in Fig. 8 (right). The entries with the largest magnitudes (the leftmost, most negative, and the rightmost, most positive, values of the sorted average vector) correspond to the most relevant and important features.

As described in Section 3, we group highly correlated clusters to achieve more robust interpretations. On this dataset, we have considered two subsets of highly correlated clusters among 10 representative clusters identified by our clustering algorithm. Fig. 9 depicts the averages of these subsets along with the individual cluster averages. Since the clusters within a subset are highly correlated, each individual cluster average behaves similarly to the subset average. Based on the sorted average of the first subset, we observe that the leftmost features in the list have a negative impact, and the rightmost features a positive impact, on the positive class (“died within a year”). For the second subset, we observe almost identical features with positive and negative impact on the positive class. To interpret validation samples, we examine their correlations with each subset.
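Ranking features by the sorted average of a subset's Jacobian difference vectors can be sketched as below. The sketch assumes every vector is computed with a fixed class order (positive class output minus negative class output) so that signs are comparable across samples; the function name and feature names are hypothetical.

```python
import numpy as np

def rank_features(diff_vectors, feature_names, top=3):
    """Average the difference vectors and return the most negative and most positive features."""
    mean_vec = np.mean(diff_vectors, axis=0)
    order = np.argsort(mean_vec)  # ascending: most negative first (leftmost in the sorted plot)
    negative = [feature_names[i] for i in order[:top]]
    positive = [feature_names[i] for i in order[::-1][:top]]
    return mean_vec, negative, positive

diffs = np.array([[0.5, -0.2, 0.0, 0.1],
                  [0.7, -0.4, 0.0, 0.1]])
names = ["feat_a", "feat_b", "feat_c", "feat_d"]
mean_vec, negative, positive = rank_features(diffs, names, top=2)
```

Features near the middle of the sorted average, with values close to zero, would correspond to the neutral features discussed next.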

As a result, we have found that a specific set of features contributes positively to the “died within a year” class, while another set of features contributes positively to the “did not die within a year” class. Some features show neutral behavior in the classification task; these are placed in the middle of the spectrum, with slight tendencies towards either the positive or the negative end. Due to space limitations, we illustrate the contributions of selected features only. We have excluded ethnicity- and religion-related features, since most of them show a neutral effect on the prediction outcome. Word embedding features are also excluded, since they are not useful for interpreting the results. Table 3 shows some examples of the most positive, negative, and neutral features contributing to the positive class. For instance, most of the categorical features related to the treatment procedures undertaken are listed with a negative impact on the positive class. In other words, these features have a positive impact on the negative class (“did not die within a year”), which is plausible, since more treatment should lead to a better outcome. Another example is marital status, which is shown to act neutrally in the classification outcome, whereas a positive cancer status is shown to positively impact the positive class (“died within a year”). Patient age at the time of admission, along with abnormal laboratory test values, contributes positively to the positive class. The identified features are largely consistent with those identified by other studies [Yang et al., 2019]. Also, the features with a positive impact on the positive class are used in APACHE-II, a conventional ICU mortality prediction tool [Mercado-Martínez et al., 2010].

5 Related Work

The lack of interpretability of deep neural networks is a limiting factor in their adoption in healthcare and clinical practice. Improving interpretability while retaining the high accuracy of deep neural networks is inherently challenging. Existing interpretability enhancement methods can be categorized into integrated (intrinsic interpretation) and post-hoc (extrinsic interpretation) approaches [Adadi and Berrada, 2018; Du et al., 2018]. Integrated methods use transparent, intrinsically interpretable models [Melis and Jaakkola, 2018]. Methods in this family usually suffer from low performance, as transparent models cannot approximate the complex decisions of deep neural networks well. In contrast, post-hoc interpretation methods attempt to provide explanations for an uninterpretable black-box model [Koh and Liang, 2017]. Such techniques can be further categorized into local and global interpretation methods. Local interpretation methods determine the importance of specific features to the model's output; a model is locally interpretable if it can explain why the system produces a specific prediction. This differs from the global interpretability approach, which provides a certain level of transparency about the whole model [Du et al., 2018]. The proposed method has the advantages of both local and global approaches. Our method relies on the local Jacobian difference vector to capture the importance of input features. At the same time, clusters of the difference vectors capture robust model behavior supported by multiple training samples, reducing complexity while retaining high accuracy.

6 Conclusion and Future Work

In this paper, we have proposed an interpretability method that clusters the local linear models of multiple networks and captures feature importance compactly using cluster means. Using the consensus of multiple models allows us to improve classification accuracy and interpretation robustness. Furthermore, the proposed deep (n, k) consensus algorithm overcomes overgeneralization to irrelevant inputs and oversensitivity to adversarial examples, which is necessary for meaningful interpretations. Our results seem to resolve the apparent deep neural network paradox, namely that models with many parameters still generalize well; we will investigate this further and more systematically.