Improving Clinical Predictions through Unsupervised Time Series Representation Learning

12/02/2018 ∙ by Xinrui Lyu, et al.

In this work, we investigate unsupervised representation learning on medical time series, which bears the promise of leveraging copious amounts of existing unlabeled data in order to eventually assist clinical decision making. By evaluating on the prediction of clinically relevant outcomes, we show that in a practical setting, unsupervised representation learning can offer clear performance benefits over end-to-end supervised architectures. We experiment with using sequence-to-sequence (Seq2Seq) models in two different ways, as an autoencoder and as a forecaster, and show that the best performance is achieved by a forecasting Seq2Seq model with an integrated attention mechanism, proposed here for the first time in the setting of unsupervised learning for medical time series.


1 Introduction

Patient representation learning is a popular topic in machine learning for healthcare. The generality of supervised representations is usually constrained by the amount of labeled data, while unsupervised representations can leverage information from all data, labeled or not. Hence, unsupervised learning can produce representations of general utility dosovitskiy2014unsupervisedForImages ; mikolov14doc2vec ; mikolov13word2vec ; miotto16_deep_patient_unsupervised , which are useful when downstream tasks are not known a priori.

These conditions are especially pronounced in the medical domain. Routine medical practice generates a wealth of patient-related time series, while data annotation often requires medical experts, whose time is very limited. Additionally, new tasks of interest keep emerging, and different hospitals or health systems often define tasks in different ways. Thus, generally useful representations that provide good performance over a broad range of downstream tasks are highly desirable.

In this work, we investigate unsupervised representation learning on medical time series, which remains relatively unexplored. We propose adapted and novel models well suited for this objective and elucidate under which conditions they provide a performance benefit over end-to-end supervised learning with respect to predicting clinically relevant outcomes.

2 Related Work

The unsupervised learning approaches studied in this paper are rooted in the autoencoding principle bengio2013representation . The basic autoencoding architecture has been extended in several ways, such as denoising vincent2010stacked , variational kingma2013auto , convolutional masci2011stacked , or contractive rifai2011contractive autoencoders. Sequence-to-sequence (Seq2Seq) sutskever2014sequence architectures have been used successfully in translation weiss2017sequence , and on text and images chen2015mind ; gregor2015draw . Seq2Seq models have also been pre-trained in an unsupervised way ramachandran2016unsupervised and fine-tuned with labeled data.

Several models for unsupervised representation learning have been successfully employed in medical applications pivovarov2015learning ; miotto16_deep_patient_unsupervised ; suresh17_use_autoencoders_discovering ; jones16_canonical_correlation_analysis ; choi2016multi . While in many cases the resulting representations had both descriptive and predictive utility, which reconstruction principles and loss functions lead to accurate clinical outcome prediction has not been widely studied.

Attention mechanisms can improve performance and interpretability and have been widely used across domains chorowski2015attention ; xu2015show ; kumar2016ask ; choi2016retain . Although attention has been used for unsupervised representation learning of natural language jang2018RNNSVAE , attention architectures in the medical domain have so far focused exclusively on specific supervised prediction tasks.

3 Representation Learning Models

3.1 Baselines: Autoencoders

Autoencoding consists of two steps: encoding maps the input data space $\mathcal{X}$ to a representation space $\mathcal{Z}$, where typically $\dim\mathcal{Z} < \dim\mathcal{X}$, while decoding maps in the reverse direction to reconstruct the data from representations. The objective of autoencoding is to minimize the reconstruction error between the input data and their reconstructions.

Non-Sequential models

Principal component analysis (PCA) and its inverse can together be considered a simple autoencoding process in which the encoding is a learned linear projection. An autoencoder (AE) is a neural network composed of an encoder and a decoder, each implemented as a multi-layer perceptron; it encodes the data non-linearly. Our goal is to encode temporal sequences of physiological signal vectors, but the architectures of PCA and AE do not allow them to exploit the temporal structure of time series. To make the data compatible with the input format of PCA and AE, we flatten a $T \times D$ time series (i.e. $T$ time samples, each of dimension $D$) into a $TD$-dimensional vector.
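To make the PCA baseline concrete, here is a minimal NumPy sketch of PCA used as a linear autoencoder on flattened windows: each length-$T$ window of $D$-dimensional vectors is flattened to length $TD$, projected onto the top-$m$ principal directions (encoding) and mapped back (decoding). The function names `flatten_windows`, `fit_pca`, `pca_encode` and `pca_decode` are hypothetical and not taken from the authors' code.

```python
import numpy as np

def flatten_windows(x: np.ndarray, T: int) -> np.ndarray:
    """Stack every length-T window (stride 1) of a (time, D) series as one row of length T*D."""
    n_steps, D = x.shape
    return np.stack([x[t - T + 1 : t + 1].reshape(T * D) for t in range(T - 1, n_steps)])

def fit_pca(X: np.ndarray, m: int):
    """Return the data mean and the top-m principal directions as a (T*D, m) matrix."""
    mean = X.mean(axis=0)
    # The right singular vectors of the centred data are the principal directions.
    _, _, vt = np.linalg.svd(X - mean, full_matrices=False)
    return mean, vt[:m].T

def pca_encode(X: np.ndarray, mean: np.ndarray, W: np.ndarray) -> np.ndarray:
    return (X - mean) @ W      # linear projection to m dimensions (the representation)

def pca_decode(Z: np.ndarray, mean: np.ndarray, W: np.ndarray) -> np.ndarray:
    return Z @ W.T + mean      # inverse map back to the flattened window
```

Keeping all $TD$ components reconstructs the windows exactly; choosing $m < TD$ gives the lossy, compressed representation used as a baseline.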

Seq2Seq model

While Seq2Seq models are often used in supervised training settings in natural language processing sutskever2014sequence ; ramachandran2016unsupervised ; weiss2017sequence , we use them in an unsupervised way by minimizing the input reconstruction error as an objective; we refer to such a model as an S2S-AE. Figure 1 shows the structure of an S2S-AE model. A Long Short-Term Memory (LSTM) cell is used for both the encoder and decoder recurrent neural network (RNN) units, because it can retain information over more time steps than simple RNN cells hochreiter1998vanishing ; hochreiter2001gradient .
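The following is an illustrative NumPy forward pass (no training) of a single-layer LSTM encoder that returns its last hidden state as the representation, as in S2S-AE. The stacked gate layout, the uniform initialization and the names `init_lstm` and `lstm_encode` are our assumptions, not the authors' implementation. Note that the weight shapes depend only on $D$ and $m$, not on the window length $T$.

```python
import numpy as np

def sigmoid(a: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-a))

def init_lstm(D: int, m: int, seed: int = 0):
    """Hypothetical initializer: one stacked weight matrix for the four gates."""
    rng = np.random.default_rng(seed)
    scale = 1.0 / np.sqrt(m)
    W = rng.uniform(-scale, scale, size=(4 * m, D + m))  # acts on [x_k; h_{k-1}]
    b = np.zeros(4 * m)
    return W, b

def lstm_encode(window: np.ndarray, W: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Run the LSTM over a (T, D) window and return the last hidden state, shape (m,)."""
    m = W.shape[0] // 4
    h = np.zeros(m)
    c = np.zeros(m)
    for x_k in window:
        gates = W @ np.concatenate([x_k, h]) + b
        i = sigmoid(gates[:m])            # input gate
        f = sigmoid(gates[m:2 * m])       # forget gate
        g = np.tanh(gates[2 * m:3 * m])   # candidate cell update
        o = sigmoid(gates[3 * m:])        # output gate
        c = f * c + i * g                 # additive cell-state path
        h = o * np.tanh(c)
    return h
```

For example, `lstm_encode(window, *init_lstm(94, 94))` on a `(12, 94)` window yields a 94-dimensional vector, matching the window size and representation dimension used in Section 4.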

Figure 1: Sequence-to-sequence autoencoder (S2S-AE). $T$ is the length of history we want to encode in the representations. If we want to encode the patient's history from admission to time $t$, then $T = t$.

At time $t$, the encoder receives a sequence of signal vectors $x^{(i)}_{t-T+1:t} = (x^{(i)}_{t-T+1}, \dots, x^{(i)}_{t})$ from a time window of size $T$ as input and produces a representation $z^{(i)}_t = h^{(i)}_t$, where $h^{(i)}_t$ is the last hidden state of the encoder. The decoder, given $z^{(i)}_t$, outputs a sequence of reconstructions $\hat{x}^{(i)}_{t-T+1:t}$ for the same window. Let $f_\theta$ and $g_\phi$ denote the encoder and decoder respectively, with parameters $\theta$ and $\phi$. Then the S2S-AE model can be formulated as

$$\hat{x}^{(i)}_{t-T+1:t} = g_\phi\big(f_\theta(x^{(i)}_{t-T+1:t})\big) \tag{1}$$

$$\ell^{(i)}_t = \frac{1}{T}\sum_{k=t-T+1}^{t} \big\| x^{(i)}_k - \hat{x}^{(i)}_k \big\|_2^2 \tag{2}$$

where $\ell^{(i)}_t$ is the average reconstruction error for one window of patient $i$'s input signals from $t-T+1$ until the current time $t$. The loss for patient $i$ is then the average error over their windows, indexed by $t$ and sliding with stride 1. To train the S2S-AE model, we average the patient-wise loss over all $N$ patients. The representation $z^{(i)}_t$ from an S2S-AE model summarizes a fixed length of the medical history of a patient up to time $t$, which reflects the current state of the patient.
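The loss aggregation described for S2S-AE (per-window error, averaged over a patient's stride-1 windows, then over patients) can be sketched as follows. We assume a squared Euclidean error per time step, and `reconstruct` is a hypothetical callable standing in for the encoder-decoder pair; none of these names come from the paper.

```python
import numpy as np

def window_loss(x_true: np.ndarray, x_hat: np.ndarray) -> float:
    """Average over the T time steps of the squared error of one (T, D) window."""
    return float(np.mean(np.sum((x_true - x_hat) ** 2, axis=1)))

def patient_loss(x: np.ndarray, T: int, reconstruct) -> float:
    """Average window loss over all windows of one patient, sliding with stride 1."""
    losses = []
    for t in range(T - 1, len(x)):
        window = x[t - T + 1 : t + 1]
        losses.append(window_loss(window, reconstruct(window)))
    return float(np.mean(losses))

def dataset_loss(patients, T: int, reconstruct) -> float:
    """Average the patient-wise losses over all patients."""
    return float(np.mean([patient_loss(x, T, reconstruct) for x in patients]))
```

Because windows are first averaged within each patient, every patient contributes equally to the training objective regardless of the length of their stay.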

3.2 Sequential forecasting model (S2S-F)

We hypothesize that the requirement to forecast future time points in the patient's signal forces the encoding LSTM to extract meaningful representations of the past time series. For this purpose, we design another Seq2Seq-based variant, S2S-F ("F" for forecasting), in which the decoder predicts the future time series instead of reconstructing the past time series in the input. In this way, the representations still reflect the current patient state but are also optimized to predict the future patient state. We modify (1) and (2) to obtain the decoder function and the loss function for S2S-F:

$$\hat{x}^{(i)}_{t+1:t+T} = g_\phi\big(f_\theta(x^{(i)}_{t-T+1:t})\big) \tag{3}$$

$$\ell^{(i)}_t = \frac{1}{T}\sum_{k=t+1}^{t+T} \big\| x^{(i)}_k - \hat{x}^{(i)}_k \big\|_2^2 \tag{4}$$
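On the data side, S2S-F differs from S2S-AE only in the decoder target. A minimal sketch of building (past, future) training pairs follows; the forecast horizon `H` is a hypothetical parameter, and setting it equal to the window length ($H = T = 12$) is our assumption. For S2S-AE the target would simply be the past window itself.

```python
import numpy as np

def forecasting_pairs(x: np.ndarray, T: int, H: int):
    """Return (past, future) arrays: past = x[t-T+1..t], future = x[t+1..t+H], stride 1."""
    past, future = [], []
    for t in range(T - 1, len(x) - H):
        past.append(x[t - T + 1 : t + 1])      # encoder input
        future.append(x[t + 1 : t + 1 + H])    # decoder target for S2S-F
    return np.stack(past), np.stack(future)
```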

3.3 Forecasting with attention (S2S-F-A)

The idea behind applying attention mechanisms to time series forecasting is to enable the decoder to preferentially “attend” to specific parts of the input sequence during decoding. This allows particularly relevant events (e.g. drastic changes in heart rate) to contribute more to the generation of different points in the output sequence. Since autoencoding with attention is trivial (an effective attention mechanism would simply learn to point to the corresponding input at each time point), we only augment S2S-F with the attention mechanism and call the resulting architecture S2S-F-A (shown in Figure 2).

Figure 2: Sequence-to-Sequence Forecaster with Attention (S2S-F-A).

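Formally, at step $j$ during decoding, the objective is to produce a context vector $c_j$ that is a weighted combination of the hidden states $h_1, \dots, h_T$ of the encoder: $c_j = \sum_{k=1}^{T} \alpha_{jk} h_k$. The weights $\alpha_{jk}$ are softmax-normalized versions of scores $e_{jk}$ computed by the attention mechanism $a$, which considers both the current state $s_{j-1}$ of the decoder and each state $h_k$ of the encoder in turn: $e_{jk} = a(s_{j-1}, h_k)$ and $\alpha_{jk} = \exp(e_{jk}) / \sum_{k'=1}^{T} \exp(e_{jk'})$. To implement $a$, we use a single-layer perceptron with a tanh activation function and scalar output, following luong2015effective :

$$a(s_{j-1}, h_k) = v_a^\top \tanh\big(W_a [s_{j-1}; h_k]\big)$$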

Each $\alpha_{jk}$ reflects the importance of time point $k$ in the input sequence for decoding time point $j$ in the output. The context vector $c_j$ is thus an explicit re-summarization of the input data in light of the current decoding step. It is concatenated to the usual input fed to the decoder at step $j$ (see Figure 2).
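A NumPy sketch of a single attention step is given below. It assumes the concat-style score $v_a^\top \tanh(W_a[s; h_k])$ from Luong et al.; `W_a`, `v_a`, the function name and the assumption that the decoder state has the same dimension as the encoder states are illustrative choices, not the authors' code.

```python
import numpy as np

def softmax(a: np.ndarray) -> np.ndarray:
    a = a - a.max()            # subtract the max for numerical stability
    e = np.exp(a)
    return e / e.sum()

def attention_context(s_prev, H_enc, W_a, v_a):
    """Score each encoder state against the decoder state; return weights and context.

    s_prev: (m,) current decoder state; H_enc: (T, m) encoder states;
    W_a: (d_a, 2m) and v_a: (d_a,) are the single-layer perceptron parameters.
    """
    scores = np.array([v_a @ np.tanh(W_a @ np.concatenate([s_prev, h_k])) for h_k in H_enc])
    alpha = softmax(scores)    # attention weights over the T input time points
    context = alpha @ H_enc    # weighted combination of encoder states
    return alpha, context
```

With `W_a` set to zero every score is equal, so the weights become uniform and the context reduces to the mean encoder state, a quick sanity check on the implementation.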

The attention mechanism breaks the "bottleneck" principle of usual Seq2Seq models, so it is not obvious how to choose a self-contained representation. Following our practice for S2S-AE and S2S-F, we take the final state of the encoder, $z_t$, as the representation. We also experimented with additionally including context vectors in the representation, but interestingly, simply taking $z_t$ was sufficient for predicting downstream tasks. Table A2 summarizes the characteristics of the unsupervised representation models we analyze.

4 Experiments and results

Data

The eICU Collaborative Research Database v1.2 goldberger00_physiobank_physitoolkit_physionet was used for all experiments described in this paper. 94 time series variables, including periodic and aperiodic vital signs and irregularly measured lab tests, were extracted. The data was resampled to an hourly grid, with implausible data rejection and imputation performed online; see Appendix A.1 for more details. Overall, the dataset consists of 20,878 patients with 72-240 hours of history, extending from ICU admission to discharge. We use a window size of 12 hours (i.e. 12 time points) and a representation dimension of 94.

Reconstructing past and predicting future

We aim to evaluate the ability of representations to reconstruct past data and to predict future data. Some representations are obtained from models optimized to reconstruct past data (PCA, AE and S2S-AE), while others come from models optimized to predict future data (S2S-F and S2S-F-A). To produce a fair comparison independent of a specific decoder, we use the representations themselves as input features to a 1-layer LSTM trained to either reconstruct the past 12 hours or predict the next 12 hours. The performance of each set of representations, measured by mean-squared error (MSE), is shown in Table 1. Not surprisingly, representations from forecaster models perform better in future prediction, and the attention mechanism further improves performance. However, the extent to which attention helps is surprising.

MSE PCA AE S2S-AE S2S-F S2S-F-A
Reconstruction
Prediction
Table 1: Performance of representations used as input features to a 1-layer LSTM trained to either reconstruct the past 12 hours or predict the next 12 hours. (Best results are in bold; second-best results are marked with *.)
Predicting mortality and discharge status within the next 24 hours

Besides evaluating the ability of representations to reconstruct past and predict future signals, we are also interested in whether they can be used to predict future clinical events. Here we focus on predicting whether patients will be discharged from the ICU in a stable state within the next 24 hours (“24h Discharge”), or die within the next 24 hours (“24h Mortality”). We trained 1-layer LSTM classifiers (LSTM-1) using representations as input to predict these two events and report the area under the ROC curve (AUROC) and the area under the precision-recall curve (AUPRC) in Table 2. In addition, we include the performance of a deeper 3-layer LSTM classifier (LSTM-3), trained on the original input signals, as a baseline.

24h Discharge 24h Mortality
AUPRC AUROC AUPRC AUROC
LSTM-1 + PCA rep.
LSTM-1 + AE rep.
LSTM-1 + S2S-AE rep.
LSTM-1 + S2S-F rep.
LSTM-1 + S2S-F-A rep.
LSTM-3 + raw signals
Table 2: Prediction of discharge/mortality status within the next 24 hours using unsupervised representations or raw signals. The prevalence of positive discharge and mortality labels is 0.197 and 0.021, respectively. (Best results are in bold; second-best results are marked with *.)
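Both reported metrics can be computed from scratch as below. This is our own sketch, not the authors' evaluation code: AUROC uses the pairwise (Mann-Whitney) formulation, which needs memory proportional to the number of positive-negative pairs, and AUPRC is estimated by average precision, one common estimator among several; we do not know which estimator the paper used.

```python
import numpy as np

def auroc(y, score) -> float:
    """Probability that a random positive is scored above a random negative (ties count 1/2)."""
    y = np.asarray(y, dtype=bool)
    score = np.asarray(score, dtype=float)
    pos, neg = score[y], score[~y]
    diff = pos[:, None] - neg[None, :]
    return float((np.sum(diff > 0) + 0.5 * np.sum(diff == 0)) / (len(pos) * len(neg)))

def auprc(y, score) -> float:
    """Average precision: mean of the precision at the rank of each positive (no tie handling)."""
    y = np.asarray(y, dtype=bool)
    score = np.asarray(score, dtype=float)
    order = np.argsort(-score, kind="stable")
    hits = y[order]
    precision_at_k = np.cumsum(hits) / np.arange(1, len(hits) + 1)
    return float(precision_at_k[hits].mean())
```

With a mortality prevalence of 0.021, a classifier that ranks patients at random has an AUROC of about 0.5 but an expected AUPRC of only roughly 0.021, so AUPRC values in Table 2 should be read against the prevalence.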
Improved performance in limited data setting

Here we evaluate how unsupervised representations help boost prediction performance in the limited labeled data scenario.

We simulate this setting by reducing the quantity of labeled data available for the classification problems described in the previous section, down to as few as 1% of the labeled training examples (N = 75 patients). The results under this varying data scarcity are shown in Figure 3 for the different representation learning approaches. As baselines, we also include the prediction performance of classifiers (LSTM-1 and LSTM-3) trained in an end-to-end supervised fashion on the available labeled data.

Figure 3: 24h discharge and mortality prediction performance of LSTM-1 using unsupervised representations, as well as supervised learning with LSTM-1 and LSTM-3. (There are only 75 labeled patients for training in the 1% labeled data setting.)
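Simulating label scarcity at the patient level, so that all hours of a selected patient stay together, might look like the sketch below. The function name, the rounding rule and the seeding are our assumptions rather than the paper's procedure.

```python
import numpy as np

def subsample_patients(patient_ids, fraction: float, seed: int = 0) -> np.ndarray:
    """Keep a random fraction of *patients* (not time points), and at least one patient."""
    ids = np.unique(patient_ids)
    rng = np.random.default_rng(seed)
    n_keep = max(1, int(round(fraction * len(ids))))
    return rng.choice(ids, size=n_keep, replace=False)
```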

We observe from Figure 3 that when labels are scarce, the model trained using time-series representations as input features outperforms the end-to-end supervised models, confirming the benefit of unsupervised representation learning in limited-data settings. Even when all labeled samples at our disposal are used to train a more complex classifier, the best unsupervised representations still lead to better performance than supervised representations. For all models, however, performance has not saturated at the largest training set size, which indicates that the entire regime examined here is data-scarce. Given more data, the purely supervised models might eventually surpass the ones using learned representations.

5 Conclusion

We have studied the performance of several methods for learning unsupervised representations of patient time series, and proposed a new architecture, S2S-F-A, which is optimized for forecasting using an attention mechanism. We empirically showed that in scenarios where labeled medical time series data is scarce, training classifiers on unsupervised representations provides performance gains over end-to-end supervised learning using raw input signals, thus making effective use of information available in a separate, unlabeled training set. The proposed model, explored for the first time in the context of unsupervised patient representation learning, produces representations with the highest performance in future signal prediction and clinical outcome prediction, exceeding several baselines.

References

Appendix A Appendix

A.1 Data

The eICU Collaborative Research Database v1.2 [27] was used for all experiments described in this paper. 94 time series variables (shown in Table A1), including periodic and aperiodic vital signs and irregularly measured lab tests, were extracted from the raw database. A variable was included in our analysis if at least 10% of patients in the cohort had at least one record for this variable. As preprocessing, the raw data was resampled to a regular time grid with an interval of 60 minutes, extending from admission to the ICU to discharge from the unit. During computation of the time grid, rejection of implausible data and imputation were performed with an online algorithm. An observation was rejected if it was a statistical outlier with respect to pre-computed 5th/95th dataset percentiles. Values on the regular time grid were imputed using a combination of forward filling, personalized history mean filling and population median filling. Forward filling was used if the last value was recorded no more than 1 hour (periodic vital signs), 5 hours (aperiodic vital signs) or 1 day (lab tests) before the grid point. Otherwise, if there had been previous observations of that variable, the mean of all such observations was used to fill in the grid point. If there were no observations in the patient's history, the grid value was filled with the population median for that variable.
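The online imputation rule for a single variable can be sketched as follows. This sketch omits the outlier-rejection step (it assumes observations have already been filtered), takes 1 day as 24 hours, and uses hypothetical names (`MAX_GAP_HOURS`, `impute_variable`); it illustrates the described rule rather than reproducing the authors' pipeline.

```python
import numpy as np

# Forward-fill limits from the text, in hours (hypothetical constant name).
MAX_GAP_HOURS = {"periodic": 1, "aperiodic": 5, "lab": 24}

def impute_variable(obs_times, obs_values, grid, kind: str, population_median: float) -> np.ndarray:
    """Fill one variable on a regular hourly grid, using only observations up to each grid point."""
    obs_times = np.asarray(obs_times, dtype=float)
    obs_values = np.asarray(obs_values, dtype=float)
    max_gap = MAX_GAP_HOURS[kind]
    filled = np.empty(len(grid))
    for i, g in enumerate(grid):
        past = obs_times <= g                          # online: no look-ahead
        if not past.any():
            filled[i] = population_median              # no history for this patient yet
        elif g - obs_times[past].max() <= max_gap:
            # Forward fill: the most recent observation is recent enough.
            filled[i] = obs_values[past][np.argmax(obs_times[past])]
        else:
            filled[i] = obs_values[past].mean()        # personalized history mean
    return filled
```

For instance, a periodic vital sign observed at hours 0.5 and 3.2 (values 10 and 20) is filled on the grid 0 to 5 as 99 (population median), 10, 10, 10, 20 and 15.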

Overall, the dataset consists of 20878 patients with 72-240 hours of history.

eICU Table Variables
vitalPeriodic cvp, heartrate, respiration, sao2, st1, st2, st3, systemicdiastolic, systemicmean, systemicsystolic, temperature
vitalAperiodic noninvasivediastolic, noninvasivemean, noninvasivesystolic
Lab -bands, -basos, -eos, -lymphs, -monos, -polys, ALT (SGPT), AST (SGOT), BNP, BUN, Base Deficit, Base Excess, CPK, CPK-MB, CPK-MB index, Carboxyhemoglobin, Fe, Ferritin, FiO2, HCO3, HDL, Hct, Hgb, LDL, LPM O2, MCH, MCHC, MCV, MPV, Methemoglobin, O2 Content, O2 Sat (%), PT, PT - INR, PTT, RBC, RDW, Respiratory Rate, TIBC, TSH, TV, Total CO2, Vancomycin - trough, Vent Rate, Vitamin B12, WBC x 1000, WBC’s in urine, albumin, alkaline phos., ammonia, anion gap, bedside glucose, bicarbonate, calcium, chloride, creatinine, direct bilirubin, fibrinogen, glucose, ionized calcium, lactate, lipase, magnesium, pH, paCO2, paO2, peep, phosphate, platelets x 1000, potassium, sodium, temporature, total bilirubin, total cholesterol, total protein, triglycerides, troponin - I, troponin - T, urinary sodium, urinary specific gravity
Table A1: List of 94 selected variables.

A.1.1 Cohort selection

Among the >200,000 ICU stays available in the dataset, we included only patients with one stay, such that data splits do not have to be stratified with respect to patient ID. In the second filtering step, ICU stays shorter than 3 days or longer than 10 days were excluded. The filtering yielded a set of 20878 patients/ICU stays.

A.1.2 Data splits

From the pre-filtered dataset we created 5 replicates of random partitions into train, validation and 2 test sets with respect to patients, i.e. the entire data of a patient was contained in exactly one of the 4 sets. Size ratios of 40:40:10:10 were used for the train/validation/test1/test2 sets. The training set was used to train the representation models, and the validation set was used to tune free hyperparameters of the representation method (if any). The classifiers were trained on the patient representations obtained from the validation set, their hyperparameters were optimized on the representations from the first test set, and their predictive performance was evaluated on the unseen representations from the second test set. 5 independent experiments were performed on the replicates.
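A sketch of the patient-level 40:40:10:10 partition with five replicates is shown below; the seeding scheme and the rounding at the cut points are our assumptions. The comments record the role each set plays according to the text above.

```python
import numpy as np

def split_patients(patient_ids, seed: int) -> dict:
    """Randomly partition patients 40:40:10:10 so each patient lands in exactly one set."""
    rng = np.random.default_rng(seed)
    ids = rng.permutation(np.unique(patient_ids))
    n = len(ids)
    cuts = [int(0.4 * n), int(0.8 * n), int(0.9 * n)]
    train, val, test1, test2 = np.split(ids, cuts)
    return {
        "train": train,   # fit the representation models
        "val": val,       # early stopping for representations; classifier training
        "test1": test1,   # classifier hyperparameter selection
        "test2": test2,   # final evaluation
    }

# Five independent replicates, one per seed.
replicates = [split_patients(np.arange(20878), seed) for seed in range(5)]
```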

A.2 Representation learning

For each representation learning method, the representation model was learned from the training set. Feature columns were standard-scaled (subtracting the mean and dividing by the standard deviation) before training the models. For the deep learning models, the validation set was used to implement an early stopping heuristic for the training process. At this point, all trained representations were saved to disk. For the deep learning models, we used grid search to find the best set of hyperparameters.

For basic autoencoders, we train with mini-batches of 512 randomly sampled records, and for the recurrent autoencoders we train with mini-batches of 4 patients with full history. We use early stopping based on the validation set loss to avoid overfitting, i.e. we stop training if the validation set loss has not decreased for 10 consecutive epochs. We additionally use the validation set to optimize hyperparameters, namely the learning rate and the activation functions.
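We read the stopping rule as "no new best validation loss for 10 consecutive epochs". Under that interpretation, it can be sketched with two hypothetical callables, `train_epoch` and `val_loss`, which stand in for whatever training framework is used.

```python
def train_with_early_stopping(train_epoch, val_loss, max_epochs: int = 1000, patience: int = 10):
    """Stop once the validation loss has not improved for `patience` consecutive epochs.

    Returns the best validation loss and the number of epochs actually run.
    """
    best = float("inf")
    epochs_without_improvement = 0
    for epoch in range(max_epochs):
        train_epoch()
        loss = val_loss()
        if loss < best:
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break
    return best, epoch + 1
```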

A.3 Representation evaluation

For evaluating the future signal and task prediction performance, representations of the first 12 hours of a patient recording were excluded. In this way the results are not affected by the model-specific ways of handling incomplete histories, which occur at the beginning of the patient stay.

A.4 Model complexity

Table A2 shows the traits of the unsupervised learning models used in the paper. An advantage of Seq2Seq-based models is that the number of parameters they use does not depend on the length of the input time series to be compressed.

name nonlinear temporal decoder output attention number of parameters
PCA no no past no $O(TDm)$
AE yes no past no grows with $T$
S2S-AE yes yes past no independent of $T$
S2S-F yes yes future no independent of $T$
S2S-F-A yes yes future yes independent of $T$

Table A2: Comparison of the unsupervised representation learning models used. $T$ refers to the length of the time series to be encoded (12 in our experiments), $D$ is the dimension of the input data, and $m$ is the dimension of the hidden state of the LSTM in the S2S-based models, which is the same as the representation dimension.

A.5 Impact of representation dimension

In this section we investigate the relationship between the dimensionality of representations and their performance across tasks. In the previously described experiments, we used a representation dimension of $m = 94$, implying a compression factor of 12 (as the windows consist of 12 hourly measurements of 94 variables). Here we vary $m$ to explore how much compression is possible while retaining prediction performance.
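The compression factor is the ratio of the flattened window size to the representation dimension, $TD/m$. With $T = 12$ and $D = 94$, the three settings in Table A3 work out as follows:

```latex
\frac{T D}{m} = \frac{12 \times 94}{m} = \frac{1128}{m}
\quad\Rightarrow\quad
m = 94:\ 12, \qquad m = 50:\ 22.56, \qquad m = 2:\ 564
```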

Table A3 shows the AUROC values obtained using S2S-F-A representations for prediction. Compared with the AUROC scores obtained with raw features in Table 2, even very low-dimensional S2S-F-A representations still achieve reasonable performance.

AUROC S2S-F-A (m=2) S2S-F-A (m=50) S2S-F-A (m=94)
24h Discharge
24h Mortality
Table A3: AUROC scores of predictions using LSTM-1 classifiers on S2S-F-A representations with different dimensions.