Learning Hierarchical Representations of Electronic Health Records for Clinical Outcome Prediction

03/20/2019 · by Luchen Liu, et al.

Clinical outcome prediction based on the Electronic Health Record (EHR) plays a crucial role in improving the quality of healthcare. Conventional deep sequential models fail to capture the rich temporal patterns encoded in long and irregular clinical event sequences. We make the observation that clinical events at a long time scale exhibit strong temporal patterns, while events within a short time period tend to be disordered co-occurrences. We thus propose differentiated mechanisms to model clinical events at different time scales. Our model learns hierarchical representations of event sequences, to adaptively distinguish between short-range and long-range events, and accurately capture core temporal dependencies. Experimental results on real clinical data show that our model greatly improves over previous state-of-the-art models, achieving AUC scores of 0.94 and 0.90 for predicting death and ICU admission respectively. Our model also successfully identifies important events for different clinical outcome prediction tasks.


1 Introduction

The ever-growing massive Electronic Health Record (EHR) data present an opportunity for large-scale data-driven health analysis and intelligent medical care. Predicting clinical outcomes, such as death and intensive care unit (ICU) admission, plays an important role in improving the performance of healthcare systems. For instance, accurate clinical prediction based on patients' existing medical records can enable advanced and timely medical intervention.

Clinical outcome prediction is challenging because it is hard to utilize the rich temporal information encoded in the sequence of clinical events in EHR data [choi2016using]. In particular, EHR data usually consist of clinical events with irregular intervals [Choi2016:Multi] from heterogeneous sources, including patient health features (vital sign measurements, laboratory test results, etc.), medical interventions (procedures, drug inputs, etc.), and expert judgments (diagnoses, notes, etc.) [johnson2016mimic]. The temporal order of these events is critical for predicting clinical outcomes. For example, patient health features can be affected by previous medical interventions, and in turn determine subsequent medical interventions through expert judgments.

Conventional approaches that directly apply classic deep sequential models, such as recurrent neural networks (RNNs) [pham2016deepcare] and convolutional neural networks (CNNs) [nguyen2016deepr], usually fail to capture temporal dependencies in such long, irregular event sequences [Karlsson2016Predicting, lipton2016directly], as long-term dependencies can easily exceed their modeling capacity. To handle irregularly timed events, some extensions of the classic models have been developed, such as time-parameterized RNNs and time-irregular CNNs. However, their performance is still largely unsatisfying due to their limited ability to capture long-term dependencies.

This work aims to address the above challenges. We first make a key observation: though the clinical events in EHR data can exhibit strong temporal patterns at a long time scale, events occurring within a short time period usually do not have a definite order. Unlike word sequences in natural language, where tokens are ordered by grammar rules, clinical events recorded within a short period of time, such as a batch of clinical laboratory test results, merely reflect the patient's status from different views. Therefore, directly modeling the temporal order of such short-range events, as in previous work, can introduce noise and harm predictive performance. Instead, local dependencies of these events should be modeled as event co-occurrence, and we can further select critical events from each short-range event group as the basis for modeling the real temporal dependencies at a long time scale. A key difficulty is that the criterion for distinguishing long-term temporal dependencies from the local co-occurrence of critical events can vary across diseases and phases, especially in irregular EHR data [ghassemi2015multivariate, leebig].

To address the difficulties mentioned above, we propose a hierarchical neural network for clinical outcome prediction. Specifically, we adaptively segment irregular event sequences into sequential groups to distinguish short-range co-occurrence from long-term temporal dependencies, and learn hierarchical representations of the event sequence. At the low level, the model automatically identifies critical clinical events in each group and aggregates them to form event group representations. At the high level, meaningful long-term dependencies among clinical event groups are captured in a sequence representation vector by a recurrent neural network. Compared to traditional methods, the proposed method has several advantages:

  • Our model can deal with the temporal irregularity of clinical event sequences by adaptively segmenting an event sequence into sequential groups.

  • Our model learns a hierarchical representation of long and irregular event sequences to capture long-term dependencies of clinical events.

  • The model is capable of discovering critical event groups as well as critical events in each group, through a temporal attention and event attention mechanism. This provides useful clinical insights for accurate prediction.

2 Related Work

2.1 Modeling EHR Data

Most existing works based on EHR data ignore the irregular time intervals of clinical events, treating clinical event sequences as word sequences [Liu2015Temporal]. For example, some approaches train semantic embeddings of clinical event categories for adverse drug event detection [Henriksson2015Modeling]. Another work proposes a multi-view learning approach that generalizes Canonical Correlation Analysis to an arbitrary collection of matrices with missing data [huddar2016predicting]. These works make predictions as if clinical events had regular time intervals, and cannot distinguish short-range order from long-term temporal order across different diseases and patients. Our work addresses this issue by adaptively segmenting clinical event sequences.

Since long-term temporal dependencies are hard to capture, many works use only a small subset of the whole EHR information to avoid dealing with long clinical event sequences. Some select a subset of the numerical clinical features (the numerical attributes of clinical events) according to clinicians' expertise [che2016recurrent]. For instance, Ahmed uses only 21 (temporal) physiological streams, comprising 11 vital signs and 10 lab test scores, to predict ICU admission. Others use graphical models to model patients' health status [caballero2015dynamically], or transform 99 selected time-series features of the EHR data into a new latent space using the hyperparameters of multi-task Gaussian process (MTGP) models to model patient similarity [choi2016using]. Recently, RETAIN used two recurrent neural networks (RNNs) run in reverse time order to generate attention weights over sequential International Classification of Diseases (ICD) code groups for prediction [choi2016retain]. However, the codes are grouped into fixed-length time slots for all patients and diseases, so local dependencies and long-term dependencies may be mixed up. Moreover, these works can lose significant information, due to expert bias when selecting a limited fraction of all clinical features in the EHR as model input, and fail to provide new data-driven insights for better healthcare.

2.2 Clinical Outcome Prediction

The clinical outcome prediction problem has been studied by many works; however, many of them cannot take advantage of the temporal information in EHR data. Some studies used latent variable models to decompose bag-of-words free text extracted from clinical event descriptions into meaningful features to predict patient mortality [ghassemi2014unfolding]. "Deep Patient" [miotto2016deep] arranged all clinical descriptors (features) within a period of time in a sparse vector without temporal information and trained a deep patient representation with a three-layer denoising autoencoder for diagnosis. Another study combined hybrid manifold learning for non-temporal clinical feature embedding with the bootstrap aggregating (Bagging) algorithm to diagnose and predict Alzheimer's disease (AD) [dai2016bagging]. Yet another work modeled EHR data by factorizing the medical feature matrix into a latent medical concept mapping matrix and a concept value evolution matrix, then averaged all vectors in the evolution matrix to predict heart failure [Zhou2014From]. In contrast, our model learns hierarchical representations of clinical event sequences to exploit temporal information for clinical outcome prediction.

3 Data and Task Descriptions

We introduce the notation, the data, and the predictive tasks in the following.

Clinical Events in EHR

A clinical event is a record in an EHR database describing a clinical activity of a particular patient at a certain time. Events can be measurements of vital signs, injections of drugs, results of laboratory tests, and so on, as summarized in Table 1. Each clinical event has attributes, both categorical and numerical. For example, a lab test event has two categorical attributes and one numerical attribute (e.g., the test name, an abnormality flag, and the test value); such an event might record that the result of a Cholesterol test reflects an abnormal health status. An episode of a patient's EHR data is a clinical event sequence, which may consist of hundreds of clinical events.

Clinical Outcome Prediction

Clinical outcome prediction is the task of dynamically predicting whether a clinical outcome will happen within 24 hours, based on an episode of a patient's record. We aim to dynamically predict two outcomes in this work. In the first, the "death prediction task", the outcome is death in hospital versus discharge to home. In the second, the "ICU admission prediction task", the outcome is clinical deterioration followed by ICU admission versus clinical stability.

| Event type | Description | Event name (top 2) | Frequency (rank-in-all) | Coverage (rank-in-all) | Frequency per patient |
| --- | --- | --- | --- | --- | --- |
| Chart events | Include routine vital signs, ventilator settings, mental status, and so on. | Heart Rate | 5171250 (0.01%) | 0.64 (0.16%) | 173.4 |
| | | SpO2 | 3410702 (0.01%) | 0.479 (0.33%) | 153.0 |
| Input events | Any fluids administered to the patient, such as oral or tube feedings or intravenous solutions containing medications. | 0.9% Normal Saline | 2363812 (0.05%) | 0.393 (0.66%) | 129.2 |
| | | Propofol | 369103 (0.81%) | 0.217 (1.45%) | 36.5 |
| Lab events | All laboratory measurements for a given patient. | HEMATOCRIT | 881846 (0.22%) | 0.976 (0.01%) | 19.4 |
| | | POTASSIUM | 845825 (0.23%) | 0.886 (0.05%) | 20.5 |
| Procedure events | Procedures performed on patients. | Chest X-Ray | 32723 (3.2%) | 0.204 (1.52%) | 3.44 |
| | | EKG | 13962 (4.35%) | 0.167 (1.82%) | 1.79 |
| Output events | Fluids either excreted by the patient, such as urine, or extracted from the patient. | Chest Tubes CTICU CT 1 | 151766 (1.57%) | 0.098 (2.67%) | 33.2 |
| | | Urine | 107465 (1.93%) | 0.075 (3.05%) | 30.8 |

Table 1: Statistics of clinical events.

Patient Cohort Setup

We set up two datasets from one real clinical data source, MIMIC-III (Medical Information Mart for Intensive Care III) [johnson2016mimic], a large, freely available database comprising de-identified health-related data of over forty thousand patients who stayed in intensive care units of the Beth Israel Deaconess Medical Center between 2001 and 2012.

We extract 18192 kinds of clinical events with their attributes from the database to obtain patients' event sequences (the most frequent events are listed in Table 1). Events whose total frequency is less than 2500 are dropped, as are admissions whose time span from the beginning to the target clinical outcome is less than 36 hours. Each input sample is an episode of a patient's clinical event sequence ending 24 hours before the target outcome. The statistics of the final clinical event sequences in the two tasks are summarized in Table 2.

| Dataset | # of samples | # of events | Avg. timespan |
| --- | --- | --- | --- |
| death | 24301 (8%) | 20290879 | 3d 15h 58m |
| ICU admission | 19451 (21%) | 14515198 | 4d 18h 31m |

Table 2: Statistics of the datasets (the percentage in the second column is the positive sample rate).
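For concreteness, the cohort filtering described above might look like the following sketch; the tuple layouts of `events` and `admission_times` are our assumptions for illustration, not the paper's actual pipeline.

```python
from collections import Counter

# events: list of (admission_id, event_type, time) tuples -- assumed layout
type_freq = Counter(etype for _, etype, _ in events)
events = [e for e in events if type_freq[e[1]] >= 2500]

# admission_times: {admission_id: (first_event_time, outcome_time)} -- assumed
HOUR = 3600.0
kept = {adm for adm, (start, outcome) in admission_times.items()
        if (outcome - start) / HOUR >= 36}
events = [e for e in events if e[0] in kept]
```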

4 Methodology

Figure 1: The overall architecture of our model. The original irregular event sequence is segmented into sequential event groups by the adaptive segmentation module. Then our model learns the hierarchical representation of the sequential event groups. In the low-level representation, each event group is represented as a vector by the event attention. In the high-level representation, the embedded sequential groups are modeled by gated recurrent units (GRU) with inter-group temporal attention.

In this section, we introduce the technical details of our proposed model. Our model first segments the whole clinical event sequence into several event groups via the adaptive event sequence segmentation module. Then the model learns hierarchical representations of event sequences with both event attention and temporal attention mechanisms. The architecture of our model is illustrated in Figure 1.

4.1 Adaptive Segmentation

To distinguish long-term temporal dependencies from co-occurrences of important events in a short range, we adaptively segment the event sequence of a patient into sequential groups, according to the irregular record times of the events. As events in the same group are exchangeable, sequential groups avoid the influence of short-range order noise in clinical events. Moreover, sequential groups reduce the length of the sequence fed to the RNN, which makes capturing long-term temporal dependencies easier.

We find the segmentation points of an event sequence by minimizing the maximum time span of the resulting segments. Formally, given an event sequence $\{e_1, e_2, \dots, e_n\}$, segmentation points $s_1 < s_2 < \dots < s_{m-1}$ split the sequence into groups $\{g_1, g_2, \dots, g_m\}$, where group $g_i$ is the episode of clinical events from $e_{s_{i-1}+1}$ to $e_{s_i}$ (with $s_0 = 0$ and $s_m = n$). The time span of a group is defined as the time difference between the last event and the first in the group, namely $\mathrm{span}(g_i) = t(e_{s_i}) - t(e_{s_{i-1}+1})$, where $t(e)$ is the record time of event $e$. The optimal segmentation points are found by minimizing the following:

$$\min_{s_1, \dots, s_{m-1}} \; \max_{1 \le i \le m} \mathrm{span}(g_i), \quad \text{s.t. } m \le M,$$

where $M$ is the maximum number of groups and the constraint $m \le M$ avoids making the segmentation too fine-grained.

The adaptive segmentation combines a greedy method with binary search: we binary-search the minimal upper bound on the maximum time span over all groups, and verify each candidate bound by greedily trying to construct a segmentation that satisfies both the group-number constraint $M$ and the candidate time span bound. The time complexity of the algorithm is $O(n \log D)$, where $n$ is the length of the event sequence and $D$ is the difference between the end time and the start time of the sequence. In practice, the algorithm behaves as a linear-time algorithm with a large constant factor.
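A minimal Python sketch of this combination of binary search and greedy verification; the float time representation and the fixed iteration count of the binary search are our simplifications.

```python
from typing import List

def adaptive_segmentation(times: List[float], max_groups: int) -> List[int]:
    """Split an irregular event sequence into at most `max_groups` groups,
    minimizing the maximum within-group time span. `times` holds the sorted
    record times of the events; the start index of each group is returned."""

    def groups_needed(max_span: float) -> int:
        # Greedily open a new group whenever the current one would exceed max_span.
        count, start = 1, times[0]
        for t in times[1:]:
            if t - start > max_span:
                count += 1
                start = t
        return count

    # Binary-search the minimal feasible upper bound on the group time span.
    lo, hi = 0.0, times[-1] - times[0]
    for _ in range(50):  # fixed iteration count gives ample precision
        mid = (lo + hi) / 2
        if groups_needed(mid) <= max_groups:
            hi = mid
        else:
            lo = mid

    # Materialize the segmentation under the found bound.
    starts, start = [0], times[0]
    for i, t in enumerate(times[1:], 1):
        if t - start > hi:
            starts.append(i)
            start = t
    return starts
```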

To represent clinical events with their attributes, we embed each clinical event into a low-dimensional space as a vector, in the way described in previous work [liu2018learning]. The representation vector $e \in \mathbb{R}^N$ (where $N$ is the event embedding dimension) is the sum of an event type vector (carrying basic event information) and an event attribute encoding vector (describing the event's features).
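As a hedged sketch of this embedding scheme (the exact attribute encoder of liu2018learning may differ; the module below is our simplified reading, written in PyTorch rather than the paper's Theano):

```python
import torch
import torch.nn as nn

class EventEmbedding(nn.Module):
    """Event vector = type embedding + attribute encoding (simplified)."""
    def __init__(self, n_event_types: int, n_cat_values: int, embed_dim: int):
        super().__init__()
        self.type_emb = nn.Embedding(n_event_types, embed_dim)
        self.cat_emb = nn.Embedding(n_cat_values, embed_dim)
        self.num_proj = nn.Linear(1, embed_dim)  # encodes the numerical attribute

    def forward(self, type_id, cat_ids, num_value):
        # type_id: scalar id; cat_ids: (k,) ids; num_value: (1,) float tensor
        attr = self.cat_emb(cat_ids).sum(dim=0) + self.num_proj(num_value)
        return self.type_emb(type_id) + attr
```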

4.2 Hierarchical Representations with Attention Mechanisms

Based on the sequential event groups, the model can learn hierarchical representations to capture long-term temporal dependencies. In the low-level model, the model automatically identifies critical clinical events in each group via event attention mechanism and aggregates the events to form event group representations. In the high-level model, the meaningful long-term temporal dependencies of clinical event groups are captured by a recurrent neural network with temporal attention mechanism in the sequence representation. The hierarchical representations help to learn long-term temporal dependencies in the original event sequences.

4.2.1 Event Group Representation

To select the significant events in each group and compact the events of a group into one vector as the event group representation, an event attention mechanism is added to the low-level model.

Given the sequential groups $\{g_1, \dots, g_m\}$ produced by the adaptive segmentation module, where $g_t = \{e_{t,1}, \dots, e_{t,|g_t|}\}$, the attention score $\alpha_{t,j}$ of each event in a group is calculated by the event attention mechanism. The scalars $\alpha_{t,j}$ are the event attention weights that govern the influence of the event embeddings in group $g_t$.

We use a multi-layer perceptron (MLP) with one hidden layer to generate $\alpha_{t,j}$ based on the event embedding vector $e_{t,j}$ and the hidden state of the previous time step as follows:

$$u_{t,j} = \tanh(W_e e_{t,j} + W_h h_{t-1} + b),$$

$$\alpha_{t,j} = \frac{\exp(v^\top u_{t,j})}{\sum_{k} \exp(v^\top u_{t,k})},$$

where $h_{t-1} \in \mathbb{R}^{d_h}$ is the hidden state of the previous gated recurrent unit (GRU) [chung2014empirical] (described in the following section), $u_{t,j} \in \mathbb{R}^{d_u}$ is the latent layer of event $j$ at time step $t$, and $W_e$, $W_h$, $b$, and $v$ are parameters to learn. Notice that $d_u$ is the hidden layer dimension and $d_h$ is the GRU hidden state dimension.

The resulting attention scores reflect the importance of each event in a group according to the temporal context of the group. The events in the $t$-th group are weight-averaged with $\alpha_{t,j}$ to get the group representation $x_t = \sum_j \alpha_{t,j} e_{t,j}$, which is the input to the $t$-th GRU unit.
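A minimal PyTorch sketch of this event attention; the module and dimension names are ours, and this is a single-group sketch of the formulation above rather than the paper's Theano implementation.

```python
import torch
import torch.nn as nn

class EventAttention(nn.Module):
    """Scores the events of one group against the previous GRU hidden state."""
    def __init__(self, embed_dim: int, hidden_dim: int, attn_dim: int):
        super().__init__()
        self.W_e = nn.Linear(embed_dim, attn_dim, bias=False)
        self.W_h = nn.Linear(hidden_dim, attn_dim)   # carries the bias b
        self.v = nn.Linear(attn_dim, 1, bias=False)

    def forward(self, events, h_prev):
        # events: (num_events, embed_dim); h_prev: (hidden_dim,)
        u = torch.tanh(self.W_e(events) + self.W_h(h_prev))  # latent layer u_tj
        alpha = torch.softmax(self.v(u).squeeze(-1), dim=0)  # weights alpha_tj
        return alpha @ events, alpha                         # group repr. x_t
```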

4.2.2 Sequence Representation

To spot the critical phases over the sequence for the final decision and to capture long-term temporal dependencies among event groups, gated recurrent units (GRU) [chung2014empirical] equipped with a temporal attention mechanism are employed as the high-level model:

$$h_t = \mathrm{GRU}(h_{t-1}, x_t; \theta_g),$$

where $\mathrm{GRU}(\cdot)$ represents the recurrent unit, which uses the previous hidden state $h_{t-1}$ and the current input vector $x_t$ to update the hidden state, and $\theta_g$ represents all the parameters of the GRU.

The vector $\beta = (\beta_1, \dots, \beta_m)$ contains the temporal attention weights of the groups in the sequence. We use a fully connected feedforward network to generate $\beta$ from the outputs of the GRU at each time step as follows:

$$\beta_t = \frac{\exp(w_\beta^\top h_t + b_\beta)}{\sum_{k=1}^{m} \exp(w_\beta^\top h_k + b_\beta)},$$

where $H = [h_1, \dots, h_m]$ is the output matrix of the GRU and $w_\beta$, $b_\beta$ are parameters to learn.

The sequence representation $c = \sum_{t=1}^{m} \beta_t h_t$ is the weighted average of the output matrix $H$. We use $c$ to predict the true label of the sequence:

$$\hat{y} = \sigma(w_o^\top c + b_o),$$

where $w_o$ and $b_o$ are parameters to learn.

The cross-entropy loss function is used to calculate the classification loss of each sample as follows:

$$L(x, y) = -y \log \hat{y} - (1 - y) \log(1 - \hat{y}),$$

where $x$ is the input event sequence and $y \in \{0, 1\}$ is the label indicating whether the clinical outcome happens. The losses of all samples in a mini-batch are summed to obtain the total loss for backpropagation.
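Putting the high-level model together, a hedged PyTorch sketch might look like this; the names are ours, it processes one sequence at a time, and it is a reading of the equations above rather than the paper's implementation.

```python
import torch
import torch.nn as nn

class SequenceModel(nn.Module):
    """High-level GRU over group representations with temporal attention."""
    def __init__(self, embed_dim: int, hidden_dim: int):
        super().__init__()
        self.gru = nn.GRU(embed_dim, hidden_dim, batch_first=True)
        self.attn = nn.Linear(hidden_dim, 1)   # generates beta_t from h_t
        self.out = nn.Linear(hidden_dim, 1)    # final outcome classifier

    def forward(self, group_reprs):
        # group_reprs: (1, m, embed_dim) -- the sequence of group vectors x_t
        H, _ = self.gru(group_reprs)                    # (1, m, hidden_dim)
        beta = torch.softmax(self.attn(H), dim=1)       # temporal weights
        c = (beta * H).sum(dim=1)                       # sequence repr. c
        return torch.sigmoid(self.out(c)).squeeze(-1)   # predicted P(outcome)

# Training would pair this with nn.BCELoss(), i.e. the cross-entropy above.
```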

5 Results and Discussions

5.1 Comparison Methods and Settings

We compare our model with popular models from the literature, including bag-of-words vector classifiers (i.e., support vector machine (SVM) [scholkopf2001learning], logistic regression (LR) [yalcin2011gis], and random forest (RF) [liaw2002classification]) and deep sequential models, such as RETAIN [choi2016retain] and an RNN [leebig] (implemented with GRU).

Since SVM, LR, and RF cannot handle sequence input, the event sequence is compressed into a 0-1 vector whose i-th element indicates whether the i-th event type occurs, which is then fed into SVM, LR, or RF to make the outcome prediction. We implement these bag-of-words vector classifiers using scikit-learn (https://scikit-learn.org/stable/).
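A minimal sketch of this baseline encoding with scikit-learn; the variable names and the integer event-id mapping are our assumptions.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

def bag_of_events(seq, n_event_types: int) -> np.ndarray:
    """Compress an event sequence into a 0-1 indicator vector."""
    x = np.zeros(n_event_types)
    x[list(seq)] = 1.0  # seq holds integer event-type ids -- assumed mapping
    return x

# X_train: stacked indicator vectors; y_train: 0/1 outcome labels
# clf = LogisticRegression(max_iter=1000).fit(X_train, y_train)
# scores = clf.predict_proba(X_test)[:, 1]
```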

Deep sequential models (i.e., RETAIN [choi2016retain] and the RNN [leebig]) take the original event sequence as input, as described in Section 4.1. We implemented our model and the neural-network-based baselines with Theano 0.8.

The event embedding size is set to 32 and the hidden layer size to 64. The maximum number of groups $M$ is set to 32. We train the models with Adam [Kingma2014Adam], mini-batches of 32 samples, and early stopping.

5.2 Evaluation Metrics

Metrics for binary labels, such as accuracy, are not suitable for measuring performance on imbalanced datasets. Therefore, similar to prior work [Liu2015Temporal, choi2016retain], we adopt ROC curves (Receiver Operating Characteristic curves) and PRCs (Precision-Recall curves) as evaluation metrics. Both curves reflect the overall quality of the predicted scores with respect to the true labels. For quantitative measurements, we use the area under the ROC curve (AUC) and the area under the PRC (AUPRC).
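With scikit-learn, which we already use for the baselines, these metrics can be computed as follows; note that `average_precision_score` is one common estimator of the area under the PRC.

```python
from sklearn.metrics import roc_auc_score, average_precision_score

# y_true: binary outcome labels; y_score: predicted outcome probabilities
auc = roc_auc_score(y_true, y_score)
auprc = average_precision_score(y_true, y_score)
```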

5.3 Quantitative Results

| Models | Death AUC | Death AUPRC | ICU admission AUC | ICU admission AUPRC |
| --- | --- | --- | --- | --- |
| SVM | 0.7523 | 0.5154 | 0.7973 | 0.7074 |
| LR | 0.8843 | 0.5213 | 0.8734 | 0.7266 |
| Random Forest | 0.8644 | 0.5867 | 0.8389 | 0.8177 |
| RETAIN | 0.8967 | 0.5808 | 0.8693 | 0.8029 |
| RNN | 0.9038 | 0.6234 | 0.8636 | 0.8051 |
| OURS | 0.9428 | 0.7476 | 0.9016 | 0.8424 |

Table 3: Performance of different models on the death and ICU admission prediction tasks.

Table 3 shows the AUC and AUPRC of different models on the death prediction and the ICU admission prediction tasks. From the results shown in Table 3, we can draw the following conclusions:

First, on the whole, deep sequential models (including RETAIN, RNN, and the proposed model) outperform non-sequence models (including SVM, LR, and Random Forest) on both tasks, which suggests that temporal information is effective for the outcome prediction tasks.

Second, our model outperforms all the sequential models. For example, on the ICU admission task, the proposed model improves AUC by at least 3.4% and AUPRC by at least 3.0% compared with the other models. The improvement supports our claim that it is more appropriate to capture the temporal dependencies of clinical event sequences in a hierarchical way.

5.4 Ablation Studies

In this section, we perform ablation studies to examine the effects of our proposed techniques, namely the event attention mechanism, the temporal attention mechanism, and the adaptive segmentation module.

The first study, over the two attention mechanisms, is performed on both tasks. Specifically, we re-train our model with certain components ablated (a code sketch of the two variants follows the list):

  • W/O E-Attn, where no event attention is performed and a group representation is the plain average of the event embeddings in the group.

  • W/O T-Attn, where no temporal attention is performed and the sequence representation is the final output of the GRU.
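In terms of the earlier sketches, the two ablations amount to one-line substitutions (variable names follow our sketches above, not the paper's code):

```python
# W/O E-Attn: plain average instead of the attention-weighted average
x_t = events.mean(dim=0)
# W/O T-Attn: final GRU output instead of the beta-weighted average
c = H[:, -1, :]
```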

| Models | Death AUC | Death AUPRC | ICU admission AUC | ICU admission AUPRC |
| --- | --- | --- | --- | --- |
| W/O T-Attn | 0.9348 | 0.7181 | 0.8987 | 0.8400 |
| W/O E-Attn | 0.9170 | 0.6404 | 0.8930 | 0.8376 |
| Full Model | 0.9428 | 0.7476 | 0.9016 | 0.8424 |

Table 4: Ablation study over the attention mechanisms.

The results of the attention ablation studies are presented in Table 4. Both attention mechanisms contribute to the strong empirical results reported above. It is noteworthy that event attention, an important part of the hierarchical representations, plays a more critical role than temporal attention, especially on the death prediction task.

Besides the attention mechanisms, a study of the adaptive segmentation is performed on the death prediction task. We re-train our model with the adaptive segmentation module replaced by fixed-length segmentation, which splits the original sequence into groups of an equal number of events (except the last group). The group size of the fixed-length segmentation is a hyperparameter. Notice that fixed-length segmentation degenerates to no segmentation when the group size is set to 1.
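For reference, the fixed-length baseline is trivial to implement (a sketch, mirroring the return convention of `adaptive_segmentation` above):

```python
def fixed_length_segmentation(n_events: int, group_size: int) -> list:
    """Start indices of equal-size groups; the last group may be smaller.
    group_size == 1 degenerates to no segmentation at all."""
    return list(range(0, n_events, group_size))
```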

Figure 2: Ablation over the adaptive segmentation on the death prediction task. Fixed-length segmentation splits the event sequence into groups of equal size (except the last one).

Figure 2 shows the AUC and AUPRC of the proposed model when the adaptive segmentation is replaced by fixed-length segmentation with different group sizes on the death prediction task (the trend on the ICU admission task is similar). Performance degrades when the group size becomes too small or too large. We infer that if the group size is too small, local dependencies of events are modeled as long-term dependencies; if it is too large, long-term dependencies are lost because the corresponding events are assigned to the same group. Moreover, there is a clear performance gap between adaptive segmentation and all fixed-length variants, which supports our claim that adaptive segmentation helps model long-term dependencies and suits long, irregular event sequences.

5.5 Important Events

| Death: event | median | ICU admission: event | median |
| --- | --- | --- | --- |
| Blood Products | 0.9965 | Blood PH | 0.9998 |
| Radiologic Study: thoracic lumbar sac | 0.9896 | Vancomycin | 0.9995 |
| NV#2 Waveform Appear: overshoot | 0.9713 | Hematocrit (35-51) | 0.9967 |
| Heart Rhythm | 0.9702 | Sedimentation rate | 0.9885 |
| Pain Location: periumbilical | 0.9668 | Daily Weight | 0.9850 |
| Family Communication | 0.9523 | Bilirubin Total | 0.9834 |

Table 5: Top important events (sorted by median attention score) on the death prediction and ICU admission prediction tasks.

In this section, we analyze which events our proposed model pays more attention to. The event attention score of an event measures how much attention the model pays to it when making a prediction. We therefore use the median of all event attention scores of an event type on a specific task as the importance of that event type for the task.
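A short sketch of this importance computation; the `attention_records` layout is a hypothetical way of collecting per-event scores over the test set.

```python
from collections import defaultdict
from statistics import median

# attention_records: iterable of (event_type, attention_score) pairs -- assumed
scores = defaultdict(list)
for event_type, alpha in attention_records:
    scores[event_type].append(alpha)

importance = {etype: median(vals) for etype, vals in scores.items()}
top_events = sorted(importance, key=importance.get, reverse=True)[:6]
```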

Table 5 lists the top important events for the two tasks. Although the model mainly focuses on physiological measurements and laboratory tests (such as "Heart Rhythm" and "Blood PH") on both tasks, the specific events it attends to differ between the tasks due to their different prediction targets. It is also perhaps surprising that, owing to our data-driven approach, the model selects "Family Communication" as an important event type for death prediction, a signal that may be overlooked by doctors.

6 Conclusion

In this paper, we proposed a model that learns hierarchical representations of long and irregular clinical event sequences in EHR data for clinical outcome prediction. We validated its performance on real clinical datasets for the death and ICU admission prediction tasks. The significant improvements indicate that our model is suitable for irregularly timed EHR data and can capture long-term temporal dependencies of clinical event sequences for precise clinical outcome prediction.
