Deep learning has demonstrated early successes in health risk prediction using electronic health records (EHR) data, especially for patients with chronic and progressing conditions such as heart disease and Parkinson’s disease (Ma et al., 2018; Choi et al., 2016; Heo et al., 2018; Baytas et al., 2017). Most existing works focus on extracting disease patterns by modeling the relationship between disease progression and time from longitudinal patient data. For example, Pham et al. (Pham et al., 2016) utilized an RNN and a multiscale pooling layer to integrate temporal disease patterns from different time scales. Baytas et al. (Baytas et al., 2017) and Ma et al. (Ma et al., 2018) simulate the progression of patients’ status by using temporal information to decay the information from historical timesteps.
Despite these successes, the aforementioned works (Ma et al., 2018; Baytas et al., 2017; Pham et al., 2016; Zheng et al., 2017) implicitly assume that disease progression is smooth in time: the longer the elapsed time, the greater the change in health status. In reality, however, the speed of disease progression can vary significantly depending on the underlying disease stage.
Motivating example: Fig. 1 plots the variation of albumin and hypersensitive C-reactive protein (hs-CRP) of an end-stage renal disease (ESRD) patient. We can observe the patient’s health status between and is different from the rest time period. The sudden decline of albumin and fluctuation of hs-CRP indicate the patient’s status is deteriorating rapidly. We consider the patient’s health condition enters a new stage at . The disease stage here refers to a time period with consistent status progression. It is not specific to a single disease (e.g. Alzheimer’s Stage 1 or 2) but considers generally all comorbidities a patient has. For example, deteriorating and recovering are two different stages.
Challenges: Disease progression stages indicate different risk levels. However, we still have the following challenges to utilize stage information in risk prediction tasks.
How to extract disease progression stages from complex EHR data? Although important, disease stage information is often unavailable in the data and is sometimes not even clearly defined for many diseases. Many temporal models such as LSTM have gating mechanisms to control what historical information to remember or forget. However, the explicit underlying stage or change point is not clearly specified by those models.
How to leverage disease progression stage information for more accurate risk prediction? Intuitively, disease progression patterns should be similar within one stage but different across stages, as shown in Fig. 1(b). Probably due to the absence of stage information in the data, disease stage information is largely ignored in predictive modeling of EHR data. Since the progression patterns relate to health risks (Halm et al., 1998; Plöchl et al., 1996), the model should learn to extract and select informative patterns from each disease stage for more accurate risk prediction.
Contributions: To address these challenges, we propose a new Stage-aware neural Network model (StageNet) to extract disease stage information from patient data and integrate it into risk prediction. StageNet consists of two modules, a stage-aware LSTM module and a stage-adaptive convolutional module. Our StageNet model is enabled by the following technical contributions.
Extract disease progression stage without supervision. The stage-aware LSTM module of StageNet can capture the stage variation of patients’ health conditions unsupervisedly. Specifically, we integrate inter-visit time information into LSTM cell states, which enables each dimension in the cell state to decide the storage proportion between long-term progression and short-term health status information. With such a design, StageNet can evaluate the variation of patient health conditions compared with previous stages.
Learn disease progression patterns from disease stage information. StageNet further incorporates extracted disease progression stage information into the convolution operation, which can learn progression patterns closely related to the current stage. We further re-calibrate these patterns to emphasize informative patterns for outcome prediction.
We evaluated StageNet on both a health risk prediction task and a patient subtyping task, using the real-world urgent care MIMIC-III dataset and an end-stage renal disease dataset. Risk prediction results show that StageNet consistently outperforms all state-of-the-art baselines on both datasets across the evaluation metrics. The improvement of StageNet is up to 12% in AUPRC compared to the best baseline model. The patient subtyping results show that StageNet performs better than baseline models in identifying discriminative patient subgroups.
2. Related Work
Disease Progression Modeling In recent years, various deep learning models have been proposed to model disease progression. One solution is to model the disease trajectory using Markov-based models. For example, Sukkar et al. (Sukkar et al., 2012) applied a Hidden Markov Model (HMM) to learn Alzheimer’s disease progression. Wang et al. (Wang et al., 2014) proposed a probabilistic model to learn a continuous-time disease progression. Liu et al. (Liu et al., 2015) developed a continuous-time HMM for modeling glaucoma and Alzheimer’s patients. However, these works assume the Markov property, which might not hold in practice when long-term dependencies exist.
Instead of modeling state transition probability, another line of work, to which our approach belongs, focuses on utilizing patients’ general health status progression to conduct clinical prediction tasks such as subtyping (Baytas et al., 2017; Galagali and Xu-Wilson, 2018) or risk prediction (Zheng et al., 2017; Pham et al., 2017). For example, T-LSTM (Baytas et al., 2017) incorporated the elapsed time information into the standard LSTM architecture to model status progression in the presence of time irregularities. Health-ATM (Ma et al., 2018) used a time-aware convolutional layer that integrates time stamps into the original convolutional layer to model status progression. However, none of these works model disease progression from the stage perspective. To the best of our knowledge, this is the first work to explicitly extract disease stage progression information and utilize it in risk prediction tasks.
Attention mechanism for feature re-calibration
Feature re-calibration refers to using an attention mechanism to emphasize informative input features or convolutional feature maps and suppress less informative ones. Feature re-calibration has achieved success in computer vision tasks. Hu et al. (Hu et al., 2018) proposed the Squeeze-and-Excitation Block to explicitly model the dependencies between the channels of convolutional features by aggregating convolutional feature maps across spatial dimensions. In EHR analysis, feature re-calibration is mainly used to provide model interpretability. For example, Choi et al. proposed RETAIN to provide feature interpretability using attention over input features.
In this work, we re-calibrate disease progression patterns by extracting the progression theme at the current stage. Our model can adaptively emphasize the most indicative features to better predict patients’ health risks at different stages.
Other related work:
The underlying model of the stage-aware LSTM module in our work is related to the ordered neuron mechanism in (Shen et al., 2018), which was proposed for NLP parsing tasks. However, unlike in word sequences, the time irregularity in EHR data has clinical meaning. In this work, the inter-visit time intervals are used to help the LSTM summarize stages of patients’ status progression.
3. StageNet Method
Below we define the data and task studied in this work and provide the list of notations used in StageNet in Table 1.
Definition 0 (Patient Records).
In longitudinal EHR data, each patient can be represented as a sequence of multivariate observations , where is the number of visits of the patient. Each visit record is a concatenation of clinical features (i.e., lab tests and measurements) including multi-hot and numerical values, where is the number of medical features at each visit. To model the time irregularity of visits, we also use to represent the elapsed time between two visits, , specifically.
Problem 1 (Health Risk prediction).
Given a patient’s visit records , we define dynamic patient risk prediction as to predict the health risk at the -th timestep, which indicates the target outcome of interest (e.g., mortality, decompensation).
|;||patient record sequence; time interval sequence|
|multivariate visit record at the -th visit|
|time interval between and|
|ground truth of prediction targets at the -th visit|
|predictions at the -th visit|
|hidden state of LSTM at the -th timestep|
|cell state of LSTM at the -th timestep|
|extracted progression patterns within the current stage|
|status progression theme at the current stage|
|re-calibration weights of progression patterns|
|re-calibrated stage-adaptive progression patterns|
|stage variation at the -th timestep|
|stage variation from to|
|;||master forget gate; master input gate|
As shown in Figure 2, StageNet comprises three modules: (1) a stage-aware LSTM module, (2) a stage-adaptive convolutional module, and (3) a prediction module. StageNet takes patient EHR sequence and time interval sequence as input of the stage-aware LSTM module, and output the -th time hidden state along with stage variation factor . Here represents the variation degree at the -th visit compared to historical status – a large indicating a higher chance of entering a new stage. We feed and to the stage-adaptive convolutional module to extract patients’ status progression patterns within an observation window . The length of observation window is similar to convolutional kernel size, which decides the timescale of extracted convolutional patterns. The stage information is integrated to convolution operations to extract disease progression patterns that closely related to the current stage. We re-calibrate these patterns to emphasize informative patterns and suppress less useful ones via extracting the disease progression theme at the current stage. Last, we predict the health risk based on these improved patient representations.
3.2. Stage-aware LSTM module
LSTM background Given a sequence of patient health records and a sequence of inter-visit time intervals , our goal is to infer the stage variation of the patient’s health status (i.e., ) while constructing current health status (i.e., ). Original LSTM consists of forget gate , input gate and output gate . At the -th timestep, the original LSTM takes previous hidden state , cell state and current visit as input, and output current hidden state and cell state . However, cell state does not differentiate where the historical information is from (whether is from recent past or long history). StageNet will use two new gates to make different dimensions in the cell state indicate different time scales.
Goal The key objective is to differentiate historical information in the cell state. If we can differentiate recent and old history in , we can determine whether the change to is due to old history or recent history. Then a change mainly due to the recent history means the underlying disease stage has just changed. Hence, we can derive the disease stage change based on the change to the cell state.
Idea The idea is to make each dimension in the cell state represent patients’ status at a different time scale. Intuitively, the dimensions in cell state can be divided into the low-ranking part (the first half of ) and the high-ranking part (the second half of ). The low-ranking part contains short-term health status information that only concerns recent visits, and the high-ranking part contains patients’ long-term progression information that lasts several visits or even the entire visit sequence. Note that because the low-ranking part relates to the most recent visits, low-ranking dimensions are always updated more frequently than high-ranking dimensions. One simple way to enforce low- and high-ranking parts is to use two separate binary mask vectors. The low-ranking mask vector can be, while the high-ranking mask will be .
Soft masking Instead of hard binary masks, we learn two soft mask vectors:
for high-ranking part (representing old history);
for low-ranking part (representing recent history).
In particular, the ranking is dynamically determined using information from patients’ current visit and historical health status , adjusted by the elapsed time between two visits. We utilize and further extend the master forget gate and master input gate in (Shen et al., 2018) by using time interval information as follows:
where is the bias, denotes the cumulative sum, the arrow above indicates the direction of cumulative sum and denotes the concatenation operation. and correspond to the probabilistic distribution of dimensions of for high-ranking (old history) and low-ranking (recent history), respectively. Following the properties of the operation, the values in are monotonically increasing from 0 to 1 (e.g. ), and those in are monotonically decreasing from 1 to 0 (e.g. ). Hence, and can be used to serve as the soft mask for high-ranking part and low-ranking part, respectively.
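The monotonic soft masks described above can be illustrated with a small sketch. It assumes the directional cumulative-sum activation is the cumulative sum of a softmax, as in the ordered neuron mechanism of (Shen et al., 2018); the function names and the input values are hypothetical, and the same pre-activation vector is reused for both gates purely for illustration (in the model each gate has its own parameters).

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cumax(logits, direction="right"):
    """Cumulative sum of a softmax (assumed form of the activation).

    direction="right": sum from left to right, so values are
    non-decreasing and the last one is 1.
    direction="left": sum from right to left, so values are
    non-increasing and the first one is 1.
    """
    probs = softmax(logits)
    if direction == "left":
        probs = probs[::-1]
    out, running = [], 0.0
    for p in probs:
        running += p
        out.append(running)
    return out[::-1] if direction == "left" else out

# Hypothetical pre-activation values for a 5-dimensional cell state.
logits = [0.2, -1.0, 3.0, 0.5, -0.3]
master_forget = cumax(logits, "right")  # soft mask for the high-ranking part
master_input = cumax(logits, "left")    # soft mask for the low-ranking part
```

Because the softmax output sums to one, both masks stay within (0, 1] and are monotonic by construction, which is what lets them act as soft versions of the binary masks.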
Cell state update We define the new calculation for current cell state as:
where the calculation of intermediate cell state and hidden state are same with the original LSTM. Values in are used to decide which dimensions in to store long-term information about status progression (i.e. ), and values in decide which dimensions to store short-term information (i.e. ). In Eq. 2, decide which dimensions to store the overlap part between and . Besides the overlapping information, the independent (non-overlapping) information in and are stored into based on the values of and , respectively. Fig 3 shows the structure of stage-aware LSTM.
Stage progression variation Since values in decide where to store progression information, we denote as:
The value of decides how much history information is used to calculate the current . If is large, there is almost no historical state information in the current cell state, which means that the patient’s current health status has changed a lot compared to the historical status. In other words, a large may indicate that the patient’s status has entered a new stage.
function is non-differentiable, we use the following equation to estimate:
where and are the -th values in and .
Example illustration We use a toy example to better illustrate how and store patients’ health status at different timescales and summarize stage variation factor . Assume , and . According to Eq. 1 and 2, , and . And . indicates the variation of patients’ stage.
The is calculated as:
The fifth dimension in is used to store long-term progression information from , and the first and second dimensions are used to store short-term health status from . The third and fourth dimensions are used to store overlapping information. However, because we use activation, the actual values in and are decimals instead of 0 and 1.
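The toy example can be traced numerically with a short sketch. It assumes the cell state update follows the ordered neuron rule of (Shen et al., 2018), which this module extends: overlapping dimensions use the ordinary LSTM update, and the remaining dimensions copy either the previous cell state or the candidate state. The binary mask values and all inputs below are hypothetical, chosen only to match the description above (dimensions 1 and 2 short-term, 3 and 4 overlapping, 5 long-term).

```python
def stage_aware_update(c_prev, c_cand, f, i, master_f, master_i):
    """Combine the previous and candidate cell states dimension by dimension.

    overlap = master_f * master_i selects dimensions that use the ordinary
    LSTM update; (master_f - overlap) keeps long-term content from c_prev;
    (master_i - overlap) writes short-term content from c_cand.
    """
    new_c = []
    for cp, cc, fg, ig, mf, mi in zip(c_prev, c_cand, f, i, master_f, master_i):
        overlap = mf * mi
        new_c.append(overlap * (fg * cp + ig * cc)
                     + (mf - overlap) * cp
                     + (mi - overlap) * cc)
    return new_c

# Hard (binary) masks consistent with the toy example.
master_f = [0, 0, 1, 1, 1]   # increasing from 0 to 1
master_i = [1, 1, 1, 1, 0]   # decreasing from 1 to 0

c_prev = [10.0] * 5          # hypothetical previous cell state
c_cand = [1.0] * 5           # hypothetical candidate cell state
f = [0.5] * 5                # ordinary forget gate
i = [0.5] * 5                # ordinary input gate

c_new = stage_aware_update(c_prev, c_cand, f, i, master_f, master_i)
# Dimensions 1-2 take the candidate (1.0), dimensions 3-4 mix both (5.5),
# and dimension 5 keeps the previous state (10.0).
```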
As and only focus on coarse-grained control, in practice, we reduce the dimension of and to similar to (Shen et al., 2018), where is a chunk size factor. Therefore, every dimension within each -sized chunk shares the same master gates. A smaller can make the model describe patients’ status variation in more details.
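The chunking idea can be sketched as follows: the master gates are computed at a reduced size and each value is then shared by all dimensions in its chunk. The helper name and the example values are hypothetical.

```python
def expand_chunk_gate(chunk_gate, chunk_size):
    """Repeat each chunk-level gate value so that every dimension
    inside a chunk shares the same master gate."""
    return [g for g in chunk_gate for _ in range(chunk_size)]

# Two chunk-level gate values expanded to a 6-dimensional cell state.
full_gate = expand_chunk_gate([0.1, 0.9], chunk_size=3)
```

With a chunk size of 1 the gate is unchanged, which corresponds to the most fine-grained description of status variation.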
3.3. Stage-adaptive convolutional module
To leverage the stage information learned from LSTM, we develop a stage-adaptive convolutional module on top of the recurrent layer to extract and re-calibrate patient health progression patterns for risk prediction.
Progression patterns of patients’ health status are critical in predicting patients’ risks (Yeh et al., 1984; Halm et al., 1998). Since many medical studies indicate that these progression patterns are often similar within one stage but vary across different stages (Halm et al., 1998; Plöchl et al., 1996; Yeh et al., 1984), we expect StageNet to extract patterns that are closely related to patients’ current stage using convolutional filters. We also design the model to adaptively select the most informative patterns for risk prediction at the current stage. We achieve this through three steps: 1) learning stage progression patterns, 2) extracting the progression theme at the current stage, and 3) re-calibrating progression patterns. The structure of the stage-adaptive convolutional module is shown in Fig. 4.
Learning stage progression patterns: We further extract the progression patterns of at the current stage by using stage-weighted convolution operators. We modify the original CNN by integrating disease stage information into the convolution operation.
Mathematically, at the -th time, we calculate the distance between stages of historical visits within the observation window and the stage of current visit as:
where is the length of observation window. The length of observation window is similar to convolutional kernel size, which decides the timescale of extracted progression patterns. The values in are monotonically decreasing from 1 to 0. We denote the -th value in as . A large indicates that the stage of is far from the current stage of .
In stage-weighted convolution operation at the -th timestep, the convolutional module takes concatenated historical hidden states sequence of LSTM (i.e. ) as input. Different from the original convolutional layer, the weights of input variables are re-weighted by their stage distance in each convolution computation as:
where is the convolution operation, , . is a 1D convolution kernel representing a single channel of that acts on the corresponding channel of and is the kernel size. We use multiple kernels to generate extract different patterns, and the number of kernels is . We concatenate the output of kernels to get the final convolution output as . We set to make each kernel can extract progression patterns that represent the whole stage, so that the final dimension is .
In Eq. 7, the weights of patients’ historical health status are adjusted according to the distance between stages in order to extract patterns that are closely related to the current stage. If the stage of a historical status is far from the current stage, it receives a lower weight in the stage-weighted convolution operation, and vice versa.
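A minimal sketch of the re-weighting idea is given below. It assumes the kernel spans the whole observation window, so each kernel yields one value, and that a per-timestep weight between 0 and 1 is already available (higher meaning closer to the current stage). How those weights are derived from the stage variation is not reproduced here; the function name and all numbers are hypothetical.

```python
def stage_weighted_conv(hidden_window, stage_weights, kernels):
    """Apply kernels to a window of hidden states after re-weighting
    each timestep by its stage-based weight.

    hidden_window: K timesteps, each a list of N hidden values.
    stage_weights: K weights, one per timestep.
    kernels: list of kernels, each K x N (one 1D filter per channel).
    Returns one extracted pattern value per kernel.
    """
    outputs = []
    for kernel in kernels:
        total = 0.0
        for weight, hidden, taps in zip(stage_weights, hidden_window, kernel):
            total += weight * sum(t * h for t, h in zip(taps, hidden))
        outputs.append(total)
    return outputs

window = [[1.0, 2.0], [3.0, 4.0]]        # two timesteps, two channels
weights = [0.5, 1.0]                     # older visit is down-weighted
kernels = [[[1.0, 1.0], [1.0, 1.0]]]     # a single all-ones kernel
patterns = stage_weighted_conv(window, weights, kernels)
```

Setting every weight to 1 recovers an ordinary convolution over the window, which makes the effect of the stage weights easy to compare.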
Extracting progression theme at the current stage: The output of the convolution operation is the concatenation of multiple patterns extracted by different kernels, and the importance of these patterns may vary depending on patients’ status at different stages. To select the most informative patterns at the current stage, we provide the model with a global view of patients’ status at the current stage as:
The global status representation at the current stage is the weighted average of hidden states within the observation window. can be regarded as the progression theme of the current stage. The importance of different temporal patterns will be calculated based on this theme.
Re-calibrating progression patterns: After obtaining the status progression theme at the current stage, we map the representation to an importance vector , where the -th value in indicates the importance of the -th extracted temporal pattern in . is calculated as:
where refers to the sigmoid function, is the ReLU function, and . We use two fully-connected (FC) layers to map the progression theme to , i.e. a dimensionality-reduction layer with ReLU activation to compress the representation while capturing the non-linearity, and a dimensionality-increasing layer to rescale the output to the original dimension of . We use sigmoid activation to generate the importance weights between 0 and 1.
Finally, the features in is re-calibrated using as:
The re-calibration mechanism can be regarded as a channel-wise attention mechanism like (Hu et al., 2018). However, the attention weights (i.e. the importance of each pattern) are calculated from patients’ status progression theme at the current stage instead of using global average pooling to generate channel-wise statistics as in (Hu et al., 2018) or calculating alignment between historical states.
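The theme extraction and re-calibration steps can be sketched together. This is an illustration under stated assumptions: the theme is taken as a weighted average normalized by the sum of the weights, bias terms are omitted, and the weight matrices, function names and inputs are all hypothetical.

```python
import math

def progression_theme(hidden_window, stage_weights):
    """Weighted average of the hidden states inside the observation window."""
    total = sum(stage_weights)
    width = len(hidden_window[0])
    return [sum(w * h[j] for w, h in zip(stage_weights, hidden_window)) / total
            for j in range(width)]

def matvec(matrix, vector):
    """Plain matrix-vector product on nested lists."""
    return [sum(a * b for a, b in zip(row, vector)) for row in matrix]

def recalibrate(patterns, theme, w_reduce, w_expand):
    """Two FC layers (ReLU, then sigmoid) map the theme to importance
    weights in (0, 1), which rescale the extracted patterns."""
    squeezed = [max(0.0, v) for v in matvec(w_reduce, theme)]
    importance = [1.0 / (1.0 + math.exp(-v)) for v in matvec(w_expand, squeezed)]
    return [p * z for p, z in zip(patterns, importance)], importance

theme = progression_theme([[1.0, 2.0], [3.0, 4.0]], [1.0, 1.0])
# All-zero layers give an importance of 0.5 for every pattern.
scaled, importance = recalibrate([2.0, 4.0], theme,
                                 w_reduce=[[0.0, 0.0]],
                                 w_expand=[[0.0], [0.0]])
```

The reduction layer here maps two values to one and the expansion layer maps back to two, mirroring the bottleneck shape described in the text.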
3.4. Prediction module
The prediction layer takes the output of the stage-adaptive convolutional module as input, and outputs a binary label , which indicates the patient’s current health risk. Note that we include residual connections between the convolutional module and the output layer. In order to achieve this, we set. We compute as:
where . We choose the cross-entropy function to calculate the loss for each patient as:
We use the Adam algorithm (Kingma and Ba, 2014) for optimization. We summarize StageNet algorithm below.
We evaluated StageNet by comparing it against other baselines on the public MIMIC-III dataset and an end-stage renal disease (ESRD) dataset. The code is available at https://github.com/v1xerunt/StageNet.
4.1. Dataset description
We use the following data to evaluate our model.
MIMIC-III Dataset We use Intensive Care Unit (ICU) data from the publicly available Medical Information Mart for Intensive Care (MIMIC-III) database (Johnson et al., 2016). Following (Harutyunyan et al., 2017), we use a cohort of 33,678 unique patients with a total of 2,202,114 samples (i.e., records). The raw data includes 17 physiologic variables at each visit, which are transformed into a 76-dimensional vector including numerical and one-hot encoded categorical clinical features.
End-Stage Renal Disease (ESRD) Dataset We perform mortality risk prediction on an end-stage renal disease dataset. Many people worldwide suffer from ESRD (Tangri et al., 2011; Isakova et al., 2011). They face severe threats to life and need lifelong treatment with periodic hospital visits for a variety of tests (e.g., blood routine examination). The whole procedure calls for a dynamic risk prediction system that helps patients prevent adverse outcomes, based on the medical records collected at each visit. The cleaned dataset consists of 656 patients and 13,091 visit records, and the percentage of positive labels is 17.5%. The raw data includes 17 numeric physiologic variables at each timestep. During and after data collection and analysis, the authors could not identify individual participants, as patients’ names were replaced by IDs. We use patients’ previous records to fill in missing data in order to prevent the leakage of future information.
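Filling missing values from previous records only can be sketched as a forward fill. This is a minimal illustration, assuming missing entries are marked with None and that values missing at the first visit stay missing; the function name and sample values are hypothetical.

```python
def forward_fill(visits):
    """Fill missing feature values (None) with the most recent earlier
    observation of the same feature. Later visits are never consulted,
    so no future information leaks into earlier timesteps."""
    filled, last = [], None
    for visit in visits:
        if last is None:
            current = list(visit)
        else:
            current = [prev if value is None else value
                       for value, prev in zip(visit, last)]
        filled.append(current)
        last = current
    return filled

# Three visits with two features each.
records = [[1.0, None], [None, 2.0], [None, None]]
filled_records = forward_fill(records)
```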
|Group||Model||MIMIC-III AUPRC||MIMIC-III AUROC||MIMIC-III min(Re, P+)||ESRD AUPRC||ESRD AUROC||ESRD min(Re, P+)|
|Baseline||LSTM||0.280 (0.003)||0.897 (0.002)||0.324 (0.003)||0.270 (0.029)||0.805 (0.026)||0.318 (0.015)|
|Baseline||ON-LSTM||0.304 (0.002)||0.895 (0.003)||0.343 (0.004)||0.291 (0.021)||0.810 (0.021)||0.333 (0.034)|
|Baseline||T-LSTM||0.282 (0.004)||0.895 (0.002)||0.322 (0.005)||0.276 (0.027)||0.812 (0.026)||0.331 (0.031)|
|Baseline||Decay-LSTM||0.294 (0.002)||0.893 (0.003)||0.330 (0.004)||0.289 (0.020)||0.808 (0.022)||0.328 (0.021)|
|Baseline||Health-ATM||0.291 (0.002)||0.897 (0.003)||0.325 (0.003)||0.287 (0.021)||0.810 (0.039)||0.331 (0.025)|
|Reduced Model||StageNet-I||0.313 (0.003)||0.899 (0.003)||0.360 (0.002)||0.296 (0.014)||0.814 (0.031)||0.333 (0.018)|
|Reduced Model||StageNet-II||0.311 (0.003)||0.897 (0.002)||0.358 (0.003)||0.302 (0.029)||0.812 (0.027)||0.334 (0.017)|
|Proposed||StageNet||0.323 (0.002)||0.903 (0.002)||0.372 (0.003)||0.327 (0.022)||0.821 (0.024)||0.352 (0.019)|
We evaluated StageNet against the following baselines, which share some of the similar insights with StageNet. It is worth noting that there are lots of state-of-the-art clinical prediction models which utilize attention mechanism to extract long-term dependencies in patients’ historical visits (Lee et al., 2018; Choi et al., 2016; Song et al., 2018). However, their contribution is orthogonal to ours. We focus on capturing and utilizing stage information of patients’ health status in EHR data. Our model StageNet can be easily combined with attention mechanism.
LSTM (Gers et al., 1999) The visit input at the -th timestep is fed into the LSTM model. Then it directly output the prediction results based on the hidden state vector .
ON-LSTM (Shen et al., 2018) uses LSTM to model tree-like structures for natural language sequences by separately allocating hidden state dimensions with long and short-term information.
Decay-LSTM (Zheng et al., 2017) uses feature-level time intervals to enable memory decay similar to T-LSTM. We adopt the decay mechanism on the input gate of LSTM. Decay-LSTM requires the time intervals of each feature as input; for fairness, we also provide this information to all the other models.
Health-ATM (Ma et al., 2018) uses irregular time intervals to decay the information from historical timesteps via a hybrid CRNN structure. The original model utilizes a target-aware attention mechanism to achieve disease prediction. Since our task does not have a specific target embedding to guide the attention, we remove the target-aware attention mechanism from Health-ATM.
We also compare StageNet against its reduced models:
StageNet-I consists of regular LSTM and the stage-adaptive convolutional module. The weighted convolution operation is also replaced by regular convolution operation. We use the average of within the observation window to calculate .
StageNet-II only has stage-aware LSTM. The visit input at the -th timestep is fed into the stage-aware LSTM model. Then it directly outputs the prediction results based on the hidden state vector .
4.3. Health Risk Prediction
In this section, we report experimental results for following supervised tasks on two datasets.
Decompensation risk prediction
We perform the physiologic decompensation prediction task on MIMIC-III dataset. This task involves the detection of patients who are physiologically decompensating, which means conditions are deteriorating rapidly. Detection of decompensation is closely related to problems like condition monitoring and sepsis detection that have received significant attention from the machine learning community. The task is formulated as a binary classification task for predicting whether the patient’s date of death (DOD) falls within the next 24 hours of the current time point. These labels are assigned to each hour, starting at four hours after admission to the ICU and ending when the patient dies or is discharged.
We truncate the length of samples to a reasonable limit (i.e. 400). We fix a test set of
of patients, and divide the rest of the dataset into the training set and validation set with a proportion of 85%:15%. We fix the best model on the validation set and report the performance in the test set. We also report the standard deviation of the performance measures by bootstrapping the results on the test set for 10,000 times.
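The bootstrap estimate of the standard deviation can be sketched as follows. The metric is passed in as a function; the simple accuracy metric, the seed, the default number of rounds in the example and the sample data are assumptions for illustration only.

```python
import random
import statistics

def bootstrap_std(labels, scores, metric, n_rounds=10000, seed=0):
    """Resample the test set with replacement n_rounds times and return
    the standard deviation of the metric across the resamples."""
    rng = random.Random(seed)
    n = len(labels)
    values = []
    for _ in range(n_rounds):
        idx = [rng.randrange(n) for _ in range(n)]
        values.append(metric([labels[i] for i in idx],
                             [scores[i] for i in idx]))
    return statistics.pstdev(values)

def accuracy(labels, scores):
    """Hypothetical stand-in metric: accuracy at a 0.5 threshold."""
    hits = sum((s >= 0.5) == bool(y) for y, s in zip(labels, scores))
    return hits / len(labels)

test_labels = [1, 0, 1, 0]
test_scores = [0.9, 0.6, 0.8, 0.2]
std_estimate = bootstrap_std(test_labels, test_scores, accuracy, n_rounds=200)
```

A metric such as AUPRC can be substituted for the accuracy function, provided resamples that contain a single class are handled by that metric.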
Mortality risk prediction We perform the mortality risk prediction task on the ESRD dataset. Similar to the decompensation task, the mortality risk prediction task is formulated as a binary classification task to predict whether the patient will die within 12 months, and the predictions are made at each timestep. We evaluate the models with 10-fold cross-validation strategy and report the average performance and standard deviations.
Implementation Details The models were trained on a server equipped with an Intel Xeon E5-2620 Octa-Core CPU, 256GB memory and a Titan V GPU.
We assess performance using the area under the receiver operating characteristic curve (AUROC), the area under the precision-recall curve (AUPRC), and the minimum of precision and sensitivity, Min(Re,P+). Min(Re,P+) is calculated as the maximum of min(recall, precision) over the precision-recall curve.
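The Min(Re,P+) metric can be computed by sweeping the decision threshold over the sorted scores. The sketch below is a plain-Python illustration with hypothetical names and data; tied scores are treated as a single threshold.

```python
def min_re_p(labels, scores):
    """Maximum over all thresholds of min(recall, precision)."""
    total_pos = sum(labels)
    if total_pos == 0:
        return 0.0
    order = sorted(range(len(labels)), key=lambda i: -scores[i])
    best, true_pos = 0.0, 0
    for rank, idx in enumerate(order, start=1):
        true_pos += labels[idx]
        # Evaluate only after the last sample of a group of tied scores.
        last_of_tie = rank == len(order) or scores[order[rank]] != scores[idx]
        if last_of_tie:
            precision = true_pos / rank
            recall = true_pos / total_pos
            best = max(best, min(precision, recall))
    return best

example_labels = [1, 0, 1, 0]
example_scores = [0.9, 0.8, 0.7, 0.1]
example_value = min_re_p(example_labels, example_scores)
```

In this example the best trade-off is reached when the top three samples are predicted positive, giving precision 2/3 and recall 1.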
Table 2 compares the performance of all models on both datasets. StageNet consistently outperforms all state-of-the-art models on both datasets. On the MIMIC-III dataset, StageNet achieves 10% higher AUPRC and min(Re,P+) compared to the best baseline models, ON-LSTM and Health-ATM. On the ESRD dataset, StageNet achieves 12% higher AUPRC and 6% higher min(Re, P+) compared to the best baseline model, ON-LSTM.
The reduced models StageNet-I and StageNet-II still outperform all state-of-the-art models in most cases on both datasets. This suggests that extracting higher-level temporal variation features and summarizing stage information are both helpful for predicting patients’ health risks. Among all baselines, ON-LSTM and Health-ATM achieve better performance in most cases because they handle the aforementioned challenges to some extent.
Health status stability vs. Cause of death In order to understand how different clinical events affect disease progression, we further analyze the health status stability of patients with different causes of death in the ESRD dataset. At each timestep, StageNet will output a scalar , which indicates the variation of patients’ stage. A large indicates the patient’s current health status has changed a lot compared to history status (i.e. enter a new stage). We compute the average stage variation (i.e. the average of ) for each patient. A patient with stable health status will have a low . The average stage variation of patients with different causes of death is shown in Fig. 5.
The result shows that patients who died of peritonitis, cardiovascular disease (CVD) and cerebrovascular events (CVE) have the highest , which means that their health status is very unstable. These are acute diseases with high mortality rates (Kannel et al., 1987; Fried et al., 1996; ESTANOL and M. MARIN, 1975). The health status of patients who have these diseases tends to deteriorate rapidly in a short period of time, which explains why our model considers these patients’ status highly unstable. In contrast, patients with cancer have the most stable status compared to other patients, since their health status often deteriorates more chronically and they have longer survival times than patients with acute cardiac diseases (Derogatis et al., 1979; Prentice and Gloeckler, 1978). Clinicians should pay more attention to patients with a history of heart disease or peritonitis in order to intervene in time.
Health status stability vs. Health risk. The stability of patients’ health status is an important indicator to evaluate patients’ health risk (Plöchl et al., 1996; Halm et al., 1998; Yeh et al., 1984). At each visit, our model will evaluate the patient’s current health status and output a health risk score. We divide patients’ visits into three groups according to predicted health risk: Low risk (risk score <= 0.4), Medium risk (0.4 < risk score <= 0.7) and High risk (risk score >= 0.7). We compute the average stage variation of each group and report the results in Table 3.
|Risk level||Low risk||Medium risk||High risk|
|Avg. stage var.||0.354 (0.003)||0.393 (0.003)||0.437 (0.005)|
The results show that the health status of high-risk patients is the most unstable, and low-risk patients have the most stable status. This is consistent with findings in medical research, where clinicians use a physiologic stability index to evaluate patients’ health risk (i.e. patients with unstable status have higher mortality risk) (Yeh et al., 1984; Halm et al., 1998).
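The grouping behind this comparison can be sketched in a few lines. The thresholds follow the text; assigning a score of exactly 0.7 to the medium group is an assumption made here to keep the groups disjoint, and the function name and sample values are hypothetical.

```python
def average_stage_variation_by_risk(risk_scores, stage_variations):
    """Group visits by predicted risk and average the stage variation
    within each group. Empty groups map to None."""
    groups = {"low": [], "medium": [], "high": []}
    for risk, variation in zip(risk_scores, stage_variations):
        if risk <= 0.4:
            groups["low"].append(variation)
        elif risk <= 0.7:   # assumption: exactly 0.7 counts as medium
            groups["medium"].append(variation)
        else:
            groups["high"].append(variation)
    return {name: (sum(vals) / len(vals) if vals else None)
            for name, vals in groups.items()}

group_means = average_stage_variation_by_risk(
    [0.1, 0.3, 0.5, 0.9], [0.2, 0.4, 0.5, 0.8])
```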
4.3.3. Case study
To explore how our model extracts the stage variation information of patients’ health status and further utilize it to make predictions, we present a specific case study of a patient in the test set. As shown in Figure 6, the purple line indicates the stage variation and the red line is the predicted mortality risk of the patient.
For this specific patient, there is a distinct rising period in the two lines. Before May 2011, the patient’s risk remains within a relatively low range. At the same time, also keeps a low value, which indicates that the model believes that the patient’s health status is stable during this time. However, around 2012, the stage variation reaches the peak and the predicted risk rises rapidly, which means the patient has entered a high-risk stage. The risk remains a high value until the end. According to the clinical notes, the patient encountered Acute Myocardial Infarction (AMI) around 2012, and eventually died because of heart disease. If physicians were reminded when the model found the patient’s status was changing drastically, the adverse outcome may be prevented or delayed by taking early interventions.
4.4. Patient Subtyping
In this experiment, we conduct patient subtyping on the ESRD dataset to investigate the expressive power of the patient representation learned by StageNet. ESRD is a chronic disease, and patients need to receive continuous medical treatment for years or even decades. Patients may face various risk factors such as infection, heart disease or cancer (Weiner et al., 2004). The patient subtyping task seeks patient groups with similar disease progression pathways (Baytas et al., 2017). Identifying patient subtypes can help clinicians develop targeted treatment plans and prevent adverse outcomes.
We use the representation learned at the last timestep of the previous risk prediction task as the representation of each patient’s health status. For baseline models, we use the representations before the output layer. The learned representations are used to cluster the patients with the k-means algorithm (Hartigan and Wong, 1979). Since we do not know the ground-truth subtype groups, we conduct several statistical analyses to assess the subtyping performance. Moreover, we use the Calinski-Harabasz score (Caliński and Harabasz, 1974) (C-H score for short) to evaluate the subtyping performance quantitatively. A higher C-H score indicates better-defined clusters. The C-H score is calculated as:
$$ s = \frac{\mathrm{tr}(B_k)}{\mathrm{tr}(W_k)} \times \frac{N - k}{k - 1} $$
where $N$ is the sample size, $k$ is the number of clusters, $B_k$ is the covariance matrix between clusters, $W_k$ is the covariance matrix within clusters, and $\mathrm{tr}(\cdot)$ is the trace of a matrix.
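As a hedged illustration of the C-H score, the sketch below computes it in pure Python using its standard definition: the ratio of the between-cluster trace to the within-cluster trace, scaled by (N - k)/(k - 1) for N samples and k clusters. The traces are computed as sums of squared distances, which is an assumption about how the two matrices are defined (it matches the common dispersion-matrix convention). The function name and the toy points are hypothetical.

```python
def calinski_harabasz(points, labels):
    """Calinski-Harabasz score for a labelled set of points.

    points: list of equal-length numeric tuples.
    labels: list of cluster ids, one per point.
    """
    n = len(points)
    dim = len(points[0])

    # Group the points by cluster label.
    clusters = {}
    for point, label in zip(points, labels):
        clusters.setdefault(label, []).append(point)
    k = len(clusters)
    if k < 2 or n <= k:
        raise ValueError("need at least 2 clusters and more samples than clusters")

    def centroid(rows):
        return [sum(row[d] for row in rows) / len(rows) for d in range(dim)]

    def squared_distance(a, b):
        return sum((x - y) ** 2 for x, y in zip(a, b))

    overall = centroid(points)
    between = 0.0  # trace of the between-cluster dispersion
    within = 0.0   # trace of the within-cluster dispersion
    for rows in clusters.values():
        center = centroid(rows)
        between += len(rows) * squared_distance(center, overall)
        within += sum(squared_distance(row, center) for row in rows)

    if within == 0.0:
        raise ValueError("within-cluster dispersion is zero; score is undefined")
    return (between / within) * (n - k) / (k - 1)


# Toy example: two tight, well-separated groups versus a poor labelling.
toy_points = [(0.0, 0.0), (0.0, 2.0), (10.0, 0.0), (10.0, 2.0)]
good_score = calinski_harabasz(toy_points, [0, 0, 1, 1])
bad_score = calinski_harabasz(toy_points, [0, 1, 0, 1])
```

The labelling that matches the geometry of the toy points gets a much higher score than the labelling that mixes the two groups, which is the behaviour the metric is used for here.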
We fix a test set of 20% of patients, and the other 80% of patients are used for training. We use the same hyper-parameters as in the risk prediction task. We tried several values of k for the k-means algorithm and observe four main clusters. Therefore, we report the clustering Calinski-Harabasz score and the average ground-truth mortality risk in each cluster when k = 4. The results are shown in Table 4.
|Model||C-H score||Cluster I||Cluster II||Cluster III||Cluster IV|
The results show that StageNet achieves an over 58% higher Calinski-Harabasz score than the best baseline model. The average risk in each cluster shows that StageNet divides patients into two high-risk groups (Cluster III and IV) and two low-risk groups (Cluster I and II). ON-LSTM, Decay-LSTM and Health-ATM divide patients into one high-risk group (Cluster IV), two medium-risk groups (Cluster II and III) and one low-risk group (Cluster I). However, T-LSTM only identifies one high-risk group (Cluster IV) and three low-risk groups (Cluster I, II and III). For the baseline models, the risk scores increase evenly across clusters, which suggests that the learned representations are spread evenly in the latent space and do not form meaningful clusters, resulting in low C-H scores.
4.4.1. High-risk patient subtypes
To interpret the clustering results in terms of subtyping, we compare the medium-risk and high-risk clusters with the low-risk clusters using a t-test to identify discriminative features (p-value < 0.05). We find that there are 5-7 significant features in each cluster, and we report the top 5 significant features ranked by p-value in Table 5.
|Cluster II||Albumin||C-rp||Blood chlorine|
|Cluster III||Albumin||C-rp||Blood chlorine|
|Cluster IV||Albumin||Blood chlorine||Serum creatinine|
|Cluster II||DBP||Albumin||Blood chlorine|
|Cluster III||Albumin||DBP||Blood chlorine|
|Cluster IV||DBP||Blood chlorine||Albumin|
|Cluster II||Blood chlorine||Blood potassium||Albumin|
|Cluster III||Blood chlorine||Albumin||Blood potassium|
|Cluster IV||Blood chlorine||Blood potassium||Albumin|
|Cluster III||Albumin||Serum creatinine||Blood urea|
|Cluster IV||C-rp||DBP||Blood potassium|
The results show that almost all baseline models choose albumin, C-reactive protein, glucose, chlorine, diastolic blood pressure and blood urea to distinguish between low-risk patients and high-risk patients. However, significant features in different high-risk and medium-risk clusters are almost the same for all baseline models, which indicates that these models are unable to further distinguish subtypes among high-risk patients and therefore have worse clustering performance.
In contrast, StageNet clearly divides patients into two high-risk groups and also identifies more discriminative features. In Cluster III, albumin, blood urea and appetite are important indicators related to patients’ nutritional status (Gama-Axelsson et al., 2012; Patel et al., 2013; Di Iorio et al., 2017). These biomarkers reflect patients’ health status from a long-term perspective. In Cluster IV, by contrast, blood potassium and diastolic blood pressure are important indicators for heart diseases such as heart failure (Kannel et al., 1987; Lee et al., 2009). Patients with high white blood cell count and C-reactive protein are likely to have severe infections (Jialal and Devaraj, 2001; Franz et al., 1999). We also notice that StageNet identifies hemoglobin as a significant feature in Cluster IV, which is not identified by any baseline model. According to medical research, a sustained decrease in hemoglobin is a key sign of acute GI bleeding (Tomizawa et al., 2014), which may cause sudden death.
We also explore the distribution of causes of death in Cluster III and Cluster IV of StageNet; the result is shown in Fig. 7. In Cluster III, the main causes of death are cancer and gastrointestinal disease (GI disease), which are generally considered more chronic diseases (Derogatis et al., 1979; Prentice and Gloeckler, 1978). In Cluster IV, the main causes of death are cardiovascular disease, peritonitis and cerebrovascular disease, which are acute conditions (Kannel et al., 1987; Fried et al., 1996; ESTANOL and M. MARIN, 1975). This is consistent with medical research and our previous experimental results. However, the baseline models fail to identify these high-risk subtypes. For example, the cause of death distribution for ON-LSTM is shown in Fig. 8. There is no significant difference in the cause of death distribution between different clusters.
4.4.2. Low-risk patient subtypes
As shown in Table 4, StageNet also divides low-risk patients into two subtypes (Cluster I and II). To identify the difference between the two clusters, we report the features that a t-test finds discriminative between them, together with the mean value of these features in each cluster. The result is shown in Table 6.
|Feature||Cluster I Mean||Cluster II Mean|
|Serum creatinine||988.1 (172.4)||758.3 (112.8)|
|Glucose||6.9 (1.4)||5.4 (1.0)|
|Albumin||36.7 (3.2)||40.9 (2.5)|
|Blood chlorine||100.1 (2.9)||97.5 (2.1)|
We can observe that patients in Cluster I have higher serum creatinine, glucose and blood chlorine, and lower albumin, than patients in Cluster II. As discussed above, albumin and serum creatinine are important indicators of patients’ nutritional status, and indicate the severity of patients’ ESRD progression (Di Iorio et al., 2017; Patel et al., 2013). Kidney damage may cause a high blood chlorine level, and a high glucose level may indicate that patients have diabetes (Batlle et al., 1981). In conclusion, patients in Cluster I have a higher potential health risk than patients in Cluster II.
4.4.3. Health status stability vs. Patient subtypes
To understand how patients’ disease progression stage information helps StageNet identify patient subtypes, we calculate the average stage variation of patients in each cluster. The results are shown in Table 7.
|Cluster||I||II||III||IV|
|Avg. stage var.||0.297 (0.044)||0.295 (0.039)||0.350 (0.038)||0.409 (0.031)|
The results show that the variation of patients’ health status is positively correlated with patients’ health risk in each cluster, which agrees with the results in Table 3. We notice that patients in Cluster IV have the most unstable status, since they have more acute symptoms. Although patients in Cluster III still have high health risk, they have lower stage variation than patients in Cluster IV, because their disease progression is more chronic. The results in Table 7 are consistent with the observations and medical findings discussed above. Compared to baseline models, StageNet can learn discriminative patient representations from EHR sequences by extracting and utilizing the disease progression stage information.
5. Conclusion
In this work, we propose a stage-aware neural network model, StageNet, to conduct health risk prediction using the stage variation of patients’ health status. StageNet consists of a stage-aware LSTM module and a stage-adaptive convolutional module. StageNet extracts the stage of patients’ health status at each visit in an unsupervised manner, then leverages and re-calibrates stage-related variation patterns for risk prediction. Supervised health risk prediction experiments on two real-world datasets demonstrate that StageNet consistently outperforms state-of-the-art methods by better capturing inherent disease progression stage information in EHR data. Compared to the best baseline model, StageNet achieves 10% higher AUPRC and min(Re,P+) on the public MIMIC-III dataset, and 12% higher AUPRC and 6% higher min(Re,P+) on the ESRD dataset. The patient subtyping experiment shows that StageNet learns more discriminative representations than the baseline models by extracting and utilizing the stage information. In clinical practice, we hope our model can help physicians identify patients with unstable health status to prevent or delay adverse outcomes.
This work is partly supported by the National Natural Science Foundation of China (No. 91546203), National Science Foundation awards IIS-1418511, CCF-1533768 and IIS-1838042, and the National Institutes of Health awards NIH R01 1R01NS107291-01 and R56HL138415.
Appendix A Dataset details
The basic statistics of the two datasets are shown in Table 8.
|ESRD dataset|
|# patients died||261|
|# visits with positive label||2,287|
|# visits with negative label||10,804|
|Average time interval between visits||3.4 months|
|MIMIC-III dataset|
|# ICU stays||41,902|
|# visits with positive label||45,364|
|# visits with negative label||2,156,750|
Appendix B Implementation details
For the hyper-parameter settings of each baseline model, our principle is as follows: for each hyper-parameter, we use the recommended setting if it is available in the original paper. Otherwise, we determine its value by grid search on the validation set.
LSTM/T-LSTM/Decay-LSTM. The hidden units of LSTM cell are set to 64 / 128 for ESRD / MIMIC-III dataset respectively.
ON-LSTM. The hidden units of LSTM cell are set to 72 / 384. The chunk size factor is set to 36 / 128 for ESRD / MIMIC-III dataset respectively.
Health-ATM. The hidden units of LSTM cell are set to 64 / 128. The number of convolutional filters is set to 32 / 64 and the size of filters is set to 3 / 5 for ESRD / MIMIC-III dataset respectively.
StageNet. The length of observation window is set to 10. The hidden units of LSTM cell are set to 72 / 384. The chunk size factor is set to 36 / 128.
Additionally, we use dropout (Srivastava et al., 2014) before the output layer and dropconnect (Wan et al., 2013) in the LSTM layer. The dropout rate is set to 0.5 / 0.3 for ESRD / MIMIC-III dataset respectively. We train each model for 50 epochs on MIMIC-III dataset and 200 epochs on the ESRD dataset. The learning rate is set to 0.001.
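For reference, the StageNet settings listed above can be collected in a small configuration object. This is a sketch only: the class and field names are hypothetical, and the check that the hidden size is divisible by the chunk size factor is an assumption borrowed from how ordered-neuron LSTMs are usually implemented, not something stated in this appendix. All numeric values are copied from the text.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class StageNetConfig:
    """Hyper-parameters for one dataset (field names are illustrative)."""
    hidden_units: int
    chunk_size: int
    dropout: float
    epochs: int
    observation_window: int = 10
    learning_rate: float = 0.001

    def __post_init__(self):
        # Assumption: the hidden state is split into equal-sized chunks.
        if self.hidden_units % self.chunk_size != 0:
            raise ValueError("hidden_units must be divisible by chunk_size")
        if not 0.0 <= self.dropout < 1.0:
            raise ValueError("dropout must be in [0, 1)")

    @property
    def num_chunks(self) -> int:
        """Number of chunks the hidden state is divided into."""
        return self.hidden_units // self.chunk_size


# Values as reported for the ESRD / MIMIC-III datasets respectively.
ESRD_CONFIG = StageNetConfig(hidden_units=72, chunk_size=36, dropout=0.5, epochs=200)
MIMIC3_CONFIG = StageNetConfig(hidden_units=384, chunk_size=128, dropout=0.3, epochs=50)
```

Keeping both dataset settings in one frozen structure makes it harder to mix ESRD and MIMIC-III values by accident when reproducing the experiments.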
References
- The pathogenesis of hyperchloremic metabolic acidosis associated with kidney transplantation. The American journal of medicine 70 (4), pp. 786–796.
- Patient subtyping via time-aware lstm networks. In Proceedings of the 23rd ACM SIGKDD international conference on knowledge discovery and data mining, pp. 65–74.
- A dendrite method for cluster analysis. Communications in Statistics-theory and Methods 3 (1), pp. 1–27.
- Retain: an interpretable predictive model for healthcare using reverse time attention mechanism. In Advances in Neural Information Processing Systems, pp. 3504–3512.
- Psychological coping mechanisms and survival time in metastatic breast cancer. Jama 242 (14), pp. 1504–1508.
- Nutritional therapy reduces protein carbamylation through urea lowering in chronic kidney disease. Nephrology Dialysis Transplantation 33 (5), pp. 804–813.
- Cardiac arrhythmias and sudden death in subarachnoid hemorrhage. Stroke 6 (4), pp. 382–386.
- Comparison of procalcitonin with interleukin 8, c-reactive protein and differential white blood cell count for the early diagnosis of bacterial infections in newborn infants. The Pediatric infectious disease journal 18 (8), pp. 666–671.
- Peritonitis influences mortality in peritoneal dialysis patients. Journal of the American Society of Nephrology 7 (10), pp. 2176–2182.
- Patient subtyping with disease progression and irregular observation trajectories. arXiv preprint arXiv:1810.09043.
- Serum albumin as predictor of nutritional status in patients with esrd. Clinical Journal of the American Society of Nephrology 7 (9), pp. 1446–1453.
- Learning to forget: continual prediction with lstm.
- Time to clinical stability in patients hospitalized with community-acquired pneumonia: implications for practice guidelines. Jama 279 (18), pp. 1452–1457.
- Algorithm as 136: a k-means clustering algorithm. Journal of the Royal Statistical Society. Series C (Applied Statistics) 28 (1), pp. 100–108.
- Multitask learning and benchmarking with clinical time series data. arXiv preprint arXiv:1703.07771.
- Uncertainty-aware attention for reliable interpretation and prediction. In Advances in Neural Information Processing Systems, pp. 909–918.
- Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 7132–7141.
- Fibroblast growth factor 23 and risks of mortality and end-stage renal disease in patients with chronic kidney disease. Jama 305 (23), pp. 2432–2439.
- Inflammation and atherosclerosis: the value of the high-sensitivity c-reactive protein assay as a risk marker. Pathology Patterns Reviews 116 (suppl_1), pp. S108–S115.
- MIMIC-iii, a freely accessible critical care database. Scientific data 3, pp. 160035.
- Heart rate and cardiovascular mortality: the framingham study. American heart journal 113 (6), pp. 1489–1494.
- Adam: a method for stochastic optimization. arXiv preprint arXiv:1412.6980.
- Relation of disease etiology and risk factors to heart failure with preserved or reduced ejection fraction: insights from the national heart, lung, and blood institute’s framingham heart study. Circulation 119 (24), pp. 3070.
- Diagnosis prediction via medical context attention networks using deep generative modeling. In 2018 IEEE International Conference on Data Mining (ICDM), pp. 1104–1109.
- Efficient learning of continuous-time hidden markov models for disease progression. In Advances in Neural Information Processing Systems 28, C. Cortes, N. D. Lawrence, D. D. Lee, M. Sugiyama, and R. Garnett (Eds.), pp. 3600–3608.
- Health-atm: a deep architecture for multifaceted patient health record representation and risk prediction. In Proceedings of the 2018 SIAM International Conference on Data Mining, pp. 261–269.
- Automatic differentiation in pytorch.
- Serum creatinine as a marker of muscle mass in chronic kidney disease: results of a cross-sectional study and review of literature. Journal of cachexia, sarcopenia and muscle 4 (1), pp. 19–29.
- Deepcare: a deep dynamic memory model for predictive medicine. In Pacific-Asia Conference on Knowledge Discovery and Data Mining, pp. 30–41.
- Predicting healthcare trajectories from medical records: a deep learning approach. Journal of biomedical informatics 69, pp. 218–229.
- Nutritional status, icu duration and icu mortality in lung transplant recipients. Intensive care medicine 22 (11), pp. 1179–1185.
- Regression analysis of grouped survival data with application to breast cancer data. Biometrics, pp. 57–67.
- Ordered neurons: integrating tree structures into recurrent neural networks. arXiv preprint arXiv:1810.09536.
- Attend and diagnose: clinical time series analysis using attention models. In Thirty-Second AAAI Conference on Artificial Intelligence.
- Dropout: a simple way to prevent neural networks from overfitting. The journal of machine learning research 15 (1), pp. 1929–1958.
- Disease progression modeling using hidden markov models. In 2012 Annual International Conference of the IEEE Engineering in Medicine and Biology Society, pp. 2845–2848.
- Determining factors that predict technique survival on peritoneal dialysis: application of regression and artificial neural network methods. Nephron Clinical Practice 118 (2), pp. c93–c100.
- Reduced hemoglobin and increased c-reactive protein are associated with upper gastrointestinal bleeding. World Journal of Gastroenterology 20 (5), pp. 1311–1317.
- Regularization of neural networks using dropconnect. In International conference on machine learning, pp. 1058–1066.
- Unsupervised learning of disease progression models. In Proceedings of the 20th ACM SIGKDD international conference on Knowledge discovery and data mining, pp. 85–94.
- Chronic kidney disease as a risk factor for cardiovascular disease and all-cause mortality: a pooled analysis of community-based studies. Journal of the American Society of Nephrology 15 (5), pp. 1307–1315.
- Validation of a physiologic stability index for use in critically ill infants and children. Pediatric research 18 (5), pp. 445.
- Capturing feature-level irregularity in disease progression modeling. In Proceedings of the 2017 ACM on Conference on Information and Knowledge Management, pp. 1579–1588.