SAVEHR: Self Attention Vector Representations for EHR based Personalized Chronic Disease Onset Prediction and Interpretability

11/13/2019 ∙ by Sunil Mallya, et al. ∙ Amazon

Chronic disease progression is emerging as an important area of investment for healthcare providers. As the quantity and richness of available clinical data continue to increase along with advances in machine learning, there is great potential to advance our approaches to caring for patients. An ideal approach to this problem should perform well on at least three axes: a) perform across many clinical conditions without requiring deep clinical expertise or extensive data scientist effort, b) generalize across populations, and c) be explainable (model interpretability). We present SAVEHR, a self-attention based architecture on heterogeneous structured EHR data that achieves > 0.51 AUC-PR and > 0.87 AUC-ROC on predicting the onset of four clinical conditions (CHF, Kidney Failure, Diabetes and COPD) 15 months in advance, and transfers with high performance onto a new population. We demonstrate that the SAVEHR model outperforms ten baselines on all three axes stated above.




1 Introduction

Clinicians record structured data such as diagnosis codes and vitals from lab tests, and unstructured data such as clinical notes, in electronic health record (EHR) systems. Accurately predicting the progression of diseases using such EHR data could allow clinicians and patients to make more informed choices, reduce costs, and decrease mortality and morbidity. However, EHR data are challenging to model: they are heterogeneous, temporally dependent, sparse, incomplete, and high dimensional [jensen2012mining, weiskopf2013defining, tran2014framework]. Data-driven approaches for feature selection from EHR have been proposed to address these challenges [huang2014toward, lyalina2013identifying, wang2014unsupervised]. An initial step in modeling a disease trajectory is to predict its onset. A variety of deep learning approaches to predicting disease onset have been explored, including predictions of congestive heart failure mallya2019effectiveness, Kidney Failure perotte2015risk, Dementia de2018unsupervised and Delirium wong2018development. While high performance is a necessity, validation on a new population and interpretability are also key for adoption in a healthcare system. We address these aspects by proposing SAVEHR, which uses self-attention lin2017structured on structured EHR data to learn pairwise comorbidities to effectively predict disease onset and generate personalized feature importance visualizations.

Related Work: In recent years, attention mechanisms have made substantial gains in conjunction with RNNs. Attention allows the network to focus on certain regions of data while perceiving other regions in “low resolution”. As a consequence, it facilitates the interpretation of learned representations. Attention is now being applied to healthcare data (clinical notes and structured EHR) as well. To represent the behavior of physicians during an encounter, choi2016retain use a two-level neural attention model that focuses on events in reverse time order. cheng2016risk represent EHR events as a temporal matrix and use a CNN-based architecture to predict the onset of CHF and COPD; to obtain feature importance, they aggregate the weights of the neurons. kaji2019attention use attention to predict outcomes on ICU events, and chu2018using use attention on clinical notes to detect adverse medical events. Patient2Vec uses multi-stage attention, where self-attention is applied within a sub-sequence of homogeneous features such as medical codes, followed by creation of an aggregated deep representation to predict outcomes and generate a personalized heatmap for a patient that can explain model predictions.

2 Cohort

We create one development and one external test cohort for each of the four chronic diseases (CHF, Kidney Failure, Diabetes Type II and COPD) from de-identified, anonymized, structured EHR data from two distinct patient populations, referred to as P1 (the development cohort) and P2 (the test cohort). We use a 12-month observation window that lies between 27 and 15 months before the index date, and a prediction window of 15 months. Since the observation window is fixed and relatively long (12 months), we aggregate frequency counts of the diagnosis codes assigned for encounters within each time slice (3-month slices, or quarters), similar to Choi2016UsingRN, to facilitate temporal learning. The case-control design, index date selection and time windows are explained in Appendix C. We represent each patient’s data with static demographic features (gender, race, age) and a sequence of diagnosis and procedure codes, termed medical concepts. The feature sequence for a patient is denoted as $x_1, x_2, \ldots, x_T$, where the subscript $t$ denotes the quarter of interest, so the higher the subscript value, the closer the quarter is to the index date. The number of case and control patients for each disease is presented in Table 2 (Appendix B).
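The quarterly aggregation of code frequencies can be sketched in a few lines of Python. This is a minimal illustration under stated assumptions, not the authors' pipeline: the function names, the `(date, code)` encounter format and the whole-calendar-month arithmetic are all hypothetical choices made for the example.

```python
from collections import Counter
from datetime import date


def months_before(index_date, event_date):
    """Whole calendar months from event_date to index_date (assumption: day of month is ignored)."""
    return (index_date.year - event_date.year) * 12 + (index_date.month - event_date.month)


def quarterly_counts(encounters, index_date, n_quarters=4, start=27, end=15):
    """Return one Counter of code frequencies per quarter, oldest quarter first.

    Only encounters more than `end` and at most `start` months before the
    index date fall inside the observation window.
    """
    slices = [Counter() for _ in range(n_quarters)]
    width = (start - end) // n_quarters  # 3 months per slice for a 12-month window
    for event_date, code in encounters:
        m = months_before(index_date, event_date)
        if end < m <= start:
            quarter = (start - m) // width  # 0 is the quarter farthest from the index date
            slices[quarter][code] += 1
    return slices


# Illustrative, made-up encounters for one patient.
index_date = date(2018, 6, 15)
encounters = [
    (date(2016, 3, 10), "428"),  # 27 months before: first quarter
    (date(2016, 4, 1), "428"),   # 26 months before: first quarter
    (date(2017, 2, 1), "401"),   # 16 months before: last quarter
    (date(2017, 3, 1), "401"),   # 15 months before: outside the window
    (date(2018, 1, 1), "585"),   # 5 months before: outside the window
]
slices = quarterly_counts(encounters, index_date)
```

Each element of `slices` corresponds to one quarter, so the last element is the quarter closest to the index date.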

3 SAVEHR Model Architecture

Figure 1:

SAVEHR Deep Neural Network

In this section, we present the details and architecture (Figure 1) of the proposed self-attention based SAVEHR neural network. There are three main components of this architecture: 1) a self-attention layer for heterogeneous features, followed by 2) a Bi-GRU layer and 3) an MLP attention mechanism. In the following, we describe each component and its contribution to the classification task in detail.

Self-Attention with Heterogeneous features: Self-attention relates elements at different positions of a single sequence by computing the attention between each pair of inputs, $x_i$ and $x_j$. Non-categorical information such as age is converted into a categorical feature by binning, while race and gender are integer encoded. The final feature representation for any given time slice is obtained by concatenating the one-hot representations of all the above features into a single vector $x_t^p$, where $x_t^p$ represents the homogeneous feature representation for patient $p$ in time slice $t$; we denote its length as $n$. The input is passed through an embedding layer, where an embedding is learnt for each feature, represented by a matrix $E$ of shape $n \times d$, where $d$ is the embedding dimension, and is then fed into the self-attention layer. We compute attention for every feature with respect to the other features in $x_t^p$ via $E$. The self-attention layer produces a vector of weights $a$:

$$a = \mathrm{softmax}\left(w_{s2} \tanh\left(W_{s1} E^{\top}\right)\right)$$

where $W_{s1}$ is a weight matrix with a shape of $d_a \times d$ and $w_{s2}$ is a vector of parameters with size $d_a$. To capture nuanced interactions, especially for time slices with long sequences of medical concepts, we perform multiple hops of attention. As an example, say we want $r$ different parts to be extracted from $x_t^p$: we extend $w_{s2}$ into an $r \times d_a$ matrix, denoted $W_{s2}$, and the resulting annotation vector $a$ becomes an annotation matrix $A$. It is formally represented in Equation 1, where the softmax function is applied along the second dimension of its input:

$$A = \mathrm{softmax}\left(W_{s2} \tanh\left(W_{s1} E^{\top}\right)\right) \tag{1}$$

We compute the $r$ weighted sums by multiplying the annotation matrix $A$ and the embedding output matrix $E$, resulting in $M = AE$.
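To make the multi-hop self-attention step concrete, here is a small pure-Python sketch in the style of the structured self-attention of lin2017structured: a tanh projection of the embedded features, a softmax over features for each hop, and a weighted sum of the embeddings. The names `E`, `W_s1`, `W_s2`, the sizes `n`, `d`, `d_a`, `r`, and the random weights are assumptions for illustration; a real model would learn the weights and use a tensor library.

```python
import math
import random


def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m):
    return [list(col) for col in zip(*m)]


def softmax_rows(m):
    """Apply a numerically stable softmax to each row."""
    out = []
    for row in m:
        peak = max(row)
        exps = [math.exp(v - peak) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def structured_self_attention(E, W_s1, W_s2):
    """E: n x d embeddings, W_s1: d_a x d, W_s2: r x d_a. Returns (A, M)."""
    hidden = [[math.tanh(v) for v in row] for row in matmul(W_s1, transpose(E))]  # d_a x n
    A = softmax_rows(matmul(W_s2, hidden))  # r x n, one attention distribution per hop
    M = matmul(A, E)                        # r x d, one weighted sum of embeddings per hop
    return A, M


def random_matrix(rows, cols, rng):
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
n, d, d_a, r = 5, 4, 3, 2  # hypothetical sizes: features, embedding dim, attention dim, hops
E = random_matrix(n, d, rng)
W_s1 = random_matrix(d_a, d, rng)
W_s2 = random_matrix(r, d_a, rng)
A, M = structured_self_attention(E, W_s1, W_s2)
```

Each row of `A` sums to one and can be read as how strongly one hop attends to each input feature, which is the quantity a feature importance heatmap would be built from.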


To capture the longitudinal dependencies and understand the importance of each time slice for a given patient, we feed the sequence of encoded quarterly representations from the self-attention layer into a bidirectional GRU-based RNN with aggregated MLP-Attention; refer to Appendix E for details.


We use ten baselines, categorized into common baselines (Logistic Regression, Random Forest, Multi-Layer Perceptron), deep learning baselines (1D-CNN and bidirectional GRU based) and attention based models. A wide variety of baselines was evaluated in order to understand the trade-off between performance and model complexity. Baselines are described in depth in Appendix F.


4 Experiments

We perform a robust evaluation with 11 onset prediction models on four clinical conditions (CHF, Kidney Failure, Diabetes Type II and COPD) over three axes (Performance, Generalization and Interpretability). Given the imbalance in the data, we consider the Area under the Precision-Recall Curve (AUC-PR) as the primary metric for performance [saito2015precision, davis2006relationship]; it is reported in Table 1. Standard deviations from three-fold cross validation are reported in Appendix I.
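As a reminder of why AUC-PR is informative under class imbalance, the sketch below computes average precision, one common step-wise estimate of the area under the precision-recall curve. The paper does not state which estimator was used, so this is an assumption, and the labels and scores are made up for illustration; ties between scores are not treated specially.

```python
def average_precision(labels, scores):
    """Step-wise estimate of the area under the precision-recall curve.

    labels: 1 for case, 0 for control; scores: predicted risk, higher means more likely a case.
    """
    n_pos = sum(labels)
    if n_pos == 0:
        raise ValueError("need at least one positive label")
    ranked = sorted(zip(scores, labels), key=lambda pair: pair[0], reverse=True)
    true_pos = 0
    total = 0.0
    for rank, (_, label) in enumerate(ranked, start=1):
        if label == 1:
            true_pos += 1
            total += true_pos / rank  # precision at the rank where recall increases
    return total / n_pos


# Made-up example with 20% prevalence: cases are ranked 1st and 4th.
labels = [1, 0, 0, 1, 0, 0, 0, 0, 0, 0]
scores = [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0.05]
ap = average_precision(labels, scores)  # (1/1 + 2/4) / 2 = 0.75
```

A model that ranks patients at random would score close to the case prevalence (0.2 here), so the baseline for AUC-PR moves with the class balance, unlike AUC-ROC where it stays at 0.5.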


Experiment i) Across clinical conditions: For each of the four conditions mentioned above, we created training, validation, internal test (P1) and external test (P2) sets, and use AUC-PR as the primary metric for evaluating performance.

Experiment ii) Generalize across populations: Hospital systems can vary in how diagnosis codes are assigned for each clinical visit, as shown in many studies [burns2011systematic, quach2010administrative, jolley2015validity, vlasschaert2011validity]; hence it is essential for the model to be evaluated on different populations justice1999assessing. Several studies show that characterizing performance on a single population can be insufficient [collins2014external, bleeker2003external, konig2007practical, ivanescu2016importance]. Hence, we take the same trained model that was evaluated on the internal test set (P1), evaluate it on the corresponding condition’s cohort in the external population (P2), and report AUC-PR in Table 1 under section P2.

Experiment iii) Interpretability: Non-linearity in deep learning based models helps achieve better performance than linear methods, but may make the model opaque to humans. In order for the model’s prediction to be trusted, we believe the model should provide insights into why it produced the result it did. We evaluate the interpretability of the models by generating both population level (Appendix K) and per patient feature importance visualizations for SAVEHR (Figure 2).

Model      P1                                  P2
LR         0.2636  0.2622  0.2560  0.2465      0.5708  0.2425  0.3182  0.1958
RF         0.2729  0.3983  0.3154  0.6088      0.5211  0.2544  0.2421  0.1984
MLP        0.4536  0.4771  0.4637  0.5491      0.5361  0.2204  0.4127  0.4588
BG         0.4921  0.5969  0.4859  0.7594      0.712   0.279   0.5462  0.3731
CNN-1G     0.5122  0.5565  0.4284  0.7371      0.6853  0.1741  0.3509  0.4204
CNN-LK     0.5333  0.5809  0.4954  0.7413      0.6706  0.2839  0.4599  0.4503
BG-A       0.4978  0.6009  0.5125  0.7436      0.6725  0.3395  0.6502  0.3662
Dense-A    0.5109  0.5581  0.4745  0.7264      0.7099  0.3559  0.6523  0.385
CNN-1G-A   0.5043  0.5330  0.4976  0.7380      0.6988  0.3959  0.5843  0.4405
CNN-LK-A   0.5353  0.5464  0.5474  0.7734      0.7016  0.3743  0.6251  0.3562
SAVEHR     0.5464  0.6112  0.5174  0.7776      0.7541  0.5819  0.7074  0.4839
Table 1: Area under the curve (AUC-PR) performance across populations and conditions with 15-month prediction window

5 Results and Discussion

The SAVEHR model outperformed all baseline models on the AUC-PR metric across all four conditions on the internal test set P1 (except Diabetes) and on the external test set P2 as well. In the external test set P2, SAVEHR gains ranged from 7-46% over the next best performing model, as shown in Table 1.

External Test Cohort: A major strength of our work is that we used a formal external test cohort, which is as large as most studies’ development cohorts, to validate the model’s performance. Importantly, performance, as measured by AUC-PR, was higher (except Diabetes) in the external test cohort, providing evidence that our architecture may generate models that generalize across cohorts. Although testing model performance on a new population seems to be an important criterion, a vast majority of published studies do not evaluate how their model transfers to a new population.

Interpretability: Predictive models are not, in general, intended to be explanatory, yet clinicians certainly desire an explanation of the model’s prediction, particularly when that prediction is inconsistent with the clinician’s intuition. A powerful characteristic of the SAVEHR architecture is that it allows us to assign importance (or risk) scores to features and combinations of features (Figure 2). While this is not a full explanation, we believe based on the findings in this study that it may be possible to provide the clinician with a summary and visualization that indicates the underlying reasoning for the model’s prediction for an individual patient. In addition, by exploring the importance scores across populations of patients, such as those in a certain age category or with specific risk predictions, the clinician may gain insight into which features contribute to risk in that category of patients.

Figure 2: Feature importance heatmaps

We examine the feature importance for two patients with similar characteristics (demographics and clinical encounters): Patient A, with elevated Congestive Heart Failure (CHF) risk (57%), and Control A (13%). We graphically illustrate the importance of pairwise feature interactions with color, from deep blue to deep red indicating increasing importance. The features listed on the x-axis and the y-axis are the same for each panel: the x-axis shows the ICD-9/10 codes, while the y-axis has descriptive labels for the codes. The mutual interactions are averaged, given that no feature takes precedence over another. We observe that the patient identified as high risk has more interactions with high importance than the patient identified as low risk. The interactions with high importance are multiple and diffuse: there are not one or two interactions but many that have high importance in patients with elevated risk, and many of the high importance interactions, but certainly not all, make clinical sense. A similar visualization is provided for one of the best performing attention based baselines in Appendix J (Figure 10).
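The averaging of mutual interactions can be illustrated with a short sketch: the score of feature i attending to feature j and the score of j attending to i are averaged so that the heatmap is symmetric. The function name and the score values are hypothetical; how the raw pairwise scores are extracted from the model is not shown here.

```python
def symmetrize(scores):
    """Average scores[i][j] and scores[j][i] so neither feature takes precedence."""
    n = len(scores)
    return [[(scores[i][j] + scores[j][i]) / 2 for j in range(n)] for i in range(n)]


# Made-up pairwise importance scores for three features.
pairwise = [
    [0.5, 0.25, 0.0],
    [0.75, 0.5, 0.125],
    [0.5, 0.375, 0.25],
]
heatmap = symmetrize(pairwise)
```

After this step `heatmap[i][j]` equals `heatmap[j][i]`, which matches a figure whose x-axis and y-axis list the same features.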

6 Conclusion

We provide a new self-attention based deep neural network architecture to extract interpretable and actionable information from heterogeneous, sparse time-series data from electronic health records. We provide a multitude of performance metrics for a comprehensive comparison of the current state-of-the-art and our models. Our model yields SoTA results across four different clinical conditions on an external cohort of thousands of patients monitored for a year or longer. Finally, we provide samples from anonymized patients to illustrate the interpretability of prediction scores and to demonstrate how clinicians can incorporate our risk scores into clinical workflows. We believe the relative importance of these features and feature interactions, with an appropriate visualization, can improve clinicians’ confidence in model predictions. Clinicians could utilize these predictions to target and modulate clinical interventions with greater precision.


Appendix A Population Statistics in P1 and P2

We describe the population statistics such as gender ratio and average age for all the four clinical conditions and present them in Figure 3.

Figure 3: Population statistics for P1 and P2

Appendix B Disease Cohorts

We create cohorts (Training, Validation and Test) on Population P1, and use P2 entirely as an external test set for four chronic diseases - Congestive Heart Failure (CHF), Kidney Failure, Diabetes Type II and Chronic Obstructive Pulmonary Disease (COPD).

Disease            Training (P1)     Validation (P1)   Test (P1)        External Test (P2)
(case : control)
CHF                14343 : 159567    793 : 8361        3916 : 41851     1259 : 5890
Kidney             8085 : 66045      447 : 3455        2216 : 17292     757 : 9351
Diabetes Type II   7674 : 53308      429 : 2781        2088 : 13961     3422 : 6997
COPD               11301 : 104719    641 : 5466        3107 : 27425     1767 : 5000
Table 2: Disease cohorts with train, validation and test sets for P1 and external test (P2) populations

Appendix C Index Date

Figure 4: Illustration of index date, observation and prediction windows

The case-control design within cohorts for each disease was created for patients who received care between 2015 and 2018. Incident cases for each condition were defined as patients between 30 and 80 years of age for whom an ICD-9 or ICD-10 code representing the condition was recorded as an encounter diagnosis at least three times in a six-month period, and who never had any prior diagnosis for the condition. We defined the index date as the date of the first of the three qualifying encounters. We did not consider any data from the 3 months prior to the index date (buffer period), in order to avoid incorporating diagnostic data that had been obtained but had not yet resulted in a diagnosis being recorded. Control patients were selected as those who had at least 5 encounters in a 2-year period but never had a diagnostic code recorded for the condition being modelled. The last encounter recorded in the system was chosen as the index date for control patients. Figure 4 illustrates our use of a 12-month observation window that lies between 27 and 15 months before the index date and a prediction window of 15 months. Since the observation window is fixed and relatively long (12 months), we aggregate frequency counts of diagnosis codes assigned for encounters across a time window, similar to Choi2016UsingRN, to facilitate temporal learning. Codes that had fewer than 50 occurrences in the cohort were filtered out.
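The "three qualifying encounters in a six-month period" rule for case index dates can be sketched as a sliding window over sorted diagnosis dates. This is an illustrative reading of the rule, not the authors' code: the function name is hypothetical, six months is approximated as 183 days, and the age limits and the control-patient criteria are not handled.

```python
from datetime import date, timedelta


def find_index_date(diagnosis_dates, min_encounters=3, window_days=183):
    """Return the earliest date that starts `min_encounters` diagnoses within the window, else None.

    Assumption: six months is approximated as 183 days.
    """
    dates = sorted(diagnosis_dates)
    for i in range(len(dates) - min_encounters + 1):
        last_in_group = dates[i + min_encounters - 1]
        if last_in_group - dates[i] <= timedelta(days=window_days):
            return dates[i]
    return None


# Made-up diagnosis dates for one condition.
case_dates = [date(2016, 10, 15), date(2016, 9, 1), date(2016, 12, 20), date(2017, 5, 1)]
sparse_dates = [date(2016, 1, 5), date(2016, 9, 1), date(2017, 6, 1)]

case_index = find_index_date(case_dates)      # three diagnoses between Sep and Dec 2016
sparse_index = find_index_date(sparse_dates)  # never three diagnoses within six months
```

With an index date in hand, the buffer period and the observation window are then measured backwards from it.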

Appendix D End-to-End Data and Modeling Pipeline

HealtheDataLab is a big data processing platform built on Amazon EMR. The data, population health data with longitudinal patient records, are ingested from Amazon S3. The end-to-end flow is as follows: an AWS Data Pipeline job orchestrates the transformation of data and the launch of an Amazon EMR cluster, and creates a data catalog along with a Hive metastore in Amazon RDS and AWS Glue. HealtheDataLab provides a Jupyter notebook running on an EC2 instance that connects to a Spark pipeline on Amazon EMR. HealtheDataLab has custom packages such as ontologies, FHIR support, and concept mapping that help data scientists create patient cohorts in a simplified manner. Once cohorts are created, they are stored in S3 as compressed NumPy arrays. Then an Amazon SageMaker machine learning job is kicked off with the specified cohort location in S3, along with the hyper-parameters to be optimized. After the completion of the job, the best hyper-parameters are recorded and the job id is noted to run evaluation on secondary populations.

Figure 5: HealtheDataLab Workflow on AWS

Appendix E Aggregated Deep Representation with MLP Attention across quarters

To capture the longitudinal dependencies and understand the importance of each quarter for a given patient, we feed the sequence of encoded quarterly representations from the self-attention layer into a bidirectional GRU-based RNN, presented in Equation 2:

$$h_t = \mathrm{BiGRU}(m_t), \quad t = 1, \ldots, T \tag{2}$$

where $m_t$ is the encoded representation of quarter $t$ and $h_t$ represents the output of the GRU for quarter $t$. We use MLP Attention (or multi-layer perceptron attention) [58] on top of the BiGRU layer to obtain a weighted representation of each quarter. The weight of each attention vector at GRU output $h_t$, $\alpha_t$, is calculated as a normalized weighted sum,

$$\alpha_t = \frac{\exp\left(f(h_t)\right)}{\sum_{k=1}^{T} \exp\left(f(h_k)\right)}$$

where $h_t$ are the hidden state vectors from the BiGRU cell and $f$ is the attention network. Once we obtain the attention weights, the vector representation aggregated for the patient across quarters is computed by:

$$c = \sum_{t=1}^{T} \alpha_t h_t$$

Once we obtain the aggregated representation for a patient across quarters, we add a softmax layer for the final outcome prediction, given by

$$\hat{y} = \mathrm{softmax}\left(W_c c + b_c\right)$$

where $W_c$ and $b_c$ are the weights and bias of the output layer.
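The attention pooling over quarters can be sketched as follows: score each quarter's GRU output with a small MLP, normalize the scores with a softmax, and take the weighted sum of the outputs. This is a generic additive-attention sketch under stated assumptions, not the authors' implementation: `mlp_score`, its fixed weights `W` and `v`, and the 2-dimensional hidden states are all hypothetical.

```python
import math


def mlp_score(h, W=((0.5, -0.25), (0.1, 0.3)), v=(1.0, -1.0)):
    """Hypothetical one-hidden-layer scorer: v . tanh(W h)."""
    hidden = [math.tanh(sum(w_ij * h_j for w_ij, h_j in zip(row, h))) for row in W]
    return sum(v_i * z_i for v_i, z_i in zip(v, hidden))


def attention_pool(hidden_states, score_fn):
    """Softmax-normalize one score per quarter and return (weights, weighted sum of states)."""
    scores = [score_fn(h) for h in hidden_states]
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]  # subtract the max for numerical stability
    total = sum(exps)
    weights = [e / total for e in exps]
    dim = len(hidden_states[0])
    context = [sum(w * h[k] for w, h in zip(weights, hidden_states)) for k in range(dim)]
    return weights, context


# Made-up GRU outputs for four quarters, oldest first.
hidden_states = [(0.1, 0.2), (0.0, -0.3), (0.4, 0.1), (0.9, -0.5)]
weights, context = attention_pool(hidden_states, mlp_score)
```

The entries of `weights` play the role of the per-quarter attention that is averaged over patients in the attention-over-quarters analysis.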

Appendix F Baseline Models

Figure 6: Attention based Architecture

Logistic regression and Random Forest: We trained three commonly used baselines: logistic regression (LR), random forest (RF) and a multi-layer perceptron (MLP) with dropout. For the logistic regression and random forest, we use the implementation provided by scikit-learn, with no regularization.

Deep Learning Baselines (1D-CNN + BiGRU and BiGRU): Inspired by the success of 1D-CNN and RNN based architectures for clinical notes liu2018deep, we extend them to structured EHR data. The diagnosis codes for conditions and procedures, which we collectively term medical concepts, can be considered analogous to words in sentences. We use an embedding layer to encode the features into a continuous space. We use the frequency of each code assigned in a time slice, and concatenate the frequencies for each code into the medical concept embedding, as shown to be effective in mallya2019effectiveness. We feed the embedded data into a 1D-CNN first. Since the medical concepts are inherently unordered, we use the equivalent of a 1-gram, i.e. a kernel size of 1 across feature embeddings, for the 1D-CNN on each of the time slices. To exploit the longitudinal nature of EHR data, we feed the time slice aggregated representation from the 1D-CNN into a bidirectional GRU (Bi-GRU) layer; we name this model CNN-1G. To measure the incremental value of the 1D-CNN filters, we create another baseline (BG) that uses only a Bi-GRU layer on top of the embedding layer. Demographic information may be static in nature, but it is critical to clinical decisions; hence we incorporate these features by concatenating them to the Bi-GRU layer output of both models described earlier. We also experiment with BiGRU instead of 1D-CNN and report performance on that.

Attention based Deep Learning Baselines (CNN-LargeKernel + BiGRU + Attention): To understand whether attention could help in EHR based modeling, we add MLP attention chorowski2015attention to the BiGRU layer for the baselines (CNN-1G & BG) described in the earlier section, to enable them to focus on the most important time window. The approach above with 1D-CNN would not capture the interactions among features very effectively, due to our kernel size of 1. Ideally, we would like to compute the collective and relative importance of any 2-, 3- or n-grams of features. However, the ordering of medical concepts within a given time window does not matter, while anything beyond a 1-gram kernel would require an ordering. To incorporate this and avoid the need for massive n-gram computation, we propose a novel baseline named CNN Large Kernel (CNN-LK), where the 1D-CNN kernel size is set equal to the number of input features, essentially giving us a weighted combination of all the input features. To understand whether the 1D-CNN adds value, we also use another baseline in which we replace the large kernel layer with a Dense layer of the same size. We note that, for the aforementioned architectures, we are unable to determine pairwise importance between any two features.
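The difference between a kernel size of 1 and a kernel as wide as the input can be seen in a toy single-channel convolution. The sketch below is a simplification (one channel, one filter, no embedding dimension, no padding) with made-up values; the function name is hypothetical.

```python
def conv1d_valid(signal, kernel):
    """Single-channel 1D convolution (cross-correlation) without padding."""
    k = len(kernel)
    return [
        sum(signal[i + j] * kernel[j] for j in range(k))
        for i in range(len(signal) - k + 1)
    ]


features = [1.0, 2.0, 3.0, 4.0]  # made-up values for four input features

# Kernel size 1: every feature is scaled on its own, so no interaction is modelled.
one_gram = conv1d_valid(features, [0.5])

# Kernel size equal to the number of features: a single weighted combination of all inputs.
large_kernel = conv1d_valid(features, [0.25, 0.25, 0.25, 0.25])
```

With the full-width kernel the filter produces one output per time slice, which is why replacing it with a Dense layer of the same size is a natural comparison.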


Appendix G AUC-ROC results for all conditions across populations

Model      P1                                  P2
LR         0.8474  0.8586  0.8358  0.8628      0.7949  0.7441  0.7426  0.8187
RF         0.8187  0.8314  0.7980  0.8826      0.8138  0.7159  0.7038  0.6466
MLP        0.8466  0.7969  0.8170  0.8435      0.8167  0.8271  0.7861  0.8341
BG         0.8695  0.8677  0.8411  0.9129      0.8497  0.8066  0.7611  0.8325
CNN-1G     0.8677  0.8684  0.8247  0.9008      0.8724  0.8407  0.8611  0.8137
CNN-LK     0.8717  0.8709  0.8641  0.9144      0.8661  0.7983  0.8505  0.8437
BG-A       0.8751  0.8725  0.8467  0.9123      0.8724  0.8382  0.8922  0.8543
Dense-A    0.8695  0.8628  0.8392  0.9084      0.8769  0.8273  0.8726  0.7934
CNN-1G-A   0.8722  0.8575  0.8409  0.9055      0.8909  0.8424  0.8651  0.8362
CNN-LK-A   0.8752  0.8609  0.8421  0.9068      0.8811  0.8584  0.8913  0.8081
SAVEHR     0.8749  0.8728  0.8717  0.9160      0.9093  0.8616  0.8788  0.8369
Table 3: Area under the curve (AUC-ROC) performance across populations and conditions with 15-month prediction window

Appendix H Distribution of MLP attention weights across quarters

In order to understand the importance across time slices, we compute the average attention per time slice across the entire test set P1 and report it in Table 4. The quarter closest to the index date, t4, is the most prominent.

Average attention per time slice
t1           t2          t3           t4
0.22868084   0.2113695   0.20590293   0.35404657
Table 4: Attention over quarters

H.1 MLP attention weights vs number of diagnosis counts

To assess the importance of attention with respect to the number of diagnoses in a given time slice, we plot the average and standard deviation of the attention scores for each diagnosis count. We observe that the model pays very low to zero attention to quarters without any diagnosis code. As the count increases, attention increases, but the large error bars in both Figure 7(a) and Figure 7(b) suggest that it is not always paying attention to the time slice with the most counts.

(a) For predicted case patients
(b) For predicted control patients
Figure 7: MLP attention scores vs number of diagnosis in a timeslice

Appendix I Performance metric graphs with error bars

In this section, we report the AUC-PR (Figure 8) and AUC-ROC (Figure 9) for population P1, with cross validation error bars. We observe that SAVEHR, CNN-1G-A and CNN-LK-A have very low standard deviations for both AUC-PR and AUC-ROC across all diseases when compared to the other models.

Figure 8: AUC-PR in P1 for all models with cross validation error bars
Figure 9: AUC-ROC in P1 for all models with cross validation error bars

Appendix J Example patient heatmaps for CNN-LK-A model

In Section 5 of the paper, we provide an example visualization for the SAVEHR model. In contrast, below we provide a heatmap from the CNN-LK-A model on the same set of case and control patients.

Figure 10: Feature importance for CNN-LK-A

Appendix K SAVEHR Case population heatmaps in P1 and P2 for CHF

To understand the features that induce risk across the population as a whole, we generate averaged heat maps across all the case patients in P1 (Figure 11) and P2 (Figure 12). Noticeably, the top features in both of the populations differ, suggesting that the model is able to learn different characteristics and adapt.

Figure 11: SAVEHR heatmap visualization across quarters for all case patients in P1
Figure 12: SAVEHR heatmap visualization across quarters for all case patients in P2

Appendix L Feature importance tables for baselines

We report the feature importance as determined by averaging the importance scores predicted by the model across the predicted case patients.

Logistic Regression
LR coefficient     Diagnosis Code   Description
1.42               90656            Flu Vaccine
1.207              G0378            Hospital Observation Service
1.096              735              Acquired hammer toe
0.934              361              Retinal defect
0.926              v54              Aftercare fracture arm
0.901              816              Closed fracture of middle phalanx of second finger of right hand
0.862              557              Enterocolitis
0.794              191              Malignant Neoplasm of Brain
0.788              041              Mycoplasma infection in conditions classified elsewhere
0.783              432              Chronic spont intraparenchymal hemorrhage

Random Forest
Coefficient        Diagnosis Code   Description
0.035              age              Age
0.006              race             Race
0.006              v58              Encounter for other and unspecified procedures and aftercare
0.005              401              Essential Hypertension
0.005              v76              Screening Colitis
0.005              36415            Blood Draw
0.004              v57              Care involving use of rehabilitation procedures
0.004              gender           Gender
0.004              v70              General psychiatric examination
0.004              786              Chest wall pain

CNN 1 gram
Importance score   Diagnosis Code   Description
0.060405           v76              Screening Colitis
0.043208           427              Atrial tachycardia
0.041183           v45              Status post lumbar surgery
0.03769            793              Abnormal Findings X-Ray Breast
0.02934            v57              Care involving use of rehabilitation procedures
0.020804           585              chronic renal failure
0.018625           562              Small bowel diverticular disease
0.017839           530              Cardiochalasia
0.016736           v10              Personal History of Malignant Neoplasm of Eye
0.016391           455              External hemorrhoids with complication

CNN LargeKernel
Importance score   Diagnosis Code   Description
0.05082            569              Colostomy and enterostomy complications
0.48861            M06              Rheumatoid arthritis with negative rheumatoid factor (HCC)
0.38591            333              degenerative diseases of the basal ganglia
0.35872            R57              Cardiogenic Shock
0.34976            C95              Acute leukemia
0.31523            250              Diabetes mellitus TypeII
0.31516            I62              Nontraumatic subdural hemorrhage
0.31463            182              Malignant Neoplasm of body of uterus
0.29420            H25              Senile cataract of right eye
0.29093            M21              Limb deformity