Effective modeling of electronic health records (EHR) is rapidly becoming an important topic in both academia and industry. A recent study showed that utilizing the graphical structure underlying EHR data (e.g. relationship between diagnoses and treatments) improves the performance of prediction tasks such as heart failure diagnosis prediction. However, EHR data do not always contain complete structure information. Moreover, when it comes to claims data, structure information is completely unavailable to begin with. Under such circumstances, can we still do better than just treating EHR data as a flat-structured bag-of-features? In this paper, we study the possibility of utilizing the implicit structure of EHR by using the Transformer for prediction tasks on EHR data. Specifically, we argue that the Transformer is a suitable model to learn the hidden EHR structure, and propose the Graph Convolutional Transformer, which uses data statistics to guide the structure learning process. Our model empirically demonstrated superior prediction performance to previous approaches on both synthetic data and publicly available EHR data on encounter-based prediction tasks such as graph reconstruction and readmission prediction, indicating that it can serve as an effective general-purpose representation learning algorithm for EHR data.
Graph Convolutional Transformer implemented in pytorch
Large volumes of medical records collected by electronic health record (EHR) systems in healthcare organizations have enabled deep learning methods to show impressive performance in diverse tasks such as predicting diagnosis (Lipton et al., 2015; Choi et al., 2016a; Rajkomar et al., 2018), learning medical concept representations (Che et al., 2015; Choi et al., 2016b, c; Miotto et al., 2016), and making interpretable predictions (Choi et al., 2016d; Ma et al., 2017). As diverse as they are, these tasks share one thing: under the hood, some form of neural network is processing EHR data to learn useful patterns from them. To successfully perform any EHR-related task, it is essential to learn effective representations of various EHR features: diagnosis codes, lab values, encounters, and even patients themselves.
EHR data are typically stored in a relational database that can be represented as a hierarchical graph depicted in Figure 1. The common approach for processing EHR data with neural networks has been to treat each encounter as an unordered set of features, or in other words, a bag of features. However, the bag of features approach completely disregards the graphical structure that reflects the physician’s decision process. For example, if we treat the encounter in Figure 1 as a bag of features, we will lose the information that Benzonatate was ordered because of Cough, not because of Abdominal pain.
Recently, motivated by this EHR structure, Choi et al. (2018) proposed MiME, a model architecture that reflects EHR's encounter structure, specifically the relationships between diagnoses and their treatments. MiME outperformed various bag-of-features approaches in prediction tasks such as heart failure diagnosis prediction. Their study, however, naturally raises the question: when the EHR data do not contain structure information (the red edges in Figure 1), can we still do better than bag of features in learning the representation of the data for various prediction tasks? This question arises on many occasions, since EHR data do not always contain the entire structure information. For example, a dataset might describe which treatment led to measuring certain lab values, but might not describe the diagnosis that was the reason for ordering that treatment. Moreover, when it comes to claims data, such structure information is completely unavailable to begin with.
To address this question, we study the possibility of using the Transformer (Vaswani et al., 2017) to utilize the unknown encounter structure for various prediction tasks when the structure information is unavailable. Specifically, we describe the graphical nature of encounter records, and argue that the Transformer is a reasonable model to discover implicit encounter structure. Then we propose the Graph Convolutional Transformer (GCT) to more effectively utilize the characteristics of EHR data while performing diverse prediction tasks. We test the Transformer and GCT on both synthetic data and real-world EHR records for encounter-based prediction tasks such as graph reconstruction and readmission prediction. In all tasks, GCT consistently outperformed baseline models, showing its potential to serve as an effective general-purpose representation learning algorithm for EHR data.
Although there are recent works on medical concept embedding, focusing on patients (Che et al., 2015; Miotto et al., 2016; Suresh et al., 2017; Nguyen et al., 2018), visits (Choi et al., 2016c), or codes (Tran et al., 2015; Choi et al., 2017), the graphical nature of EHR has not been fully explored yet. Choi et al. (2018) proposed MiME, which derives the visit representation in a bottom-up fashion according to the encounter structure. For example, in Figure 1, MiME first combines the embedding vectors of lab results with the Cardiac EKG embedding, which in turn is combined with both the Abdominal Pain embedding and the Chest Pain embedding. Then all diagnosis embeddings are pooled together to derive the final visit embedding. By outperforming various bag-of-features models in heart failure prediction and general disease prediction, MiME demonstrated the usefulness of the structure information of encounter records.

The Transformer (Vaswani et al., 2017) was proposed for natural language processing, specifically machine translation. It processes sequence data using only attention (Bahdanau et al., 2014), and has recently shown impressive performance in other tasks such as word representation learning (Devlin et al., 2018). Graph (convolutional) networks encompass various neural network methods for handling graphs such as molecular structures, social networks, or physical experiments (Kipf and Welling, 2016; Hamilton et al., 2017; Battaglia et al., 2018; Xu et al., 2019). In essence, many graph networks can be described as different ways to aggregate a given node's neighbor information, combine it with the given node, and derive the node's latent representation (Xu et al., 2019).

Some recent works focused on the connection between the Transformer's self-attention and graph networks (Battaglia et al., 2018). Graph Attention Networks (Veličković et al., 2018) applied self-attention on top of the adjacency matrix to learn non-static edge weights, and Wang et al. (2018) used self-attention to capture non-local dependencies in images. Although our work also relies on self-attention, our interest lies in whether the Transformer can be an effective tool to capture the underlying graphical structure of EHR data even when the structure information is missing, thereby improving encounter-based prediction tasks. In the next section, we first describe the graphical nature of EHR encounter data, then show that the Transformer is a reasonable algorithm for learning the hidden graphical structure of encounter records.
As depicted in Figure 1, the $t$-th visit $\mathcal{V}^{(t)}$ starts with the visit node $v^{(t)}$ at the top. Beneath the visit node are diagnosis nodes $d_1^{(t)}, \ldots, d_{|d^{(t)}|}^{(t)}$, which in turn lead to ordering a set of treatments $m_1^{(t)}, \ldots, m_{|m^{(t)}|}^{(t)}$, where $|d^{(t)}|$ and $|m^{(t)}|$ respectively denote the number of diagnosis and treatment codes in $\mathcal{V}^{(t)}$. Some treatments produce lab results $r_1^{(t)}, \ldots, r_{|r^{(t)}|}^{(t)}$, which may be associated with continuous values (e.g. blood pressure) or binary values (e.g. positive/negative allergic reaction). Since we focus on a single encounter in this study, we omit the time index $t$ throughout the paper.
If we assume all features $d_i$, $m_i$, $r_i$ can be represented in the same latent space (if we bucketize the continuous values associated with $r_i$, we can treat $r_i$ as a discrete feature like $d_i$ and $m_i$), then we can view an encounter as a graph whose nodes are these features, with an adjacency matrix $A$ that describes the connections between the nodes. We use $c_i$ as the collective term to refer to any of $d_i$, $m_i$, and $r_i$ for the rest of the paper. Given $c_i$ and $A$, we can use graph networks or MiME (MiME is in fact a special form of graph network with residual connections) to derive the visit representation $v$ and use it for downstream tasks such as heart failure prediction. However, if we do not have the structural information $A$, which is the case for many EHR datasets and for claims data, we typically use feed-forward networks to derive $v$, which essentially amounts to summing all node representations $c_i$ and projecting the sum to some latent space.

Even without the structure information $A$, it is unreasonable to treat $\mathcal{V}$ as a bag of nodes $c_i$, because physicians must have made deliberate decisions when making diagnoses and ordering treatments. The question is how to utilize the underlying structure without an explicit $A$. One way to view this problem is to assume that all nodes in $\mathcal{V}$ are implicitly fully connected, and to try to figure out which connections are stronger than others, as depicted in Figure 2. In this work, as discussed in Section 2, we use Transformer to learn the underlying encounter structure. To elaborate, we draw a comparison between two cases:
Case A: We know $A$, hence we can use Graph Convolutional Networks (GCN). In this work, we use multiple hidden layers between each convolution, motivated by Xu et al. (2019):

$$C^{(j)} = \mathrm{MLP}^{(j)}\left(\tilde{D}^{-1}\tilde{A}\,C^{(j-1)}W^{(j)}\right) \tag{1}$$

where $\tilde{A} = A + I$, $\tilde{D}$ is the diagonal node degree matrix of $\tilde{A}$ (Xu et al. (2019) do not use the normalizer $\tilde{D}^{-1}$, to improve model expressiveness on multi-set graphs, but we include it to make the comparison with Transformer clearer), and $C^{(j)}$ and $W^{(j)}$ are the node embeddings and the trainable parameters of the $j$-th convolution, respectively. $\mathrm{MLP}^{(j)}$ is a multi-layer perceptron of the $j$-th convolution with its own trainable parameters.

Case B: We do not know $A$, hence we use Transformer, specifically the encoder with single-head attention, which can be formulated as

$$C^{(j)} = \mathrm{MLP}^{(j)}\left(\mathrm{softmax}\left(\frac{Q^{(j)}K^{(j)\top}}{\sqrt{d}}\right)C^{(j-1)}W_V^{(j)}\right) \tag{2}$$

where $Q^{(j)} = C^{(j-1)}W_Q^{(j)}$, $K^{(j)} = C^{(j-1)}W_K^{(j)}$, and $d$ is the column size of $W_K^{(j)}$. $W_Q^{(j)}$, $W_K^{(j)}$, and $W_V^{(j)}$ are trainable parameters of the $j$-th Transformer block. (Since we use an MLP in both GCN and Transformer, the terms $W^{(j)}$ and $W_V^{(j)}$ are unnecessary, but we include them to follow the original formulations.) Note that positional encoding using sine and cosine functions is not required, since features in an encounter are unordered.
Given Eq. (1) and Eq. (2), we can readily see that there is a correspondence between the normalized adjacency matrix $\tilde{D}^{-1}\tilde{A}$ and the attention map $\mathrm{softmax}(Q^{(j)}K^{(j)\top}/\sqrt{d})$, and between the node embeddings $C^{(j-1)}W^{(j)}$ and the value vectors $C^{(j-1)}W_V^{(j)}$. In fact, GCN can be seen as a special case of Transformer in which the attention mechanism is replaced with the known, fixed adjacency matrix. Conversely, Transformer can be seen as a graph embedding algorithm that assumes fully connected nodes and learns the connection strengths during training. Given this connection, it seems natural to use Transformer to learn the underlying structure of visits.
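To make the correspondence concrete, here is a minimal pure-Python sketch (an illustration, not the authors' TensorFlow implementation) of the shared propagation step `weights @ (C @ W)` in Eq. (1) and Eq. (2), with the MLP omitted. The helper names `normalized_adjacency`, `attention_map`, and `propagate` are hypothetical. GCN plugs in the fixed matrix $\tilde{D}^{-1}\tilde{A}$, whereas Transformer plugs in an input-dependent softmax map.

```python
import math


def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def normalized_adjacency(adj):
    """GCN weights D~^{-1} A~, where A~ = A + I and D~ is the degree matrix of A~."""
    n = len(adj)
    a_tilde = [[adj[i][j] + (1.0 if i == j else 0.0) for j in range(n)] for i in range(n)]
    # Multiplying by D~^{-1} divides each row by its degree (the row sum).
    return [[v / sum(row) for v in row] for row in a_tilde]


def attention_map(c, w_q, w_k):
    """Transformer weights softmax(Q K^T / sqrt(d)) with Q = C W_Q and K = C W_K."""
    q, k = matmul(c, w_q), matmul(c, w_k)
    d = len(w_k[0])  # column size of W_K
    scores = matmul(q, [list(col) for col in zip(*k)])  # Q K^T
    weights = []
    for row in scores:
        top = max(row)  # subtract the row maximum for numerical stability
        exps = [math.exp((s - top) / math.sqrt(d)) for s in row]
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights


def propagate(weights, c, w):
    """Shared step of Eq. (1) and Eq. (2) before the MLP: weights @ (C @ W)."""
    return matmul(weights, matmul(c, w))


# Toy 3-node chain (diagnosis - treatment - lab) with 2-d embeddings.
adj = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]
c = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
eye = [[1.0, 0.0], [0.0, 1.0]]
gcn_out = propagate(normalized_adjacency(adj), c, eye)    # fixed, known structure
attn_out = propagate(attention_map(c, eye, eye), c, eye)  # structure inferred from C
```

Both weight matrices are row-stochastic. With $W_Q = W_K = 0$ every attention row becomes uniform, which corresponds to a fully connected graph with equal edge weights; training moves the map away from this toward connections that help the task.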
Although Transformer can potentially learn the hidden encounter structure, without any hint it must search the entire attention space to discover meaningful connections between encounter features. Therefore we propose the Graph Convolutional Transformer (GCT), which, based on data statistics, restricts the search to the region of the attention space that is likely to contain meaningful attention distributions.
Specifically, we use 1) the characteristics of EHR data and 2) the conditional probabilities between features. First, we use the fact that some connections are not allowed in the encounter record. For example, we know that a treatment code can be connected to diagnosis codes, but not to other treatment codes. Based on this observation, we can create a mask $M$, which will be used during the attention generation step. $M$ has negative infinities where connections are not allowed and zeros where connections are allowed, so adding $M$ to the attention logits before the softmax gives disallowed connections exactly zero attention.
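A minimal sketch of such a mask follows. It assumes, as a hypothetical rule set for illustration only, that the allowed undirected connections are visit-diagnosis, diagnosis-treatment, and treatment-lab (following the hierarchy in Figure 1), and that each feature may also attend to itself so that no row of the mask is entirely $-\infty$. The type names and the `ALLOWED` set are assumptions, not taken from the paper.

```python
import math

NEG_INF = float("-inf")

# Hypothetical allowed (undirected) connections between feature types.
ALLOWED = {
    frozenset(("visit", "diagnosis")),
    frozenset(("diagnosis", "treatment")),
    frozenset(("treatment", "lab")),
}


def build_mask(types):
    """M[i][j] = 0 if features i and j may be connected (or i == j), -inf otherwise."""
    n = len(types)
    return [
        [0.0 if i == j or frozenset((types[i], types[j])) in ALLOWED else NEG_INF for j in range(n)]
        for i in range(n)
    ]


def masked_softmax(scores, mask):
    """Row-wise softmax(scores + M); entries with M = -inf receive exactly zero weight."""
    out = []
    for s_row, m_row in zip(scores, mask):
        logits = [s + m for s, m in zip(s_row, m_row)]
        top = max(logits)  # finite, because the diagonal is always allowed
        exps = [math.exp(x - top) for x in logits]  # math.exp(-inf) == 0.0
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


types = ["visit", "diagnosis", "treatment", "treatment"]
M = build_mask(types)
attn = masked_softmax([[0.0] * 4 for _ in range(4)], M)
```

With all-zero logits, the first treatment attends only to itself and to the diagnosis; the other treatment and the visit node are masked out.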
Conditional probabilities can be useful for determining potential connections between features. For example, given chest pain, fever, and EKG, without any structure information we do not know which diagnosis is the reason for ordering EKG. However, we can calculate from EHR data that $p(\text{EKG} \mid \text{chest pain})$ is typically larger than $p(\text{EKG} \mid \text{fever})$, indicating that the connection between the former pair is more likely than the connection between the latter pair. Therefore we propose to use the conditional probabilities calculated from the encounter records as guidance for deriving the attention. After calculating $p(m \mid d)$, $p(d \mid m)$, $p(r \mid m)$, and $p(m \mid r)$ from all encounter records for all diagnosis codes $d$, treatment codes $m$, and lab codes $r$, we can create a guiding matrix for a given encounter record, as depicted in Figure 3. We use $P$ to denote the matrix of conditional probabilities of all features, normalized such that each row sums to 1. Note that GCT's attention $\hat{A}$, the mask $M$, and the conditional probabilities $P$ are of the same size.
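The sketch below estimates conditional probabilities from a toy set of training encounters and assembles a row-normalized guide matrix $P$ for one encounter. It is an illustration under stated assumptions: encounters are treated as sets of codes, $p(b \mid a)$ is estimated as the co-occurrence count of $a$ and $b$ divided by the count of $a$, only diagnosis-treatment and treatment-lab pairs receive nonzero entries, and a row with no eligible partner falls back to a self-loop so that it can still sum to 1. Code names such as `chest_pain` are made up.

```python
from collections import Counter
from itertools import permutations


def conditional_probs(encounters):
    """Estimate p(b | a) = count(a and b co-occur) / count(a) over training encounters."""
    single, pair = Counter(), Counter()
    for codes in encounters:
        codes = set(codes)  # count each code at most once per encounter
        single.update(codes)
        pair.update(permutations(codes, 2))  # ordered pairs (a, b) with a != b
    return {(a, b): n / single[a] for (a, b), n in pair.items()}


# Assumed eligible (undirected) type pairs for nonzero guide entries.
ELIGIBLE = {frozenset(("diagnosis", "treatment")), frozenset(("treatment", "lab"))}


def guide_matrix(features, kinds, cond):
    """Row i holds p(c_j | c_i) for eligible partners j, normalized to sum to 1."""
    n = len(features)
    rows = []
    for i in range(n):
        row = [0.0] * n
        for j in range(n):
            if i != j and frozenset((kinds[i], kinds[j])) in ELIGIBLE:
                row[j] = cond.get((features[i], features[j]), 0.0)
        if sum(row) == 0.0:
            row[i] = 1.0  # assumed fallback: attend to itself
        total = sum(row)
        rows.append([v / total for v in row])
    return rows


train = [{"chest_pain", "ekg"}, {"chest_pain", "ekg"}, {"fever", "ekg"}, {"fever", "acetaminophen"}]
cond = conditional_probs(train)
P = guide_matrix(["chest_pain", "fever", "ekg"], ["diagnosis", "diagnosis", "treatment"], cond)
```

Here $p(\text{ekg} \mid \text{chest\_pain}) = 1.0$ exceeds $p(\text{ekg} \mid \text{fever}) = 0.5$, and the EKG row of `P` places twice as much weight on chest pain as on fever. In line with the experimental protocol, `cond` would be estimated on the training split only.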
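Given $P$ and $M$, we want to guide GCT to recover the true graph structure as much as possible, while still leaving some room for GCT to learn novel connections that are helpful for solving the given prediction tasks. Therefore GCT uses the following formulation:

$$\hat{A}^{(j)} = \mathrm{softmax}\left(\frac{Q^{(j)}K^{(j)\top}}{\sqrt{d}} + M\right)$$
$$\text{Self-attention:}\quad C^{(j)} = \begin{cases} \mathrm{MLP}^{(j)}\left(P\,C^{(j-1)}W_V^{(j)}\right) & \text{when } j = 1 \\ \mathrm{MLP}^{(j)}\left(\hat{A}^{(j)}C^{(j-1)}W_V^{(j)}\right) & \text{when } j > 1 \end{cases}$$
$$\text{Regularization:}\quad \mathcal{L}_{reg}^{(j)} = \begin{cases} D_{KL}\left(P \,\|\, \hat{A}^{(j)}\right) & \text{when } j = 1 \\ D_{KL}\left(\hat{A}^{(j-1)} \,\|\, \hat{A}^{(j)}\right) & \text{when } j > 1 \end{cases} \tag{3}$$

$$\mathcal{L} = \mathcal{L}_{pred} + \lambda \sum_{j} \mathcal{L}_{reg}^{(j)} \tag{4}$$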
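The following is a minimal pure-Python sketch of Eq. (3) and Eq. (4). It assumes the attention logits $Q^{(j)}K^{(j)\top}/\sqrt{d}$ are passed in precomputed, sums the KL terms over attention rows (the reduction is an assumption), and uses a small `eps` to guard against $\log 0$. Every block computes a masked attention map, block 1 propagates with $P$ instead of its own map, and each map is pulled toward its predecessor ($P$ for block 1). The function names are hypothetical.

```python
import math


def masked_softmax(scores, mask):
    """Row-wise softmax(scores + M); each mask row must allow at least one entry."""
    out = []
    for s_row, m_row in zip(scores, mask):
        logits = [s + m for s, m in zip(s_row, m_row)]
        top = max(logits)
        exps = [math.exp(x - top) for x in logits]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def row_kl(p, q, eps=1e-12):
    """Sum over rows of D_KL(p_row || q_row); entries with p = 0 contribute nothing."""
    return sum(
        pi * math.log(pi / max(qi, eps))
        for p_row, q_row in zip(p, q)
        for pi, qi in zip(p_row, q_row)
        if pi > 0
    )


def gct_weights_and_reg(prior, block_scores, mask):
    """Eq. (3): propagation weights per block and the summed KL regularizers.

    block_scores[j] stands in for Q K^T / sqrt(d) of block j + 1.
    """
    attn = [masked_softmax(s, mask) for s in block_scores]
    weights = [prior] + attn[1:]    # block 1 propagates with P instead of its attention
    previous = [prior] + attn[:-1]  # P for block 1, otherwise the previous block's map
    reg = sum(row_kl(p, q) for p, q in zip(previous, attn))
    return weights, reg


def gct_loss(pred_loss, reg, lam):
    """Eq. (4): prediction loss plus lambda times the summed regularizers."""
    return pred_loss + lam * reg
```

A larger `lam` keeps the later blocks closer to $P$; `lam = 0` lets them drift freely from the guide.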
In preliminary experiments, we noticed that attention was often uniformly distributed in the first block of Transformer, apparently because Transformer did not yet know which connections were worth attending to. Therefore we replace the attention mechanism in the first GCT block with the conditional probabilities $P$. The following blocks use the masked self-attention mechanism. However, we do not want GCT to deviate drastically from the informative $P$, but rather to improve upon it gradually. Therefore, based on the fact that attention is itself a probability distribution, and inspired by Trust Region Policy Optimization (Schulman et al., 2015), we sequentially penalize the attention of the $j$-th block if it deviates too much from the attention of the $(j-1)$-th block, using KL divergence. As shown by Eq. (4), the regularization terms are added to the prediction loss term (e.g. negative log-likelihood), and the trade-off is controlled by the coefficient $\lambda$. GCT's code will be made publicly available in the future.

Choi et al. (2018) evaluated their model on proprietary EHR data that contained structure information. Unfortunately, to the best of our knowledge, there are no publicly available EHR data that contain structure information (which is the main motivation of this work). In order to evaluate GCT's ability to learn EHR structure, we instead generated synthetic data that have a similar structure to EHR data.
The synthetic data have the same visit-diagnosis-treatment-lab results hierarchy as EHR data, and were generated in a top-down fashion. Each level was generated conditioned on the previous level, where the probabilities were modeled with the Pareto distribution. The Pareto distribution follows a power law, which captures the long-tailed nature of medical codes well. Using 1,000 diagnosis codes, 1,000 treatment codes, and 1,000 lab codes, we initialized $p(D)$, $p(D \mid D)$, $p(M \mid D)$, and $p(R \mid M, D)$ to follow the Pareto distribution, where $D$, $M$, and $R$ respectively denote the diagnosis, treatment, and lab random variables. $p(D)$ is used to draw independent diagnosis codes $d$, and $p(D \mid D)$ is used to draw diagnosis codes that are likely to co-occur with the previously sampled $d$. $p(M \mid D)$ is used to draw a treatment code $m$ given some $d$. $p(R \mid M, D)$ is used to draw a lab code $r$ given some $m$ and $d$. A detailed description of generating the synthetic records and the link to download them are provided in Appendix A and Appendix F, respectively. Code for generating the synthetic records will be open-sourced in the future. Table 1 summarizes the data statistics.

Table 1: Data statistics of the synthetic dataset and eICU.

Statistic | Synthetic | eICU |
---|---|---|
# of encounters | 50,000 | 41,026 |
# of diagnosis codes | 1,000 | 3,093 |
# of treatment codes | 1,000 | 2,132 |
# of lab codes | 1,000 | N/A |
Avg. # of diagnoses per visit | 7.93 | 7.70 |
Avg. # of treatments per visit | 14.59 | 5.03 |
Avg. # of labs per visit | 21.31 | N/A |
To test GCT on real-world EHR records, we use the Philips eICU Collaborative Research Dataset (https://eicu-crd.mit.edu/about/eicu/) (Pollard et al., 2018). eICU consists of Intensive Care Unit (ICU) records filtered for remote caregivers, collected from multiple sites in the United States between 2014 and 2015. From the encounter records, medication orders, and procedure orders, we extracted diagnosis codes and treatment codes (i.e. medication and procedure codes). Since the data were collected from an ICU, a single encounter can last several days, during which the encounter structure evolves over time rather than being fixed as in Figure 1. Therefore we used only encounters where the patient was admitted for less than 24 hours, and removed duplicate codes (e.g. medications administered multiple times). Additionally, we did not use lab results, as their values change over time in the ICU setting (e.g. blood pH level). We leave as future work how to handle ICU records that evolve over a longer period of time. Note that eICU does not contain structure information. For example, we know that cough and acetaminophen in Figure 1 occur in the same visit, but do not know whether acetaminophen was prescribed due to cough. Table 1 summarizes the data statistics.
GCN: Given the true adjacency matrix $A$, we follow Eq. (1) to learn the representation of each feature $c_i$ in a visit $\mathcal{V}$. The visit embedding $v$ (i.e. the graph-level representation) is obtained from the placeholder visit node. This model will serve as the optimal model during the experiments.
GCN_P: Instead of the true adjacency matrix $A$, we use the conditional probability matrix $P$, and follow Eq. (1).
GCN_random: Instead of the true adjacency matrix $A$, we use a randomly generated normalized adjacency matrix where each element is independently sampled from a uniform distribution between 0 and 1. This model will let us evaluate whether the true encounter structure is useful at all.
Shallow: Each $c_i$ is converted to a latent representation using multi-layer feedforward networks with ReLU activations. The visit representation $v$ is obtained by simply summing all of these latent representations. We use layer normalization (Ba et al., 2016), dropout (Srivastava et al., 2014), and residual connections (He et al., 2016) between layers.
Deep: We use multiple feedforward layers with ReLU activations (including layer normalization, dropout, and residual connections) on top of Shallow to increase expressivity. Note that Zaheer et al. (2017) showed theoretically that this model is sufficient to obtain the optimal representation of a set of items (i.e. a visit consisting of multiple features).
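The difference between the two set baselines can be sketched as follows. This is a simplified illustration, assuming a single bias-free layer on each side of the sum and omitting layer normalization, dropout, and residual connections; the actual baselines used 15 layers (Shallow) and 8 + 7 layers (Deep), as described in Appendix C. The function names are hypothetical.

```python
def relu(v):
    return [max(0.0, x) for x in v]


def linear(w, v):
    """Matrix-vector product w @ v (no bias, for brevity)."""
    return [sum(wi * xi for wi, xi in zip(row, v)) for row in w]


def shallow(features, phi_w):
    """Shallow-style visit vector: embed each feature independently, then sum."""
    v = [0.0] * len(phi_w)
    for c in features:
        v = [a + b for a, b in zip(v, relu(linear(phi_w, c)))]
    return v


def deep(features, phi_w, rho_w):
    """Deep-style visit vector: extra layers applied after the sum (Deep Sets form)."""
    return relu(linear(rho_w, shallow(features, phi_w)))


feats = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
phi = [[1.0, 0.0], [0.0, 1.0]]
rho = [[1.0, -1.0], [0.5, 0.5]]
```

Because both models aggregate by summation, neither can see which features are connected; reordering `feats` leaves the output unchanged.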
In order to evaluate the model’s capacity to leverage the implicit encounter structure, we use prediction tasks based on a single encounter, rather than a sequence of encounters, which was the experiment setup in Choi et al. (2018). Specifically, we test the models on the following tasks. Parentheses indicate which dataset is used for each task.
Graph reconstruction (Synthetic): Given an encounter with features $c_i$, we train models to learn feature embeddings, and predict whether there is an edge between every pair of features by taking the inner product of each pair of feature embeddings (i.e. binary predictions). We do not use the Deep baseline for this task, as we need individual embeddings for all features $c_i$.
Diagnosis-Treatment classification (Synthetic): We assign labels to an encounter if there are specific connections between diagnosis codes ($d_1$ and $d_2$) and a treatment code ($m$). Specifically, we assign label "1" if the encounter contains a $d_1$-$m$ connection, and label "2" if the encounter contains a $d_2$-$m$ connection. We intentionally made the task difficult so that the models cannot achieve a perfect score just by basing their prediction on whether $d_1$, $d_2$, and $m$ exist in an encounter. The prevalence for both labels is approximately . Further details on the labels are provided in Appendix B. This is a multi-label prediction task using the visit representation $v$.
Masked diagnosis code prediction (Synthetic, eICU): Given an encounter record, we mask a random diagnosis code $d_i$. We train models to learn the embedding of the masked code to predict its identity, i.e. a multi-class prediction. For Shallow and Deep, we use the visit embedding as a proxy for the masked code representation. The row and the column of the conditional probability matrix $P$ that correspond to the masked diagnosis were also set to zero.
Readmission prediction (eICU): Given an encounter record, we train models to learn the visit embedding to predict whether the patient will be admitted to the ICU again during the same hospital stay, i.e., a binary prediction. The prevalence is approximately .
Mortality prediction (eICU): Given an encounter record, we train models to learn the visit embedding to predict patient death during the ICU admission, i.e., a binary prediction. The prevalence is approximately .
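The graph reconstruction objective can be sketched as below: an inner product between two feature embeddings, passed through a sigmoid, gives the probability of an edge. The source specifies the inner product and the binary predictions; the sigmoid with mean binary cross-entropy, and scoring only pairs $i < j$ because the adjacency is undirected, are assumptions made for this sketch. The function names are hypothetical.

```python
import math


def edge_probabilities(emb):
    """p(edge i-j) = sigmoid(<c_i, c_j>) for every pair i < j."""
    n = len(emb)
    return {
        (i, j): 1.0 / (1.0 + math.exp(-sum(a * b for a, b in zip(emb[i], emb[j]))))
        for i in range(n)
        for j in range(i + 1, n)
    }


def reconstruction_loss(emb, adj, eps=1e-12):
    """Mean binary cross-entropy between predicted and true (undirected) edges."""
    probs = edge_probabilities(emb)
    total = 0.0
    for (i, j), p in probs.items():
        y = adj[i][j]
        total -= y * math.log(max(p, eps)) + (1 - y) * math.log(max(1.0 - p, eps))
    return total / len(probs)
```

Orthogonal embeddings have an inner product of 0, so every edge is predicted with probability 0.5; training has to pull connected features together in embedding space.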
For each task, data were randomly divided into train, validation, and test sets in an 8:1:1 ratio five times, yielding five trained models, and we report the average performance. Note that the conditional probability matrix $P$ was calculated using only the training set. Further training details and hyperparameter settings are described in Appendix C.

Table 2: Graph reconstruction and diagnosis-treatment classification performance. Parentheses denote standard deviations. We report the performance measured in AUROC in Appendix D.

Model | Graph reconstruction: Validation AUCPR | Graph reconstruction: Test AUCPR | Diagnosis-Treatment classification: Validation AUCPR | Diagnosis-Treatment classification: Test AUCPR |
---|---|---|---|---|
GCN | 1.0 (0.0) | 1.0 (0.0) | 1.0 (0.0) | 1.0 (0.0) |
GCN_P | 0.5807 (0.0019) | 0.5800 (0.0021) | 0.8439 (0.0166) | 0.8443 (0.0214) |
GCN_random | 0.5644 (0.0018) | 0.5635 (0.0021) | 0.7839 (0.0144) | 0.7804 (0.0214) |
Shallow | 0.5443 (0.0015) | 0.5441 (0.0017) | 0.8530 (0.0181) | 0.8555 (0.0206) |
Deep | - | - | 0.8210 (0.0096) | 0.8198 (0.0046) |
Transformer | 0.5755 (0.0020) | 0.5752 (0.0015) | 0.8329 (0.0282) | 0.8380 (0.0178) |
GCT | 0.5972 (0.0027) | 0.5965 (0.0031) | 0.8686 (0.0103) | 0.8671 (0.0247) |

Table 3: Masked diagnosis code prediction performance on the synthetic dataset and eICU.

Model | Synthetic: Validation Accuracy | Synthetic: Test Accuracy | eICU: Validation Accuracy | eICU: Test Accuracy |
---|---|---|---|---|
GCN | 0.2862 (0.0048) | 0.2834 (0.0065) | - | - |
GCN_P | 0.2002 (0.0024) | 0.1954 (0.0064) | 0.7434 (0.0072) | 0.7432 (0.0086) |
GCN_random | 0.1868 (0.0031) | 0.1844 (0.0058) | 0.7129 (0.0044) | 0.7186 (0.0067) |
Shallow | 0.2084 (0.0043) | 0.2032 (0.0068) | 0.7313 (0.0026) | 0.7364 (0.0017) |
Deep | 0.1958 (0.0043) | 0.1938 (0.0038) | 0.7309 (0.0050) | 0.7344 (0.0043) |
Transformer | 0.1969 (0.0045) | 0.1909 (0.0074) | 0.7190 (0.0040) | 0.7170 (0.0061) |
GCT | 0.2220 (0.0033) | 0.2179 (0.0071) | 0.7704 (0.0047) | 0.7704 (0.0039) |

Table 4: Readmission prediction and mortality prediction performance on eICU.

Model | Readmission: Validation AUCPR | Readmission: Test AUCPR | Mortality: Validation AUCPR | Mortality: Test AUCPR |
---|---|---|---|---|
GCN_P | 0.5121 (0.0154) | 0.4987 (0.0105) | 0.5808 (0.0331) | 0.5647 (0.0201) |
GCN_random | 0.5078 (0.0116) | 0.4974 (0.0173) | 0.5717 (0.0571) | 0.5435 (0.0644) |
Shallow | 0.3704 (0.0123) | 0.3509 (0.0144) | 0.6041 (0.0253) | 0.5795 (0.0258) |
Deep | 0.5219 (0.0182) | 0.5050 (0.0126) | 0.6119 (0.0213) | 0.5924 (0.0121) |
Transformer | 0.5104 (0.0159) | 0.4999 (0.0127) | 0.6069 (0.0291) | 0.5931 (0.0211) |
GCT | 0.5313 (0.0124) | 0.5244 (0.0142) | 0.6196 (0.0259) | 0.5992 (0.0223) |
Table 2 shows the graph reconstruction and diagnosis-treatment classification performance of all models. Naturally, GCN shows the best performance, since it uses the true adjacency matrix $A$. In graph reconstruction, GCN_P is outperformed only by GCT (apart from GCN), which suggests that the conditional probability matrix $P$ is indeed indicative of the true structure. GCT, which combines the strengths of GCN_P and Transformer, shows the best performance besides GCN. It is noteworthy that GCN_random outperforms Shallow in graph reconstruction. This seems to indicate that for graph reconstruction, attending to other features, regardless of how accurately the process follows the true structure, is better than individually embedding each feature. Diagnosis-treatment classification, on the other hand, clearly penalizes randomly attending to the features, since GCN_random shows the worst performance. GCT again shows the best performance besides GCN.
Table 3 shows the model performance for masked diagnosis code prediction on both datasets. GCN could not be evaluated on eICU, since eICU does not have the true structure. However, GCN naturally shows the best performance on the synthetic dataset. Interestingly, Transformer shows performance comparable to GCN_random, indicating that this task behaves in the opposite way to graph reconstruction, where simply letting each feature attend to other features significantly improved performance. Note that the task performance is significantly higher for eICU than for the synthetic dataset. This is mainly due to eICU having a very skewed diagnosis code distribution. In eICU, more than 80% of encounters have diagnosis codes related to whether the patient was in an operating room prior to the ICU admission. Therefore randomly masking one of them does not make the prediction task as difficult as it is for the synthetic dataset.
Table 4 shows the readmission prediction and mortality prediction performance of all models on eICU. GCT's superior performance suggests that readmission prediction benefits from using the latent encounter structure. Mortality prediction, on the other hand, seems to rely little on the encounter structure, as can be seen from GCT's only marginally better performance compared to Transformer and Deep. Even when the encounter structure seems unnecessary, however, GCT still outperforms all other models, demonstrating its potential to be used as a general-purpose EHR modeling algorithm. These two experiments indicate that not all prediction tasks require the true encounter structure; we leave applying GCT to a wider range of prediction tasks to evaluate its effectiveness as future work.
In this section, we analyze the learned structure of both Transformer and GCT. Because we know the true structure of the synthetic records, we can evaluate how well both models learned $A$ via the self-attention $\hat{A}$. Since we can view the normalized true adjacency matrix $\tilde{D}^{-1}\tilde{A}$ as a probability distribution, we can measure how well the attention map $\hat{A}$ in Eq. (3) approximates $\tilde{D}^{-1}\tilde{A}$ using the KL divergence $D_{KL}(\tilde{D}^{-1}\tilde{A} \,\|\, \hat{A})$.
Table 5: KL divergence between the normalized true adjacency matrix and the learned attention, and the entropy of the attention, on the synthetic test set for three tasks.

Model | Graph reconstruction: KL divergence | Graph reconstruction: Entropy | Diagnosis-Treatment classification: KL divergence | Diagnosis-Treatment classification: Entropy | Masked diagnosis code prediction: KL divergence | Masked diagnosis code prediction: Entropy |
---|---|---|---|---|---|---|
GCN_P | 8.4844 (0.0140) | 1.5216 (0.0044) | 8.4844 (0.0140) | 1.5216 (0.0040) | 8.4844 (0.0140) | 1.5216 (0.0044) |
Transformer | 19.6268 (2.9114) | 1.7798 (0.1411) | 14.3178 (0.2084) | 1.9281 (0.0368) | 15.1837 (0.8646) | 1.9941 (0.0522) |
GCT | 7.6490 (0.0476) | 1.8302 (0.0135) | 8.0363 (0.0305) | 1.6003 (0.0244) | 8.9648 (0.1944) | 1.3305 (0.0889) |
Table 5 shows the KL divergence between the normalized true adjacency and the learned attention on the test set of the synthetic data while performing three different tasks. For GCN_P, the adjacency matrix is fixed to the conditional probability matrix $P$, so the KL divergence can be readily calculated. For Transformer and GCT, we calculated the KL divergence between $\tilde{D}^{-1}\tilde{A}$ and the attention map $\hat{A}^{(j)}$ of each self-attention block, and averaged the results. We repeated this process 5 times (on 5 randomly sampled train, validation, and test sets) and report the average performance. Note that the KL divergence can be lowered by distributing the attention evenly across all features, which is the opposite of learning the encounter structure. Therefore we also show the entropy of $\hat{A}$ alongside the KL divergence.
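The evaluation can be sketched as follows: compute $D_{KL}(\tilde{D}^{-1}\tilde{A} \,\|\, \hat{A}^{(j)})$ and the entropy of $\hat{A}^{(j)}$ for each block, then average over blocks. Averaging over attention rows within a matrix and the `eps` guard against $\log 0$ are assumptions of this sketch, since the exact reduction is not specified here. The function names are hypothetical.

```python
import math


def row_kl(p, q, eps=1e-12):
    """Mean over rows of D_KL(p_row || q_row); eps guards against log(0)."""
    per_row = [
        sum(pi * math.log(pi / max(qi, eps)) for pi, qi in zip(pr, qr) if pi > 0)
        for pr, qr in zip(p, q)
    ]
    return sum(per_row) / len(per_row)


def row_entropy(a):
    """Mean over rows of -sum_j a_ij log a_ij (zero entries contribute nothing)."""
    per_row = [-sum(x * math.log(x) for x in row if x > 0) for row in a]
    return sum(per_row) / len(per_row)


def structure_scores(true_norm_adj, attention_maps):
    """Block-averaged KL(D~^{-1}A~ || A^(j)) and entropy of A^(j)."""
    kls = [row_kl(true_norm_adj, a) for a in attention_maps]
    ents = [row_entropy(a) for a in attention_maps]
    return sum(kls) / len(kls), sum(ents) / len(ents)
```

A sharply peaked attention map that ignores a true neighbor has low entropy but a large KL divergence. A flat map avoids that KL penalty only by spreading its weight, which its high entropy reveals. This is why the two numbers are reported together.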
As shown in Table 5, the conditional probabilities $P$ are closer to the true structure than what Transformer has learned, in all three tasks. GCT performs similarly to GCN_P in all tasks, and was even able to improve upon $P$ in the graph reconstruction and diagnosis-treatment classification tasks. It is notable that, despite having attention significantly different from the true structure, Transformer demonstrated strong graph reconstruction performance in Table 2. This again indicates the importance of simply attending to other features in graph reconstruction, which was discussed in Section 4.5 regarding the performance of GCN_random. For the other two tasks, regularizing the models to stay close to $P$ helped GCT outperform Transformer as well as the other models. We show visual examples of the attention behavior of both Transformer and GCT in Appendix E.
Learning effective patterns from raw EHR data is an essential step for improving the performance of many downstream prediction tasks. In this paper, we addressed the issue where the previous state-of-the-art method required the complete encounter structure information, and proposed GCT to capture the underlying encounter structure when the structure information is unknown. Experiments demonstrated that GCT outperformed various baseline models on encounter-based tasks on both synthetic data and a publicly available EHR dataset, demonstrating its potential to serve as a general-purpose EHR modeling algorithm. In the future, we plan to apply GCT on patient-level tasks such as heart failure diagnosis prediction or unplanned emergency admission prediction, while working on improving the attention mechanism to learn more medically meaningful patterns.
Choi et al. (2016a). Doctor AI: Predicting clinical events via recurrent neural networks. In Machine Learning for Healthcare Conference, pages 301–318, 2016.
Ma et al. (2017). Dipole: Diagnosis prediction in healthcare via attention-based bidirectional recurrent neural networks. In Proceedings of the 23rd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, pages 1903–1911. ACM, 2017.
Tran et al. (2015). Learning vector representation of medical objects via EMR-driven nonnegative restricted Boltzmann machines (eNRBM). Journal of Biomedical Informatics, 2015.
Choi et al. (2017). GRAM: Graph-based attention model for healthcare representation learning. In SIGKDD, 2017.
He et al. (2016). In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 770–778, 2016.

We describe the synthetic data creation process in this section. As described in Section 4.1, we use the Pareto distribution to capture the long-tailed nature of medical codes. We also define stopping parameters that determine when to stop sampling the codes. The overall generation process starts by sampling a diagnosis code. Then we sample a diagnosis code that is likely to co-occur with the previously sampled diagnosis code. After the diagnosis codes are sampled, we iterate through the sampled diagnosis codes to sample treatment codes that are likely to be ordered for each diagnosis code. While sampling each treatment code, we also sample lab codes that are likely to be produced by that treatment code. The overall algorithm is described in Algorithm 1.
Note that we use $p(M \mid D)$ to model a treatment being ordered due to a diagnosis code, instead of also conditioning on the treatments already ordered, which might be more accurate since a treatment may depend on the already ordered treatments as well. However, we assume that, given a diagnosis code, the treatments that follow are conditionally independent, so the probability of the treatments factorizes into a product of $p(M \mid D)$ terms. The same assumption went into using $p(R \mid M, D)$ instead of also conditioning on the previously sampled lab codes.
Finally, among the generated synthetic encounters, we removed those with fewer than 5 diagnosis or treatment codes, in order to make the encounter structure sufficiently complex. Additionally, we removed encounters containing more than 50 diagnosis, treatment, or lab codes, in order to keep the encounter structure realistic (i.e. it is unlikely that a patient receives more than 50 diagnosis codes in one hospital encounter). For the eICU dataset, we also removed the encounters with more than 50 diagnosis or treatment codes. However, we did not remove any encounters for having fewer than 5 diagnosis or treatment codes, as that would leave us with only approximately 7,000 encounter records, which is rather small for training neural networks.
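A minimal sketch of this top-down generation and filtering is shown below. It is not the authors' generator (Algorithm 1), and its assumptions are as follows. Pareto-shaped probabilities are built by normalizing `random.paretovariate` draws. Vocabularies are shrunk from 1,000 to 30 codes per type, because a dense $p(R \mid M, D)$ table over 1,000 codes of each type would be far too large for a toy. The stopping probabilities `p_stop_dx`, `p_stop_tx`, and `p_stop_lab` are hypothetical, since their values are not given here. Each encounter is represented as a set of typed edges.

```python
import random


def pareto_probs(rng, n, alpha=1.5):
    """Long-tailed categorical distribution built from normalized Pareto(alpha) draws."""
    w = [rng.paretovariate(alpha) for _ in range(n)]
    total = sum(w)
    return [x / total for x in w]


def make_tables(rng, n_dx, n_tx, n_lab):
    """Toy dense stand-ins for p(D), p(D|D), p(M|D) and p(R|M,D)."""
    return {
        "d": pareto_probs(rng, n_dx),
        "d|d": [pareto_probs(rng, n_dx) for _ in range(n_dx)],
        "m|d": [pareto_probs(rng, n_tx) for _ in range(n_dx)],
        "r|m,d": {(m, d): pareto_probs(rng, n_lab) for m in range(n_tx) for d in range(n_dx)},
    }


def draw(rng, probs):
    """Sample one code index from a categorical distribution."""
    return rng.choices(range(len(probs)), weights=probs)[0]


def sample_encounter(rng, t, p_stop_dx=0.15, p_stop_tx=0.5, p_stop_lab=0.5):
    """Top-down sampling: diagnoses, then treatments per diagnosis, then labs per treatment."""
    dx = [draw(rng, t["d"])]
    while rng.random() > p_stop_dx:  # next diagnosis conditioned on the previous one
        dx.append(draw(rng, t["d|d"][dx[-1]]))
    edges = set()
    for d in dx:
        while True:  # at least one treatment per diagnosis
            m = draw(rng, t["m|d"][d])
            edges.add(("dx", d, "tx", m))
            while rng.random() > p_stop_lab:  # zero or more labs per treatment
                edges.add(("tx", m, "lab", draw(rng, t["r|m,d"][(m, d)])))
            if rng.random() < p_stop_tx:
                break
    return edges


def codes(edges, kind):
    """Distinct codes of one kind ("dx", "tx" or "lab") appearing in the edges."""
    return {e[1] for e in edges if e[0] == kind} | {e[3] for e in edges if e[2] == kind}


def keep(edges, lo=5, hi=50):
    """Filter described above: >= lo diagnosis and treatment codes, <= hi codes of any kind."""
    n_dx, n_tx, n_lab = (len(codes(edges, k)) for k in ("dx", "tx", "lab"))
    return n_dx >= lo and n_tx >= lo and max(n_dx, n_tx, n_lab) <= hi


def generate(n, seed=0, vocab=(30, 30, 30)):
    """Return n filtered synthetic encounters, each a set of typed edges."""
    rng = random.Random(seed)
    tables = make_tables(rng, *vocab)
    out = []
    while len(out) < n:
        edges = sample_encounter(rng, tables)
        if keep(edges):
            out.append(edges)
    return out
```

Because the edges are recorded during sampling, the true adjacency matrix of every generated encounter is known, which is what makes the graph reconstruction and structure analyses possible on synthetic data.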
This task is used to test the model's ability to derive a visit representation (i.e. graph-level representation) that correctly preserves the encounter structure. As described in Section 4.4, this is a multi-label classification problem, where an encounter is assigned the label “1” if it contains a connected pair of a diagnosis code $d_1$ and a treatment code $m$ (i.e. $m$ was ordered because of $d_1$). An encounter is assigned the label “2” if it contains a connected pair of $d_2$ and $m$. Therefore an encounter can be assigned both labels “1” and “2”, or no label at all.
Since we want to test the model’s ability to correctly learn the encounter structure, we do not want the model to achieve a perfect score by, for example, predicting label “1” simply based on whether both and exist in an encounter. We therefore adjusted the sampling probabilities to make this task difficult. Specifically, we set . Therefore, the probability of an encounter containing a - connection is , and the probability of an encounter containing a - connection is . The overall probabilities of the two connection pairs occurring in an encounter are thus roughly the same, and the model cannot achieve a perfect score unless it correctly identifies the encounter structure.
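The labeling rule can be sketched as follows, again assuming the nested record format of `visits_50k.p`. The two (diagnosis, treatment) pairs that define labels “1” and “2” are not given here, so the function takes them as arguments, and the code values in the example are hypothetical. The example shows why co-occurrence alone is not enough: a label is assigned only when the treatment is ordered because of the diagnosis, that is, when it is nested under that diagnosis.

```python
def has_connection(record, dx_code, tx_code):
    """True if tx_code was ordered because of dx_code, i.e. it is nested under it."""
    return any(
        dx == dx_code and any(tx == tx_code for tx, _labs in treatments)
        for dx, treatments in record
    )


def assign_labels(record, pair_1, pair_2):
    """Multi-label target: a subset of {1, 2}, possibly empty."""
    labels = set()
    if has_connection(record, *pair_1):
        labels.add(1)
    if has_connection(record, *pair_2):
        labels.add(2)
    return labels


# Hypothetical encounter: diagnosis 10 ordered treatment 30, diagnosis 20 ordered 40.
encounter = [[10, [[30, []]]], [20, [[40, [50]]]]]
print(assign_labels(encounter, (10, 30), (20, 40)))  # -> {1, 2}
print(assign_labels(encounter, (10, 40), (20, 30)))  # -> set(): codes co-occur, not connected
```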
All models were trained with Adam Kingma and Ba [2014] on the training set, and performance on the validation set was used to select the final model. Final performance was evaluated on the test set. We used a minibatch size of 32 and trained all models for 1,000,000 iterations (i.e. minibatch updates), which was sufficient for convergence in all tasks. After an initial round of preliminary experiments, the embedding size of the encounter features was set to 128. For GCN, GCN, GCN, Transformer, and GCT, we used undirected adjacency/attention matrices to enhance message-passing efficiency. All models were implemented in TensorFlow 1.13 Abadi et al. [2016] and trained on a system equipped with Nvidia P100 GPUs.
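The text does not say how the matrices were made undirected. One straightforward option, shown here as an assumption rather than the authors' code, is to symmetrize a directed 0/1 adjacency matrix with an elementwise maximum, so that every edge can carry messages in both directions. The toy example uses a three-node chain (diagnosis, treatment, lab) and a weight-free propagation `h = A h` to show that, with the directed matrix, the diagnosis node's signal never reaches the lab node, whereas with the symmetrized matrix it does after two steps.

```python
import numpy as np


def to_undirected(adj):
    """Symmetrize a 0/1 adjacency matrix: i and j are linked if i->j or j->i."""
    adj = np.asarray(adj)
    return np.maximum(adj, adj.T)


def propagate(adj, h, steps):
    """Weight-free message passing: node i sums h[j] over all j with adj[i, j] = 1."""
    for _ in range(steps):
        h = adj @ h
    return h


# Nodes: 0 = diagnosis, 1 = treatment, 2 = lab; directed edges 0->1 and 1->2.
directed = np.array([[0, 1, 0],
                     [0, 0, 1],
                     [0, 0, 0]])
undirected = to_undirected(directed)
signal = np.array([1.0, 0.0, 0.0])  # one-hot signal placed on the diagnosis node

print(propagate(directed, signal, 2))    # lab node (index 2) stays 0
print(propagate(undirected, signal, 2))  # lab node now receives the signal
```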
Tunable hyperparameters for models Shallow, Deep, GCN, GCN, GCN, and Transformer are as follows:
Adam learning rate
Drop-out rate between layers
Transformer used three self-attention blocks, which was sufficient to cover the entire depth of EHR encounters. Shallow used 15 feedforward layers, and Deep used 8 feedforward layers before and 7 feedforward layers after summing the embeddings. The numbers of layers were chosen to match the number of trainable parameters of Transformer and GCT. GCN, GCN, and GCN used 5 convolution steps to match the number of trainable parameters of Transformer. Transformer used one attention head so that its representational power matches that of GCN, GCN, and GCN, allowing us to accurately evaluate the effect of learning the correct encounter structure.
Tunable hyperparameters for GCT are as follows:
Adam learning rate
Drop-out rate between layers
Regularization coefficient
GCT also used three self-attention blocks and one attention head. All hyperparameters were searched via Bayesian optimization with a Gaussian process for 72 hours of wall-clock time, using one of the five randomly sampled train/validation/test splits. The chosen hyperparameters were then used to train models on all five splits. The hyperparameters used for each task are listed below in Table 6.
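The tuning protocol can be illustrated with a small Bayesian-optimization loop that uses a Gaussian-process surrogate and an expected-improvement acquisition function. This is our sketch, not the authors' tooling (the text does not name a library), and everything in it is an assumption for illustration: the search ranges, the RBF kernel and its length scale, the fixed iteration count (the paper used a 72-hour budget instead), and `toy_validation_score`, which stands in for training a model on one split and measuring its validation performance.

```python
import math

import numpy as np


def rbf_kernel(a, b, length_scale=0.2):
    """Squared-exponential kernel between row vectors of a (n, d) and b (m, d)."""
    sq_dist = ((a[:, None, :] - b[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-0.5 * sq_dist / length_scale**2)


def gp_posterior(x_train, y_train, x_query, noise=1e-4):
    """Zero-mean GP posterior mean and std at x_query (unit signal variance)."""
    k = rbf_kernel(x_train, x_train) + noise * np.eye(len(x_train))
    k_s = rbf_kernel(x_train, x_query)
    chol = np.linalg.cholesky(k)
    alpha = np.linalg.solve(chol.T, np.linalg.solve(chol, y_train))
    v = np.linalg.solve(chol, k_s)
    var = np.clip(1.0 - (v**2).sum(axis=0), 1e-12, None)
    return k_s.T @ alpha, np.sqrt(var)


def expected_improvement(mean, std, best, xi=0.01):
    """Expected improvement over `best` for a maximization problem."""
    improvement = mean - best - xi
    z = improvement / std
    cdf = 0.5 * (1.0 + np.vectorize(math.erf)(z / math.sqrt(2.0)))
    pdf = np.exp(-0.5 * z**2) / math.sqrt(2.0 * math.pi)
    return improvement * cdf + std * pdf


def bayes_opt(objective, bounds, n_init=5, n_iter=20, n_candidates=2000, seed=0):
    """Maximize `objective` over a box; bounds is a list of (low, high) pairs."""
    rng = np.random.default_rng(seed)
    low, high = np.array(bounds, dtype=float).T

    def to_params(u):  # map the unit cube onto the search space
        return low + u * (high - low)

    x = rng.random((n_init, len(low)))
    y = np.array([objective(to_params(u)) for u in x])
    for _ in range(n_iter):
        y_norm = (y - y.mean()) / (y.std() + 1e-9)  # standardize for the zero-mean GP
        candidates = rng.random((n_candidates, len(low)))
        mean, std = gp_posterior(x, y_norm, candidates)
        u_next = candidates[np.argmax(expected_improvement(mean, std, y_norm.max()))]
        x = np.vstack([x, u_next])
        y = np.append(y, objective(to_params(u_next)))
    best = int(np.argmax(y))
    return to_params(x[best]), y[best]


def toy_validation_score(params):
    """Hypothetical stand-in for 'train on one split, return validation AUROC'."""
    log10_lr, dropout = params
    return -((log10_lr + 4.0) ** 2) - (dropout - 0.3) ** 2


best_params, best_score = bayes_opt(toy_validation_score, bounds=[(-5.0, -3.0), (0.0, 0.9)])
print(best_params, best_score)
```

In the paper's protocol, the selected hyperparameters would then be reused unchanged to train on all five splits.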
Hyperparameters for graph reconstruction on the synthetic data.

| | GCN | GCN | GCN | Shallow | Deep | Transformer | GCT |
|---|---|---|---|---|---|---|---|
| Learning rate | 0.00045 | 0.0006 | 0.0003 | 0.00025 | - | 0.0007 | 0.0005 |
| MLP dropout rate | 0.3 | 0.01 | 0.5 | 0.2 | - | 0.8 | 0.3 |
| Post-MLP dropout rate | 0.2 | 0.02 | 0.005 | - | - | 0.001 | 0.1 |
| Regularization coef. | - | - | - | - | - | - | 0.02 |
Hyperparameters for diagnosis-treatment classification on the synthetic data.

| | GCN | GCN | GCN | Shallow | Deep | Transformer | GCT |
|---|---|---|---|---|---|---|---|
| Learning rate | 0.0001 | 0.0001 | 0.0001 | 0.0002 | 0.0008 | 0.00015 | 0.0001 |
| MLP dropout rate | 0.2 | 0.3 | 0.5 | 0.02 | 0.01 | 0.5 | 0.85 |
| Post-MLP dropout rate | 0.65 | 0.02 | 0.4 | - | 0.3 | 0.01 | 0.03 |
| Regularization coef. | - | - | - | - | - | - | 0.05 |
Hyperparameters for masked diagnosis code prediction on the synthetic data.

| | GCN | GCN | GCN | Shallow | Deep | Transformer | GCT |
|---|---|---|---|---|---|---|---|
| Learning rate | 0.0003 | 0.0007 | 0.0002 | 0.0007 | 0.0004 | 0.0003 | 0.0001 |
| MLP dropout rate | 0.01 | 0.8 | 0.5 | 0.08 | 0.12 | 0.4 | 0.85 |
| Post-MLP dropout rate | 0.88 | 0.005 | 0.5 | - | 0.75 | 0.5 | 0.6 |
| Regularization coef. | - | - | - | - | - | - | 0.05 |
Hyperparameters for masked diagnosis code prediction on eICU.

| | GCN | GCN | GCN | Shallow | Deep | Transformer | GCT |
|---|---|---|---|---|---|---|---|
| Learning rate | - | 0.0005 | 0.0001 | 0.0001 | 0.00012 | 0.0001 | 0.0009 |
| MLP dropout rate | - | 0.5 | 0.3 | 0.3 | 0.4 | 0.87 | 0.5 |
| Post-MLP dropout rate | - | 0.5 | 0.4 | - | 0.45 | 0.2 | 0.03 |
| Regularization coef. | - | - | - | - | - | - | 50.0 |
Hyperparameters for readmission prediction on eICU.

| | GCN | GCN | GCN | Shallow | Deep | Transformer | GCT |
|---|---|---|---|---|---|---|---|
| Learning rate | - | 0.00024 | 0.0001 | 0.0001 | 0.00011 | 0.0002 | 0.00022 |
| MLP dropout rate | - | 0.3 | 0.7 | 0.63 | 0.05 | 0.45 | 0.08 |
| Post-MLP dropout rate | - | 0.1 | 0.01 | - | 0.33 | 0.28 | 0.024 |
| Regularization coef. | - | - | - | - | - | - | 0.1 |
Hyperparameters for mortality prediction on eICU.

| | GCN | GCN | GCN | Shallow | Deep | Transformer | GCT |
|---|---|---|---|---|---|---|---|
| Learning rate | - | 0.0003 | 0.00013 | 0.0001 | 0.00015 | 0.0006 | 0.00011 |
| MLP dropout rate | - | 0.85 | 0.9 | 0.25 | 0.01 | 0.88 | 0.72 |
| Post-MLP dropout rate | - | 0.04 | 0.01 | - | 0.01 | 0.2 | 0.005 |
| Regularization coef. | - | - | - | - | - | - | 1.5 |
| Model | Graph reconstruction (validation AUROC) | Graph reconstruction (test AUROC) | Diagnosis-treatment classification (validation AUROC) | Diagnosis-treatment classification (test AUROC) |
|---|---|---|---|---|
| GCN | 1.0 (0.0) | 1.0 (0.0) | 1.0 (0.0) | 1.0 (0.0) |
| GCN | 0.8870 (0.0011) | 0.8865 (0.0005) | 0.9493 (0.0127) | 0.9475 (0.0135) |
| GCN | 0.8806 (0.0009) | 0.8799 (0.0008) | 0.9230 (0.0053) | 0.9221 (0.0070) |
| Shallow | 0.8578 (0.0010) | 0.8573 (0.0005) | 0.9575 (0.0116) | 0.9584 (0.0140) |
| Deep | - | - | 0.9387 (0.0071) | 0.9374 (0.0041) |
| Transformer | 0.8843 (0.0013) | 0.8844 (0.0008) | 0.9494 (0.0226) | 0.9493 (0.0210) |
| GCT | 0.8936 (0.0012) | 0.8931 (0.0013) | 0.9626 (0.0146) | 0.9600 (0.0154) |
| Model | Readmission prediction (validation AUROC) | Readmission prediction (test AUROC) | Mortality prediction (validation AUROC) | Mortality prediction (test AUROC) |
|---|---|---|---|---|
| GCN | 0.7403 (0.0078) | 0.7355 (0.0081) | 0.8971 (0.0047) | 0.8953 (0.0065) |
| GCN | 0.7243 (0.0046) | 0.7259 (0.0080) | 0.8939 (0.0243) | 0.8941 (0.0220) |
| Shallow | 0.6794 (0.0129) | 0.6734 (0.0101) | 0.9000 (0.0083) | 0.8972 (0.0038) |
| Deep | 0.7478 (0.0124) | 0.7412 (0.0074) | 0.9101 (0.0057) | 0.9092 (0.0060) |
| Transformer | 0.7333 (0.0065) | 0.7301 (0.0101) | 0.9089 (0.0121) | 0.9017 (0.0152) |
| GCT | 0.7525 (0.0128) | 0.7502 (0.0114) | 0.9089 (0.0052) | 0.9120 (0.0048) |
Table 7 shows the graph reconstruction and diagnosis-treatment classification performance of all models, and Table 8 shows their readmission and mortality prediction performance, all measured in AUROC. In terms of test AUROC, GCT outperforms all other models in every task as well, except for the GCN variant in Table 7 that achieves a perfect score of 1.0 on both synthetic tasks.
In this section, we compare the attention behavior of Transformer and GCT in two different contexts: graph reconstruction and masked diagnosis code prediction. To enhance readability, we randomly chose an encounter record with fewer than 30 codes from the test set of the synthetic dataset. To show the attention distribution of a specific code, we chose the first diagnosis code connected to at least one treatment. Figure 4 shows Transformer’s attention in each self-attention block when performing graph reconstruction. Specifically, we show the attention given by the diagnosis code D_199 to other codes. The red bars indicate the true connections, and the blue bars indicate the attention given to all codes. Transformer attends evenly to all codes in the first block, then develops its own attention pattern. In the second block, it successfully recovers two of the true connections, but it attends to incorrect codes in the third block.
Figure 5 shows GCT’s attention in each self-attention block when performing graph reconstruction. In contrast to Transformer, GCT starts with a very specific attention distribution. The first two attention values, given to the placeholder Visit node and to the code itself, are determined by the scalar value from Figure 3, whereas the attention given to the treatment codes, especially T_939, is derived from the conditional probability matrix. In the following self-attention blocks, GCT starts to deviate from the conditional probability matrix, and its attention distribution becomes more similar to the true adjacency matrix. This illustrates the benefit of using the conditional probability matrix as a guide for learning the encounter structure.
Since the goal of the graph reconstruction task is to predict the edges between nodes, it is perhaps unsurprising that both Transformer’s and GCT’s attention mimic the true adjacency matrix. Therefore, we also show attention from Transformer and GCT trained for the masked diagnosis code prediction task. Figure 6 shows Transformer’s attention while performing masked diagnosis code prediction. Note that the diagnosis code D_294 is masked, so the model does not know its identity. As in graph reconstruction, Transformer starts with evenly distributed attention and then develops its own structure. Interestingly, it learns to attend to the correct treatment in the third block, but it mostly tries to predict the masked node’s identity by attending to other diagnosis codes, while largely ignoring the lab codes.
Figure 7 shows GCT’s attention while performing the masked diagnosis code prediction task. Again, GCT starts with the conditional probability matrix and then develops its own attention. This time, however, the attention maps are understandably not as similar to the true structure as in the graph reconstruction task. Interestingly, GCT attends heavily to the placeholder Visit node in this task. This is inevitable: we only allow diagnosis codes to attend to treatment codes (see the white cells in Figure 3), so if GCT needs to look at other diagnosis codes, it can only do so indirectly by receiving information via the Visit node. As Figure 6 suggests, predicting the identity of the masked code seems to require knowing the co-occurring diagnosis codes as well as the treatment codes. Therefore, unlike in the graph reconstruction task, GCT puts heavy attention on the Visit node in this task in order to learn about the co-occurring diagnosis codes.
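The masking argument above can be illustrated with a small sketch. Only the diagnosis row follows the text (a diagnosis code may attend to the Visit node, to itself, and to treatment codes); the rules for the other node types (Visit with diagnoses, treatments with labs, every node with itself) are our assumption, as are the names `allowed_mask`, `masked_softmax`, and `ALLOWED_PAIRS`. With random attention scores, the example confirms that a diagnosis gives exactly zero attention to another diagnosis within one block, while a two-block path to it exists by way of the Visit node.

```python
import numpy as np

VISIT, DX, TX, LAB = "visit", "dx", "tx", "lab"

# Unordered node-type pairs that may exchange attention (hypothetical rule set).
ALLOWED_PAIRS = [{VISIT, DX}, {DX, TX}, {TX, LAB}]


def allowed_mask(node_types):
    """Boolean matrix: mask[i, j] is True if node i may attend to node j."""
    n = len(node_types)
    mask = np.eye(n, dtype=bool)  # every node may attend to itself
    for i, ti in enumerate(node_types):
        for j, tj in enumerate(node_types):
            if {ti, tj} in ALLOWED_PAIRS:
                mask[i, j] = True
    return mask


def masked_softmax(scores, mask):
    """Row-wise softmax that gives exactly zero weight to disallowed positions."""
    scores = np.where(mask, scores, -np.inf)
    scores = scores - scores.max(axis=1, keepdims=True)  # finite: self is always allowed
    weights = np.exp(scores)
    return weights / weights.sum(axis=1, keepdims=True)


node_types = [VISIT, DX, DX, TX, LAB]
mask = allowed_mask(node_types)
attn = masked_softmax(np.random.default_rng(0).normal(size=mask.shape), mask)
print(np.round(attn[1], 3))  # diagnosis 1: nonzero only for Visit, itself, and the treatment
print(mask[1, 0] and mask[0, 2])  # True: diagnosis 1 -> Visit -> diagnosis 2 over two blocks
```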
The synthetic records used for the experiments can be downloaded via this link (https://www.dropbox.com/s/ojx9jr4yyvmfdum/synthetic.tar.gz). It is a compressed file, which you can decompress to obtain the following files.
visits_50k.p: This is a Python cPickle file containing a list of encounter records. Each record is a list of [diagnosis code, treatments] pairs, where treatments is itself a list of [treatment code, lab codes] pairs. For example, [[1, []], [2, [[3, []], [4, [5, 6]]]]] describes a single encounter record. The first diagnosis code is “1”, and no treatment or lab codes follow. The second diagnosis code is “2”, and treatments “3” and “4” are ordered because of diagnosis “2”. Additionally, treatment “4” is followed by two lab codes, “5” and “6”. visits_50k.p consists of 500,000 encounter records in this format.
dx_probs.npy: This Python Numpy file corresponds to in Algorithm 1. It is a 1000-dimensional vector, where the -th element represents .
dx_dx_cond_probs.npy: This Python Numpy file corresponds to in Algorithm 1. It is a 1000-by-1000 matrix, where the -th element represents .
dx_proc_cond_probs.npy: This Python Numpy file corresponds to in Algorithm 1. It is a 1000-by-1000 matrix, where the -th element represents .
dx_dx_probs.npy: This Python Numpy file corresponds to in Algorithm 1. It is a 1000-dimensional vector, where the -th element represents .
multi_proc_probs.npy: This Python Numpy file corresponds to in Algorithm 1. It is a 1000-dimensional vector, where the -th element represents .
multi_lab_probs.npy: This Python Numpy file corresponds to in Algorithm 1. It is a 1000-by-1000 matrix, where the -th element represents .
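A minimal sketch for reading the released records, assuming they follow the description above; the helper names `load_visits` and `iter_edges` are ours. Since the archive is not downloaded here, the example writes the sample record from the `visits_50k.p` description to a temporary pickle and reads it back; for the real file one would call `load_visits("visits_50k.p")`. The file was written with Python 2's cPickle; Python 3's `pickle.load` can generally read such files, and passing `encoding="latin1"` is the usual workaround if string decoding fails. The NumPy files can be read with `np.load`, which, per the list above, should yield 1000-dimensional vectors or 1000-by-1000 matrices.

```python
import os
import pickle
import tempfile


def load_visits(path):
    """Load the list of nested encounter records from a pickle file."""
    with open(path, "rb") as f:
        return pickle.load(f)


def iter_edges(record):
    """Yield the diagnosis->treatment and treatment->lab edges of one encounter."""
    for dx, treatments in record:
        for tx, labs in treatments:
            yield ("dx-tx", dx, tx)
            for lab in labs:
                yield ("tx-lab", tx, lab)


# The example record from the visits_50k.p description above.
example = [[1, []], [2, [[3, []], [4, [5, 6]]]]]

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "visits_demo.p")
    with open(path, "wb") as f:
        pickle.dump([example], f)
    records = load_visits(path)

for edge in iter_edges(records[0]):
    print(edge)
```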
Note that and discussed in Appendix B correspond to the actual codes and , respectively. Therefore, for example, the -th element in dx_proc_cond_probs.npy equals approximately 0.2, and the -th element equals approximately 0.8.