Deep Ensemble Tensor Factorization for Longitudinal Patient Trajectories Classification

by Edward De Brouwer, et al.

We present a generative approach to classify scarcely observed longitudinal patient trajectories. The available time series are represented as tensors and factorized using generative deep recurrent neural networks. The learned factors represent the patient data in a compact way and can then be used in a downstream classification task. For more robustness and accuracy in the predictions, we use an ensemble of those deep generative models to mimic Bayesian posterior sampling. We illustrate the performance of our architecture on an intensive-care case study of in-hospital mortality prediction with 96 longitudinal measurement types measured across the first 48 hours from admission. Our combination of generative and ensemble strategies achieves an AUC of over 0.85, and outperforms the SAPS-II mortality score and GRU baselines.




1 Introduction

Envisioned as a key tool for precision medicine, the computational analysis of patient trajectories has recently been a major focus of interest [7; 18; 16]. In particular, patient trajectories are expected to be of great importance in diseases that evolve over long time periods (e.g., chronic diseases such as diabetes or multiple sclerosis) or with highly patient-specific medical trajectory patterns (e.g., Intensive Care Unit (ICU) data). However, despite the increasing literature on medical health records mining [17], there is still a lack of methods for modeling this particularly challenging type of data in a natural way.

Patient trajectories usually consist of complex sets of longitudinal measurements (e.g., blood glucose), medical events (e.g., onset of comorbidities), or patient covariates (e.g., gender). They are relevant for a broad range of medical tasks, such as (1) patient segmentation [22; 11], (2) prognosis [20], or (3) treatment optimization [13]. However, their statistical analysis is challenging because of the inherent properties of the data [9; 19]. First, the temporal series are often scarcely observed (i.e., only a few percent of possible instances are actually measured) and irregularly sampled (e.g., a multiple sclerosis patient typically has a medical visit every 6 months). Second, the observation pattern is informative in itself as it reflects, among other things, the need for additional or fewer medical visits, and thus the state of the patient. Lastly, observations are noisy and, because of the complexity of the diseases, there is seldom a natural way to align patients on a common time scale. Those limitations typically prevent us from directly feeding this type of data into classical supervised learning methods.

In this work, we propose a modeling strategy that addresses the aforementioned issues in a direct way. We consider that the observed measurements are generated by a low-dimensional hidden temporal process that summarizes the health state of the patient at each point in time. The noisy observations are then interpreted as outputs, which results in a natural way of handling missing values. For modeling complex temporal dependencies in the data, and to address the trajectory alignment issue, we rely on a recurrent neural network architecture to generate the observations. We further boost the performance of our model by using an ensemble method to approximate the Bayesian sampling of the posterior of the predictions.

2 Previous work

The machine learning community has recently started to address the challenging problem of patient trajectory modelling [14; 17; 15; 4; 6; 12; 21]. Choi et al. [3] proposed Doctor AI, a GRU-based architecture to predict medical events at the next visit. In the specific task of patient trajectory classification, Lipton et al. [10] proposed an LSTM-based model fed by imputed data concatenated with an observation mask. The closest work to our approach is that of Che et al. [2], who designed an extended GRU cell with a dynamic imputation mechanism that is trained for classification only. In contrast, we propose a model that, in addition to classification, generates the observed trajectories. Other approaches also include convolutional networks [5]. Several recent works have also focused on the specific task of in-hospital mortality prediction in intensive care [2; 1].

3 Methods

3.1 Data representation: Tensorization

The data we aim at analyzing typically consists of multiple longitudinal measurements for each patient together with their time labels. We first discretize the time into bins with high granularity resulting in minimal loss of information. Measurements falling in the same bin are either averaged or summed depending on the specific measurement type.
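The discretization step described above can be sketched in a few lines of Python. This is a minimal illustration and not the authors' implementation: the function name `bin_measurements`, the record layout, the bin width, and the measurement names in the example are all assumptions made for this sketch.

```python
from collections import defaultdict


def bin_measurements(records, bin_width_hours, aggregate):
    """Group timestamped measurements into fixed-width time bins.

    records: iterable of (patient_id, measurement_type, time_hours, value).
    aggregate: dict mapping measurement_type to "mean" or "sum"
        (hypothetical convention chosen for this sketch).
    Returns a dict {(patient_id, measurement_type, bin_index): value}.
    """
    buckets = defaultdict(list)
    for patient, mtype, time_hours, value in records:
        bin_index = int(time_hours // bin_width_hours)
        buckets[(patient, mtype, bin_index)].append(value)

    binned = {}
    for key, values in buckets.items():
        total = sum(values)
        # Average or sum depending on the measurement type.
        if aggregate[key[1]] == "mean":
            binned[key] = total / len(values)
        else:
            binned[key] = total
    return binned


# Toy example: glucose readings are averaged, urine output is summed.
records = [
    ("p1", "glucose", 0.2, 100.0),
    ("p1", "glucose", 0.7, 120.0),
    ("p1", "urine_out", 0.1, 30.0),
    ("p1", "urine_out", 0.9, 50.0),
    ("p1", "glucose", 5.5, 90.0),
]
binned = bin_measurements(
    records, 1.0, {"glucose": "mean", "urine_out": "sum"}
)
```

Averaging suits concentration-like quantities, while summing suits quantities that accumulate over the bin, such as administered or collected volumes.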

We then represent the temporal medical measurements of all patients as an order-3 tensor indexed by patient, measurement type, and time step. In the applications we focus on, this tensor is usually scarcely observed, resulting in a low fill rate of just a few percent.
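To make the tensor representation and its fill rate concrete, the following sketch builds a dense patients × measurement types × time steps array together with an observation mask from a sparse dictionary of binned values. The function name `tensorize`, the input layout, and the choice of 0.0 as a placeholder for missing entries are assumptions of this example.

```python
def tensorize(binned, patients, types, n_steps):
    """Build a dense order-3 tensor and its observation mask.

    binned: dict {(patient_id, measurement_type, bin_index): value}.
    Returns (values, mask, fill_rate), where values and mask are nested
    lists indexed as [patient][measurement_type][time_step].
    """
    p_idx = {p: i for i, p in enumerate(patients)}
    m_idx = {m: j for j, m in enumerate(types)}
    values = [[[0.0] * n_steps for _ in types] for _ in patients]
    mask = [[[0] * n_steps for _ in types] for _ in patients]

    observed = 0
    for (patient, mtype, step), value in binned.items():
        if 0 <= step < n_steps:
            values[p_idx[patient]][m_idx[mtype]][step] = value
            mask[p_idx[patient]][m_idx[mtype]][step] = 1
            observed += 1

    # Fill rate: fraction of tensor entries that are actually observed.
    fill_rate = observed / (len(patients) * len(types) * n_steps)
    return values, mask, fill_rate


toy = {
    ("p1", "glucose", 0): 110.0,
    ("p1", "urine_out", 2): 80.0,
    ("p2", "glucose", 3): 95.0,
}
values, mask, fill_rate = tensorize(
    toy, ["p1", "p2"], ["glucose", "urine_out"], 4
)
```

In this toy example 3 of the 16 entries are observed; the mask is what distinguishes a missing entry from a measured zero.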

On top of longitudinal measurements, static information about patients is also available. We collect the static patient covariates in a matrix with one row per patient and one column per covariate. This matrix is assumed to be fully observed; that is, all covariates are known for each patient.

Finally, each patient is assigned a label that defines its class. The labels of all patients are collected in a vector of class labels.

3.2 Model definition

The learning objective is to correctly classify the patient labels using the temporal information (the measurement tensor) and the static information (the covariate matrix). For this purpose, we proceed in two joint tasks: deep factorization of the tensor and classification of the patients based on the retrieved latent factors.

Regarding the factorization task, we assume that the temporal (tensor) observations of each patient are generated by a low-dimensional latent process through measurement functions, one per measurement type.

This individual latent process can be interpreted as the hidden health status of the patient, conditioned on which the observations are generated. The model decomposes the tensor into temporal patient-specific factors and measurement-specific functions. We assume the first latent factors are generated from the static covariates (through a generating function, with noise) and that these factors then transition over time according to an unknown process, also with noise.

We further assume that the binary patient labels are generated from the last value of the hidden health process through a mapping.
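The generative assumptions of this section can be summarized in equation form. The notation below is an assumption introduced for illustration: $x_{i,m,t}$ denotes measurement type $m$ of patient $i$ at time step $t$, $h_{i,t}$ the latent state, $s_i$ the static covariates, $g_m$ the measurement functions, $f_0$ and $f$ the initial and transition functions, $c$ the label mapping, $T$ the last time step, and the $\epsilon$ terms noise. Additive noise and a Bernoulli label model are also assumptions of this sketch.

```latex
\begin{align}
x_{i,m,t} &= g_m(h_{i,t}) + \epsilon^{x}_{i,m,t}, \\
h_{i,0} &= f_0(s_i) + \epsilon^{0}_{i}, \qquad
h_{i,t+1} = f(h_{i,t}) + \epsilon^{h}_{i,t}, \\
y_i &\sim \mathrm{Bernoulli}\big(c(h_{i,T})\big).
\end{align}
```

Read this way, the tensor is factorized into patient-specific temporal factors $h_{i,t}$ and measurement-specific functions $g_m$.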

3.3 Inference and learning

As we want to account for complex generating functions, exact inference of the latent temporal process is intractable. We therefore adopt a recurrent neural network approach that computes an approximate inference of the patient hidden process.

The first latents are generated from a nonlinear mapping of the static covariates, as described in Section 3.2. We then feed a GRU-based network with the longitudinal observations concatenated with an observation mask. When the samples are not observed, we impute the missing observations with the predictions of the previous time steps.

Specifically, at each time step and for each patient, we feed the GRU cell with a vector that contains the measured value where the sample is observed and the prediction from the previous time step otherwise, concatenated with a mask whose entries are 1 if the sample is observed and 0 otherwise. The network is thus also fed with the observation pattern, and we let the GRU design its own imputation strategy when the sample is not available. Note that we design the network such that it generates the observations based on the latents at each time step, following the generative model of Section 3.2. We then jointly train a classifier on the last hidden vector of each patient for the labels.
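The construction of the GRU input at a single time step can be illustrated as follows. This is a sketch under assumptions: the function name `build_gru_input` and the 1/0 mask convention are choices made for this example, and the GRU cell itself is not shown.

```python
def build_gru_input(observations, mask, previous_prediction):
    """Build the input vector for one patient at one time step.

    observations: measured values (entries at unobserved positions
        are ignored).
    mask: 1 where the sample is observed, 0 otherwise.
    previous_prediction: values generated by the model at the previous
        time step, used to fill unobserved positions.
    Returns the imputed values concatenated with the mask.
    """
    imputed = [
        x if m == 1 else p
        for x, m, p in zip(observations, mask, previous_prediction)
    ]
    # Appending the mask exposes the observation pattern to the network.
    return imputed + [float(m) for m in mask]


gru_input = build_gru_input(
    observations=[1.0, 0.0, 3.0],
    mask=[1, 0, 1],
    previous_prediction=[0.5, 2.5, 0.1],
)
```

Here the second measurement is missing, so its slot is filled with the previous prediction (2.5), while the mask tells the network that this value was not actually measured.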

Two objectives coexist in this model architecture: the ability to reconstruct the observed patient trajectories based on the hidden process and the classification performance. We train our model on a mixture of those two goals, balanced by a hyperparameter. The overall loss to optimize combines the reconstruction error on the observed trajectories with the cross-entropy loss on the labels, and is minimized with respect to the weights of the model. A visual representation of the network architecture is presented in Figure 1. The generative approach presents two main advantages: the latent vectors are constrained to be representative of the whole observed trajectory, and they provide a more natural way to deal with missing input samples.
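The mixed objective described above can be sketched as a weighted sum of a classification term and a reconstruction term. The exact form used by the authors is not recoverable from the text, so the convex combination with weight `alpha`, the squared-error reconstruction term restricted to observed entries, and all function names are assumptions of this example.

```python
import math


def masked_mse(predicted, observed, mask):
    """Mean squared reconstruction error over observed entries only."""
    n_observed = sum(mask)
    if n_observed == 0:
        return 0.0
    squared = sum(
        m * (p - x) ** 2 for p, x, m in zip(predicted, observed, mask)
    )
    return squared / n_observed


def binary_cross_entropy(prob, label, eps=1e-12):
    """Cross-entropy loss for a single binary label."""
    prob = min(max(prob, eps), 1.0 - eps)
    return -(label * math.log(prob) + (1 - label) * math.log(1.0 - prob))


def combined_loss(predicted, observed, mask, prob, label, alpha):
    """Assumed mixing: alpha weights classification, (1 - alpha) reconstruction."""
    classification = binary_cross_entropy(prob, label)
    reconstruction = masked_mse(predicted, observed, mask)
    return alpha * classification + (1.0 - alpha) * reconstruction


loss_reconstruction_only = combined_loss(
    [1.0, 5.0], [2.0, 0.0], [1, 0], prob=0.5, label=1, alpha=0.0
)
loss_classification_only = combined_loss(
    [1.0, 5.0], [2.0, 0.0], [1, 0], prob=0.5, label=1, alpha=1.0
)
```

Masking the reconstruction term means unobserved entries never contribute to the loss, which is what lets the model train on a tensor with a fill rate of only a few percent.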

Figure 1: Model architecture. Rectangles represent GRU cells unfolded over time. The network both generates class labels and the observed temporal trajectories.

For better predictive accuracy, we use a performance-driven ensemble learning method to approximate Bayesian posterior sampling of the predictions. In practice, we train a large number of models with hyperparameters sampled from a prior distribution. We then select the best models according to validation performance and average their predictions.
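The selection and averaging step can be sketched as follows. The function name `ensemble_predict` and the representation of each trained model as a pair of validation score and predicted probabilities are assumptions of this example; model training and hyperparameter sampling are not shown.

```python
def ensemble_predict(models, k):
    """Average the predictions of the k models with the best validation score.

    models: list of (validation_score, predictions) pairs, where
        predictions is a list of predicted probabilities, one per patient.
    """
    best = sorted(models, key=lambda model: model[0], reverse=True)[:k]
    n_patients = len(best[0][1])
    return [
        sum(predictions[i] for _, predictions in best) / len(best)
        for i in range(n_patients)
    ]


models = [
    (0.80, [0.2, 0.8]),
    (0.85, [0.4, 0.6]),
    (0.70, [0.9, 0.1]),
]
averaged = ensemble_predict(models, k=2)
```

With `k=2`, the model with validation score 0.70 is discarded and the remaining two sets of predictions are averaged.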

4 Experiments

4.1 Case study definition

We apply our methodology to an intensive care unit case study. We used the publicly available MIMIC-III data set, which contains longitudinal measurements for more than 40,000 critical care patients [8]. The objective of our case study is to predict, based on longitudinal data of a patient in a 48-hour observation window, in-hospital mortality for that individual patient. We selected a subset of 20,000 patients with at least 48 hours of hospital stay. For each patient, we selected 96 different longitudinal measurement types divided into 4 main categories: lab measurements, inputs to patients, outputs collected from patients, and drug prescriptions. The full list is available in the supplementary material. The selected time series were scarcely observed, leading to a fill rate of 5.9% for the measurement tensor. We also selected static patient covariates, such as age, admission type, and main ICD10 diagnosis for the admission.

4.2 Our models and baselines

We consider the following baselines: the SAPS-II severity score and a missingness-informed GRU baseline with smart imputation, which we call GRU-imputed. In contrast to our proposed method, GRU-imputed is trained for classifying the in-hospital mortality labels only. It therefore does not learn any hidden process generating the required observations. At each time step, GRU-imputed is fed with a vector of observations (with missing values imputed to their means, as suggested by Lipton et al. [10] and Che et al. [2]) concatenated with an observation mask and the elapsed time since the last observed sample. The SAPS-II is a static severity score widely used in clinical practice; we restricted it to the variables that were available in the data subset under consideration.

We then trained 200 of our models on the training set, with hyperparameters sampled from prior distributions. We ranked the models based on their performance on the validation set, selected an ensemble of the best 20, averaged their predictions, and report the performance on a held-out test set. The same data splitting was used to tune and evaluate all models.

4.3 Performance

The performance of both baselines and our method is presented in Figure 2(a). Our methodology outperforms the proposed baselines. Furthermore, we notice an increase in AUC from 0.842 to 0.855 due to the ensemble strategy. The impact of the number of models in the ensemble is presented in Figure 2(b). We observe that only a few models are required to obtain a significant performance improvement.
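For readers unfamiliar with the metric, the AUC reported above is the probability that a randomly chosen positive case receives a higher score than a randomly chosen negative case. The sketch below computes it by direct pairwise comparison; the function name `roc_auc` and the toy data are assumptions of this example, and this is not the evaluation code used for the reported results.

```python
def roc_auc(labels, scores):
    """Area under the ROC curve via pairwise comparison (ties count 0.5)."""
    positives = [s for y, s in zip(labels, scores) if y == 1]
    negatives = [s for y, s in zip(labels, scores) if y == 0]
    if not positives or not negatives:
        raise ValueError("AUC needs at least one positive and one negative")

    wins = 0.0
    for p in positives:
        for n in negatives:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(positives) * len(negatives))


auc = roc_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
```

The pairwise form is quadratic in the number of patients, which is fine for illustration; rank-based implementations are preferable for large test sets.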

(a) ROC curves and AUC for the different models
(b) Impact of the ensemble strategy on the performance
Figure 2: Results of the models considered

5 Conclusion

We propose to consider multivariate longitudinal patient trajectories as a higher-order tensor that is factorized using deep recurrent neural networks. The temporal factors of each patient are then used for in-hospital mortality prediction. The performance of the proposed architecture shows that the generative approach outperforms both classification-only and static models.


6 Supplementary material

6.1 Retained features for the intensive care case study

In Table 1, we present the longitudinal features retained for the training of our models.

Lab measurements: Anion Gap; Bicarbonate; Calcium, Total; Chloride; Glucose; Magnesium; Phosphate; Potassium; Sodium; Alkaline Phosphatase; Asparate Aminotransferase; Bilirubin, Total; Urea Nitrogen; Basophils; Eosinophils; Hematocrit; Hemoglobin; Lymphocytes; MCH; MCHC; MCV; Monocytes; Neutrophils; Platelet Count; RDW; Red Blood Cells; White Blood Cells; PTT; Base Excess; Calculated Total CO2; Lactate; pCO2; pH; pO2; PT; Alanine Aminotransferase; Albumin; Specific Gravity

Inputs: Potassium Chloride; Calcium Gluconate; Insulin - Regular; Heparin Sodium; K Phos; Sterile Water; Gastric Meds; GT Flush; LR; Furosemide (Lasix); Solution; Hydralazine; Midazolam (Versed); Lorazepam (Ativan); PO Intake; Insulin - Humalog; OR Crystalloid Intake; Morphine Sulfate; D5 1/2NS; Insulin - Glargine; Metoprolol; OR Cell Saver Intake; Dextrose 5%; Norepinephrine; Piggyback; Packed Red Blood Cells; Phenylephrine; Albumin 5%; Nitroglycerin; KCL (Bolus); Magnesium Sulfate (Bolus)

Outputs: Stool Out Stool; Urine Out Incontinent; Ultrafiltrate Ultrafiltrate; Gastric Gastric Tube; Foley; Void; TF Residual; Pre-Admission; Chest Tube 1; OR EBL; Chest Tube 2; Fecal Bag; Jackson Pratt 1; Condom Cath

Prescriptions: D5W; Docusate Sodium; Magnesium Sulfate; Potassium Chloride; Bisacodyl; Humulin-R Insulin; Aspirin; Sodium Chloride 0.9% Flush; Metoprolol Tartrate

Table 1: Retained longitudinal features in the intensive care case study.