Misdiagnosis, in which a diagnostic decision is made inaccurately, is widespread. Approximately 12 million US adults experience diagnostic errors every year, and half of these errors could be harmful Singhbmjqs-2013-002627. As many as 40,500 adult patients in intensive care units (ICUs) die with a misdiagnosis annually winters2012diagnostic. A major source of misdiagnosis is sub-optimal interpretation and usage of clinical data schiff2009diagnostic, DiagError, Singh2013Types, doi:10.7326/0003-4819-145-7-200610030-00006. Nowadays, physicians are overwhelmed by large amounts of medical data, including laboratory tests, vital signs, clinical notes, and medication prescriptions. Among all kinds of clinical data, laboratory tests play an important role. According to the American Clinical Laboratory Association, laboratory tests guide more than 70% of diagnostic decisions. Unfortunately, comprehensively understanding laboratory test results and discovering their underlying clinical implications is not easy. Incorrect interpretation of laboratory tests is a major breakdown point in the diagnostic process doi:10.7326/0003-4819-145-7-200610030-00006, DiagError, schiff2009diagnostic, Singh2013Types.
Laboratory tests are difficult to understand for two reasons. First, missing values are pervasive. Typically, only a small subset of laboratory tests is examined at a given time point, leaving the values of most tests missing. This missing data prevents physicians from getting a full picture of patients’ clinical states, leading to sub-optimal decisions. Second, laboratory test values have a complex multivariate time-series structure: during an in-hospital stay, multiple laboratory tests are examined at a particular time, and the same test may be examined multiple times at different time points. These multivariate temporal data exhibit complicated patterns along the dimensions of both time and tests. Learning these patterns is highly valuable for diagnosis, but it is technically challenging.
In this work, we study how to leverage the ability of machine learning (ML) to automatically distill patterns from complex, noisy, incomplete and irregular laboratory test data in order to address the above-mentioned issues, and we build an end-to-end diagnostic model to assist with diagnosis decisions. Previous studies have applied ML to perform diagnosis based on laboratory tests. In these approaches, three major tasks (handling missing values, discovering patterns from multivariate time-series and predicting diseases) were often conducted separately. However, these three tasks are tightly coupled and can mutually benefit each other. On one hand, better imputation of missing values leads to the discovery of more informative patterns, which boosts the accuracy of diagnosis. On the other hand, during model training, the supervision of diagnosis provides guidance for pattern discovery, which in turn influences the imputation of missing values, tailoring the discovered patterns and imputed values to the diagnosis task. Performing these tasks separately fails to consider their synergistic relationships and hence leads to a sub-optimal solution. Another limitation of previous studies is that they were often built on a discriminative structure, which in principle cannot address the missing value problem well or learn generalizable patterns.
In this paper, we develop an end-to-end deep neural model to perform diagnosis based on laboratory tests. Our model seamlessly integrates three tasks (imputing missing values, discovering patterns from multivariate time-series data and predicting diseases) and performs them jointly. Our model combines two major learning paradigms in machine learning, generative learning and discriminative learning: the generative learning component is utilized to deal with missing values and discover robust and generalizable patterns, and the discriminative learning component is used to predict diseases based on the patterns discovered by the generative component. We evaluate the proposed model on 46,252 patient visits in the ICU and demonstrate that our model achieves (1) significantly better diagnosis performance than baseline models, (2) better imputation of missing values, and (3) better discovery of patterns from the laboratory test data.
Lasko et al. proposed to use a Gaussian process to model longitudinal Electronic Health Records (EHR) and used a standard auto-encoder (AE) to learn hidden features from raw inputs lasko2013computational, and Ghassemi et al. further introduced a multi-task Gaussian process to model clinical data ghassemi2015multivariate. Che et al. and Miotto et al. took advantage of the denoising auto-encoder (DAE) vincent2008extracting to learn hidden representations and then used these representations to make diagnoses che2015deep, miotto2016deep. In recent years, the recurrent neural network (RNN) has demonstrated its superiority in modeling longitudinal data, such as natural language mikolov2010recurrent and speech signals graves2013speech. RNNs have also been utilized for medical diagnosis. Lipton et al. proposed a medical diagnosis model based on long short-term memory (LSTM) hochreiter1997long recurrent neural networks and obtained better performance than several strong baselines lipton2015learning, lipton2015phenotyping. In that work, missing values were addressed by heuristic forward- and back-filling. Their later work continued the LSTM-based diagnosis but used a missing value indicator as part of the inputs; they found that missing value patterns can help diagnosis performance lipton2016modeling. Choi et al. proposed “Doctor AI”, which used previous diagnoses to predict future diagnoses choi2016doctor. In another work of theirs, a neural attention mechanism was further introduced choi2016retain. Che et al. proposed a diagnosis model based on stacked DAE and LSTM, and introduced Gradient Boosting Trees to make the learned features more interpretable che2015distilling. Their later work modified the structure of Gated Recurrent Units (GRU) to deal with incomplete inputs che2016recurrent.
All of the above work is based on discriminative models, which in principle cannot address the missing value problem very well. A generative model usually works better in this respect. A typical generative model is the Gaussian Mixture Model (GMM), with which missing values can be easily addressed by the Expectation-Maximization (EM) algorithm Ghahramani1994Learning, ghahramani1994supervised. For instance, Marlin et al. employed GMMs to discover patterns from clinical data and conducted mortality prediction Marlin2012Unsupervised. These generative models, however, are mostly shallow, linear and Gaussian. Recently, researchers proposed several deep generative models, e.g., the variational auto-encoder (VAE) kingma2013auto and the variational RNN (VRNN) chung2015recurrent. Compared to conventional generative models, VAE and VRNN can model more complex conditional distributions and hence represent more complex patterns. Our work utilizes these deep generative models. However, generative models are not task-oriented yakhnenko2005discriminatively, bernardo2007generative, yogatama2017generative. Therefore, we propose an end-to-end approach that combines the advantages of both types of models and trains them jointly. This can be seen either as generative learning with the discriminative target as guidance, or as discriminative learning with the generative target as regularization.
The data used in this study come from MIMIC-III johnson2016mimic, which is publicly available. The dataset was derived from the laboratory tests of 46,252 patients and contains both in-hospital and outpatient records. Each in-hospital episode has 1 to 39 corresponding ICD-9 (International Classification of Diseases, 9th Edition) codes; in this study, only the primary diagnoses are considered. In total, there are 2,789 different diagnoses and 513 unique laboratory tests. As some diagnoses and tests are quite rare, we limit our study to the 50 most frequent diagnoses and the 50 most frequent laboratory tests. We group the test results by day, which yields 30,931 temporal sequences of in-hospital records, each labeled by a disease ID from 0 to 49. The lengths of the temporal sequences range from 2 to 171, and we focus on the latest 100 days. Figure 1 shows the number of samples for each of the 50 disease IDs. We randomly split the dataset 5 times, each time keeping the proportions of the training (Train), development (Dev) and testing (Test) sets at 65%:15%:20%. Hence, the numbers of samples in these three sets are 20,105, 4,640 and 6,186, respectively.
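As a rough illustration of the 65%:15%:20% split described above, the following sketch shuffles sample indices and cuts them into three parts. The function name `split_indices`, the use of integer seeds, and the rounding rule are assumptions made for illustration; the paper does not specify how its splits were drawn.

```python
import random

def split_indices(n_samples, seed, fractions=(0.65, 0.15, 0.20)):
    """Shuffle indices 0..n_samples-1 and cut them into train/dev/test parts."""
    rng = random.Random(seed)
    indices = list(range(n_samples))
    rng.shuffle(indices)
    n_train = round(fractions[0] * n_samples)
    n_dev = round(fractions[1] * n_samples)
    train = indices[:n_train]
    dev = indices[n_train:n_train + n_dev]
    test = indices[n_train + n_dev:]  # the remainder goes to the test set
    return train, dev, test

# Five random splits of the 30,931 sequences, one per (hypothetical) seed.
splits = [split_indices(30931, seed) for seed in range(5)]
```

With this rounding rule, each split reproduces the set sizes reported in the text (20,105, 4,640 and 6,186).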
Some of the tests take discrete categorical values, such as “ABNORMAL” and “NORMAL”. We map these categories to integers, e.g. 0 for “ABNORMAL” and 1 for “NORMAL”. Test results are then Z-normalized, i.e. the values of each test have the mean subtracted and are divided by the standard deviation. Note that a patient cannot take every test every day, so missing values are pervasive in our data. Figure 2 presents an example of a patient’s laboratory test records, which contains many missing values. Simple statistics show that, over the whole dataset, the average missing value rate is about 54%, i.e. on average, only 27 of the 50 laboratory tests have values in a patient’s one-day record. In our experiments, we initially impute the missing values with 0. After Z-normalization, the mean of each test’s values is zero, so zero imputation is equivalent to mean imputation. Moreover, since our models are neural networks, zero inputs do not introduce additional bias in the computation. Note that in the baseline models this zero imputation serves as the solution to the missing value problem, whereas in our models it serves as an indicator of missing values, which are further addressed by the deep generative models.
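The preprocessing described above (Z-normalization followed by zero imputation) can be sketched for a single test as follows. The function name `z_normalize_with_mask`, the use of `None` to mark a missing value, the population (rather than sample) standard deviation, and the explicit mask output are all assumptions for illustration, not details taken from the paper.

```python
import math

def z_normalize_with_mask(column):
    """Z-normalize one test's values; None marks a missing value.

    Returns (values, mask): missing entries become 0.0, which is the mean
    after normalization, and mask[i] is 1 where a value was observed.
    Assumes at least one observed value in the column.
    """
    observed = [v for v in column if v is not None]
    mean = sum(observed) / len(observed)
    var = sum((v - mean) ** 2 for v in observed) / len(observed)
    std = math.sqrt(var) or 1.0  # guard against constant columns
    values = [0.0 if v is None else (v - mean) / std for v in column]
    mask = [0 if v is None else 1 for v in column]
    return values, mask

values, mask = z_normalize_with_mask([1.0, None, 3.0, None])
```

Because the statistics are computed on observed values only, filling the gaps with 0.0 after normalization is the same as filling them with the per-test mean before normalization.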
Given $N$ i.i.d. data points $\{X^{(i)}\}_{i=1}^{N}$, each $X^{(i)}$ is a temporal sequence of length $T^{(i)}$, i.e. $X^{(i)} = (x^{(i)}_1, \dots, x^{(i)}_{T^{(i)}})$. Each $x^{(i)}_t \in \mathbb{R}^{D}$, $t = 1, \dots, T^{(i)}$, where $D$ is the dimension of the input features. Meanwhile, there is a class label $y^{(i)}$ for each $X^{(i)}$, and our purpose is to predict the class labels accurately. Specifically, $X^{(i)}$ is the longitudinal laboratory test record of a patient, and each $x^{(i)}_t$ is a one-day record; $y^{(i)}$ is the primary diagnosis. Where unambiguous, the sample index $(i)$ is omitted. We propose two models in this study, denoted by VAE+NN and VRNN+NN: the former is a static model that demonstrates the contribution of deep generative models, while the latter is a temporal model that extends the deep generative learning approach to learn long-term temporal dependency.
In this model, we address the temporal records by simply averaging the vectors at all time points, i.e. $\bar{x} = \frac{1}{T}\sum_{t=1}^{T} x_t$. Although the averaging operation alleviates the missing value problem to some extent, the rate of missing values in $\bar{x}$ is still high (about 29%). To further deal with the missing values and capture the complex patterns in the data, we propose a VAE+NN model, where a VAE is the generative model used to handle missing values and discover patterns, and a standard neural network (NN) is used as a classifier, as shown in Figure 3 (a).
Following the idea of VAE kingma2013auto, suppose that $\bar{x}$ is generated from a latent variable $z$; then the joint probability is defined as:
$$p(\bar{x}, z) = p(\bar{x} \mid z)\, p(z).$$
Let the prior over $z$ be a centered isotropic multivariate Gaussian distribution, i.e. $p(z) = \mathcal{N}(0, I)$. Further assume that, with a fixed $z$, $\bar{x}$ also follows a Gaussian distribution, and suppose different dimensions of $\bar{x}$ are independent; then the generation of $\bar{x}$ is defined as:
$$p(\bar{x} \mid z) = \mathcal{N}\big(\mu_x, \mathrm{diag}(\sigma_x^2)\big), \qquad [\mu_x, \sigma_x] = \varphi^{dec}(z),$$
where $\mu_x$ and $\sigma_x$ are parameters of the conditional distribution and $\varphi^{dec}$ is a feed-forward neural network. Then, the expectation of the posterior distribution, $E(z \mid \bar{x})$, is fed into a discriminative network $\varphi^{y}$ to perform classification, formulated as follows:
$$p(y \mid \bar{x}) = \varsigma\big(\varphi^{y}(E(z \mid \bar{x}))\big),$$
where $\varsigma$ is the softmax function, written as $\varsigma(a)_j = \exp(a_j) / \sum_{k} \exp(a_k)$.
To better capture the long-term dependency of clinical records, previous work introduced RNNs into this task lipton2015learning, lipton2016modeling. However, the internal transition structure of an RNN is entirely deterministic, so it cannot efficiently model the complexity of highly structured sequential data. Therefore, Chung et al. extended the VAE idea to a recurrent framework for modeling high-dimensional sequences, leading to the variational RNN (VRNN) chung2015recurrent. We propose a VRNN+NN model, as shown in Figure 3 (b), which involves a VRNN to generate sequential hidden features and an NN model to make decisions based on the average of these hidden features.
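Assuming that the generation of each $x_t$ in $X$ is conditioned on a latent variable $z_t$ and the previous state $h_{t-1}$, the joint probability is defined as follows:
$$p(x_{\le T}, z_{\le T}) = \prod_{t=1}^{T} p(x_t \mid z_{\le t}, x_{<t})\, p(z_t \mid x_{<t}, z_{<t}),$$
where $p(z_t \mid x_{<t}, z_{<t})$ is the prior of the latent variable and $p(x_t \mid z_{\le t}, x_{<t})$ is the conditional probability of the observations. We assume that the prior distribution follows a Gaussian distribution, and different dimensions of $z_t$ are independent, i.e. the covariance matrix is diagonal. Therefore, the prior distribution is defined as follows:
$$p(z_t \mid x_{<t}, z_{<t}) = \mathcal{N}\big(\mu_{0,t}, \mathrm{diag}(\sigma_{0,t}^2)\big), \qquad [\mu_{0,t}, \sigma_{0,t}] = \varphi^{prior}(h_{t-1}).$$
$\mu_{0,t}$ and $\sigma_{0,t}$ represent the mean and the standard deviation of the prior distribution, respectively. Again, assume different dimensions of $x_t$ are independent; then the conditional probability of the observation is written as:
$$p(x_t \mid z_{\le t}, x_{<t}) = \mathcal{N}\big(\mu_{x,t}, \mathrm{diag}(\sigma_{x,t}^2)\big), \qquad [\mu_{x,t}, \sigma_{x,t}] = \varphi^{dec}(z_t, h_{t-1}),$$
where $\mu_{x,t}$ and $\sigma_{x,t}$ are parameters of the generating distribution. $\varphi^{prior}$ and $\varphi^{dec}$ are neural networks. The hidden states are updated by LSTM hochreiter1997long cells denoted by $f$, so the recurrence equation is defined as:
$$h_t = f(x_t, z_t, h_{t-1}).$$
We take the average of the hidden states, $\bar{h} = \frac{1}{T}\sum_{t=1}^{T} h_t$, as the input of a discriminative NN $\varphi^{y}$.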
Inference & Learning
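The involvement of neural networks makes the true posterior $p(z \mid \bar{x})$ intractable. Resorting to variational inference, we use $q(z \mid \bar{x})$ to approximate the true posterior, which is defined as:
$$q(z \mid \bar{x}) = \mathcal{N}\big(\mu_z, \mathrm{diag}(\sigma_z^2)\big), \qquad [\mu_z, \sigma_z] = \varphi^{enc}(\bar{x}),$$
where $\varphi^{enc}$ is a feed-forward neural network. Denote the parameters involved in the VAE as $\theta$; the learning target is $\log p_\theta(\bar{x})$, which can be replaced by the lower bound of $\log p_\theta(\bar{x})$:
$$\log p_\theta(\bar{x}) \ge -\mathrm{KL}\big(q(z \mid \bar{x}) \,\|\, p(z)\big) + \mathbb{E}_{q(z \mid \bar{x})}\big[\log p(\bar{x} \mid z)\big].$$
Since we assume different dimensions of $\bar{x}$ are independent, $\log p(\bar{x} \mid z) = \sum_{d=1}^{D} \log p(\bar{x}_d \mid z)$. Missing items in $\bar{x}$ should be dropped from the probability, i.e. $\log p(\bar{x} \mid z) = \sum_{d \notin M} \log p(\bar{x}_d \mid z)$, where $M$ is the set of indexes of the missing items. So, the generative loss can be rewritten as follows:
$$\mathcal{L}_g(\theta) = \mathrm{KL}\big(q(z \mid \bar{x}) \,\|\, p(z)\big) - \mathbb{E}_{q(z \mid \bar{x})}\Big[\sum_{d \notin M} \log p(\bar{x}_d \mid z)\Big].$$
Also, there is a discriminative loss from classification, which is the cross entropy between the true disease ID $y$ and the model-predicted posterior probability, $\mathcal{L}_d(\theta, \phi) = -\log p(y \mid \bar{x})$, where $\phi$ denotes the parameters of the discriminative network.
Therefore, the overall loss function of the VAE+NN model is given by the sum of the two costs, denoted $\mathcal{L}$:
$$\mathcal{L}(\theta, \phi) = \mathcal{L}_d(\theta, \phi) + \lambda\, \mathcal{L}_g(\theta),$$
where $\lambda$ is a trade-off parameter on the generative loss: because our final target is classification, the generative loss can be taken as a regularization term.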
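The combined objective described above (a generative loss that skips missing entries, plus a cross-entropy term, weighted by a trade-off parameter) can be sketched in plain Python as follows. This is a minimal sketch under stated assumptions: the expectation over the approximate posterior is replaced by a single evaluation of the decoder outputs, the function names and the `mask` convention (1 = observed, 0 = missing) are hypothetical, and none of this is the authors' TensorFlow code.

```python
import math

def kl_to_standard_normal(mu, sigma):
    """KL( N(mu, diag(sigma^2)) || N(0, I) ) for a diagonal Gaussian."""
    return 0.5 * sum(m * m + s * s - 1.0 - 2.0 * math.log(s)
                     for m, s in zip(mu, sigma))

def masked_gaussian_nll(x, mu, sigma, mask):
    """Gaussian negative log-likelihood summed over observed dimensions only."""
    nll = 0.0
    for xd, m, s, observed in zip(x, mu, sigma, mask):
        if observed:  # missing dimensions contribute nothing to the loss
            nll += 0.5 * math.log(2.0 * math.pi) + math.log(s) \
                   + 0.5 * ((xd - m) / s) ** 2
    return nll

def cross_entropy(probs, label):
    """Cross entropy between a one-hot label and predicted class probabilities."""
    return -math.log(probs[label])

def total_loss(x, mask, mu_z, sigma_z, mu_x, sigma_x, probs, label, trade_off=0.5):
    """Discriminative loss plus trade_off times the masked generative loss."""
    generative = (kl_to_standard_normal(mu_z, sigma_z)
                  + masked_gaussian_nll(x, mu_x, sigma_x, mask))
    return cross_entropy(probs, label) + trade_off * generative
```

Here `mu_z`, `sigma_z` stand for encoder outputs, `mu_x`, `sigma_x` for decoder outputs, and `probs` for the classifier's softmax output; in a real model these would come from the networks rather than being passed in by hand.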
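Similarly, the true posterior $p(z_t \mid x_{\le t}, z_{<t})$ is intractable, so $q(z_t \mid x_{\le t}, z_{<t})$ is used to approximate it. $q(z_t \mid x_{\le t}, z_{<t})$ and the generative loss are defined as follows:
$$q(z_t \mid x_{\le t}, z_{<t}) = \mathcal{N}\big(\mu_{z,t}, \mathrm{diag}(\sigma_{z,t}^2)\big), \qquad [\mu_{z,t}, \sigma_{z,t}] = \varphi^{enc}(x_t, h_{t-1}),$$
$$\mathcal{L}_g = -\mathbb{E}_{q(z_{\le T} \mid x_{\le T})}\Big[\sum_{t=1}^{T}\Big(-\mathrm{KL}\big(q(z_t \mid x_{\le t}, z_{<t}) \,\|\, p(z_t \mid x_{<t}, z_{<t})\big) + \log p(x_t \mid z_{\le t}, x_{<t})\Big)\Big].$$
Also, since different dimensions of $x_t$ are independent, $\log p(x_t \mid z_{\le t}, x_{<t}) = \sum_{d=1}^{D} \log p(x_{t,d} \mid z_{\le t}, x_{<t})$. Denote the set of indexes of missing values in $x_t$ as $M_t$; then:
$$\log p(x_t \mid z_{\le t}, x_{<t}) = \sum_{d \notin M_t} \log p(x_{t,d} \mid z_{\le t}, x_{<t}).$$
The overall loss function is also defined as the combination of the generative loss and the discriminative loss.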
In this section, we build several baseline models for the comparative study, denoted by NN, AE+NN and RNN+NN, respectively. The NN and AE+NN models are compared with the VAE+NN model, to see whether deep generative models perform better when representing single feature vectors. The RNN+NN model is similar to the model structure in previous studies lipton2015learning, lipton2016modeling, and it is compared with our VRNN+NN model.
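To demonstrate the superiority of VAE, we present an AE+NN baseline based on the standard auto-encoder, shown in Figure 4 (b). AE is similar to VAE, but its structure is deterministic, so it is less generative. In this model, we also combine the losses from the AE and the NN. We use mean squared error (MSE) as the training objective for the AE, and missing items are not included in the loss function. In summary, the model and the learning target are defined as follows:
$$h = \varphi^{enc}(\bar{x}), \qquad \hat{x} = \varphi^{dec}(h), \qquad p(y \mid \bar{x}) = \varsigma\big(\varphi^{y}(h)\big),$$
$$\mathcal{L} = \mathcal{L}_d + \lambda\, \frac{1}{D - |M|} \sum_{d \notin M} (\bar{x}_d - \hat{x}_d)^2,$$
where $M$ is the set of indexes of the missing items in $\bar{x}$ and $\mathcal{L}_d$ is the cross-entropy classification loss.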
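The RNN+NN model is shown in Figure 4 (c). In this model, an RNN processes the raw temporal features, and the average of the hidden states is used as the input of the NN. Here $f$ also denotes the recurrent computation of the LSTM cells. The model can be formulated as follows:
$$h_t = f(x_t, h_{t-1}), \qquad \bar{h} = \frac{1}{T}\sum_{t=1}^{T} h_t, \qquad p(y \mid X) = \varsigma\big(\varphi^{y}(\bar{h})\big).$$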
In our experiments, models are implemented in TensorFlow r1.0. All of the non-recurrent networks in our models are feed-forward neural networks with one hidden layer and ReLU activations. The size of the hidden layer is set to 64. We use Adam kingma2014adam as the optimizer, with a learning rate of 0.0005 and a learning rate decay of 0.99. The trade-off parameter is set to 0.5 in all the experiments.
Generally, the F1 score is defined as $F_1 = 2 \cdot \frac{\mathrm{precision} \cdot \mathrm{recall}}{\mathrm{precision} + \mathrm{recall}}$, where precision is the number of correct positive results divided by the number of all positive results, and recall is the number of correct positive results divided by the number of positive results that should have been returned. Since our task is a 50-class classification, we use three kinds of F1 scores: micro-F1, macro-F1 and macro-F1-weighted (macro-F1-w). Micro-F1 is computed from the flattened predictions $\hat{y}$ and labels $y$. Macro-F1 is the arithmetic mean of the F1 scores of the different classes. Macro-F1-weighted is the weighted mean of the F1 scores of the different classes, where the weight of class $c$ is defined as $w_c = n_c / \sum_{k} n_k$ and $n_c$ is the number of samples of class $c$ in the testing set. The F1 score ranges from 0 to 1, and the best performance is achieved when it equals 1. We also simulate a blind prediction, i.e. equal probabilities for all classes; the results are a micro-F1 of 0.111, a macro-F1 of 0.004 and a macro-F1-weighted of 0.022.
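The three F1 variants described above can be sketched as follows. The sketch assumes single-label (argmax) predictions, so it may differ from how the paper flattens its predictions; the function name `f1_scores` is hypothetical.

```python
def f1_scores(y_true, y_pred, n_classes):
    """Return (micro, macro, weighted-macro) F1 for single-label predictions."""
    tp = [0] * n_classes
    fp = [0] * n_classes
    fn = [0] * n_classes
    for t, p in zip(y_true, y_pred):
        if t == p:
            tp[t] += 1
        else:
            fp[p] += 1  # predicted class gains a false positive
            fn[t] += 1  # true class gains a false negative

    def f1(tp_, fp_, fn_):
        # Equivalent to 2 * precision * recall / (precision + recall).
        denom = 2 * tp_ + fp_ + fn_
        return 2 * tp_ / denom if denom else 0.0

    per_class = [f1(tp[c], fp[c], fn[c]) for c in range(n_classes)]
    support = [tp[c] + fn[c] for c in range(n_classes)]  # samples per class
    micro = f1(sum(tp), sum(fp), sum(fn))
    macro = sum(per_class) / n_classes
    weighted = sum(s * f for s, f in zip(support, per_class)) / sum(support)
    return micro, macro, weighted

micro, macro, weighted = f1_scores([0, 0, 1, 2], [0, 1, 1, 2], n_classes=3)
```

Under the single-label assumption, micro-F1 coincides with accuracy, while macro-F1 gives every class the same weight regardless of its frequency.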
AUC is the area under the ROC curve. Similarly, we use three AUCs in evaluation: micro-AUC, macro-AUC and macro-AUC-weighted (macro-AUC-w). AUC also ranges from 0 to 1, and the best performance is achieved when AUC equals 1. In the blind prediction, all AUCs equal 0.5.
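For a single class treated one-vs-rest, the AUC can be computed as the probability that a randomly chosen positive sample is scored above a randomly chosen negative one. The sketch below uses this pairwise definition; the function name `binary_auc` is hypothetical, and the quadratic loop is chosen for clarity rather than speed.

```python
def binary_auc(scores, labels):
    """AUC as the fraction of (positive, negative) pairs ranked correctly.

    Ties count as half a correct pair. Assumes at least one positive and
    one negative label.
    """
    pos = [s for s, l in zip(scores, labels) if l == 1]
    neg = [s for s, l in zip(scores, labels) if l == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0
               for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

auc = binary_auc([0.9, 0.8, 0.3, 0.1], [1, 0, 1, 0])
blind = binary_auc([0.02, 0.02, 0.02, 0.02], [1, 0, 1, 0])
```

When every sample receives the same score, as in a blind prediction, all pairs are ties and the AUC is 0.5, which matches the statement above.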
Table 1 shows the diagnosis performance in three sets of experiments. The top set reports the performance of the different models on the diagnosis task, measured by the different variants of F1 and AUC. Additionally, to test whether joint training results in better representations than unsupervised generative models, the representations derived from VAE, VAE+NN, VRNN and VRNN+NN are used to train a new NN model for the diagnosis decision; these results are presented in the middle set. Finally, in the bottom set, we compare our VRNN+NN model’s ability to deal with missing values against four heuristic imputation methods. “zero” is the default approach for the baseline models, and “last&next”, “row mean” and “NOCB” are the three best-known imputation methods according to a previous study engels2003imputation: “last&next” is the average of the last known and next known values; “row mean” is the mean of the patient’s values before and after; “NOCB” is the next observation carried backward.
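The four heuristic imputation methods named above can be sketched for one patient's series of a single test as follows. `None` marks a missing value. The fallbacks used when a neighbour does not exist (for example, no next observation for NOCB) and the reading of “row mean” as the mean of all of the patient's observed values are assumptions for illustration, as is the function name `impute`.

```python
def impute(series, method):
    """Fill None entries of one test's time series with a heuristic method."""
    observed = [v for v in series if v is not None]
    filled = []
    for i, v in enumerate(series):
        if v is not None:
            filled.append(v)
            continue
        # Nearest observed values before and after position i, if any.
        last = next((u for u in reversed(series[:i]) if u is not None), None)
        nxt = next((u for u in series[i + 1:] if u is not None), None)
        if method == "zero":
            candidates = [0.0]
        elif method == "nocb":
            candidates = [nxt]
        elif method == "last&next":
            candidates = [last, nxt]
        elif method == "row mean":
            candidates = observed
        else:
            raise ValueError(method)
        known = [c for c in candidates if c is not None]
        # Assumed fallback: use 0.0 when no usable neighbour exists.
        filled.append(sum(known) / len(known) if known else 0.0)
    return filled

example = [1.0, None, 3.0, None]
```

For instance, `impute(example, "last&next")` fills the first gap with the average of its two neighbours and the trailing gap with the last known value, since no next value exists there.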
To evaluate whether the performance in Table 1 is reliable, we apply a paired t-test to check whether the performance differences among the models are statistically significant. The results are shown in Table 2.
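As a reminder of what the paired t-test computes, the sketch below derives the t statistic from per-split scores of two models. The function name `paired_t_statistic` and the example scores are made up for illustration and are not results from the paper; converting the statistic to a p-value additionally requires the Student t distribution, which is omitted here.

```python
import math

def paired_t_statistic(a, b):
    """Return (t, degrees_of_freedom) for paired samples a and b.

    Assumes the paired differences are not all identical (non-zero variance).
    """
    diffs = [x - y for x, y in zip(a, b)]
    n = len(diffs)
    mean = sum(diffs) / n
    var = sum((d - mean) ** 2 for d in diffs) / (n - 1)  # sample variance
    return mean / math.sqrt(var / n), n - 1

# Hypothetical per-split scores of two models (illustrative values only).
t_stat, dof = paired_t_statistic([0.5, 0.6, 0.7], [0.4, 0.4, 0.4])
```

Pairing matters because both models are evaluated on the same random splits, so split-to-split variation cancels out in the differences.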
Since deep generative models can reconstruct input data, we conjecture that our VRNN+NN model has the potential to impute missing values better. To test this conjecture, we first randomly drop 10% of the values from the original data, and then use the trained VRNN+NN to impute the intentionally dropped values. The results in terms of MSE are shown in Table 3, where the MSE values of the heuristic imputation methods are also presented. The paired t-test results of these methods are shown as well.
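The evaluation protocol described above (hide a fraction of the observed values, impute them, and measure the error on the hidden entries only) can be sketched as follows. The function names, the seed handling and the use of a trivial zero fill as the imputer are assumptions for illustration; in the paper the imputed values come from the trained VRNN+NN or from the heuristic methods.

```python
import random

def drop_values(series, fraction, seed):
    """Randomly hide a fraction of the observed (non-None) values."""
    rng = random.Random(seed)
    observed_idx = [i for i, v in enumerate(series) if v is not None]
    n_drop = int(round(fraction * len(observed_idx)))
    dropped = set(rng.sample(observed_idx, n_drop))
    corrupted = [None if i in dropped else v for i, v in enumerate(series)]
    return corrupted, sorted(dropped)

def imputation_mse(original, imputed, dropped_idx):
    """Mean squared error measured only on the intentionally dropped entries."""
    return sum((original[i] - imputed[i]) ** 2 for i in dropped_idx) / len(dropped_idx)

original = [float(v) for v in range(20)]
corrupted, dropped_idx = drop_values(original, fraction=0.10, seed=0)
zero_filled = [0.0 if v is None else v for v in corrupted]  # stand-in imputer
mse = imputation_mse(original, zero_filled, dropped_idx)
```

Restricting the error to the dropped entries is what makes the comparison fair: values that were missing in the original data have no ground truth and cannot be scored.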
|Diagnosis performances of different models|
|Model||micro-F1||macro-F1||macro-F1-w||micro-AUC||macro-AUC||macro-AUC-w|
|NN||0.376 ± 0.004||0.221 ± 0.003||0.347 ± 0.005||0.939 ± 0.001||0.905 ± 0.001||0.913 ± 0.001|
|AE+NN||0.366 ± 0.004||0.219 ± 0.002||0.344 ± 0.002||0.938 ± 0.001||0.903 ± 0.002||0.912 ± 0.001|
|VAE+NN||0.374 ± 0.003||0.226 ± 0.005||0.352 ± 0.004||0.941 ± 0.000||0.908 ± 0.001||0.916 ± 0.001|
|RNN+NN||0.395 ± 0.004||0.248 ± 0.003||0.373 ± 0.003||0.945 ± 0.003||0.918 ± 0.004||0.923 ± 0.003|
|VRNN+NN||0.426 ± 0.002||0.291 ± 0.006||0.407 ± 0.002||0.958 ± 0.000||0.937 ± 0.000||0.938 ± 0.001|
|Performance of features derived from different models (with a simple NN classifier)|
|$E(z \mid \bar{x})$ (VAE)||0.363 ± 0.004||0.195 ± 0.004||0.326 ± 0.003||0.936 ± 0.001||0.896 ± 0.003||0.906 ± 0.002|
|$E(z \mid \bar{x})$ (VAE+NN)||0.380 ± 0.004||0.228 ± 0.004||0.353 ± 0.002||0.943 ± 0.001||0.911 ± 0.003||0.918 ± 0.002|
|$\bar{h}$ (VRNN)||0.406 ± 0.003||0.261 ± 0.003||0.381 ± 0.003||0.953 ± 0.000||0.928 ± 0.001||0.930 ± 0.001|
|$\bar{h}$ (VRNN+NN)||0.427 ± 0.003||0.297 ± 0.004||0.410 ± 0.003||0.958 ± 0.001||0.936 ± 0.001||0.937 ± 0.000|
|Performance with different missing value imputation methods|
|RNN+NN (zero)||0.395 ± 0.005||0.248 ± 0.003||0.374 ± 0.002||0.945 ± 0.003||0.918 ± 0.004||0.923 ± 0.003|
|RNN+NN (last&next)||0.385 ± 0.002||0.233 ± 0.003||0.360 ± 0.002||0.941 ± 0.001||0.912 ± 0.002||0.918 ± 0.001|
|RNN+NN (row mean)||0.393 ± 0.003||0.243 ± 0.005||0.369 ± 0.001||0.945 ± 0.002||0.917 ± 0.003||0.923 ± 0.002|
|RNN+NN (NOCB)||0.384 ± 0.003||0.231 ± 0.002||0.359 ± 0.001||0.941 ± 0.002||0.911 ± 0.003||0.917 ± 0.002|
|VRNN+NN||0.426 ± 0.002||0.291 ± 0.006||0.407 ± 0.002||0.958 ± 0.000||0.937 ± 0.001||0.938 ± 0.001|
|VAE+NN vs. NN||*||*||*||**|
|VAE+NN vs. AE+NN||*||*||*||*|
|RNN+NN vs. VAE+NN||**||**||**||**||**|
|VRNN+NN vs. RNN+NN||***||***||***||***||***||***|
|$E(z \mid \bar{x})$ (VAE+NN) vs. $E(z \mid \bar{x})$ (VAE)||**||***||***||**||***||***|
|$\bar{h}$ (VRNN+NN) vs. $\bar{h}$ (VRNN)||**||***||***||***||***||***|
|Different imputation methods||Performance comparison|
|Imputation method||Imputation error (MSE)||Comparison||P-value|
|Zero||0.909 ± 0.112||VRNN+NN vs. Zero||***|
|Last&next||0.434 ± 0.110||VRNN+NN vs. Last&next||*|
|Row mean||0.541 ± 0.114||VRNN+NN vs. Row mean||***|
|NOCB||0.547 ± 0.112||VRNN+NN vs. NOCB||***|
From the diagnosis performance shown in the top set of Table 1, it can first be observed that modeling temporal dependencies, e.g., with RNN or VRNN, significantly improves diagnosis performance, confirming that long-term dependency is an important property of clinical data. Secondly, involving deep generative models provides consistent performance improvement (see AE+NN vs. VAE+NN, and RNN+NN vs. VRNN+NN). This is an encouraging result and shows that the generative models (VAE and VRNN) learn better representations than their less-generative counterparts (AE and RNN).
Comparing the improvements provided by generative modeling, it can be observed that when dealing with single averaged feature vectors (VAE), the improvement is not very significant. This indicates that the missing data problem in the averaged vectors is not very severe, and zero imputation is reasonably good. When dealing with temporal vector sequences (VRNN), the improvement is highly significant. This is understandable, as the missing data problem is more serious in this case.
According to the results of the NN classifier with different features (the middle set of Table 1), the representations learned from VAE+NN perform significantly better than those learned from VAE; similarly, the representations learned from VRNN+NN are better than those learned from VRNN. This implies that joint training that considers the classification target leads to better feature learning. This is not surprising, as the learning is now more task-oriented.
The comparison of different missing value imputation methods, shown in the bottom set of Table 1, demonstrates that our VRNN+NN model outperforms all the heuristic imputation methods, confirming that the generative model is a more principled way to treat missing data. This conclusion can be drawn more explicitly from the results in Table 3, where the different imputation methods are compared directly by their imputation error. This advantage in data imputation is quite useful in practice by itself. For example, it can help complete incomplete laboratory test data and give a rough range for the missing values, which may assist physicians in their analysis.
Limitations and Future Works
Some limitations exist in the current study and need to be addressed in future work. First, we only consider the 50 most frequent diagnoses, leaving the remaining 2,739 diseases untested. This is because there is severe data imbalance between different diseases: about 35% of the 2,789 diseases have only one sample. Even for the top 50 diseases selected in the study, the data distribution is still quite skewed. As shown in Figure 1, the disease with ID=0 has 3,566 samples, while the disease with ID=49 has only 178 samples. Imbalanced data poses a big challenge for model training, and the resulting models tend to score higher for frequent labels than for infrequent ones. In this study, we did not apply any additional preprocessing, such as over-sampling or under-sampling, to address the data imbalance problem. Future work will address it by either sampling or re-weighting for infrequent classes.
Second, after looking into the VRNN+NN model’s performance on each disease, as shown in Figure 5, we find that the F1 scores and AUCs of different diseases vary dramatically. Some diagnoses, like “Alcohol withdrawal” (ID=37) and “Single liveborn without cesarean section” (ID=0), are quite accurate, while others, like “Gram-neg septicemia NEC” (ID=42) and “Mal neo upper lobe lung” (ID=47), obtain F1 scores close to 0. This implies that some diseases are difficult for our model to detect. It may result from our restriction to the 50 most frequent laboratory tests, or from the absence of other clinical data. More laboratory tests and other clinical information will be included in future studies.
Third, we only used the in-hospital records that have diagnosis labels, but there are also many outpatient test results in MIMIC-III that possess no labels. In future work, we will investigate semi-supervised learning to enhance the capability of our model by utilizing the unlabeled data.
Longitudinal, incomplete and noisy laboratory test data poses a major difficulty for automatic medical diagnosis. In this study, we proposed to utilize a deep sequential generative model in the form of VRNN to deal with the missing data problem and learn the complex temporal patterns in clinical data; this generative model is trained jointly with the back-end discriminative model that makes the diagnosis decision. This leads to an end-to-end system that takes advantage of both generative learning and discriminative learning. Our experiments show that the VRNN+NN model significantly surpasses all the baselines and its non-temporal version, the VAE+NN model. We also find that deep generative models can help with imputing missing values in clinical laboratory tests and distilling more informative patterns for diagnosis, and that the combination of generative and discriminative models leads to improvement in both feature learning and diagnosis decisions. Future work involves addressing data imbalance and utilizing unlabeled data.
Author contributions statement
S.Z. and P.X. conceived and designed this study. S.Z. processed the data and performed the experiments. S.H., P.X., D.W. wrote the paper. D.W. and E.P.X. take responsibility for the paper as co-senior authors. All authors reviewed the manuscript.
Supplementary information accompanies this paper.
Competing financial interests: The authors declare no competing financial interests.