Uncertainty-Gated Stochastic Sequential Model for EHR Mortality Prediction

03/02/2020 · Eunji Jun et al. · Korea University, ETRI

Electronic health records (EHR) are characterized as non-stationary, heterogeneous, noisy, and sparse data; therefore, it is challenging to learn the regularities or patterns inherent within them. In particular, sparseness, caused mostly by many missing values, has attracted the attention of researchers, who have attempted to make better use of all available samples for a primary target task by defining a secondary imputation problem. Methodologically, existing methods, either deterministic or stochastic, have applied different assumptions to impute missing values. However, once the missing values are imputed, most existing methods do not consider the fidelity or confidence of the imputed values in the modeling of downstream tasks. Undoubtedly, an erroneous or improper imputation of missing variables can cause difficulties in modeling as well as a degraded performance. In this study, we present a novel variational recurrent network that (i) estimates the distribution of missing variables, allowing the representation of uncertainty in the imputed values, (ii) updates hidden states by explicitly applying fidelity, based on the variance of the imputed values, during the recurrence (i.e., uncertainty propagation over time), and (iii) predicts the possibility of in-hospital mortality. It is noteworthy that our model can conduct these procedures in a single stream and learn all network parameters jointly in an end-to-end manner. We validated the effectiveness of our method using the public MIMIC-III and PhysioNet Challenge 2012 datasets, outperforming the other state-of-the-art methods for mortality prediction considered in our experiments. In addition, we identified the behavior of the model in representing the uncertainties of the imputed estimates, which indicated a high correlation between the calculated MAE and the uncertainty.


I Introduction

In the past decade, there has been growing interest in and research on applying machine learning (ML) to clinical domains, particularly to electronic health record (EHR) data analysis for intensive care units [jagannatha2016structured, choi2016medical]. However, owing to the nature of physiological EHR data, they generally involve a substantial number of missing values because of a lack of collection (i.e., unexpected accidents such as equipment damage) or documentation (i.e., irregular recording across medical variables and even time) [wells2013strategies]. Such an unfavorable characteristic constrains the use of conventional ML models, which commonly assume fully observed and fixed-size observations in practice [yadav2018mining].

To address this issue, some previous studies have directly modeled observations with missing values, for example, by transforming them into a time series of distributions over the possible values [zheng2017resolving]. However, not only does this approach show a low performance under a high rate of missing data, it also requires separate modeling for different datasets. Meanwhile, numerous imputation methods have been proposed to fill in missing values. Broadly, existing imputation methods can be categorized into (i) deterministic or (ii) stochastic approaches, depending on the presence of randomness in the imputation process [kalton1986treatment]. Deterministic imputation methods determine only one possible value for each missing value using the model parameters and/or conditions, thus resulting in a unique imputed value for each observation. This approach ranges from statistical methods (i.e., simple mean [kantardzic2011data], median [acuna2004treatment], and ratio imputation) to ML approaches, such as expectation maximization (EM) [garcia2010pattern], k-nearest neighbor (KNN), matrix factorization [koren2009matrix], and matrix completion [mazumder2010spectral]. In recent years, centered on deep learning (DL), recurrent neural networks (RNNs) such as long short-term memory (LSTM) and the gated recurrent unit (GRU) have shown a remarkable performance in modeling the temporal dependencies of a clinical time series and explicitly estimating the missing values, thus being regarded as de facto methods [che2018recurrent, yoon2017multi].

In contrast, stochastic methods, for example, multivariate imputation by chained equations (MICE) [azur2011multiple], possess some inherent randomness and take the underlying distributions into account, thus allowing the generation of samples. More recently, [luo2018multivariate] exploited the adversarial learning framework [goodfellow2014generative], in which a generator imputes missing values based on other observed values and a discriminator criticizes whether the completed values (combining both the observed and imputed values) are realistic. In addition, a variational autoencoder (VAE) [kingma:vae, rezende2014stochastic] has also been used for time-series imputation, where the temporal dynamics is separately modeled through a Gaussian process [fortuin2019multivariate] or an RNN [jun2019].

Although previous imputation methods have shown a reasonable performance, there is no doubt that an erroneous or improper imputation of missing data can degrade the performance in downstream tasks [kreindler2016effects]. This has resulted in the need to take the fidelity of the imputed values into account. That is, imputed values with low and high fidelity should be treated differently during modeling. To the best of our knowledge, most existing methods exploiting the aforementioned imputation techniques do not consider the fidelity of the imputed values in the downstream tasks [che2018recurrent, yoon2017multi, cao2018brits, luo2018multivariate].

In this study, we explicitly utilize the uncertainty of the imputed values, represented in terms of their variance, as the fidelity, and propose a novel uncertainty-gated stochastic sequential model for clinical time-series prediction. Inspired by the success of stochastic RNN models that introduce a stochastic gradient variational Bayes (SGVB) approach [rezende2014stochastic, kingma:vae] into an RNN sequence model, we take advantage of their capability to capture the underlying sequential structure and temporally generate missing values for multivariate time-series imputation. In addition to providing probabilistic imputation estimates as a result of stochastic inference, we estimate their corresponding uncertainty from the latent space and further propagate it within the GRU cells in time-series modeling for mortality prediction. Note that, to harness the rich representational power of the latent space, our proposed method implicitly uses the imputed time series for prediction. The main contributions of this work are as follows:

  • To the best of our knowledge, our study is the first to use an extended RNN with stochastic units to provide probabilistic imputation estimates with uncertainty.

  • We propose a novel GRU cell, called GRU-U, that exploits uncertainty-gated attention and further leverages attention weights for a reliable mortality prediction.

  • We conduct the missing value imputation and the further prediction task jointly in an end-to-end manner.

  • We evaluated our model on real-world healthcare datasets and achieved state-of-the-art results for the mortality prediction task, validating the effectiveness of the proposed approach in a clinical setting.

Fig. 1: Graphical illustrations of (a) the whole architecture of the stochastic recurrent imputation method and (b) each step of the VRNN.

II Related Work

In the past decades, numerous imputation methods have been proposed to handle missing values in multivariate time series that are sampled irregularly and/or sparsely. Conventional imputation methods can largely be divided into three classes. The first comprises statistical imputation methods, ranging from simple mean [kantardzic2011data], median [acuna2004treatment], and ratio imputation to classical statistical time-series models, including the auto-regressive integrated moving average (ARIMA) [ansley1984estimation], which eliminates the non-stationary parts of a sequence and fits a parameterized stationary model. However, these statistical imputation methods are limited in that they inadequately model the temporal sequence and deterministically impute missing values without any stochastic factors.

The second class comprises a variety of ML-based imputation methods developed for better missing value estimation, such as the EM algorithm [garcia2010pattern], KNN, matrix factorization [koren2009matrix], and matrix completion [mazumder2010spectral]. Furthermore, MICE [azur2011multiple] is widely used in practice, iteratively applying the aforementioned ML-based methods and averaging the results. Although such methods exploit added randomness as a stochastic approach, they rarely consider the uncertainty information.

Finally, in the third class, i.e., deep learning-based imputation methods, RNNs [che2018recurrent, yoon2017multi] have more recently proven successful in modeling temporal dependencies and imputing sequences, particularly within the healthcare domain. Hence, we focus on the RNN-based state-of-the-art imputation methods for clinical time series, which are usually leveraged together with a further downstream task, i.e., classification/regression.

GRU-D [che2018recurrent] assumes that missing variables can be derived by combining the mean imputation and forward filling with the last observation. For this, a trainable temporal decay factor toward the global mean is introduced from the time-interval information. In addition, the masking vector is directly modeled with the time-series data inside the GRU cell, allowing missing patterns to also be modeled internally. GRU-D achieves a superior performance in various clinical tasks but makes a strong assumption regarding the data, which may not be well suited to typical time-series datasets.

As another RNN-based imputation method, M-RNN [yoon2017multi] utilizes a bi-directional RNN to reconstruct missing values by operating both within streams (i.e., interpolation) and across streams (i.e., imputation). The imputed values in M-RNN are treated as constants, which cannot be sufficiently updated.

As another bidirectional approach, similar to M-RNN, BRITS [cao2018brits] also uses bi-directional recurrent dynamics, considering the forward and backward directions of a given time series to solve the error-delay problem until the presence of the next observation. It further formulates a feature-based estimation as a correlation among variables in addition to a history-based estimation. At the same time, it considers the imputed values as variables and updates them during back-propagation, as opposed to M-RNN, thereby resulting in a SOTA performance in the healthcare domain.

Despite these remarkable deep learning-based imputation models, it should be noted that these methods have a major drawback in that they do not investigate the uncertainty of the imputation estimates. Meanwhile, [luo2018multivariate] recently proposed GRU-I, which uses the adversarial learning framework to impute a multivariate time series and explicitly leverages it for further classification. However, this method has a limitation in that the imputation and downstream steps are separated, and thus training does not occur in an end-to-end manner. In addition, although an uncertainty arises in the input space, it is not utilized explicitly for classification tasks.

By adding stochasticity to the hidden states of an RNN, this study is the first to provide probabilistic imputation estimates with uncertainty, to further apply uncertainty-gated attention within the GRU cell, and to demonstrate its effectiveness on an in-hospital mortality prediction task.

III Methods

In this section, we describe our proposed uncertainty-gated stochastic sequential model that extends a variational RNN (VRNN) [chung2015recurrent] for mortality prediction using EHR data. In particular, by leveraging a VRNN as our base model, we devise a novel network that handles representation learning, missing value imputation, and in-hospital mortality prediction simultaneously. Notably, we propose a new type of GRU cell, in which temporal information encoding in hidden states is updated by propagating the uncertainty for the inferred distribution over the variables. The overall architecture of the proposed model is shown in Fig. 1.

Our proposed method consists of three parts: (i) stochastic recurrent missing value imputation, (ii) uncertainty-gated attention, and (iii) in-hospital mortality prediction.

III-A Data Representation

Given a multivariate time series with $D$ variables over $T$ time points, we denote it as $X = (x_1, \ldots, x_T)^\top \in \mathbb{R}^{T \times D}$, where $x_t \in \mathbb{R}^D$ represents the $t$-th observation of all variables observed at timestamp $s_t$, and $x_t^d$ is the $d$-th element or variable in $x_t$. In this setting, because the time series includes missing values, we introduce the masking vector across the time series, $M = (m_1, \ldots, m_T)^\top \in \{0, 1\}^{T \times D}$, with the same size as $X$, to mark which variables are observed or missing. In particular, $m_t^d = 1$ if $x_t^d$ is observed, and $m_t^d = 0$ otherwise. Considering the masking vector, we define a new multivariate time series including missing values, $\tilde{X}$, as follows:

$\tilde{x}_t^d = \begin{cases} x_t^d, & \text{if } m_t^d = 1 \\ *, & \text{otherwise} \end{cases}$  (1)

where $*$ indicates an unobserved value to be estimated by the proposed imputation method. Initially, we set $*$ in $\tilde{X}$ to zero [luo2018multivariate, nazabal2018handling]. In addition, we maintain the time interval $\delta_t^d$, defined as the difference between the last observation and the current timestamp, following the equation for each variable $d$:

$\delta_t^d = \begin{cases} s_t - s_{t-1} + \delta_{t-1}^d, & \text{if } t > 1, \; m_{t-1}^d = 0 \\ s_t - s_{t-1}, & \text{if } t > 1, \; m_{t-1}^d = 1 \\ 0, & \text{if } t = 1 \end{cases}$  (2)

Given a clinical time-series dataset $\{X^{(n)}\}_{n=1}^{N}$ for $N$ subjects, we define in-hospital mortality prediction as a binary classification problem with labels $y^{(n)} \in \{0, 1\}$. To avoid clutter, and without loss of generality, we simply use functional notation for a patient's EHR at time $t$, ignoring the superscript $(n)$.
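To make the notation concrete, the following is a minimal sketch (our own, not the authors' released code) of how the zero-filled series, the masking matrix, and the time intervals of Eqs. (1)-(2) could be derived from a raw series containing NaNs; the function and variable names are hypothetical.

```python
import numpy as np

def build_mask_and_delta(X, timestamps):
    """Derive the zero-filled series, masking matrix, and time intervals
    (Eqs. (1)-(2)) from a (T, D) series X with NaNs at missing entries."""
    T, D = X.shape
    m = (~np.isnan(X)).astype(float)       # m[t, d] = 1 if x_t^d is observed
    delta = np.zeros((T, D))
    for t in range(1, T):
        gap = timestamps[t] - timestamps[t - 1]
        # accumulate the gap while the previous value was missing,
        # restart from the raw gap once it was observed
        delta[t] = np.where(m[t - 1] == 1, gap, gap + delta[t - 1])
    x_tilde = np.nan_to_num(X, nan=0.0)    # '*' entries initialized to zero
    return x_tilde, m, delta
```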

III-B Stochastic Recurrent Missing Value Imputation

Inspired by the finding that the use of stochastic representations for hidden states in RNNs helps improve time-series modeling [fabius2014variational, chung2015recurrent], we adopt a VRNN [chung2015recurrent] as our base model architecture.

The VRNN is a probabilistic extension of an RNN with stochastic units. In particular, the hidden states of RNNs include latent random variables by combining the elements of the VAE. This allows modeling their distributional characteristics in a latent space, where the underlying structure of sequential data can be better represented. Based on the estimated distribution of latent variables, it becomes possible to generate input values, with which we can impute missing values accordingly.

The overall stochastic imputation process comprises a series of steps, as shown in Fig. 1: (i) prior and (ii) posterior inference over the latent variable, (iii) estimation of the observational distribution, called generation, and (iv) representation of hidden states through recurrence. In other words, the VRNN iteratively performs both the inference and generation processes at every time step. Note that the inference step approximates the true posterior, and the generation step performs the imputation by reconstructing data from this posterior. To handle the output probability distributions and stochastic training, arbitrarily flexible functions such as neural networks can be chosen.

III-B1 Prior

The prior distribution on the latent random variable $z_t$ follows the Gaussian distribution $p(z_t) = \mathcal{N}(\mu_{0,t}, \mathrm{diag}(\sigma_{0,t}^2))$, with

$[\mu_{0,t}, \sigma_{0,t}] = \varphi^{\mathrm{prior}}_\theta(h_{t-1})$  (3)

where $\varphi^{\mathrm{prior}}_\theta$ is a function with learnable parameters $\theta$ and the previous hidden state $h_{t-1}$ as input (Prior in Fig. 1(b)). In a clinical setting, this latent variable can be interpreted as the patient's hidden health status at a particular time point.

III-B2 Inference

During the inference phase (black dashed arrows in Fig. 1(a) and Inference in Fig. 1(b)), we aim to learn the inference network that approximates the true posterior distribution over the latent variables with $q(z_t \mid x_{\le t}, z_{<t}) = \mathcal{N}(\mu_{z,t}, \mathrm{diag}(\sigma_{z,t}^2))$. We estimate the mean and log-variance using the inference function $\varphi^{\mathrm{enc}}_\phi$ with parameters $\phi$, conditioned on both $x_t$ and $h_{t-1}$:

$[\mu_{z,t}, \sigma_{z,t}] = \varphi^{\mathrm{enc}}_\phi(\varphi^x(x_t), h_{t-1})$  (4)

where $\varphi^x$ is a non-linear feature extractor from $x_t$. For all time steps, the approximate posterior, depending on $x_{\le T}$ and $z_{<T}$, factorizes as follows:

$q(z_{\le T} \mid x_{\le T}) = \prod_{t=1}^{T} q(z_t \mid x_{\le t}, z_{<t})$  (5)

Here, the reparameterization trick [kingma:vae, rezende2014stochastic] is used to make the network differentiable in our implementation, as in the auto-encoding variational Bayes algorithm. We sample $\epsilon \sim \mathcal{N}(0, I)$ and then represent $z_t = \mu_{z,t} + \sigma_{z,t} \odot \epsilon$ with $\mu_{z,t}$ and $\sigma_{z,t}$ estimated from the inference network in Eq. (4), where $\odot$ denotes an element-wise multiplication. It should be noted that the inference step is implicitly involved in our imputation process through the approximated posterior's contribution to the ensuing generation step.
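As a concrete illustration, the reparameterized sampling step can be written in a few lines; this is a generic SGVB sketch rather than the authors' code.

```python
import torch

def reparameterize(mu, logvar):
    """Sample z = mu + sigma * eps with eps ~ N(0, I); the draw stays
    differentiable w.r.t. mu and logvar, enabling end-to-end training."""
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std
```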

III-B3 Generation

The generation step (green dashed arrows in Fig. 1(a) and Generation in Fig. 1(b)) learns the generative network following the reconstruction distribution $p(x_t \mid z_{\le t}, x_{<t}) = \mathcal{N}(\mu_{x,t}, \mathrm{diag}(\sigma_{x,t}^2))$, the mean and log-variance of which are estimated as:

$[\mu_{x,t}, \sigma_{x,t}] = \varphi^{\mathrm{dec}}_\theta(\varphi^z(z_t), h_{t-1})$  (6)

where $\varphi^{\mathrm{dec}}_\theta$ is a generating function with parameters $\theta$, and $\varphi^z$ is a non-linear feature extractor from $z_t$. Note that they are dependent on the posterior approximated at the inference step, as well as on $h_{t-1}$. The joint distribution across subsequent time steps also factorizes as:

$p(x_{\le T}, z_{\le T}) = \prod_{t=1}^{T} p(x_t \mid z_{\le t}, x_{<t}) \, p(z_t \mid x_{<t}, z_{<t})$  (7)

where $p(z_t \mid x_{<t}, z_{<t})$ and $p(x_t \mid z_{\le t}, x_{<t})$ can be obtained from the prior distribution in Eq. (3) and the reconstruction distribution in Eq. (6) at time $t$, respectively.

The imputation is performed using the mean values from the reconstruction distribution. Here, we consider a temporal decay factor and the feature correlations of the multivariate time series. In a temporal context where the influence of a medical feature fades over time when its values are missing in the long term, it is well suited to introduce a decay mechanism for EHR time-series modeling [che2018recurrent]. Following the aforementioned properties, a negative exponential rectifier is exploited to make the temporal decay rate $\gamma_t$ monotonically decrease, computed from the time interval $\delta_t$ as follows:

$\gamma_t = \exp\{-\max(0, W_\gamma \delta_t + b_\gamma)\}$  (8)

We learn the decay rates from the training data by learning the model parameters $\{W_\gamma, b_\gamma\}$, rather than fixing them a priori. The actual observations and the estimated mean values from the VRNN are then combined by a weight determined from the temporal decay rate and masking vector as follows:

$\beta_t = \sigma(W_\beta [\gamma_t; m_t] + b_\beta)$  (9)
$\hat{x}_t = \beta_t \odot \mu_{x,t} + (1 - \beta_t) \odot \tilde{x}_t$  (10)

where $W_\beta$ and $b_\beta$ are learnable parameters. This temporal decay mechanism ensures that missing values are smoothly replaced over time. In addition, we further introduce feature correlations into the imputation process, where one feature is represented by a linear combination of the others:

$\check{x}_t = W_r \hat{x}_t + b_r$  (11)

where $W_r$ is a parameter matrix with a diagonal of all zeros.

Consequently, the temporally decayed estimates $\hat{x}_t$ and feature-correlated estimates $\check{x}_t$ are fully integrated into a combination vector $c_t$ using a convolution operation ($\ast$), followed by a max-pooling operation that incorporates all channel information:

$c_t = \mathrm{maxpool}\big([\hat{x}_t; \check{x}_t] \ast \kappa\big)$  (12)

where $\kappa$ denotes the convolution kernel. Thus, depending on the presence of an observation indicated by the masking vector $m_t$, we can formulate the final imputation estimates as follows:

$\bar{x}_t = m_t \odot x_t + (1 - m_t) \odot c_t$  (13)

In addition, we also capture the uncertainty of the imputation estimates by using $\sigma_{x,t}$ from the generative network. Here, the uncertainties of observed values are set to zero, indicating full fidelity. It is worth noting that the imputed values are stochastically estimated with a corresponding uncertainty, considering the posteriors inferred at the current time step. For the imputed values, the uncertainties are evaluated as follows:

$u_t = (1 - m_t) \odot \sigma_{x,t}^2$  (14)
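Putting Eqs. (8)-(14) together, a single imputation step might look as follows. This is a hedged sketch under our reading of the paper: the parameter names ($W_\gamma$, $W_\beta$, $W_r$, ...) are our own, and the simple averaging in place of the conv + max-pool fusion of Eq. (12) is an explicit simplification.

```python
import torch
import torch.nn.functional as F

def impute_step(x_t, m_t, delta_t, mu_x, sigma2_x, params):
    """One imputation step (sketch of Eqs. (8)-(14)) for a single patient.
    x_t, m_t, delta_t: (D,) observation, mask, and time interval.
    mu_x, sigma2_x: (D,) decoder mean and variance from the VRNN.
    params: dict of assumed parameters W_gamma (D, D), b_gamma (D,),
            W_beta (D, 2D), b_beta (D,), W_r (D, D), b_r (D,)."""
    # Eq. (8): temporal decay rate, monotonically decreasing in delta_t
    gamma_t = torch.exp(-F.relu(params["W_gamma"] @ delta_t + params["b_gamma"]))
    # Eqs. (9)-(10): blend decoder means with the zero-filled input using a
    # weight computed from the decay rate and the mask
    beta_t = torch.sigmoid(params["W_beta"] @ torch.cat([gamma_t, m_t])
                           + params["b_beta"])
    x_hat = beta_t * mu_x + (1.0 - beta_t) * x_t
    # Eq. (11): feature-based estimate; diag(W_r) forced to zero
    W_r = params["W_r"] - torch.diag(torch.diag(params["W_r"]))
    x_check = W_r @ x_hat + params["b_r"]
    # Eq. (12): a simple average stands in for the conv + max-pool fusion
    c_t = 0.5 * (x_hat + x_check)
    # Eq. (13): keep observed entries, impute the rest
    x_bar = m_t * x_t + (1.0 - m_t) * c_t
    # Eq. (14): uncertainty is the decoder variance, zeroed where observed
    u_t = (1.0 - m_t) * sigma2_x
    return x_bar, u_t
```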

III-B4 Recurrence

During the recurrence phase (black solid arrows in Fig. 1(a) and Recurrence in Fig. 1(b)), the extracted features from $\bar{x}_t$ and $z_t$ and the previous hidden state are fed into the hidden state as

$h_t = f_\theta\big(\varphi^x(\bar{x}_t), \varphi^z(z_t), h_{t-1}\big)$  (15)

where $f_\theta$ is a transformation function conditioned on both $\bar{x}_t$ and $z_t$. For the transformation function, in our study, a GRU cell [cho2014learning] is used. Notably, the imputed values and latent variables are tied through the deterministic hidden state $h_t$, which further updates the inferred distribution over the variables during the recurrence.

Fig. 2: Comparison of (a) the original vanilla GRU cell and (b) the proposed GRU-U cell.

III-C Uncertainty-Gated Attention

$r_t = \sigma(W_r \bar{x}_t + U_r h_{t-1} + V_r m_t + P_r \alpha_t + b_r)$  (16)
$o_t = \sigma(W_o \bar{x}_t + U_o h_{t-1} + V_o m_t + P_o \alpha_t + b_o)$  (17)
$\hat{h}_t = \tanh\big(W_h \bar{x}_t + U_h (r_t \odot h_{t-1}) + V_h m_t + P_h \alpha_t + b_h\big)$  (18)
$h_t = (1 - o_t) \odot h_{t-1} + o_t \odot \hat{h}_t$  (19)

In this study, we investigate the application of uncertainty to a downstream task. At each timestamp, we efficiently propagate the uncertainty within the GRU cell gates such that the imputed data can be axiomatically incorporated with this uncertainty in a non-linear fashion. To this end, the uncertainty $u_t$ is additionally injected into the hidden state as an input to the GRU cell, together with the features from $\bar{x}_t$ and $z_t$ and the previous hidden state, as

$h_t = f_\theta\big(\varphi^x(\bar{x}_t), \varphi^z(z_t), u_t, h_{t-1}\big)$  (20)

The basic idea of the proposed approach is to reduce the influence of imputed data with low fidelity on the subsequent downstream tasks. To this end, our approach introduces a trainable decay designed for the uncertainty in the model and efficiently propagates it through the GRU gate units, which we call uncertainty-gated attention. Therefore, in this paper, we propose a novel GRU cell, called GRU-U, that injects uncertainty-gated attention into a vanilla GRU cell. The difference between the vanilla GRU cell and the proposed GRU-U is schematically depicted in Fig. 2. Similar to the temporal decay rate, the attention weights are formally

$\alpha_t = \exp\{-\max(0, W_\alpha u_t + b_\alpha)\}$  (21)

where $W_\alpha$ and $b_\alpha$ are model parameters that are jointly learned with the other recurrence parameters. In particular, $W_\alpha$ is restricted to a diagonal matrix to ensure that the decay factor of each variable is independent of the others.

In GRU-U, there are two types of gates used to control the information: a reset gate $r_t$ and an update gate $o_t$. Compared with the vanilla GRU cell, the attention weight vector $\alpha_t$ and masking vector $m_t$ are additionally fed into the gates, following the update computations in Eqs. (16)-(19), where $\sigma$ is a non-linear activation function.

It is worth noting that we learn the corresponding parameters for $\alpha_t$, thus allowing the model to learn the extent to which the attention weights are reflected in the input. In addition, the missing patterns are also modeled by directly feeding in the masking vector, which is linearly combined with the other internal representations.
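The following PyTorch sketch shows one plausible realization of the GRU-U cell. It reflects our reading of Eqs. (16)-(21): the gates see the mask, and the input is attenuated by the uncertainty-gated attention; the exact wiring of $\alpha_t$ inside the cell is an assumption, not the authors' released implementation.

```python
import torch
import torch.nn as nn

class GRUUCell(nn.Module):
    """Sketch of a GRU-U cell: a GRU whose gates additionally receive the
    masking vector, with inputs attenuated by uncertainty-gated attention."""

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # Eq. (21): W_alpha restricted to a diagonal matrix -> one weight
        # per variable, so the decay of each variable is independent
        self.w_alpha = nn.Parameter(torch.zeros(input_dim))
        self.b_alpha = nn.Parameter(torch.zeros(input_dim))
        gate_in = 2 * input_dim + hidden_dim   # [attended input; mask; h]
        self.reset = nn.Linear(gate_in, hidden_dim)
        self.update = nn.Linear(gate_in, hidden_dim)
        self.cand = nn.Linear(gate_in, hidden_dim)

    def forward(self, x_bar, m, u, h_prev):
        # attention decays toward zero as the uncertainty u grows
        alpha = torch.exp(-torch.relu(self.w_alpha * u + self.b_alpha))
        x_in = alpha * x_bar                    # down-weight low-fidelity inputs
        g = torch.cat([x_in, m, h_prev], dim=-1)
        r = torch.sigmoid(self.reset(g))        # reset gate, cf. Eq. (16)
        o = torch.sigmoid(self.update(g))       # update gate, cf. Eq. (17)
        g_c = torch.cat([x_in, m, r * h_prev], dim=-1)
        h_cand = torch.tanh(self.cand(g_c))     # candidate state, cf. Eq. (18)
        return (1 - o) * h_prev + o * h_cand    # new hidden state, cf. Eq. (19)
```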

III-D Mortality Prediction

To predict the probability of in-hospital mortality, we use the last GRU hidden state, the representation of which is more powerful than the explicit use of the imputed data because it encodes temporal information across all time steps. Given the last hidden state $h_T$, we apply a fully connected layer followed by a sigmoid activation function as follows:

$\hat{y} = \mathrm{sigm}(W_y h_T + b_y)$  (22)

where sigm is the logistic sigmoid activation function and $\{W_y, b_y\}$ are the classifier parameters. For simplicity, the recurrence, uncertainty-gated attention, and prediction model parameters are summarized as $\psi$. The overall algorithm for the unidirectional VRNN is described in Algorithm 1.

III-E Bidirectional VRNN

To better capture the long-term dependencies of clinical records, in this study, we exploit bidirectional recurrent dynamics. In other words, given a time series processed in the forward direction, each time series can also be processed by the backward-direction function, yielding backward estimates. The final estimated imputation is the mean of the forward and backward estimates at each time $t$, and the prediction is based on the average logit of the forward and backward VRNNs. The final loss is obtained by accumulating the forward loss $\mathcal{L}^{fwd}$ and the backward loss $\mathcal{L}^{bwd}$. This bidirectional approach helps solve the problems of inefficient training and biased exploding gradients caused by the error delays that often occur in long time sequences.
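A minimal sketch of the bidirectional fusion, assuming the backward pass has already been re-aligned to forward time; the function name and the averaging of logits follow the description above rather than released code.

```python
import torch

def fuse_bidirectional(x_fwd, x_bwd, logit_fwd, logit_bwd):
    """x_fwd, x_bwd: (B, T, D) imputations from the two directions
    (backward already flipped back to forward time order).
    logit_fwd, logit_bwd: (B,) pre-sigmoid mortality scores."""
    x_final = 0.5 * (x_fwd + x_bwd)                        # mean imputation
    y_prob = torch.sigmoid(0.5 * (logit_fwd + logit_bwd))  # averaged logit
    return x_final, y_prob
```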

III-F Learning

To train all model parameters $\{\theta, \phi, \psi\}$, we use a composite objective function, which consists of (i) the VRNN loss $\mathcal{L}_{\mathrm{VRNN}}$, (ii) the consistency loss $\mathcal{L}_{\mathrm{cons}}$, (iii) the masked-value imputation loss $\mathcal{L}_{\mathrm{imp}}$, and (iv) the classification loss $\mathcal{L}_{\mathrm{cls}}$. The VRNN loss is calculated by accumulating the reconstruction error and KL divergence over the time series for each sample as follows:

$\mathcal{L}_{\mathrm{VRNN}} = \sum_{t=1}^{T} \Big\{ \mathrm{KL}\big(q(z_t \mid x_{\le t}, z_{<t}) \,\|\, p(z_t \mid x_{<t}, z_{<t})\big) - \log p(x_t \mid z_{\le t}, x_{<t}) \Big\}$  (23)

All losses for the $N$ samples are gathered, which leads to $\mathcal{L}_{\mathrm{VRNN}} = \sum_{n=1}^{N} \mathcal{L}_{\mathrm{VRNN}}^{(n)}$. The consistency loss $\mathcal{L}_{\mathrm{cons}}$ is defined as the difference between the forward and backward estimations over time, measured through the mean absolute error (MAE), for consistent estimations in both directions. In terms of the imputation loss $\mathcal{L}_{\mathrm{imp}}$, we calculate a masked MAE between the original sample as the ground truth and the imputed sample for the initially marked values only¹ using the masking vector $m'_t$:

$\mathcal{L}_{\mathrm{imp}} = \dfrac{\sum_{t=1}^{T} \lVert m'_t \odot (x_t - \bar{x}_t) \rVert_1}{\sum_{t=1}^{T} \lVert m'_t \rVert_1}$  (24)

¹ Specifically, we randomly selected 10% of the non-missing values and removed them to compute the masked imputation loss. Thus, another masking vector $m'$ is introduced to mark the selected values.

Moreover, in this study, we address the poor classification performance resulting from the highly imbalanced data frequently found in healthcare datasets. To accurately detect minority-class observations that are often crucial (i.e., in-hospital death), we exploit a focal loss [lin2017focal]. The loss formulation is similar to the standard cross-entropy loss, but is reshaped such that it down-weights the loss assigned to well-classified examples in the following way:

$p_t = \begin{cases} \hat{y}, & \text{if } y = 1 \\ 1 - \hat{y}, & \text{otherwise} \end{cases}$  (25)
$\mathcal{L}_{\mathrm{cls}} = -\alpha_b (1 - p_t)^{\gamma_f} \log(p_t)$  (26)

where $\hat{y}$ is the model's predicted probability for $y = 1$, $\alpha_b$ is a weighting factor used to balance the importance between the classes, and $\gamma_f$ is a focusing parameter applied to focus on the minority class.

Hence, all the losses are accumulated by integrating both the forward and backward losses, defining the composite loss as $\mathcal{L} = \mathcal{L}_{\mathrm{VRNN}} + \lambda_c \mathcal{L}_{\mathrm{cons}} + \lambda_i \mathcal{L}_{\mathrm{imp}} + \lambda_y \mathcal{L}_{\mathrm{cls}}$, where $\lambda_c$, $\lambda_i$, and $\lambda_y$ are hyper-parameters that control the ratio between the losses. We optimize all the parameters of our model in an end-to-end manner via this composite loss.

Algorithm 1 Uncertainty-gated stochastic recurrent imputation in unidirectional VRNN
Input: time series $\tilde{X}$, masking vectors $M$, time intervals $\delta$, labels $y$
repeat
    for $t = 1$ to $T$ do
        Compute the prior $[\mu_{0,t}, \sigma_{0,t}]$ from $h_{t-1}$  (Eq. (3))
        Infer the posterior $[\mu_{z,t}, \sigma_{z,t}]$ from $\tilde{x}_t$ and $h_{t-1}$  (Eq. (4))
        Sample $z_t = \mu_{z,t} + \sigma_{z,t} \odot \epsilon$, $\epsilon \sim \mathcal{N}(0, I)$
        Generate $[\mu_{x,t}, \sigma_{x,t}]$ from $z_t$ and $h_{t-1}$  (Eq. (6))
        Compute the temporal decay rate $\gamma_t$  (Eq. (8))
        Combine the observations and estimated means into $\hat{x}_t$  (Eqs. (9)-(10))
        Compute the feature-correlated estimates $\check{x}_t$  (Eq. (11))
        Fuse the estimates into $c_t$ and form the final imputation $\bar{x}_t$  (Eqs. (12)-(13))
        Evaluate the uncertainty $u_t$  (Eq. (14))
        Compute the uncertainty-gated attention weights $\alpha_t$  (Eq. (21))
        Update the hidden state $h_t$ via the GRU-U cell  (Eqs. (16)-(20))
    end for
    Predict $\hat{y}$ from $h_T$  (Eq. (22))
    Compute the losses $\mathcal{L}_{\mathrm{VRNN}}$, $\mathcal{L}_{\mathrm{cons}}$, $\mathcal{L}_{\mathrm{imp}}$, and $\mathcal{L}_{\mathrm{cls}}$  (Eqs. (23)-(26))
    Update the parameters $\{\theta, \phi, \psi\}$ with the composite loss $\mathcal{L}$
until convergence
Masking ratio 5%:
| Method | MIMIC-III AUC | MIMIC-III AUPRC | PhysioNet AUC | PhysioNet AUPRC |
| GRU-Zero | 0.807 ± 0.021 | 0.345 ± 0.036 | 0.819 ± 0.021 | 0.444 ± 0.033 |
| GRU-Mean | 0.796 ± 0.018 | 0.330 ± 0.016 | 0.793 ± 0.021 | 0.431 ± 0.041 |
| GRU-KNN | 0.794 ± 0.020 | 0.338 ± 0.017 | 0.790 ± 0.026 | 0.412 ± 0.049 |
| GRU-D [che2018recurrent] | 0.853 ± 0.013 | 0.380 ± 0.019 | 0.809 ± 0.030 | 0.464 ± 0.046 |
| M-RNN [yoon2017multi] | 0.822 ± 0.010 | 0.317 ± 0.016 | 0.781 ± 0.023 | 0.383 ± 0.042 |
| RITS [cao2018brits] | 0.854 ± 0.009 | 0.373 ± 0.028 | 0.810 ± 0.015 | 0.456 ± 0.039 |
| BRITS [cao2018brits] | 0.863 ± 0.014 | 0.414 ± 0.025 | 0.824 ± 0.004 | 0.460 ± 0.042 |
| SAnD [song2018attend] | 0.830 ± 0.010 | 0.374 ± 0.024 | 0.787 ± 0.017 | 0.426 ± 0.026 |
| Ours | 0.865 ± 0.008 | 0.416 ± 0.029 | 0.832 ± 0.018 | 0.470 ± 0.054 |

Masking ratio 10%:
| Method | MIMIC-III AUC | MIMIC-III AUPRC | PhysioNet AUC | PhysioNet AUPRC |
| GRU-Zero | 0.809 ± 0.023 | 0.346 ± 0.024 | 0.822 ± 0.022 | 0.440 ± 0.039 |
| GRU-Mean | 0.795 ± 0.011 | 0.326 ± 0.020 | 0.793 ± 0.018 | 0.436 ± 0.028 |
| GRU-KNN | 0.798 ± 0.015 | 0.328 ± 0.021 | 0.788 ± 0.021 | 0.428 ± 0.041 |
| GRU-D [che2018recurrent] | 0.853 ± 0.013 | 0.380 ± 0.019 | 0.809 ± 0.030 | 0.464 ± 0.046 |
| M-RNN [yoon2017multi] | 0.824 ± 0.006 | 0.322 ± 0.023 | 0.781 ± 0.023 | 0.383 ± 0.042 |
| RITS [cao2018brits] | 0.860 ± 0.004 | 0.392 ± 0.004 | 0.810 ± 0.015 | 0.456 ± 0.039 |
| BRITS [cao2018brits] | 0.864 ± 0.011 | 0.412 ± 0.032 | 0.824 ± 0.004 | 0.460 ± 0.042 |
| SAnD [song2018attend] | 0.826 ± 0.009 | 0.372 ± 0.025 | 0.787 ± 0.017 | 0.426 ± 0.026 |
| Ours | 0.865 ± 0.010 | 0.415 ± 0.014 | 0.829 ± 0.022 | 0.465 ± 0.054 |
TABLE I: Performance on the mortality prediction task (mean ± std from 5-fold cross-validation)

IV Experiments

In this section, we evaluate the proposed uncertainty-gated stochastic sequential imputation method on the in-hospital mortality prediction and missing value imputation tasks using two publicly available healthcare datasets, (i) the Medical Information Mart for Intensive Care III (MIMIC-III) and (ii) the PhysioNet Challenge 2012, both of which contain multivariate time series with numerous missing values. To compare the results depending on the ratio of missing values, we considered two masking scenarios in which 5% or 10% of the observations were additionally masked for each dataset.

We report the performance of the mortality prediction task with the average results from a 5-fold cross-validation in terms of (i) the area under the ROC curve (AUC) and (ii) the area under the precision-recall curve (AUPRC). The results of the missing value imputation are reported in terms of the MAE. With respect to the mortality prediction task, both the prediction and imputation are conducted during training, whereas only the prediction is applied during testing. Regarding the imputation task, only the missing value imputation is conducted during training and testing.
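Both metrics are standard; the following is a sketch of the evaluation using scikit-learn (our choice of library, not necessarily the authors').

```python
from sklearn.metrics import average_precision_score, roc_auc_score

def evaluate_fold(y_true, y_prob):
    """AUC and AUPRC for one cross-validation fold, given predicted
    in-hospital mortality probabilities y_prob."""
    return {"AUC": roc_auc_score(y_true, y_prob),
            "AUPRC": average_precision_score(y_true, y_prob)}
```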

We compared the results of the two tasks with other state-of-the-art methods in the literature to show the superiority of the proposed method. In addition, we conducted extensive ablation studies of our model to evaluate the effects of different components in the proposed method.

In addition to validating the performances of the two tasks, we visualized the imputation estimates with the uncertainties predicted from our model against actual observations over time, and further investigated the behavior of our model regarding the representation of the uncertainties for the imputed estimates in terms of the correlation between imputation MAE and uncertainty.

All the codes are available at “https://open-after-acceptance”.

IV-A Data

We used two publicly available datasets, namely, the MIMIC-III and PhysioNet challenge 2012 datasets.

IV-A1 MIMIC-III

We used the publicly available real-world EHR dataset MIMIC-III (available at https://mimic.physionet.org/), which contains longitudinal measurements for more than 40,000 critical care patients. We selected a subset of 13,998 patients with at least 48 hours of hospital stay and sampled the time series every 2 hours over the first 48 hours. For each patient, 99 different longitudinal measurements were selected, divided into four main categories: laboratory measurements, inputs to patients, outputs collected from patients, and drug prescriptions. The selected time series were sparsely observed, leading to a missing rate of approximately 93.92%. For the in-hospital mortality label, the ratio between the 1,181 positive (dead in hospital) and 12,817 negative (alive) cases was approximately 1:10.8.

IV-A2 PhysioNet Challenge 2012

We also used the PhysioNet Challenge 2012 dataset (available at https://physionet.org/content/challenge-2012/1.0.0/), which contains longitudinal measurements for 4,000 critical care patients with at least 48 hours of hospital stay. Here, we removed 3 patients from the original dataset who had no observations at all. We sampled the observations hourly over the first 48 hours, taking the mean value of multiple observations within one hour. For each patient, 35 different longitudinal measurements were exploited, including time-series measurements of vital signs and lab test results. The time-series data contain a large number of missing values, with a missing rate of approximately 80.51%, and the in-hospital mortality label is imbalanced at a ratio of approximately 1:6 between the 554 in-hospital deaths and 3,443 survival cases.

IV-B Preprocessing and Training

For the MIMIC-III dataset, data cleaning was conducted by handling inconsistent units, multiple recordings made at the same time, and the range of the recorded feature values. We referred to [che2018recurrent, purushotham2018benchmarking] for the feature selection, data cleaning, and preprocessing.

For all datasets, because each variable has a different range, all inputs were first Winsorized to remove outliers and then z-normalized using the global mean and standard deviation from the entire training set to achieve a zero mean and unit variance in a variable-wise manner, as described in [bahadori2019temporal].
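A sketch of this per-variable preprocessing; the [1st, 99th] percentile bounds for the Winsorization are our assumption, as the paper does not report the exact cut-offs.

```python
import numpy as np

def winsorize_and_normalize(X_train, X, lower=1.0, upper=99.0):
    """Clip each variable to its training-set percentile bounds, then
    z-normalize with training-set statistics. X_train, X: (N, T, D)
    arrays with NaNs marking missing entries (left untouched)."""
    lo = np.nanpercentile(X_train, lower, axis=(0, 1))   # (D,) lower bounds
    hi = np.nanpercentile(X_train, upper, axis=(0, 1))   # (D,) upper bounds
    X_train_clip = np.clip(X_train, lo, hi)
    mean = np.nanmean(X_train_clip, axis=(0, 1))
    std = np.nanstd(X_train_clip, axis=(0, 1)) + 1e-8
    return (np.clip(X, lo, hi) - mean) / std
```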

We trained our models using the Rectified Adam (RAdam) optimizer [liu2019variance] with an initial learning rate of and a multiplicative decay of 0.5 for epochs using mini-batches of samples. We chose the final optimal model based on the performance of the validation set.

IV-C Model Implementations

The VRNN comprises an inference network, a generative network, feature extractors, and an RNN, all built as neural networks in our implementation. The inference and generative networks were fully connected with 2 hidden layers, each a linear operation followed by batch normalization and a rectified linear unit (ReLU) activation, where the dimensions of the latent variables are 32 and 16 for each layer, respectively. For the feature extractors, we constructed a single hidden layer with a Tanh activation to extract complex non-linear features. For the RNN, a single layer with 64 GRU hidden units was employed with a Tanh activation function. The $\gamma_f$ and $\alpha_b$ in the focal loss were chosen to be 5 and 0.25, respectively. The ratios in the composite loss were set as a result of varying their values in [0, …, 1].

All model parameters were initialized as small random numbers such that their values fell within the standard deviation interval, which is the inverse of the number of input nodes.

IV-D Baseline Methods

We validated the efficacy of our framework by dividing the evaluation into (i) the in-hospital mortality prediction task and (ii) the missing value imputation task. Regarding the mortality prediction task, we compared our proposed method with the vanilla GRU combined with zero, mean, and KNN imputation (i.e., GRU-Zero, GRU-Mean, and GRU-KNN; the Mean Impute (SimpleFill) and KNN are implemented using the fancyimpute library in Python, publicly available at https://github.com/iskandr/fancyimpute); RNN-based SOTA models such as GRU-D [che2018recurrent], M-RNN [yoon2017multi], and BRITS and RITS [cao2018brits], the latter removing the backward direction of BRITS; and a transformer-based SOTA model, SAnD [song2018attend], which employs a masked self-attention mechanism for clinical diagnosis.

For the missing value imputation, we included Zero Impute, Mean Impute, KNN, GRU-D [che2018recurrent], M-RNN [yoon2017multi], and BRITS and RITS [cao2018brits]. The prediction results for all datasets are compared in Table I, and those of the imputation task are compared in Table II.

V Results

V-A Result of Mortality Prediction

Table I compares the results of our proposed method with those of the baselines for mortality prediction. In both masking scenarios, our model achieved the best classification performance on both datasets. The results of the relatively simple imputation methods with a GRU (GRU-Zero, GRU-Mean, and GRU-KNN) fall behind the other SOTA models, suggesting the need for a more sophisticated imputation method. Among the RNN-based baselines, BRITS demonstrated a competitive performance. In contrast, SAnD showed a relatively low performance compared with the other RNN-based methods, despite its benefit of computational efficiency. These experimental results validate the efficacy of our proposed method equipped with the stochastic recurrent imputation using the VRNN and the GRU-U cell, showing its superior performance in the downstream task.

Masking ratio 5%:
| Method | MIMIC-III | PhysioNet |
| Zero Impute | 0.724 ± 0.005 | 0.788 ± 0.006 |
| Mean Impute | 0.520 ± 0.003 | 0.510 ± 0.010 |
| KNN | 0.508 ± 0.003 | 0.396 ± 0.005 |
| GRU-D [che2018recurrent] | 0.584 ± 0.007 | 0.660 ± 0.014 |
| M-RNN [yoon2017multi] | 0.451 ± 0.008 | 0.411 ± 0.011 |
| RITS [cao2018brits] | 0.354 ± 0.005 | 0.325 ± 0.007 |
| BRITS [cao2018brits] | 0.332 ± 0.005 | 0.297 ± 0.008 |
| Ours | 0.497 ± 0.012 | 0.525 ± 0.004 |

Masking ratio 10%:
| Method | MIMIC-III | PhysioNet |
| Zero Impute | 0.724 ± 0.005 | 0.791 ± 0.004 |
| Mean Impute | 0.523 ± 0.006 | 0.513 ± 0.004 |
| KNN | 0.515 ± 0.004 | 0.415 ± 0.004 |
| GRU-D [che2018recurrent] | 0.582 ± 0.006 | 0.658 ± 0.016 |
| M-RNN [yoon2017multi] | 0.438 ± 0.005 | 0.397 ± 0.009 |
| RITS [cao2018brits] | 0.331 ± 0.006 | 0.308 ± 0.006 |
| BRITS [cao2018brits] | 0.312 ± 0.004 | 0.283 ± 0.008 |
| Ours | 0.503 ± 0.011 | 0.526 ± 0.010 |
TABLE II: Results of the missing value imputation measured by the MAE score

V-B Result of Missing Value Imputation

Table II compares our method with the other imputation baselines on the missing value imputation task. Whereas BRITS showed the lowest MAE scores under the two masking scenarios on both datasets, the MAE scores of our proposed method are only slightly higher, comparable to those of Mean Impute and, in most cases, better than those of Zero Impute, KNN, and GRU-D.

Fig. 3: Visualization of the imputation estimates for the PhysioNet dataset under the (a) 5% and (b) 10% masking scenarios. Blue filled dots are observed measurements; blue lines and blue shades with hollow dots are the imputations and uncertainties estimated from our model, respectively; and red x marks are masked ground-truth observations. In addition, the number above each red dashed vertical line represents the exact MAE value between the masked ground-truth observation values and the model predictions.
| Method | AUC | AUPRC |
| VAE + vanilla GRU | 0.759 ± 0.007 | 0.373 ± 0.019 |
| VAE + GRU-U | 0.813 ± 0.017 | 0.442 ± 0.029 |
| VRNN + vanilla GRU | 0.794 ± 0.013 | 0.411 ± 0.038 |
| VRNN + GRU-U | 0.832 ± 0.018 | 0.470 ± 0.054 |
TABLE III: Performance of a set of ablation experiments for the mortality prediction task

| Ablation | Method | AUC | AUPRC |
| Loss | BCE loss | 0.776 ± 0.014 | 0.384 ± 0.032 |
| Loss | Focal loss | 0.832 ± 0.018 | 0.470 ± 0.054 |
| $W_\alpha$ | Full | 0.829 ± 0.016 | 0.472 ± 0.035 |
| $W_\alpha$ | Diagonal | 0.832 ± 0.018 | 0.470 ± 0.054 |
TABLE IV: Performance of the ablation experiments related to the loss and $W_\alpha$ for the mortality prediction task
Fig. 4: Visualization of the uncertainty values depending on the MAE values for the PhysioNet dataset under the (a) 5% and (b) 10% masking scenarios. Each pink dot represents a pairing of the MAE and the uncertainty for the masked observations of a single sample, and the slope of the grey line represents the correlation coefficient.

V-C Ablation Studies

We conducted a set of ablation experiments to investigate the influence of the different design options of our method, evaluated on the PhysioNet dataset under the 5% masking scenario.

  • Effect of the VRNN: To validate our stochastic recurrent imputation, we compared against the case of using a VAE and an RNN separately. Whereas the VAE only considers the dependencies among variables during the imputation process, the VRNN generates a time series considering the variable-based and temporal dependencies simultaneously. Table III summarizes the classification performances of (VAE+vanilla GRU) [jun2019] versus (VRNN+vanilla GRU) and (VAE+GRU-U) versus (VRNN+GRU-U). Interestingly, the performance improved in both comparisons, particularly when using the vanilla GRU, with a large margin. Although both approaches provide a stochastic imputation, the experimental results highlight the effectiveness of simultaneously modeling the underlying temporal dependencies and reconstructing the imputation estimates.

  • Effect of the GRU-U: To verify the effectiveness of our proposed GRU-U cell, we compared the experimental results using the vanilla GRU cell and the GRU-U cell. From the prediction results of (VAE+vanilla GRU) versus (VAE+GRU-U) and (VRNN+vanilla GRU) versus (VRNN+GRU-U) in Table III, we noticed that both comparisons show significant performance improvements; in particular, leveraging the GRU-U cell with the VAE is relatively more critical to the predictive task. These experimental findings validate the efficacy of uncertainty-gated attention.

  • Effect of the focal loss: The binary cross-entropy (BCE) loss is widely used in binary classification. Comparing the performance of the BCE loss with that of the focal loss in Table IV, we found that the focal loss obtained a better performance in terms of both AUC and AUPRC. Thus, we can conclude that the focal loss is effective for capturing the minority samples in our prediction task.

  • Weight matrix in uncertainty-gated attention: The method for propagating the uncertainty is determined by $W_\alpha$, where a diagonal matrix effectively makes the decay rate of each variable independent of the others, and a full weight matrix makes them dependent. As shown in Table IV, the AUPRC was higher when using the full weight matrix, whereas the AUC was higher when using the diagonal matrix. This suggests that the dependencies among the medical variables in the calculation of the uncertainty-gated attention do not make a significant difference in the performance of the downstream task.

VI Discussion

VI-A Visualization of Imputation Estimates

Fig. 3 compares the actual observations and our model predictions with their uncertainties on the PhysioNet dataset under both the 5% and 10% masking scenarios. Our model tends to produce temporally smooth curves and also exhibits different levels of uncertainty over the time series. It is noteworthy that the uncertainty estimates correlate qualitatively with the missingness of the features and the noise levels of the observations. This can help clinicians make informed decisions regarding the confidence they should place in the model.

VI-B Analysis of Model Behavior

In the results of the missing value imputation described in Section V-B, our proposed method showed slightly higher MAE scores. However, it should be noted that the ultimate goal of our study is to correctly predict mortality by leveraging the missing value imputation. Further, owing to the noisy observations encountered in practice, we explicitly modeled the uncertain noisy factors when imputing the missing values and exploited these factors in our prediction model. Thus, although the MAE score of our method is higher than that of the competing methods, by better reflecting the noisy factors of the imputed values in terms of uncertainty, we could achieve a better mortality prediction performance, which is imperative in a clinical setting.

Hence, we further investigated the behavior of our model in representing the uncertainties of the imputation estimates. For this, we calculated the Pearson's correlation between the MAE and the uncertainty values using the test set of the PhysioNet dataset under both the 5% and 10% masking scenarios. For each data instance, we obtained the MAE between the ground truth and the predicted imputations along with the corresponding uncertainties.

Fig. 4 shows scatter plots of the uncertainty values depending on the MAE values under the 5% and 10% masking scenarios, respectively. For the 5% masking scenario, the average correlation coefficient is 0.406, and for the 10% masking scenario, it is 0.338; in both cases, the null hypothesis that there is no correlation between the MAE and the uncertainty is rejected. These experimental results indicate that our model can provide sufficient information regarding its fidelity by largely predicting the uncertainty in line with the MAE value, even if the estimated imputations are far from the actual observations.

VII Conclusion

In this work, we proposed a novel uncertainty-gated stochastic sequential imputation method that extends the VRNN for mortality prediction with EHR data. By leveraging the VRNN as our base model, we handle representation learning, missing value imputation, and in-hospital mortality prediction in a single stream. In addition, we proposed a novel GRU cell, in which the temporal information encoded in the hidden states is updated by propagating the uncertainty of the inferred distribution over the variables. We validated the effectiveness of our method on the public MIMIC-III and PhysioNet Challenge 2012 datasets, outperforming the state-of-the-art methods considered in our experiments for mortality prediction. Furthermore, we identified the behavior of the model in representing the uncertainties of the imputed estimates, which indicated a high correlation between the calculated MAE and the uncertainty.

Acknowledgment

This work was supported by Institute of Information & communications Technology Planning & Evaluation (IITP) grant funded by the Korea government (MSIT) (No. 2017-0-00053, A technology development of artificial intelligence doctors for cardiovascular disease, and No. 2019-0-00079, Department of Artificial Intelligence (Korea University)).

Data Availability

The MIMIC-III database analyzed in this study is available in the PhysioNet repository.

References