MedGraph: Structural and Temporal Representation Learning of Electronic Medical Records

12/08/2019 ∙ by Bhagya Hettige, et al. ∙ Monash University

Electronic medical record (EMR) data contains historical sequences of patient visits, and each visit contains rich information, such as patient demographics, hospital utilisation and medical codes, including diagnosis, procedure and medication codes. Most existing EMR embedding methods capture visit-code associations by constructing input visit representations as binary vectors over a static vocabulary of medical codes. With this limited representation, they fail to encapsulate the rich attribute information of visits (demographics and utilisation information) and/or codes (e.g., medical code descriptions). Furthermore, current work treats visits of the same patient as discrete-time events and ignores the time gaps between them. However, these time gaps capture the dynamics of a patient's medical history, so past visits exert varying influence on future visits. To address these limitations, we present MedGraph, a supervised EMR embedding method that captures two types of information: (1) the visit-code associations in an attributed bipartite graph, and (2) the temporal sequencing of visits through point processes. MedGraph produces Gaussian embeddings for visits and codes to model the uncertainty. We evaluate the performance of MedGraph through an extensive experimental study and show that MedGraph outperforms state-of-the-art EMR embedding methods in several medical risk prediction tasks.


1 Introduction

Electronic medical records (EMR) contain rich clinical data from a patient’s stays in hospital. Hospitals collect a high volume of EMRs, which can be used in medical risk prediction to improve the quality of personalised healthcare. EMR data form a unique and complex data structure. An EMR represents a hospital visit, and it typically contains patient demographics (e.g. age, gender) and hospital utilisation information (e.g. duration/ward of stay). An unordered set of medical concepts (e.g., diagnosis, procedure and medication codes) is associated with each visit. These medical concepts are usually taken from pre-defined healthcare standards such as the International Classification of Diseases (ICD) and National Drug Codes (NDC). Moreover, EMRs accumulated over a period of time naturally form a sequence of visits that records a patient’s hospitalisation history. Transforming EMRs into low-dimensional vectors has recently been an active research topic, because it allows these complex data to be used by downstream machine learning algorithms for predictive healthcare tasks [DBLP:conf/ijcai/med2, DBLP:conf/kdd/med2vec, DBLP:conf/kdd/gram, DBLP:conf/nips/retain, DBLP:conf/nips/mime, DBLP:conf/kdd/dipole, miotto2016deep_pat, nguyen2016mathtt_deepr, DBLP:conf/ijcai/med1, DBLP:conf/aaai/med3]. Learning the structural visit-code associations and the temporal influences within visit sequences are two important aspects of EMR embedding.

Considering structural visit-code associations, a patient’s visit contains a set of unordered medical codes. Existing methods, such as Med2Vec [DBLP:conf/kdd/med2vec], RETAIN [DBLP:conf/nips/retain] and Dipole [DBLP:conf/kdd/dipole], propose sophisticated deep learning models to derive latent representations for visits using multi-hot-encoded medical codes as inputs. These methods have two major limitations. First, visits are represented using a fixed vocabulary of medical codes, but in real-world EMR systems new or previously unseen medical concepts can be introduced to the data (e.g. due to revised versions of medical code standards). Second, these methods do not capture the demographics and utilisation information attached to visits (except for Med2Vec [DBLP:conf/kdd/med2vec]), nor side information about medical codes such as ICD code text descriptions. This additional attribute information is important in determining similar visits and similar medical codes.

Considering temporal visit sequences, EMRs are time-stamped, longitudinal medical events. Each patient has a temporal sequence of visits with time-dependent relationships among them, i.e. previous visits in a patient’s history can influence the next visit. This influence also decays with time, so more recent visits may have a higher influence than older ones. Moreover, the time gap between consecutive visits is not fixed, and the larger the gap between two visits, the less related they tend to be. Previous work has proposed recurrent neural network (RNN) based architectures [DBLP:conf/nips/retain, DBLP:conf/nips/mime, DBLP:conf/kdd/dipole, DBLP:conf/pakdd/deepcare] to learn the temporality of visits. However, these models treat visit sequences as time series indexed by position (i.e. as if events occurred periodically) and do not account for varying time gaps. In reality, visits are continuous-time events, and the influence of historical visits on the current visit varies with the time gaps between them.

In this paper, we propose MedGraph, a novel EMR embedding method that leverages both structural and temporal information in EMRs to improve embedding quality. MedGraph uses a novel graph-based data structure to represent EMR, in which visits and codes are nodes and their different interactions are edges (see Figure 1). In Figure 1, each patient has a temporal visit sequence connected via dashed, directed edges, and each visit has a set of codes connected via thick, undirected edges. Denoting the sets of visits and codes as $\mathcal{V}$ and $\mathcal{C}$ respectively, the $\mathcal{V}$-$\mathcal{C}$ relationships form a bipartite graph with $\mathcal{V}$ and $\mathcal{C}$ node partitions. Each node carries supplementary information, such as demographics and utilisation information for visits and textual descriptions for codes, which makes $\mathcal{V}$-$\mathcal{C}$ an attributed bipartite graph. The $\mathcal{V} \rightarrow \mathcal{V}$ relationships form temporal sequences of visit nodes, where each directed edge carries the time gap between two consecutive visits of a patient. Since the graph is an open data structure, it can be extended with new or unseen medical codes.

Taking advantage of this graph structure, MedGraph learns low-dimensional representations for visits and codes by capturing both visit-code associations (based on the graph structure) and hospitalisation history (based on temporal visit sequences) under one unified framework. Structurally, MedGraph learns representations for visits and codes by considering them in an attributed bipartite graph. Temporally, MedGraph models visit sequences as a temporal point process, such as the Hawkes process [hawkes_process], to capture the varying influence of historical visits with different time gaps. Theoretical point process models often make strong assumptions about the generative process of sequential event data. These assumptions do not necessarily reflect real-world dynamics, and they limit the expressive power of the model [DBLP:conf/kdd/rmtpp, point_processes]. To learn a more expressive representation of the varying historical visit influence when the true parametric model is unknown, MedGraph uses an RNN-based architecture, as in RMTPP [DBLP:conf/kdd/rmtpp], to model the conditional intensity function. Moreover, real-world EMR data are usually noisy due to the manual data entry involved [10.1136/amiajnl/emr_noisy]. We have designed MedGraph to be robust to noise by learning node embeddings (i.e. of visits and codes) as probability distributions (e.g. Gaussians). To the best of our knowledge, MedGraph is the first EMR embedding method that models the uncertainty of visit and code embeddings as Gaussians. The contributions of MedGraph are threefold:

  1. A customised graph-based data structure designed for EMR data that naturally captures both the visit-code co-location information ($\mathcal{V}$-$\mathcal{C}$) structurally, as an attributed bipartite graph, and the visit sequence information ($\mathcal{V} \rightarrow \mathcal{V}$) temporally, as a continuous-time event sequence.

  2. A novel approach to effectively and efficiently learn $\mathcal{V} \rightarrow \mathcal{V}$ relationships using a point process based sequence learning method with an RNN-based conditional intensity function, which accounts for the varying influence of historical visits on the next visit.

  3. An extensive experimental study on two real-world EMR datasets that shows MedGraph’s superiority over state-of-the-art embedding methods on a number of tasks: 30-day readmission prediction, mortality prediction, medical code classification, medical code visualisation and uncertainty modelling.

Figure 1: The MedGraph data structure. $v_i^p$ is the $i$-th visit of the $p$-th patient.

2 Related Work

Representation learning on electronic medical records aims at learning low-dimensional vectors for hospital visits and medical codes, through their interactions in terms of visit-code relationship and visit-visit sequencing information.

Several embedding models have been proposed to learn from visit-code relationships in EMR data [DBLP:journals/titb/deep_ehr]. Most of these works, including Med2Vec [DBLP:conf/kdd/med2vec], RETAIN [DBLP:conf/nips/retain], Deepr [nguyen2016mathtt_deepr], GRAM [DBLP:conf/kdd/gram] and Dipole [DBLP:conf/kdd/dipole], capture the visit-code associations by constructing the input visit vector as a multi-hot encoded medical code vector. MiME [DBLP:conf/nips/mime] assumes a hierarchical structure in EMR data and represents patients, visits, diagnoses, procedures and medications as a hierarchy in that order. In real-world EMRs, however, this hierarchical granularity of medical codes is often not available. GCT [DBLP:journals/corr/GCT] addresses this challenge with a graph data structure that learns the implicit hierarchical relations of codes using Transformers. However, none of these works consider the rich attributes of visits/codes. The exceptions are Med2Vec [DBLP:conf/kdd/med2vec], which considers visit demographics, and MNN [DBLP:conf/ijcai/temp1], which incorporates clinical notes. Still, neither considers code attributes. GRAM [DBLP:conf/kdd/gram] instead represents each code as a convex combination of the embeddings of the code and its ancestors in the ontology tree, and therefore depends strictly on this ontology structure. However, not all medical codes form such rich ontology graphs.

Existing EMR representation learning work captures the temporality of visit sequences using either Skip-gram with multi-layer perceptron (MLP) architectures [DBLP:conf/ijcai/time_aware_emr, DBLP:conf/kdd/med2vec] or RNN-based architectures [DBLP:conf/mlhc/doctorai, DBLP:conf/nips/retain, DBLP:journals/jamia/rnn_emr, DBLP:conf/nips/mime, DBLP:conf/naacl/rnn_med, DBLP:conf/kdd/dipole, DBLP:conf/pakdd/deepcare, DBLP:conf/ijcai/temp2]. However, Skip-gram can only predict neighbouring visits within a predefined number of time steps, without considering the order of visits in the context window. RNN models assume that events are recorded periodically and cannot effectively capture the varying time gaps between events. A recent work, PacRNN [DBLP:conf/ijcai/temp1], proposes a continuous-time model using point processes for a specialised healthcare task, i.e. ranked diagnosis code recommendation and prediction of the time of the next visit.

MedGraph uses a graph-based data structure to model EMR data, and we show that this graph captures more meaningful information than existing approaches. Our approach also captures the temporality of patients' historical visits via the proposed temporal point process model. Experimental results show that MedGraph produces more effective EMR embeddings.

3 MedGraph: Medical Data Graph Embedding

In this section, we describe the MedGraph algorithm. Without loss of generality, and for simplicity of notation, we discuss the algorithm for a single patient. The structure in Figure 1 is a heterogeneous graph with two types of nodes and two types of edges. The two edge types carry two distinct kinds of information about visits, i.e. the codes in each visit and the temporal visit sequence. Therefore, we decouple the two edge types. We extract the subgraph of Figure 1 for a single patient and decompose it into an attributed bipartite graph (Fig. 2(a)) and a temporal (timestamped) sequence graph (Fig. 2(b)).

3.1 Notations of MedGraph

As illustrated in Fig. 2, assume a patient has a time-ordered sequence of visits $v_1, \dots, v_N$, where $N$ is the length of the visit sequence, and each visit $v_i$ has an unordered set of medical codes $C_i \subseteq \mathcal{C}$. $\mathcal{C}$ denotes the set of code nodes, such as diagnosis codes. We construct the graph data structure for this data as follows.

$\mathcal{V}$-$\mathcal{C}$ attributed bipartite graph: This graph has two node partitions, visits $\mathcal{V}$ and codes $\mathcal{C}$, and a set of edges $\mathcal{E}$, where $(v, c) \in \mathcal{E}$ denotes a link between visit $v$ and code $c$. Each visit $v$ has an $a_{\mathcal{V}}$-dimensional visit attribute vector $x_v$ with demographic and utilisation information, such as age, gender and length of stay. Each code $c$ has an $a_{\mathcal{C}}$-dimensional code attribute vector $x_c$ with supplementary medical information about the code, such as the ICD code description as text or multi-hot ICD ontology ancestors.

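$\mathcal{V} \rightarrow \mathcal{V}$ temporal sequence: The temporally ordered sequence of visits of a patient is denoted as $\{(v_i, t_i, y_i)\}_{i=1}^{N}$, where for visit $v_i$, $t_i$ is the timestamp and $y_i \in \{1, \dots, K\}$ is the (optional) ground-truth label for the underlying auxiliary task with $K$ classes.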

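Each visit and code, $u \in \mathcal{V} \cup \mathcal{C}$, is represented as a low-dimensional Gaussian embedding $z_u = \mathcal{N}(\mu_u, \Sigma_u)$ in a shared embedding space, where $\mu_u \in \mathbb{R}^{L}$, $\Sigma_u \in \mathbb{R}^{L \times L}$ and $L$ is the embedding dimension. The embeddings capture visit-code associations and the time-gap-based influence of historical visits. To reduce the number of parameters to learn, we learn $\Sigma_u$ as a diagonal covariance vector, $\sigma_u \in \mathbb{R}^{L}$, instead of a full covariance matrix.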
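The following is a minimal sketch of the data structure just defined. It builds the $\mathcal{V}$-$\mathcal{C}$ edge set and, for each patient, the $\mathcal{V} \rightarrow \mathcal{V}$ sequence with the time gap since the previous visit, starting from raw visit records. The `Visit` and `EMRGraph` records, the `build_graph` helper and the example codes are hypothetical illustrations, not the authors' implementation. Timestamps are assumed to be numeric (e.g. days).

```python
from dataclasses import dataclass, field


@dataclass
class Visit:
    # Hypothetical visit record: id, numeric timestamp, attribute vector x_v, code set
    visit_id: str
    timestamp: float
    attributes: list
    codes: set


@dataclass
class EMRGraph:
    # V-C edges as (visit_id, code) pairs; an unseen code simply adds a new node
    edges: set = field(default_factory=set)
    # patient id -> list of (visit_id, time gap since the previous visit)
    sequences: dict = field(default_factory=dict)


def build_graph(patients):
    """patients: dict mapping patient id -> list of Visit objects in any order."""
    graph = EMRGraph()
    for pid, visits in patients.items():
        ordered = sorted(visits, key=lambda v: v.timestamp)
        sequence, prev_t = [], None
        for visit in ordered:
            for code in visit.codes:
                graph.edges.add((visit.visit_id, code))
            gap = 0.0 if prev_t is None else visit.timestamp - prev_t
            sequence.append((visit.visit_id, gap))
            prev_t = visit.timestamp
        graph.sequences[pid] = sequence
    return graph


patients = {
    "p1": [
        Visit("v2", 30.0, [67, 1], {"I50.9"}),
        Visit("v1", 0.0, [67, 1], {"I50.9", "E11.9"}),
    ]
}
graph = build_graph(patients)
```

Keeping codes as plain dictionary keys rather than indices into a fixed vocabulary reflects the "open data structure" property: a new code only adds edges.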

(a) $\mathcal{V}$-$\mathcal{C}$ graph
(b) $\mathcal{V} \rightarrow \mathcal{V}$ sequences
(c) Structural and temporal learning of MedGraph
Figure 2: MedGraph architecture

3.2 Architecture

Figure 2 shows the architecture of our proposed algorithm, MedGraph. Given an EMR dataset, we transform the data into a graph-based representation. We then consider the $\mathcal{V}$-$\mathcal{C}$ attributed bipartite graph and learn visit and code similarities based on graph structural proximity. We further improve the visit embeddings by learning the temporal visit sequences, $\mathcal{V} \rightarrow \mathcal{V}$, as events occurring in continuous time. These events are modelled through a temporal point process that captures the mutual excitation phenomenon among temporally ordered events [hawkes_process], which lets us learn how the influence of historical visits varies with the time gaps between them. MedGraph is an end-to-end risk prediction tool with a supervised task plugged into the output layer of the RNN. Alternatively, if there is no specific supervision task, MedGraph can be learned in an unsupervised manner with only structural learning and temporal sequence learning.

3.2.1 Structural learning for $\mathcal{V}$-$\mathcal{C}$ graph (Fig. 2(a))

In contrast to previous EMR embedding methods, which collect the codes in a visit into a multi-hot vector, MedGraph employs a bipartite graph with edges between $\mathcal{V}$ and $\mathcal{C}$. Owing to this versatile structure, we can attach auxiliary attribute data to the nodes, resulting in an attributed bipartite graph.

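Let $(v, c) \in \mathcal{E}$ be an edge between visit $v$ and code $c$, with attributes $x_v$ and $x_c$, respectively. We project the two node types into a uniform semantic latent space using two transformation matrices, $W_{\mathcal{V}} \in \mathbb{R}^{d' \times a_{\mathcal{V}}}$ and $W_{\mathcal{C}} \in \mathbb{R}^{d' \times a_{\mathcal{C}}}$, for the visit and code domains respectively, where $d'$ is the intermediate vector dimension: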

$$x'_v = W_{\mathcal{V}}\, x_v, \qquad x'_c = W_{\mathcal{C}}\, x_c \tag{1}$$

Then, we apply another layer of linear transformations to the $d'$-dimensional vectors to obtain Gaussian embeddings for the two node types in a common embedding space, denoted by $z_v = \mathcal{N}(\mu_v, \sigma_v)$ and $z_c = \mathcal{N}(\mu_c, \sigma_c)$:

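$$\mu_v = W_{\mu}\, x'_v + b_{\mu} \tag{2}$$
$$\sigma_v = \mathrm{ELU}\big(W_{\sigma}\, x'_v + b_{\sigma}\big) + 1 \tag{3}$$
$$\mu_c = W_{\mu}\, x'_c + b_{\mu} \tag{4}$$
$$\sigma_c = \mathrm{ELU}\big(W_{\sigma}\, x'_c + b_{\sigma}\big) + 1 \tag{5}$$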

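where $W_{\mu}, W_{\sigma} \in \mathbb{R}^{L \times d'}$ and $b_{\mu}, b_{\sigma} \in \mathbb{R}^{L}$ denote the shared mean and variance encoders for both node types. To obtain positive covariances, so that the uncertainty is interpretable, we add one to the ELU outputs. Since ELU outputs are greater than $-1$, every entry of $\sigma$ is then strictly positive.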
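The sketch below works through Eqs. 1 to 5 for one visit and one code in plain Python. It assumes purely linear projections in Eq. 1 and for the mean, and ELU + 1 for the variance. The weights are arbitrary toy values, and the helper names (`matvec`, `elu`, `encode`) are hypothetical.

```python
import math


def matvec(W, x):
    # Dense matrix-vector product with W given as a list of rows
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def elu(z, alpha=1.0):
    # ELU output is greater than -alpha, so elu(z) + 1 > 0 when alpha = 1
    return z if z > 0 else alpha * (math.exp(z) - 1.0)


def encode(x, W_domain, W_mu, b_mu, W_sigma, b_sigma):
    """Map a raw attribute vector to a Gaussian (mean, diagonal variance)."""
    h = matvec(W_domain, x)  # Eq. 1: domain-specific projection to d'
    mu = [m + b for m, b in zip(matvec(W_mu, h), b_mu)]  # shared mean encoder
    sigma = [elu(s + b) + 1.0 for s, b in zip(matvec(W_sigma, h), b_sigma)]  # shared variance encoder
    return mu, sigma


# Toy sizes: a_V = 3 visit attributes, a_C = 2 code attributes, d' = 2, L = 2
W_V = [[1, 0, 0], [0, 1, 0]]
W_C = [[1, 0], [0, 1]]
W_mu, b_mu = [[1, 0], [0, 1]], [0.0, 0.0]
W_sigma, b_sigma = [[-5, 0], [0, -5]], [0.0, 0.0]

mu_v, sigma_v = encode([1.0, 2.0, 3.0], W_V, W_mu, b_mu, W_sigma, b_sigma)
mu_c, sigma_c = encode([0.5, -1.0], W_C, W_mu, b_mu, W_sigma, b_sigma)
```

Only the first projection is domain-specific. The mean and variance encoders are shared, which is what places visits and codes in one comparable space.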

Embedding both node types into a common $L$-dimensional space enables similarity computation between two heterogeneous nodes. Since the embeddings are Gaussians, we measure the Wasserstein distance between them, specifically the 2nd Wasserstein distance ($W_2$). Because $W_2$ is a metric, it preserves the transitivity property in the embedding space [DBLP:conf/kdd/dvne]. As a result, when we model the explicit visit-code relations, visit-visit and code-code associations are also implicitly modelled in the embedding space. For example, if both $v_1$ and $v_2$ are linked to $c$, then $v_1$ and $v_2$ are likely to be similar, and we implicitly capture this similarity by preserving the triangle inequality in the embedding space. We define $W_2(v, c)$ as the $W_2$ distance between the embeddings of visit $v$ and code $c$. As in DVNE [DBLP:conf/kdd/dvne], modelling only diagonal covariances makes the covariance matrices commute ($\Sigma_v \Sigma_c = \Sigma_c \Sigma_v$). Therefore, the $W_2$ distance computation [givens1984classW2] simplifies to:

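$$W_2(v, c)^2 = \lVert \mu_v - \mu_c \rVert_2^2 + \lVert \sigma_v^{1/2} - \sigma_c^{1/2} \rVert_2^2 \tag{6}$$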
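Eq. 6 is cheap to evaluate because it only involves element-wise operations on the mean and variance vectors. The function below is a small sketch (the name `w2_diag` is hypothetical) that takes variance vectors, i.e. the diagonals $\sigma$, and returns $W_2$ itself, the square root of Eq. 6. The usage lines compute the three pairwise distances needed to check the triangle inequality that the transitivity argument relies on.

```python
import math


def w2_diag(mu1, var1, mu2, var2):
    """2-Wasserstein distance between N(mu1, diag(var1)) and N(mu2, diag(var2))."""
    mean_term = sum((a - b) ** 2 for a, b in zip(mu1, mu2))
    # Diagonal covariances commute, so the covariance term reduces to std-dev differences
    cov_term = sum((math.sqrt(s1) - math.sqrt(s2)) ** 2 for s1, s2 in zip(var1, var2))
    return math.sqrt(mean_term + cov_term)


# Toy embeddings: two visits linked to the same code sit close to it, hence close to each other
code = ([0.0, 0.0], [1.0, 1.0])
visit_a = ([0.3, 0.1], [1.2, 0.9])
visit_b = ([-0.2, 0.2], [0.8, 1.1])

d_ac = w2_diag(*visit_a, *code)
d_bc = w2_diag(*visit_b, *code)
d_ab = w2_diag(*visit_a, *visit_b)
```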

Then, we define the joint probability between the two node distributions as the likelihood that a link exists between them:

$$p(v, c) = \frac{1}{1 + \exp\big(W_2(v, c)\big)}.$$

Since the graph is unweighted, we define the prior probability from the structural information observed in the graph as:

$$\hat{p}(v, c) = \frac{1}{|\mathcal{E}|}.$$

To preserve this proximity measure in the embedding space, we minimise the distance between the prior distribution $\hat{p}$ and the embedding-based distribution $p$ over all edges observed in the $\mathcal{V}$-$\mathcal{C}$ graph. Since $\hat{p}$ and $p$ are discrete probability distributions, we use the Kullback-Leibler divergence and define the structural loss function as:

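$$\mathcal{L}_{struct} = \sum_{(v, c) \in \mathcal{E}} \hat{p}(v, c) \log \frac{\hat{p}(v, c)}{p(v, c)} \tag{7}$$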
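Below is a minimal sketch of the structural objective. It assumes the link probability $p(v, c) = 1/(1 + \exp(W_2(v, c)))$ and the uniform prior $\hat{p}(v, c) = 1/|\mathcal{E}|$ for an unweighted graph, summed as a Kullback-Leibler divergence over observed edges. Because $W_2 \ge 0$, this $p$ never exceeds 0.5, so the loss cannot reach zero. It is reduced by pulling linked visits and codes closer together. The function names are hypothetical.

```python
import math


def link_prob(w2):
    # Assumed form: p(v, c) = 1 / (1 + exp(W2(v, c))); smaller distance -> larger probability
    return 1.0 / (1.0 + math.exp(w2))


def structural_loss(edge_distances):
    """KL(p_hat || p) summed over observed V-C edges, with uniform prior p_hat = 1 / |E|."""
    p_hat = 1.0 / len(edge_distances)
    return sum(p_hat * math.log(p_hat / link_prob(d)) for d in edge_distances)


near = structural_loss([0.1, 0.2, 0.3])  # embeddings of linked nodes are close
far = structural_loss([2.0, 3.0, 4.0])   # embeddings of linked nodes are far apart
```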

3.2.2 Temporal learning for $\mathcal{V} \rightarrow \mathcal{V}$ sequences (Fig. 2(b))

The objective of our temporal sequence modelling is to learn the time-gap-based influence of historical visits on the next visit of a patient. Thus, we model the visits as continuous-time events with a point process model [hawkes_process, kingman2005poisson, point_processes] to capture the varying historical visit influence. Typical parametric point process models impose strict assumptions on the generative process of the events. As a result, these models have restricted expressive power and are not guaranteed to reflect real-world data. Hence, following RMTPP [DBLP:conf/kdd/rmtpp], we learn the visit influence with an RNN-based architecture. This architecture models a flexible and expressive marked temporal point process and captures the visit sequence dynamics automatically, without being restricted by strict parametric assumptions.

RNNs feed the output of the hidden units at the current time step as input to the next time step. Consequently, the network can memorise the influence of each past event through its hidden state vectors. Thus, we use the hidden state $h_i$ to represent the influence of the history up to the $i$-th visit. We construct a vector $m_i$ to model the markers of the marked temporal point process from the visit embedding learned in the previous section, treating the covariance vector as noise:

(8)

For the input layer of the RNN cell (bottom block in Fig. 2(c)), we feed event information (i.e. $m_i$) and timing information (i.e. $\Delta t_i$) about the current visit event. Since we are interested in modelling the time gaps between visits, we set $\Delta t_i = t_i - t_{i-1}$, the time gap between the previous and the current visit. The RNN cell then outputs an effective hidden state vector for the current time step, $h_i$, from the current visit event and the influence of past visit events carried in its memory ($h_{i-1}$):

$$h_i = \max\big(W^{m} m_i + W^{t} \Delta t_i + W^{h} h_{i-1} + b_h,\; 0\big) \tag{9}$$

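where $W^{m} \in \mathbb{R}^{H \times L}$, $W^{t} \in \mathbb{R}^{H \times d_t}$, $W^{h} \in \mathbb{R}^{H \times H}$, $b_h \in \mathbb{R}^{H}$, $d_t$ is the time vector dimension and $H$ is the RNN hidden state dimension. We define the conditional intensity function of the point process as: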

$$\lambda^*(t) = \exp\big(\mathbf{u}^{\top} h_i + w\,(t - t_i) + b\big) \tag{10}$$

where $\mathbf{u} \in \mathbb{R}^{H}$ and $w, b \in \mathbb{R}$. The exponential is a non-linear transformation that guarantees positive intensity values. The first term captures the historical influence, the second term the current influence, and the final term the base intensity. Then, we define the likelihood of the next visit occurring at time $t$, given the historical visit sequence up to time $t_i$, as:

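$$f^*(t) = \lambda^*(t)\exp\left(-\int_{t_i}^{t} \lambda^*(\tau)\,d\tau\right) = \exp\Big\{\mathbf{u}^{\top} h_i + w\,(t - t_i) + b + \tfrac{1}{w}\exp\big(\mathbf{u}^{\top} h_i + b\big) - \tfrac{1}{w}\exp\big(\mathbf{u}^{\top} h_i + w\,(t - t_i) + b\big)\Big\} \tag{11}$$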
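To make the intensity and the density concrete, the sketch below evaluates both for a scalar historical term `hist`, standing for the projection of the hidden state $h_i$. It assumes the exponential intensity of RMTPP with a non-zero `w`. It then checks the closed-form density against a direct numerical evaluation of $\lambda^*(t)\exp(-\int \lambda^*)$ using the trapezoidal rule. Parameter values are arbitrary, and the function names are hypothetical.

```python
import math


def intensity(t, t_i, hist, w, b):
    # exp(historical + current + base influence): always positive
    return math.exp(hist + w * (t - t_i) + b)


def density(t, t_i, hist, w, b):
    # Closed form of intensity(t) * exp(-integral of the intensity from t_i to t), for w != 0
    a, dt = hist + b, t - t_i
    return math.exp(a + w * dt + math.exp(a) / w - math.exp(a + w * dt) / w)


def density_numeric(t, t_i, hist, w, b, steps=10_000):
    # Same quantity, integrating the intensity with the trapezoidal rule
    h = (t - t_i) / steps
    inner = sum(intensity(t_i + k * h, t_i, hist, w, b) for k in range(1, steps))
    ends = 0.5 * (intensity(t_i, t_i, hist, w, b) + intensity(t, t_i, hist, w, b))
    return intensity(t, t_i, hist, w, b) * math.exp(-h * (inner + ends))


hist, w, b, t_i = 0.3, 0.1, -1.0, 0.0
closed = density(2.5, t_i, hist, w, b)
numeric = density_numeric(2.5, t_i, hist, w, b)
```

With `w > 0` the intensity keeps growing, so the density integrates to one over $[t_i, \infty)$. With `w < 0` some probability mass is left for "no further visit".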

Accordingly, given a temporal visit sequence of a patient with $N$ visits, we define the temporal loss function as the negative log-likelihood of observing the visit sequence (i.e. we maximise the likelihood of the sequence by minimising its negation):

$$\mathcal{L}_{temp} = -\sum_{i=1}^{N-1} \log f^*(t_{i+1} \mid h_i) \tag{12}$$
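The negative log-likelihood of the visit sequence can be accumulated in a single pass over a patient's visits. This sketch makes several assumptions: a ReLU RNN cell in the style of RMTPP, a one-dimensional time input (the raw gap), and toy weights. `rnn_step`, `log_density` and `temporal_nll` are hypothetical names. A practical implementation would use an autodiff framework so that the loss can be minimised by gradient descent.

```python
import math


def matvec(W, x):
    return [sum(a * b for a, b in zip(row, x)) for row in W]


def rnn_step(m, gap, h_prev, W_m, W_t, W_h, b_h):
    # h_i = max(W_m m_i + W_t dt_i + W_h h_{i-1} + b_h, 0), element-wise
    parts = zip(matvec(W_m, m), matvec(W_t, [gap]), matvec(W_h, h_prev), b_h)
    return [max(0.0, p + q + r + s) for p, q, r, s in parts]


def log_density(gap, hist, w, b):
    # log f*(t_{i+1}) from the closed-form density, with gap = t_{i+1} - t_i
    a = hist + b
    return a + w * gap + math.exp(a) / w - math.exp(a + w * gap) / w


def temporal_nll(markers, times, W_m, W_t, W_h, b_h, u, w, b):
    h = [0.0] * len(b_h)
    nll = 0.0
    for i, m in enumerate(markers):
        gap = 0.0 if i == 0 else times[i] - times[i - 1]  # time since the previous visit
        h = rnn_step(m, gap, h, W_m, W_t, W_h, b_h)
        if i + 1 < len(times):  # score the arrival time of the next visit
            hist = sum(ui * hi for ui, hi in zip(u, h))
            nll -= log_density(times[i + 1] - times[i], hist, w, b)
    return nll


params = dict(W_m=[[0.5, 0.0], [0.0, 0.5]], W_t=[[0.01], [0.0]],
              W_h=[[0.0, 0.0], [0.0, 0.0]], b_h=[0.0, 0.0], u=[1.0, 1.0], w=0.1, b=-1.0)
nll_two = temporal_nll([[1.0, 0.0], [0.0, 1.0]], [0.0, 10.0], **params)
nll_one = temporal_nll([[1.0, 0.0]], [0.0], **params)
```

A single-visit sequence contributes nothing to the loss, because there is no next visit whose arrival time could be scored.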

3.2.3 Auxiliary supervision task

MedGraph is an end-to-end supervised model, so we can train it to predict future medical risk. Medical risk prediction is an important task for personalised healthcare [DBLP:conf/nips/retain, DBLP:conf/nips/mime]. Therefore, we incorporate an auxiliary medical risk prediction task, which predicts a patient's future outcomes at a given point in time to support predictive healthcare. We assume that the hidden state of the current visit, $h_i$, carries information from the current visit itself and also memorises the time-gap-based influence of past visits through point process modelling. For simplicity, we describe a classification task in which the outcome for visit $v_i$ is $y_i \in \{1, \dots, K\}$, where $K$ is the number of classes. We use $h_i$ to predict the class label (i.e. the future medical risk outcome) as follows:

$$\hat{y}_i = \mathrm{softmax}\big(W_y h_i + b_y\big) \tag{13}$$

where $W_y \in \mathbb{R}^{K \times H}$ and $b_y \in \mathbb{R}^{K}$. Then, we compute the classification loss for a patient's visit sequence using cross-entropy:

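$$\mathcal{L}_{task} = -\sum_{i=1}^{N} \sum_{k=1}^{K} \mathbb{1}[y_i = k]\, \log \hat{y}_{i,k} \tag{14}$$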
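For concreteness, the following sketch implements the prediction head and the summed cross-entropy for $K$ classes with 0-based labels. The helper names are hypothetical. With all-zero weights every class receives probability $1/K$, so the loss over $N$ visits equals $N \log K$. This gives an easy sanity check.

```python
import math


def softmax(z):
    m = max(z)  # subtract the max for numerical stability
    e = [math.exp(zi - m) for zi in z]
    total = sum(e)
    return [ei / total for ei in e]


def predict(h, W_y, b_y):
    # Class probabilities from the hidden state h_i
    logits = [sum(w * hi for w, hi in zip(row, h)) + b for row, b in zip(W_y, b_y)]
    return softmax(logits)


def task_loss(hidden_states, labels, W_y, b_y):
    # Summed cross-entropy over a patient's visits (labels are 0-based class indices)
    return -sum(math.log(predict(h, W_y, b_y)[y]) for h, y in zip(hidden_states, labels))


loss = task_loss([[1.0, 2.0]] * 3, [0, 1, 0], [[0.0, 0.0], [0.0, 0.0]], [0.0, 0.0])
```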

3.3 Unified Training and Model Optimisation

MedGraph is an end-to-end medical risk prediction model that exploits the $\mathcal{V}$-$\mathcal{C}$ graph structure and the $\mathcal{V} \rightarrow \mathcal{V}$ temporal sequences to improve the predictive performance of an underlying medical risk prediction task. For each patient, the unified loss of the predictive model is defined as:

$$\mathcal{L} = \alpha\,\mathcal{L}_{struct} + \beta\,\mathcal{L}_{temp} + \gamma\,\mathcal{L}_{task} \tag{15}$$

where $\alpha$, $\beta$ and $\gamma$ are hyperparameters that control learning from the structural, temporal and underlying healthcare prediction tasks, respectively.

MedGraph can be trained in an unsupervised manner (setting $\gamma = 0$) when there is no auxiliary supervision task, e.g., when learning general-purpose embeddings for an exploratory analysis of EMR data [DBLP:conf/kdd/med2vec].

To optimise the structural loss computation, we use negative sampling [DBLP:conf/nips/word2vec, DBLP:conf/www/line], which selects a fixed number of negative $\mathcal{V}$-$\mathcal{C}$ edges for each positive edge.
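Here is one possible sketch of negative sampling for the $\mathcal{V}$-$\mathcal{C}$ graph. It corrupts the code side of each positive edge by drawing codes uniformly from those not linked to the visit. Both the choice of corrupted side and the uniform noise distribution are assumptions. word2vec and LINE instead sample from a noise distribution proportional to frequency or degree raised to the 3/4 power. `sample_negatives` is a hypothetical helper.

```python
import random


def sample_negatives(edges, codes, k, rng):
    """For each positive (visit, code) edge, draw k (visit, code') pairs absent from the graph."""
    edge_set = set(edges)
    batch = []
    for visit, code in edges:
        candidates = [c for c in codes if (visit, c) not in edge_set]
        if not candidates:
            raise ValueError(f"visit {visit} is linked to every code; no negatives available")
        negatives = [(visit, rng.choice(candidates)) for _ in range(k)]
        batch.append(((visit, code), negatives))
    return batch


edges = [("v1", "I50.9"), ("v1", "E11.9"), ("v2", "I50.9")]
codes = ["I50.9", "E11.9", "N18.3"]
batch = sample_negatives(edges, codes, k=10, rng=random.Random(0))
```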

4 Experiments

In this section, we evaluate the performance of MedGraph on two real-world EMR datasets and compare it against several state-of-the-art embedding methods, along with two variants of MedGraph. Source code for MedGraph will be made available upon publication.

4.1 Datasets

We use two real-world, cohort-specific proprietary EMR datasets in the experimental study: heart failure (HF) and chronic liver disease (CL). We remove patients with fewer than two visits. Brief statistics of the two datasets are shown in Table 1. Both EMR datasets are extracted from the same hospital, which uses ICD-10-CM diagnosis codes and in-house procedure codes. For the visits, we extract patient demographics (e.g. age, gender, ethnicity, birth country) and hospital utilisation information (e.g. length of stay, admission source) as visit attributes. For the medical codes, we use as code attributes the tf-idf vectors of ICD-10-CM code descriptions for diagnoses and the tf-idf vectors of hospital-provided code descriptions for in-house procedures.
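The code attributes described above are tf-idf vectors of code descriptions. The sketch below computes them with plain $\mathrm{tf} \times \log(N/\mathrm{df})$ weighting. The tokenisation and idf variant actually used for the datasets are not stated, so these choices are assumptions. For comparison, scikit-learn's `TfidfVectorizer` applies idf smoothing and L2 normalisation by default. `tfidf_vectors` is a hypothetical helper.

```python
import math
import re
from collections import Counter


def tfidf_vectors(descriptions):
    """descriptions: dict code -> description text. Returns (vocabulary, code -> weight list)."""
    tokens = {code: re.findall(r"[a-z0-9]+", text.lower()) for code, text in descriptions.items()}
    vocab = sorted({tok for toks in tokens.values() for tok in toks})
    n_docs = len(tokens)
    doc_freq = Counter(tok for toks in tokens.values() for tok in set(toks))
    vectors = {}
    for code, toks in tokens.items():
        tf = Counter(toks)
        length = max(len(toks), 1)  # guard against empty descriptions
        vectors[code] = [tf[t] / length * math.log(n_docs / doc_freq[t]) for t in vocab]
    return vocab, vectors


vocab, vectors = tfidf_vectors({
    "I50.9": "Heart failure, unspecified",
    "I11.0": "Hypertensive heart disease with heart failure",
    "K74.6": "Other and unspecified cirrhosis of liver",
})
```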

Dataset HF CL
Data collection time period 2010-2017 2000-2017
# of patients 10,713 3,830
# of visits 204,753 122,733
Avg. # of visits per patient 18.51 29.84
# of unique medical codes 8,541 8,382
 # of unique diagnosis codes 6,278 6,010
 # of unique procedure codes 2,263 2,372
Avg. # of medical codes per visit 5.27 5.02
Max # of medical codes per visit 98 100
Table 1: Statistics of real-world, cohort-specific EMR datasets.

4.2 Baselines

We compare MedGraph to several state-of-the-art EMR embedding methods on several risk prediction tasks. We choose Skip-gram based (Med2Vec [DBLP:conf/kdd/med2vec]) and RNN-based (Dipole [DBLP:conf/kdd/dipole] and RETAIN [DBLP:conf/nips/retain]) EMR embedding methods for comparison.¹ We also choose a state-of-the-art general-purpose graph embedding model, GCN [DBLP:conf/iclr/gcn], to learn the $\mathcal{V}$-$\mathcal{C}$ attributed bipartite graph.

¹ GRAM [DBLP:conf/kdd/gram] is not chosen as a baseline because the in-house procedure codes in our proprietary datasets do not form an ontology, on which GRAM depends.

Med2Vec [DBLP:conf/kdd/med2vec] is a Skip-gram based EMR embedding method that produces both code- and visit-level representations by predicting medical codes appearing in neighbouring visits. It captures visit demographics, so we feed it the visit attribute vector of each visit.

Dipole [DBLP:conf/kdd/dipole] is an attention-based bidirectional RNN framework that captures the influence of historical visits via trained attention weights.

RETAIN [DBLP:conf/nips/retain] is an end-to-end RNN-based healthcare prediction model with a reverse-time attention mechanism, which models the influence of previous visits and important medical codes in them.

GCN [DBLP:conf/iclr/gcn], the Graph Convolutional Network, is a recent state-of-the-art semi-supervised graph embedding approach that learns by aggregating neighbourhood information. We use (unsupervised) GCN layers to model the $\mathcal{V}$-$\mathcal{C}$ bipartite relations. Since GCN only supports homogeneous graphs, we ignore node heterogeneity and construct the attribute vectors of visits (i.e. $x_v$) and codes (i.e. $x_c$) by pre- and post-padding with zeros, respectively.
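The padding step can be illustrated as follows, reading "pre- and post-padding" to mean two things: visit features are preceded by $a_{\mathcal{C}}$ zeros, and code features are followed by $a_{\mathcal{V}}$ zeros. This reading is our interpretation, and the helper names are hypothetical. Either way, both node types end up in one $(a_{\mathcal{V}} + a_{\mathcal{C}})$-dimensional space with disjoint non-zero coordinates, which a homogeneous GCN can consume.

```python
def pad_visit(x_v, a_c):
    # Pre-padding: a_C leading zeros, then the a_V visit features
    return [0.0] * a_c + list(x_v)


def pad_code(x_c, a_v):
    # Post-padding: the a_C code features, then a_V trailing zeros
    return list(x_c) + [0.0] * a_v


visit_vec = pad_visit([67.0, 1.0], a_c=3)    # a_V = 2 (e.g. age, a gender flag)
code_vec = pad_code([0.2, 0.0, 0.7], a_v=2)  # a_C = 3 (e.g. tf-idf weights)
```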

4.3 Our Approaches

We conduct an ablation study to evaluate the importance of the two main components of MedGraph, namely structural learning and point process based temporal learning. We denote each variant with a negation (¬) in front of the ablated component. We summarise the variants in Table 2.

Table 2: Our approaches, all trained with the unified loss of Eq. 15 (the predictive model includes the supervised task loss; the unsupervised model omits it).

MedGraph is our full model. It incorporates both structural and temporal aspects (Fig. 2) by modelling $\mathcal{V}$-$\mathcal{C}$ bipartite relations and a time-gap-based temporal point process of visit sequences.

MedGraph(S,¬T) is a simple RNN model that treats visit events as an equally spaced time series and thus does not model temporal point processes. It learns from $\mathcal{V}$-$\mathcal{C}$ bipartite associations.

MedGraph(¬S,T) does not learn any structural information from the $\mathcal{V}$-$\mathcal{C}$ graph, so it does not learn the medical conditions of visits. It implements the proposed temporal point process based RNN model to learn the time-gap-based influence of historical visits.

4.4 Hyperparameter Settings

For all baseline models, we use $d$ as the visit and code embedding dimension. Since MedGraph produces mean and covariance vectors, we halve its embedding dimension to $d/2$ for a fair comparison, so that the same number of parameters is learned per node (i.e. $\mu \in \mathbb{R}^{d/2}$ and $\sigma \in \mathbb{R}^{d/2}$). For all methods, we use a batch size of 128 for visits and 32 for visit sequences. For MedGraph, we set the number of negative edges to 10 and tune the loss weights of Eq. 15 on a validation set. We use the Adam optimiser with a fixed learning rate of 0.001. Other parameters of the baseline models are taken from their original papers and tuned to be optimal.

4.5 Medical Risk Prediction

We perform two real-world medical risk prediction tasks, 30-day readmission prediction and mortality prediction, to assess each model’s effectiveness in predictive healthcare. For MedGraph, we use each risk prediction task as the auxiliary supervision task (cf. Section 3.2.3) and train end-to-end, with the task loss included in Eq. 15. Med2Vec, Dipole and GCN are not designed for these medical risk prediction tasks. We therefore first learn visit embeddings with these methods and then use the fixed visit vectors as inputs to a separate prediction model (an XGBoost classifier [DBLP:conf/kdd/xgboost]). For both tasks, we randomly split the visit sequences into two parts with a 4:1 ratio. The former is used to train the embedding models, and the latter is held out for evaluation. We randomly sample 25% of the visit sequences in the held-out set as the validation set and use the rest as the test set.
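As an illustration of the split arithmetic, the sketch below shuffles whole visit sequences, so each patient falls into exactly one partition. It takes 4/5 for training and sends 25% of the remainder to validation. The rounding rule and the `split_sequences` helper are assumptions. The resulting partition sizes for the 10,713 HF patients are our own arithmetic, not figures reported in the paper.

```python
import random


def split_sequences(sequences, seed=0):
    """4:1 train/held-out split of visit sequences, then 25% of held-out for validation."""
    rng = random.Random(seed)
    shuffled = list(sequences)
    rng.shuffle(shuffled)
    n_train = round(len(shuffled) * 4 / 5)
    train, held_out = shuffled[:n_train], shuffled[n_train:]
    n_val = round(len(held_out) * 0.25)
    return train, held_out[:n_val], held_out[n_val:]


# Illustration with the 10,713 HF patients (one visit sequence per patient)
train_set, val_set, test_set = split_sequences(range(10_713))
```

Under these assumptions the split yields 8,570 training, 536 validation and 1,607 test sequences.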

4.5.1 30-day readmission risk prediction

A significant fraction of hospital readmissions are unplanned. They reflect lowered quality of care and incur significant costs for hospitals [journal/jama/readmission]. Accurate readmission prediction allows hospitals to target intervention strategies at high-risk patients to prevent unplanned readmissions. Given a visit, we train the models to predict whether the patient is readmitted within 30 days of discharge from the current visit/admission [journal/jama/readmission, nguyen2016mathtt_deepr], which casts the problem as a binary classification task. Both cohort datasets are imbalanced, with the positive class in the majority: the 30-day readmission rates are 69.8% (HF) and 76.2% (CL). Therefore, AUC is a better measure than accuracy for comparing the models [DBLP:conf/nips/retain]. We plot AUC scores for the HF and CL datasets in Figure 3.
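A small worked example shows why accuracy is misleading at these prevalence levels. A trivial classifier that predicts readmission for everyone reaches 69.8% accuracy on a cohort with the HF readmission rate, yet its AUC is exactly 0.5, no better than chance. The `auc` helper below is a hypothetical rank-based implementation, counting ties as one half. The 1,000-patient cohort is synthetic.

```python
def auc(labels, scores):
    """Rank-based AUC: P(score of a random positive > score of a random negative), ties count 1/2."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


labels = [1] * 698 + [0] * 302        # 69.8% positives, the HF readmission rate
always_readmit = [1.0] * len(labels)  # trivial classifier: predicts readmission for everyone
accuracy = sum(labels) / len(labels)  # accuracy of that trivial classifier
```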

As Figure 3 shows, MedGraph and its two variants consistently outperform all compared state-of-the-art baselines by a clear margin, which shows the effectiveness of each component of our model. The full model improves over both the non-structural variant (MedGraph(¬S,T)) and the time-series variant (MedGraph(S,¬T), with no time gaps), which validates the contribution of each type of information. This observation is also supported by a previous study on readmissions [DBLP:conf/kdd/readmission], which shows that a patient's readmission depends on various factors, including patient demographics and past medical incidents.

Among the EMR baseline models, Med2Vec performs the best on HF and second best on CL even though it does not model the temporality of visits, which can be attributed to its use of visit demographics in the embedding. Capturing visit-code associations is also effective, as GCN performs competitively with several EMR embedding baselines in this task.

From the performance gaps between the RNN-based methods (Dipole, RETAIN and MedGraph(S,¬T)) and our time-gap-based point process models (MedGraph and MedGraph(¬S,T)), we can see that modelling time-gap information is important for accurately predicting readmission risk in both datasets. Our proposed model learns the time gaps between historical visits via point process modelling (Eq. 12). Consequently, it predicts a patient’s readmission more accurately at the 30-day threshold.

Figure 3: Performance on 30-day readmission risk prediction task.

4.5.2 Mortality prediction

Early identification of patients who are at high risk of death can help hospitals allocate resources appropriately to these patients and mitigate the mortality risk. Given a patient’s visit sequence, we train the models to predict the patient’s death in the next admission [knaus1991apache], which casts the problem as a binary classification task. We report the area under the receiver operating characteristic curve (AUC) and average precision (AP) on the HF and CL datasets (with mortality rates of 27.0% and 40.7%, respectively) in Table 3.

Table 3 shows that MedGraph outperforms all other methods on both datasets. This demonstrates the effectiveness of the visit representations learned by MedGraph in predicting the death of patients in their next visit. Possible reasons for the superiority of our method include the incorporation of attribute information for visits and codes and the use of temporal point processes to model visit-visit sequences. Intuitively, according to a related study [DBLP:conf/amcis/mortality], a patient’s death depends on three factors, all of which our model captures. These are the patient's demographics (i.e. visit attributes), current health conditions (i.e. the $\mathcal{V}$-$\mathcal{C}$ structure) and historical health conditions (i.e. the $\mathcal{V} \rightarrow \mathcal{V}$ sequence).

Moreover, MedGraph outperforms both of its variants. This shows that the proposed structural and temporal learning both improve the predictive performance of visit representations in the mortality prediction task.

The poor performance of GCN, and the performance gain of the RNN-based methods (i.e. Dipole and RETAIN) over GCN, suggest that structural learning alone is not sufficient to capture visit similarities in this task and that temporality is important. However, Med2Vec achieves the highest AUC among the baselines on both datasets (although RETAIN achieves a higher AP). This can be attributed to its Skip-gram based code co-occurrence learning and its incorporation of visit attributes, even though it ignores temporality.

                        HF                CL
Group       Method      AUC      AP       AUC      AP
Baselines   Med2Vec     0.7091   0.4752   0.7224   0.6255
            Dipole      0.6464   0.4145   0.6411   0.5487
            RETAIN      0.6908   0.5402   0.7034   0.6836
            GCN         0.5581   0.3177   0.6172   0.4956
Ours                    0.7002   0.6548   0.7385   0.7123
                        0.7131   0.6751   0.7306   0.7118
            MedGraph    0.7205   0.6853   0.7415   0.7143
Table 3: Performance on the mortality prediction task. Best result is bolded and second-best is underlined.

4.6 Medical Code Representation Analysis

In this section, we examine the medical code embeddings to evaluate their descriptiveness. We propose a novel evaluation task to quantitatively analyse the informativeness of the code embeddings, and we qualitatively study the interpretability of the code embeddings learned by MedGraph. We learn general-purpose (unsupervised) embeddings with MedGraph via Eq. 15.

4.6.1 Multi-class code classification

The Clinical Classifications Software (CCS) (https://www.hcup-us.ahrq.gov/toolssoftware/ccsr/ccs_refined.jsp) groups ICD codes into clinically meaningful categories. Thus, if the code embeddings are predictive of their CCS categories, the embeddings have learnt useful latent information. Based on this hypothesis, we perform multi-class code classification to predict the CCS category of each medical code.

First, each method learns code embeddings; then a logistic regression (LR) classifier is trained on these embeddings to assign each code to its CCS class. Since some CCS classes are rare, we select the 10 most common CCS classes in the datasets. We randomly sample different percentages of diagnosis codes as the training set for the classifier and use the rest for evaluation. We report micro- and macro-F1 scores, which are widely used to evaluate multi-class classification tasks [DBLP:conf/www/line, DBLP:conf/kdd/dvne].
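The two reported scores differ in how classes are weighted, as the following minimal standard-library sketch shows (the function name is our own; `sklearn.metrics.f1_score` with `average="micro"` or `average="macro"` is the usual library route). For single-label multi-class prediction, micro-F1 pools all decisions and equals accuracy, whereas macro-F1 averages per-class F1 so that less frequent CCS classes count as much as common ones.

```python
from collections import Counter


def micro_macro_f1(y_true, y_pred):
    """Micro- and macro-averaged F1 for single-label multi-class predictions."""
    labels = sorted(set(y_true) | set(y_pred))  # classes seen in either list
    tp, fp, fn = Counter(), Counter(), Counter()
    for t, p in zip(y_true, y_pred):
        if t == p:
            tp[t] += 1
        else:
            fp[p] += 1  # predicted class p was wrong
            fn[t] += 1  # true class t was missed

    def f1(tp_, fp_, fn_):
        denom = 2 * tp_ + fp_ + fn_
        return 2 * tp_ / denom if denom else 0.0

    micro = f1(sum(tp.values()), sum(fp.values()), sum(fn.values()))
    macro = sum(f1(tp[c], fp[c], fn[c]) for c in labels) / len(labels)
    return micro, macro


# Hypothetical CCS labels for five held-out codes.
y_true = ["A", "A", "A", "B", "C"]
y_pred = ["A", "A", "B", "B", "A"]
micro, macro = micro_macro_f1(y_true, y_pred)
print(round(micro, 4), round(macro, 4))  # 0.6 0.4444
```

Here the rare class C is never predicted correctly, which pulls macro-F1 well below micro-F1.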

MedGraph can learn supplementary code attributes, so it should produce more expressive code embeddings than methods that do not use attributes. For a fair comparison with these baselines, we also train MedGraph without code attributes; we refer to this variant as the non-attributed MedGraph. We exclude from this task the variant that does not learn code embeddings. Figure 4 presents micro-F1 scores on HF; macro-F1 scores and results on the CL dataset follow a similar trend and are omitted for brevity.

As Fig. 4 shows, both attributed versions of our model produce embeddings that are highly descriptive of their medical context, consistently outperforming the non-attributed MedGraph by a large margin in the code classification task. GCN also shows a significant improvement over the non-attributed models. This demonstrates the effectiveness of incorporating code-level attributes and structural visit-code relations in producing high-quality code embeddings. Thus, the superiority of our method over the other evaluated methods can be attributed to two factors: (1) the use of standard code descriptions as supplementary code attributes, and (2) learning code-visit associations through the graph-based data structure rather than through multi-hot medical code vectors.

We also see that our method consistently outperforms all three evaluated EMR embedding methods (Med2Vec, Dipole and RETAIN), producing more meaningful embeddings. The difference between our method and these methods lies in how visit-code associations are learned. These baselines represent visits as multi-hot medical code vectors and use them as features to learn visit-code relations, ignoring the inherent graph structure. In contrast, our models learn visit-code associations from the structural information in the visit-code bipartite graph. Thus, our model learns not only code co-location within visits through local neighbourhoods, but also similar code neighbourhoods through global connectivity, owing to the transitivity property of the distance. This shows that the proposed structural learning of visit-code associations is effective in producing meaningful code embeddings. Among the baselines, Med2Vec outperforms Dipole and RETAIN, which we attribute to its Skip-gram-like learning of neighbouring codes within a predefined context window.
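The transitivity argument can be made concrete with a minimal sketch, assuming the distance between Gaussian embeddings is the 2-Wasserstein distance used by DVNE [DBLP:conf/kdd/dvne]; this section does not state which distance MedGraph uses, and `w2_diag` is a hypothetical helper. For diagonal covariances, W2 reduces to a Euclidean distance between concatenated (mean, standard deviation) vectors, so it is a metric and obeys the triangle inequality: if two codes are each close to a shared neighbour, they are pulled close to each other.

```python
import math


def w2_diag(mu1, var1, mu2, var2):
    """2-Wasserstein distance between N(mu1, diag(var1)) and N(mu2, diag(var2)).

    For diagonal covariances:
        W2^2 = ||mu1 - mu2||^2 + sum_k (sqrt(var1_k) - sqrt(var2_k))^2
    """
    mean_term = sum((a - b) ** 2 for a, b in zip(mu1, mu2))
    cov_term = sum((math.sqrt(s) - math.sqrt(t)) ** 2 for s, t in zip(var1, var2))
    return math.sqrt(mean_term + cov_term)


# Three hypothetical 2-D Gaussian code embeddings: (mean, diagonal variance).
code_a = ([0.0, 0.0], [1.0, 1.0])
code_b = ([1.0, 0.0], [1.0, 1.0])
code_c = ([2.0, 0.0], [4.0, 1.0])

d_ab = w2_diag(*code_a, *code_b)
d_bc = w2_diag(*code_b, *code_c)
d_ac = w2_diag(*code_a, *code_c)
print(d_ac <= d_ab + d_bc)  # True: the triangle inequality holds
```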

Our variant that learns the visit-code structure but models visit sequences without time gaps is slightly surpassed by MedGraph, since time-gap-based temporal learning of the medical history enables MedGraph to learn additional useful latent patterns in EMRs. For example, when a patient is diagnosed with cancer at an earlier visit, MedGraph can learn from the patient's medical history that they may revisit the hospital for chemotherapy or other related procedures.

Figure 4: Multi-class code classification task. Improvements are statistically significant in a paired t-test.
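For reference, the statistic behind a paired t-test can be sketched as follows, assuming each pair of scores comes from the same train/test split (the caption does not say how runs were paired, and `paired_t_statistic` and the scores are hypothetical). The p-value then follows from a t distribution with n - 1 degrees of freedom; `scipy.stats.ttest_rel` computes both.

```python
import math
import statistics


def paired_t_statistic(scores_a, scores_b):
    """t statistic of a paired t-test on matched runs (df = n - 1).

    Raises ZeroDivisionError if all paired differences are identical.
    """
    if len(scores_a) != len(scores_b) or len(scores_a) < 2:
        raise ValueError("need two equal-length lists with at least 2 runs")
    diffs = [a - b for a, b in zip(scores_a, scores_b)]
    mean_d = statistics.mean(diffs)
    sd_d = statistics.stdev(diffs)  # sample standard deviation (n - 1 denominator)
    return mean_d / (sd_d / math.sqrt(len(diffs)))


# Hypothetical micro-F1 scores of two methods on three matched splits.
ours = [0.80, 0.82, 0.85]
baseline = [0.75, 0.78, 0.80]
print(round(paired_t_statistic(ours, baseline), 6))  # 14.0
```

Pairing removes split-to-split variation that both methods share, which makes the test more sensitive than comparing two independent samples.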

4.6.2 Interpretation of code embeddings

Interpretability of the learned code embeddings is important in various applications in the medical domain, including healthcare analysis tasks. In Figure 5, we project the 128-dimensional mean vectors of the ICD-10-CM diagnosis code embeddings learned by MedGraph on the HF dataset (500 codes belonging to the top 10 CCS classes) into two dimensions using t-SNE for visualisation [DBLP:tsne]. We also publish an interactive plot for further analysis (https://bhagya-hettige.github.io/MedGraph). The colour of a node indicates its CCS class.

Overall, MedGraph clusters the codes of most CCS clinical concepts with clear boundaries. Some CCS classes overlap because of their broad definitions, and MedGraph learns interesting latent relations between codes, especially in these overlapping regions. For brevity, we analyse only two such scenarios: (1) codes of “tuberculosis pneumonia” and “bacterial infections” overlap, forming a cluster that reflects their close clinical relationship [10.1093/qjmed/pneumonia], and (2) codes related to “benign neoplasm” further separate into more granular groups (focussed on different organs) within the broader CCS class, capturing the inherent differences between these subclasses.

Figure 5: 2-D visualisation of codes. A code's colour indicates its CCS class.

4.7 Uncertainty Modelling of EMR

Unlike existing EMR embedding methods, MedGraph learns the uncertainty of visit and code embeddings by representing them as Gaussians. In this task, we study the nature of the learned uncertainty terms and their intuition in real-world EMRs. We learn interpretable diagonal covariances with non-negative values (cf. Section 3). Following [DBLP:conf/kdd/dvne], we define a node's variance as the average variance over its 10 largest-variance dimensions. We conduct the following analysis on the HF dataset. To obtain general-purpose embeddings for this exploratory analysis, we learn the embeddings in an unsupervised manner via Eq. 15.
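The node-variance definition can be sketched in a few lines, reading "10 largest dimensions" as the 10 largest entries of the diagonal covariance; this reading and the helper name `node_variance` are our assumptions.

```python
import heapq


def node_variance(diag_cov, k=10):
    """Average of the k largest entries of a node's diagonal covariance.

    If the embedding has fewer than k dimensions, all entries are averaged.
    """
    if not diag_cov:
        raise ValueError("empty covariance vector")
    if k < 1:
        raise ValueError("k must be at least 1")
    top = heapq.nlargest(k, diag_cov)
    return sum(top) / len(top)


# Hypothetical 20-dimensional diagonal covariance with variances 1.0, 2.0, ..., 20.0.
variances = [float(x) for x in range(1, 21)]
print(node_variance(variances))  # 15.5, the mean of 11.0 ... 20.0
```

Averaging only the largest entries focuses the score on the dimensions where the embedding is least certain.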

Visit Embedding Uncertainty (Fig. 6(a)): We divide all patients into 20 buckets based on their visit counts. For each bucket, we compute the average visit variance and plot it against the number of visits. As the number of visits of a patient increases, the average variance of their visit embeddings decreases. Intuitively, when a patient has a longer medical history, their visit embeddings are more comprehensive and descriptive, and thus less uncertain.

Code Embedding Uncertainty (Fig. 6(b)): We divide the ICD-10 codes into 10 buckets based on their degrees (i.e., the number of visits a code is connected to). We compute the average variance of each bucket and plot it against the code degree. As Fig. 6(b) shows, the average variance decreases as the code degree increases. Intuitively, lower-degree codes (i.e., codes that occur rarely) have less structural information to learn from, so their embeddings are more uncertain. In contrast, higher-degree codes occur more frequently, so their embeddings have lower uncertainty because the graph structure provides more information about them.
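Both bucket analyses follow the same pattern, sketched below under the assumption of equal-size buckets over items sorted by their key (visit count or code degree); the text does not say whether buckets are equal-size or equal-width, and `bucket_average` is a hypothetical helper. Each bucket yields one point (mean key, mean variance) for the trend plot.

```python
def bucket_average(keys, values, num_buckets):
    """Sort items by key, split them into num_buckets contiguous, near-equal-size
    buckets, and return (mean key, mean value) for each non-empty bucket."""
    pairs = sorted(zip(keys, values))
    n = len(pairs)
    result = []
    for b in range(num_buckets):
        start = b * n // num_buckets
        end = (b + 1) * n // num_buckets
        chunk = pairs[start:end]
        if chunk:  # skip empty buckets when there are fewer items than buckets
            mean_key = sum(k for k, _ in chunk) / len(chunk)
            mean_val = sum(v for _, v in chunk) / len(chunk)
            result.append((mean_key, mean_val))
    return result


# Hypothetical code degrees and node variances (10 buckets in the paper; 3 here).
degrees = [1, 2, 3, 4, 5, 6]
node_vars = [0.9, 0.8, 0.7, 0.5, 0.4, 0.2]
for mean_degree, mean_var in bucket_average(degrees, node_vars, 3):
    print(f"degree ~{mean_degree:.1f}: average variance {mean_var:.2f}")
```

The same function covers the visit analysis by passing per-patient visit counts as keys and 20 as the number of buckets.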

(a) Visit uncertainty.
(b) Code uncertainty.
Figure 6: Analysis of uncertainty of the embeddings. Trendlines show the trend of the results.

5 Conclusion

In this work, we propose MedGraph, an effective EMR embedding framework for visits and codes. MedGraph introduces a graph-based data structure that naturally captures both visit-code co-location information structurally and visit sequencing information temporally. Based on this structure, MedGraph learns from the visit-code bipartite graph and exploits temporal point processes to capture medical history in an end-to-end manner. MedGraph supports visit- and code-level attributes. We further improve the expressive power of MedGraph by modelling the uncertainty of the embeddings. Results on two real-world EMR datasets demonstrate that MedGraph produces meaningful representations for EMRs, significantly outperforming state-of-the-art EMR embedding methods on a number of medical risk prediction tasks.

References