Self-explaining Neural Network with Plausible Explanations

Explaining the predictions of complex deep learning models, often referred to as black boxes, is critical in high-stakes domains like healthcare. However, post-hoc model explanations are often not understandable by clinicians and are difficult to integrate into clinical workflows. Further, while most explainable models use individual clinical variables as units of explanation, human understanding often relies on higher-level concepts or feature representations. In this paper, we propose a novel, self-explaining neural network for longitudinal in-hospital mortality prediction using domain-knowledge driven Sequential Organ Failure Assessment (SOFA) organ-specific scores as the atomic units of explanation. We also design a novel procedure to quantitatively validate the model explanations against gold-standard discharge diagnosis information of patients. Our results provide insights into how each of the SOFA organ scores contributes to mortality at different timesteps within the longitudinal patient trajectory.




1 Introduction


Figure 1: The time-variant variables are passed through an imputation module Cao et al. (2018) and the imputed variables are propagated through recurrent layers to yield a latent (hidden) representation. The final latent representation, formed by concatenating the static variables (repeated across the time dimension) with the recurrent output, is used to generate predicted explanations (predicted maximum SOFA organ scores within 24 hours). The explanations are combined by a softmax-based attention mechanism and passed through a sigmoid activation function to yield the predicted probabilities of mortality.

Motivation: Deep learning models have been shown in many cases to be more powerful than conventional machine learning (ML) models because of their ability to detect complex patterns in the data and achieve high predictive accuracy Waljee and Higgins (2010). There has been an increasing trend to leverage these complex predictive models in high-stakes domains like healthcare and criminal justice Lahav et al. (2018). However, the inherent complexity of black-box deep learning models like neural networks makes it challenging to explain model predictions, especially to those unfamiliar with such models. In the clinical setting in particular, explainability is critical in engendering trust amongst clinicians in the use of deep learning based clinical decision support systems Tonekaboni et al. (2019).

Although there has been a plethora of research in recent years on the explainability of deep learning models, defining the specifications for model explanation can be challenging because there exists no single widely accepted definition Karim et al. (2018). Model developers might be more interested in the working mechanism of the model for the purpose of debugging, while end users like clinicians might focus on understanding the rationale behind the clinical predictions obtained as model output Xie et al. (2020). In this paper, we define explainability as the extent to which the model produces explanations about its predictions that are generally accepted as being understandable, useful, and plausible by subject matter experts. Recent work on explainability has mostly focused on post-hoc (a posteriori) explanations for deep learning approaches, where a second (post-hoc) model is created to explain the first black-box model. However, post-hoc model explanations are often not understandable by clinicians and are difficult to integrate into clinical workflows Rudin (2019). Therefore, the objective of this study was to design an intrinsically explainable deep learning model that renders meaningful explanations alongside its predictions.

One common approach to model explainability is to explain the model's predictions in terms of the input features when the input is low-dimensional and individual features are meaningful Lundberg et al. (2020). For example, in logistic regression classifiers, the coefficient weights are typically interpreted as the contribution of individual features to the prediction output Thomas et al. (2008). However, for high-dimensional input, using raw features as units of explanation becomes infeasible, e.g. using individual pixel values to explain image classifications Kindermans et al. (2019). Recent works have proposed that higher-level concepts or feature representations learned from the raw features can be a better choice for "units of explanation" Alvarez-Melis and Jaakkola (2018). In our work, we selected SOFA organ system scores as domain-knowledge driven plausible explanation units. SOFA is a widely used, well-validated mortality risk score with component scores for individual organ systems. Most in-hospital deaths are preceded by signs of organ failure; thus, these scores can provide insights into the reasoning behind mortality prediction.

Contributions: To date, limited research in the explainability field focuses on both improving model explainability and assessing the relevance of the generated explanations in the clinical domain. Our work differs from these recent efforts in that we develop a novel, self-explainable deep learning model that produces explanations in terms of high-level concepts grounded in domain expert knowledge, and we design a quantitative evaluation procedure to assess the quality of the model explanations. Our key contributions, described in detail in the remainder of the paper, can be summarized as follows: (i) we formulate a multi-task neural-network-based pipeline for longitudinal mortality prediction in the Intensive Care Unit (ICU) setting using domain-knowledge driven SOFA organ-specific scores as auxiliary tasks, (ii) we generate explanations in terms of how individual SOFA organ system scores contribute towards mortality at each point within the longitudinal trajectory of patients, and (iii) we propose a similarity metric to quantitatively validate the model explanations against gold-standard explanations derived from the discharge diagnosis information of patients.


Figure 2: A: Explanation-attention visualization within the longitudinal trajectory of a single patient. The x axis represents time (hours into ICU admission) and the plotted values represent either the ground truth label (black) or the predicted value (red). The highlighting within each row indicates the degree of attention (attention weight) given to the explanations, where darker hues correspond to higher attention given to a particular organ system at a specific time point. B, C: AUROC (B) and AUPRC (C) curves evaluated on the training, validation and test sets. D: Histogram showing the variation of cosine similarity values across the 283 patients in the held-out test set who experienced in-hospital mortality in the ICU.

2 Methodology

Dataset and Feature Processing: We used the Medical Information Mart for Intensive Care IV (MIMIC-IV) database, comprising de-identified clinical data of 22,944 ICU admissions between 2008 and 2019, of which 1,877 (8.2%) experienced in-hospital mortality Johnson et al. (2016). Features include both static or time-invariant variables (e.g. demographics and prior comorbidity information) and time-series variables (e.g. laboratory test results and vital signs). We also calculated the SOFA organ-specific scores corresponding to the 6 organ systems: respiratory, cardiovascular, renal, hepatic, neurological and coagulation. Preprocessing steps for numerical variables include clipping outlier values to the 1st and 99th percentile values and standardization using the RobustScaler method. Time-varying variables were aggregated into hourly time buckets, using the median for repeated values. For each of the organ systems, the SOFA scores ranging from 0 to 4 are scaled to [0, 1]. All categorical features were one-hot encoded.
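The preprocessing steps described above can be sketched in plain Python as follows. This is a minimal illustration, not the authors' pipeline: the function names are hypothetical, the percentile interpolation may differ slightly from the library actually used, `robust_scale` only mirrors the default behaviour of scikit-learn's `RobustScaler` (median centring, interquartile-range scaling), and the [0, 1] target range for the SOFA scores is an assumption.

```python
from statistics import median, quantiles


def clip_to_percentiles(values, lower_pct=1, upper_pct=99):
    """Clip values to the [lower_pct, upper_pct] percentile range."""
    cuts = quantiles(values, n=100, method="inclusive")  # 99 cut points
    lo, hi = cuts[lower_pct - 1], cuts[upper_pct - 1]
    return [min(max(v, lo), hi) for v in values]


def robust_scale(values):
    """Centre on the median and divide by the interquartile range."""
    q1, q2, q3 = quantiles(values, n=4, method="inclusive")
    iqr = (q3 - q1) or 1.0  # guard against constant columns
    return [(v - q2) / iqr for v in values]


def hourly_median(measurements):
    """Aggregate (hours_since_admission, value) pairs into hourly buckets,
    taking the median when a bucket holds repeated values."""
    buckets = {}
    for hours, value in measurements:
        buckets.setdefault(int(hours), []).append(value)
    return {hour: median(vals) for hour, vals in sorted(buckets.items())}


def scale_sofa(score, max_score=4):
    """Map an integer SOFA organ score in 0..4 onto [0, 1] (assumed range)."""
    return score / max_score
```

For example, `hourly_median([(0.2, 80), (0.7, 90), (1.5, 100)])` places the first two readings in hour 0 and the third in hour 1.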

Proposed Model architecture: Figure 1 shows our proposed approach. The time-variant features are passed through an imputation module which iteratively imputes the missing values in the time series according to recurrent dynamics Cao et al. (2018). The imputed time series values are then passed through a series of LSTM (Long Short-Term Memory) layers to yield a latent representation. The static variables, repeated across the time dimension, are concatenated with the latent representation to create the final feature representation. The concatenated representation is then passed through a series of fully-connected hidden layers, followed by a sigmoid activation function, to generate the predicted explanations, i.e. the SOFA organ-specific auxiliary task scores. The explanations are combined via a softmax-based attention mechanism (i.e. a weighted sum of the auxiliary scores) Bahdanau et al. (2014) to generate the predicted in-hospital mortality within 24 hours. The novelty of our approach lies in using the SOFA organ-specific scores both as auxiliary layers to improve the final prediction outcome and as atoms of explanation to provide insights into the clinical outcome. The attention weights signify the contribution of each SOFA organ-system score in deciding patient mortality. The model loss function is a weighted sum of the primary (mortality prediction) loss, the auxiliary losses and the imputation loss. Further discussion of the proposed architecture, along with training details, can be found in Appendix A.
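To make the explanation-and-attention step concrete, the sketch below computes the six predicted explanations and their softmax attention weights for a single timestep, then combines them as a weighted sum. It is an illustration under stated assumptions: `explain_and_predict`, `aux_logits` and `attention_logits` are hypothetical names standing in for the outputs of the fully-connected layers, and the additional sigmoid mentioned in the Figure 1 caption is left out.

```python
import math

ORGANS = ["respiratory", "cardiovascular", "renal",
          "hepatic", "neurological", "coagulation"]


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def softmax(scores):
    """Numerically stable softmax over a list of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def explain_and_predict(aux_logits, attention_logits):
    """Return (explanations, attention weights, combined score) for one timestep.

    aux_logits and attention_logits are hypothetical pre-activation outputs,
    one entry per SOFA organ system in the order given by ORGANS.
    """
    explanations = [sigmoid(z) for z in aux_logits]   # predicted organ scores in (0, 1)
    attention = softmax(attention_logits)             # non-negative, sums to 1
    combined = sum(a * e for a, e in zip(attention, explanations))
    return explanations, attention, combined
```

Because the attention weights are non-negative and sum to one, each weight can be read directly as the share of the combined score attributed to the corresponding organ system at that timestep.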

Evaluating Quality of Explanations: We design a similarity metric to validate the predicted explanations against gold-standard explanations derived from the discharge diagnosis information of patients in the dataset. First, we create a mapping (Appendix B) between the actual SOFA organ-specific scores and the comorbidity information of the patients in terms of their ICD-10 (International Classification of Diseases, 10th Revision) and ICD-9 (International Classification of Diseases, 9th Revision) codes. For example, any diagnosis code related to heart failure will be mapped to the SOFA cardiovascular component.


Let $c = (c_1, \dots, c_6)$ be a vector where each element represents the number of mapped ICD codes of a particular patient corresponding to one of the 6 SOFA organ systems, as obtained from the mapping described in Appendix B. We define the gold-standard explanation $g$ by applying a softmax function to the individual counts:

$$g_k = \frac{\exp(c_k)}{\sum_{j=1}^{6} \exp(c_j)}, \qquad k = 1, \dots, 6.$$

If $\alpha_t$ is the attention vector at timestep $t$, then the similarity $s_t$ at $t$ is calculated as the cosine similarity between $\alpha_t$ and $g$:

$$s_t = \frac{\alpha_t \cdot g}{\lVert \alpha_t \rVert \, \lVert g \rVert}.$$

Note that while $g$ is calculated only once for each patient, the attention vector $\alpha_t$ is calculated at an hourly interval within the longitudinal patient trajectory. In order to analyze the most recent contribution of the SOFA organ scores before death, we selected the attention weights 6 hours prior to the time of the patient's death or hospital discharge for computing the similarity metric.
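The similarity metric described above can be sketched in a few lines of Python. This is an illustrative implementation under the assumption that the gold-standard vector is a softmax over raw per-organ ICD code counts and that the comparison uses a single attention vector; the function names are hypothetical.

```python
import math


def softmax(values):
    """Numerically stable softmax."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def cosine_similarity(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def explanation_similarity(icd_counts, attention_weights):
    """Cosine similarity between the attention vector at one timestep and
    the gold-standard explanation (softmax over per-organ ICD code counts)."""
    gold = softmax(icd_counts)
    return cosine_similarity(attention_weights, gold)


# Five mapped codes for the first organ system, none for the others.
counts = [5, 0, 0, 0, 0, 0]
focused = explanation_similarity(counts, [1.0, 0, 0, 0, 0, 0])
uniform = explanation_similarity(counts, [1 / 6] * 6)
```

In this toy case the attention vector concentrated on the first organ system scores close to 1, while the evenly spread attention vector scores much lower.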

3 Results

Model performance was evaluated on a held-out test set using AUROC (Area Under the Receiver Operating Characteristic Curve; Figure 2B) and AUPRC (Area Under the Precision-Recall Curve; Figure 2C). Figure 3 in Appendix C shows the model performance on the auxiliary tasks. To validate the quality of the generated explanations against a gold standard, we calculated the similarity metric for all the patients in the test set who experienced in-hospital mortality (Figure 2D). Detailed insights into the low similarity values of some patients can be found in the Discussion section. Figure 2A shows the explanation-attention visualization of the longitudinal trajectory of a single patient who died during the ICU stay. We can observe that the model initially pays the most attention to anticipated SOFA hepatic dysfunction, followed by anticipated SOFA kidney (renal) dysfunction. As the predicted probability of mortality rises, the model pays more attention to anticipated respiratory, neurological, hepatic and renal organ failure, highlighting their contribution towards mortality.

4 Discussion and Future Work

This style of intrinsically explainable model requires developers to pre-define possible explanations for clinical predictions, which, if not expressive enough, might lead to a degradation in both performance and explainability. In our work, we have defined SOFA organ-specific scores as the units of explanation, which assumes acute organ failures as potential risk factors for mortality. One potential limitation is that our explanations do not take into account other serious conditions like cancer or sepsis. A patient having multiple SOFA organ dysfunctions may actually die from cancer, and the model will predict increased mortality risk due to the organ failures without identifying the exact cause of death. Although our model shows promise in identifying increased mortality risk due to organ failures, further work is required to enrich the list of defined explanation units and make our model findings more reliable for timely and accurate interventions.

The low similarity values of some patients in Figure 2D can also be intuitively understood from how the model pays attention to each of the SOFA organ systems within the longitudinal patient trajectory. For example, if the attention weights are evenly distributed across multiple systems 6 hours before death and the ground truth explanation vector of the patient has a high prevalence of diagnosis codes related to a single organ, the cosine similarity value defined in Section 2 will be low.
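A short worked calculation illustrates this effect. Writing $\alpha$ for the attention vector and $g$ for the gold-standard explanation vector (both over the 6 organ systems, with $g$ summing to one), and taking the limiting case of perfectly uniform attention as an illustrative assumption:

```latex
\alpha = \left(\tfrac{1}{6}, \dots, \tfrac{1}{6}\right), \qquad
\lVert \alpha \rVert = \sqrt{6 \cdot \tfrac{1}{36}} = \tfrac{1}{\sqrt{6}}, \qquad
\alpha \cdot g = \tfrac{1}{6} \sum_{k=1}^{6} g_k = \tfrac{1}{6}

\cos(\alpha, g)
  = \frac{\alpha \cdot g}{\lVert \alpha \rVert \, \lVert g \rVert}
  = \frac{1/6}{\tfrac{1}{\sqrt{6}} \, \lVert g \rVert}
  = \frac{1}{\sqrt{6}\, \lVert g \rVert}

\text{If } g \text{ concentrates on a single organ, } \lVert g \rVert \to 1
\text{ and } \cos(\alpha, g) \to \tfrac{1}{\sqrt{6}} \approx 0.41 .
```

Conversely, when $g$ is itself uniform, $\lVert g \rVert = 1/\sqrt{6}$ and the similarity equals 1, so under uniform attention the value depends only on how concentrated the diagnosis codes are.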

In this work, we did not compare our proposed model with any state-of-the-art baseline method. We believe that while it might be trivial to generate baselines for prediction performance, it is a non-trivial task to create suitable baselines for analyzing explainability and for assessing whether the generated explanations are actually valid in the clinical domain. In this work we take a first step toward addressing a fundamental limitation of existing deep learning explainability research. Our proposed method not only generates explanations about model predictions but also attempts to compare the explanations with a gold standard. A systematic organization of explanation units related to the clinical outcome, and refinement of the step that generates ground truth explanations, are interesting directions for future work. Our next step is to compare both the performance and the explainability of our proposed model with state-of-the-art model explainability techniques like LIME Ribeiro et al. (2016), SHAP Lundberg et al. (2020) and Layer-wise Relevance Propagation (LRP) Montavon et al. (2019).


Appendix A Model Architecture and Training details

Architecture: Let $X$ be the multivariate time series of a particular patient, represented by a sequence of $T$ observations $X = \{x_1, x_2, \dots, x_T\}$, where $x_t \in \mathbb{R}^{D}$. We denote the time-invariant or static features of that patient as $s$. The time-series features are passed through an imputation module, which is a series of LSTMs that imputes the missing values according to recurrent dynamics Cao et al. (2018), yielding imputed values $\hat{x}_t$. The imputed variables are then passed through a series of recurrent LSTM layers to yield a latent representation $h_t$. The latent representation is then concatenated with the static variables to form the final feature representation $z_t = [h_t; s]$. The concatenated representation is then passed through a series of fully-connected hidden layers, followed by a sigmoid activation function $\sigma$, to generate the auxiliary layer output or predicted explanations $\hat{y}_t^{k} = \sigma(W_k z_t + b_k)$, where $k \in \{1, \dots, 6\}$ indexes the 6 SOFA organ system scores. $W_k$ and $b_k$ represent the weights and biases of the fully connected layers corresponding to the $k$-th SOFA organ system. The auxiliary scores are combined via an attention mechanism (i.e. a weighted sum of the auxiliary scores) Bahdanau et al. (2014) to generate the predicted in-hospital mortality within 24 hours, $\hat{y}_t$. The attention weight $\alpha_t^{k}$ for the $k$-th SOFA organ is calculated by passing the concatenated latent representation through fully connected layers followed by a softmax activation function, $\alpha_t^{k} = \exp(U_k z_t + v_k) / \sum_{j=1}^{6} \exp(U_j z_t + v_j)$, where $U_k$ and $v_k$ are the weights and biases of the fully connected layers corresponding to the $k$-th organ system.


The model loss function is the weighted sum of 3 supervised losses: (i) the binary cross-entropy (CE) loss $\mathcal{L}_{p}$ between the ground truth mortality label $y_t$ and the predicted mortality $\hat{y}_t$, (ii) the Mean-Squared Error (MSE) $\mathcal{L}_{a}$ between the predicted auxiliary score probability $\hat{y}_t^{k}$ and the maximum SOFA organ score within 24 hours, and (iii) the MSE $\mathcal{L}_{i}$ between the imputed and true values of the time-series variables:

$$\mathcal{L} = \lambda_{p} \mathcal{L}_{p} + \lambda_{a} \mathcal{L}_{a} + \lambda_{i} \mathcal{L}_{i},$$

where each term is aggregated over the $T$ timesteps of the patient, and $T$ denotes the total number of timesteps, i.e. the number of hours from ICU admission until discharge. Since the imputation loss is dominated by the more frequently measured variables, there exists a high degree of imbalance between the imputation loss and the primary and auxiliary losses, which can lead to negative loss transfer and low prediction performance. To address this challenge, we selected higher weights for the primary and auxiliary losses ($\lambda_{p}$, $\lambda_{a}$) than for the imputation loss ($\lambda_{i}$).
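The weighted multi-task loss can be illustrated for a single timestep as follows. This is a sketch only: the function names are hypothetical, and the default weights are placeholders chosen to show the imputation term being down-weighted, not the values used in the paper.

```python
import math


def binary_cross_entropy(y_true, y_pred, eps=1e-7):
    """Cross-entropy between a 0/1 label and a predicted probability."""
    p = min(max(y_pred, eps), 1.0 - eps)  # avoid log(0)
    return -(y_true * math.log(p) + (1 - y_true) * math.log(1.0 - p))


def mean_squared_error(targets, predictions):
    return sum((t - p) ** 2 for t, p in zip(targets, predictions)) / len(targets)


def total_loss(y_true, y_pred, sofa_true, sofa_pred, x_true, x_imputed,
               w_primary=1.0, w_aux=1.0, w_imputation=0.1):
    """Weighted sum of the primary, auxiliary and imputation losses.

    The weights are hypothetical placeholders; the only property taken from
    the text is that the imputation weight is smaller than the other two.
    """
    primary = binary_cross_entropy(y_true, y_pred)
    auxiliary = mean_squared_error(sofa_true, sofa_pred)
    imputation = mean_squared_error(x_true, x_imputed)
    return w_primary * primary + w_aux * auxiliary + w_imputation * imputation
```

Down-weighting the imputation term keeps a large reconstruction error on frequently measured variables from dominating the gradient signal for the mortality and SOFA tasks.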


Model Training parameters: In order to address the class imbalance in the dataset (8.2% positive patients, i.e. those who suffered in-hospital mortality), we randomly selected 4 survivor patients for every positive patient. The dataset was split into training, validation and test sets (70:15:15), with the validation set used for early stopping. We trained the model using the Adam optimizer with a learning rate of 0.001 and L2 regularization. All 3 LSTM hidden layers and the fully connected layers share the same hidden size, and dropout was applied during training.
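The sampling and splitting procedure can be sketched as below. The text does not specify whether the split is stratified or whether undersampling precedes the split, so this sketch makes the simplest assumption (undersample first, then a random patient-level split); the function names and the fixed seed are illustrative.

```python
import random


def undersample(patient_ids, labels, ratio=4, seed=0):
    """Keep every positive patient and `ratio` randomly chosen survivors
    per positive patient."""
    rng = random.Random(seed)
    positives = [p for p, y in zip(patient_ids, labels) if y == 1]
    negatives = [p for p, y in zip(patient_ids, labels) if y == 0]
    n_keep = min(len(negatives), ratio * len(positives))
    return positives + rng.sample(negatives, n_keep)


def split_70_15_15(patient_ids, seed=0):
    """Random 70:15:15 split into training, validation and test sets."""
    rng = random.Random(seed)
    shuffled = list(patient_ids)
    rng.shuffle(shuffled)
    n_train = len(shuffled) * 70 // 100
    n_val = len(shuffled) * 15 // 100
    train = shuffled[:n_train]
    val = shuffled[n_train:n_train + n_val]
    test = shuffled[n_train + n_val:]
    return train, val, test
```

Splitting by patient identifier rather than by hourly record keeps all timesteps of one patient in the same partition.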

Appendix B ICD to SOFA Mapping

The predicted explanations of our proposed model are in the form of the 6 SOFA organ systems: respiratory, cardiovascular, hepatic, coagulation, renal and neurological. In order to create a ground truth explanation in terms of the SOFA organ systems, we mapped the ICD-10 diagnosis codes of the patients to one of the SOFA organ systems following the Phecode Map 1.2 with ICD-10 Codes (beta) Wu et al. (2019), as shown in Table 1. For patients with missing ICD-10 codes, we used the corresponding ICD-9 (International Classification of Diseases, Ninth Revision) codes Organization et al. (1988). For example, the ICD-10 codes corresponding to the SOFA respiratory system include all codes starting with J, and the corresponding ICD-9 codes lie lexicographically between 460 and 519.

| Category       | ICD-9 (Range) | ICD-10 (Range) |
| Respiratory    | 460-519       | J              |
| Cardiovascular | 390-459       | I              |
| Hepatic        | 570-579       | K70-K77        |
| Coagulation    | 280-289       | D50-D89        |
| Renal          | 580-589       | N17-N19        |
| Neurological   | 290-389       | G              |
Table 1: Mapping of ICD-9 and ICD-10 diagnosis codes to the SOFA organ systems.
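The ranges in Table 1 translate into a simple lookup. The sketch below is one possible encoding, assuming diagnosis codes are stored as strings whose first three characters give the ICD category (with or without a decimal point); the function names and the `(code, version)` input format are hypothetical.

```python
ORGANS = ["respiratory", "cardiovascular", "hepatic",
          "coagulation", "renal", "neurological"]

# ICD-9 three-digit category ranges from Table 1 (inclusive).
ICD9_RANGES = {
    "respiratory": (460, 519), "cardiovascular": (390, 459),
    "hepatic": (570, 579), "coagulation": (280, 289),
    "renal": (580, 589), "neurological": (290, 389),
}

# ICD-10 chapter letter plus two-digit range; whole chapters use 0-99.
ICD10_RANGES = {
    "respiratory": ("J", 0, 99), "cardiovascular": ("I", 0, 99),
    "hepatic": ("K", 70, 77), "coagulation": ("D", 50, 89),
    "renal": ("N", 17, 19), "neurological": ("G", 0, 99),
}


def icd9_to_organ(code):
    code = code.strip().upper().replace(".", "")
    if len(code) < 3 or not code[:3].isdigit():
        return None  # E-codes, V-codes and malformed codes are not mapped
    category = int(code[:3])
    for organ, (lo, hi) in ICD9_RANGES.items():
        if lo <= category <= hi:
            return organ
    return None


def icd10_to_organ(code):
    code = code.strip().upper().replace(".", "")
    if len(code) < 3 or not code[0].isalpha() or not code[1:3].isdigit():
        return None
    letter, number = code[0], int(code[1:3])
    for organ, (chapter, lo, hi) in ICD10_RANGES.items():
        if letter == chapter and lo <= number <= hi:
            return organ
    return None


def count_organs(codes):
    """Count mapped codes per organ system from (code, icd_version) pairs."""
    counts = {organ: 0 for organ in ORGANS}
    for code, version in codes:
        organ = icd10_to_organ(code) if version == 10 else icd9_to_organ(code)
        if organ is not None:
            counts[organ] += 1
    return counts
```

The per-organ counts returned by `count_organs` correspond to the count vector from which the gold-standard explanation is built in Section 2.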

Appendix C Model Performance on Auxiliary tasks

Figure 3 shows the model performance on the auxiliary tasks corresponding to the 6 SOFA organ systems. For each of the organ systems, the SOFA scores ranging from 0 to 4, in increasing order of severity, are scaled to [0, 1]. The results suggest that the model performs relatively well in predicting scores for all the organ systems.


Figure 3: Actual and predicted values of the auxiliary tasks corresponding to the maximum SOFA organ system scores within 24 hours. The dotted line represents the ideal scenario where the predicted and actual values are the same.