Self-supervised Transformer for Multivariate Clinical Time-Series with Missing Values

Multivariate time-series (MVTS) data are frequently observed in critical care settings and are typically characterized by excessive missingness and irregular time intervals. Existing approaches for learning representations in this domain handle these issues by either aggregation or imputation of values, which in turn suppresses fine-grained information and adds undesirable noise/overhead to the machine learning model. To tackle this challenge, we propose STraTS (Self-supervised Transformer for Time-Series), a model which bypasses these pitfalls by treating time-series as a set of observation triplets instead of using the traditional dense matrix representation. It employs a novel Continuous Value Embedding (CVE) technique to encode continuous time and variable values without the need for discretization. It is composed of a Transformer component with multi-head attention layers, which enables it to learn contextual triplet embeddings while avoiding the recurrence and vanishing-gradient problems of recurrent architectures. Many healthcare datasets also suffer from limited availability of labeled data; STraTS leverages unlabeled data to learn better representations by using time-series forecasting as a self-supervision task. Experiments on real-world multivariate clinical time-series benchmark datasets show that STraTS achieves better prediction performance than state-of-the-art methods for mortality prediction, especially when labeled data is limited. Finally, we also present an interpretable version of STraTS which can identify important measurements in the time-series data.

1 Introduction

Time-series data routinely occur in critical care settings where various measurements are recorded for patients throughout their course of stay (Figure 1). Predicting clinical outcomes like mortality, decompensation, length of stay, and disease risk from such complex multivariate data can facilitate both effective management of critical care units and automatic personalized treatment recommendation for patients. The successes of deep learning in the image and text domains, realized by CNNs, RNNs (sutskever2014sequence; chung2014empirical), and Transformers (vaswani2017attention), have inspired the application of these architectures to develop better prediction models for time-series data as well. However, time-series in the clinical domain present a unique set of challenges, described below.

  • Missingness and Sparsity: A patient’s condition may demand observing only a subset of variables of interest. Thus, not all the variables are observed for every patient. Also, the observed time-series matrices are very sparse as some variables may be measured more frequently than others for a given patient.

  • Irregular time intervals and Sporadicity: Clinical variables are not usually measured at regular time intervals. Thus, the measurements occur sporadically in time depending on the underlying condition of the patient.

  • Limited labeled data: Patient-level clinical data is often expensive to obtain, and labeled data subsets pertaining to a specific prediction task may be even more limited (e.g., building a severity classifier for Covid-19 patients).

A straightforward approach to dealing with irregular time intervals and missingness is to aggregate measurements into discrete time intervals and add missingness indicators, respectively. However, this suppresses important fine-grained information because the granularity of the observed time-series may differ from patient to patient based on the underlying medical condition. Existing sequence models for clinical time-series (che2018recurrent) and other interpolation-based models (shukla2019interpolation) address this issue through learnable imputation or interpolation strategies. Such techniques add undesirable noise and extra overhead to the model, which usually worsen as the time-series become increasingly sparse. It is also unreasonable to impute clinical variables without careful consideration of domain knowledge about each variable, which is nontrivial to obtain.

Considering these shortcomings, we design a framework that does not need to perform any such operations and directly builds a model based on only the observations that are available in the data. Thus, unlike conventional approaches which view each time-series as a matrix of dimensions #features x #time-steps, our model regards each time-series as a set of observation triplets (a triple containing time, variable, and value). The proposed STraTS (which stands for Self-supervised Transformer for Time-Series) model embeds these triplets by using a novel Continuous Value Embedding (CVE) scheme to avoid the need for binning continuous values before embedding them. The use of CVE for embedding time preserves the fine-grained information which is lost when the time-axis is discretized. STraTS encodes contextual information of observation triplets by using a Transformer-based architecture with multi-head attention. We choose this over recurrent neural network (RNN) architectures because the sequential nature of RNN models hinders parallel processing, while the Transformer bypasses this by using self-attention to attend from every token to every other token in a single step.

Figure 1: An illustrative example of a multivariate clinical time-series data with irregular time points and missing values.

To build superior representations using limited labeled data, we employ self-supervision and develop a time-series forecasting task to pretrain STraTS. This enables learning generalized representations in the presence of limited labeled data and alleviates sensitivity to noise. Furthermore, interpretable models are usually preferred in healthcare but existing deep models for clinical time-series lack this attribute. Thus, we also propose an interpretable version of our model (I-STraTS) which slightly compromises on performance metrics but can identify important measurements in the input. Though we evaluate the proposed model only on binary classification tasks, note that the framework can be utilized in other supervised and unsupervised settings as well, where learning robust and generalized representations of sparse sporadic time-series is desired. The main contributions of our work are summarized below.

  • We propose a Transformer-based architecture called STraTS for clinical time-series which addresses the unique characteristics of missingness and sporadicity of such data by avoiding aggregation and imputation.

  • We propose a novel Continuous Value Embedding (CVE) mechanism using a one-to-many feed-forward network to embed continuous times and measured values in order to preserve fine-grained information.

  • STraTS utilizes forecasting as a self-supervision task to leverage unlabeled data to learn more generalized and robust representations.

  • We also propose an interpretable version of STraTS that can be used when interpretability is more desirable than marginal quantitative performance gains.

  • Experiments demonstrate that the design choices of STraTS lead to its better performance over competitive baseline models for mortality prediction on two real-world datasets.

The rest of the paper is organized as follows. In Section 2, we review relevant literature on tackling sparse and sporadic time-series data and on self-supervised learning. Section 3 formally defines the prediction problem and gives a detailed description of the architecture of STraTS along with the self-supervision approach. Section 4 presents experimental results comparing STraTS with various baselines and demonstrates the interpretability of I-STraTS with a case study. Finally, Section 5 concludes the paper and provides future directions.

2 Related Work

2.1 Clinical Time-Series

A straightforward approach to address missing values and irregular time intervals is to impute and aggregate the time-series respectively, before feeding them to a classifier. However, such classifiers ignore the missingness in data which can be quite informative. lipton2016directly show that phenotyping performance can be improved by passing missingness indicators as additional features to an RNN classifier.

Several early works rely on Gaussian Processes (GPs) (rasmussen2003gaussian) to model irregular time-series. For example, lu2008reproducing represent each time-series as a smooth curve in an RKHS using a GP, optimizing the GP parameters with Expectation Maximization (EM), and then derive a distance measure on the RKHS which is used to define the SVM classifier's kernel. To account for the uncertainty in the GP, which is ignored in the former, li2015classification formulate the kernel by applying an uncertainty-aware base kernel (called the expected Gaussian kernel) to a series of sliding windows. These works take a two-step approach by first optimizing GP parameters and then training the classification model. To enable end-to-end training, li2016scalable again represent time-series using the GP posterior at predefined time points but use the reparametrization trick, back-propagating the gradients through a black-box classifier (learnable by gradient descent) into the GP model. The end-to-end model is uncertainty-aware as the output is formulated as a random variable as well. futoma2017learning extend this idea to multivariate time-series with the help of multitask GPs (bonilla2008multi) to consider inter-variable similarities. Though GPs provide a systematic way to deal with uncertainty, they are expensive to learn and their flexibility is limited by the choice of covariance and mean functions.

shukla2019interpolation also propose an end-to-end method consisting of interpolation and classification networks stacked in sequence. However, the learnable interpolation layers approximate the time-series at regular predefined time points in a deterministic fashion (unlike GP-based methods) and allow information sharing across both time and variable dimensions.

Other approaches modify traditional recurrent architectures for clinical time-series to deal with missing values and/or irregular time intervals. For example, baytas2017patient developed a time-aware long short-term memory (T-LSTM), a modification of the LSTM cell that adjusts the hidden state according to irregular time gaps. ODE-RNN (rubanova2019latent) uses ODEs to model the continuous-time dynamics of the hidden state while also updating the hidden state at each observed time point using a standard GRU cell. The GRU-D (che2018recurrent) model is a modification of the GRU cell which decays inputs (to global means) and hidden states through unobserved time intervals. DATA-GRU (tan2020data), in addition to decaying the GRU hidden state according to elapsed time, also employs a dual attention mechanism based on missingness and imputation reliability to process inputs before feeding them to a GRU cell.

The imputation/interpolation schemes in the models discussed above can lead to excessive computation and unnecessary noise, particularly when missing rates are high. Our model is designed to circumvent this issue by representing sparse and irregular time-series as a set of observations. horn2019set develop SeFT with a similar idea and use a parametrized set function for classification. The attention-based aggregation used in SeFT uses the same queries for all observations to achieve low memory and time complexity while compromising on accuracy. The initial embedding in SeFT uses fixed time encodings, while our approach uses learnable embeddings for all three components (time, variable, value) of the observation triplet.

The challenge of training in scenarios with limited labeled data still remains. In order to address this issue, we turn towards self-supervision in order to better utilize the available data to learn effective representations.

2.2 Self-supervised learning

It is well known that the more data that is available to a deep learning model, the more generalized and robust its learned representations are. Limited data can make the model overfit easily to the training data and become more sensitive to noise. As labeled data is expensive to obtain, self-supervised learning was introduced as a technique to address this challenge by constructing proxy tasks using a semi-automatic label generation process (liu2020self). Though this technique has shown great performance boosts with image (jing2020self) and text (devlin2018bert; yang2019xlnet) data, its application to time-series data has been limited. One such effort is made by jawed2020self, who use a 1D CNN for dense univariate time-series classification and show increased accuracy by using forecasting as an additional task in a multi-task learning framework. In our work, we also demonstrate that time-series forecasting is a viable and effective self-supervision task. Our work is the first to explore self-supervised learning in the context of sparse and irregular multivariate time-series.

Figure 2: An illustration of input and output construction for target and self-supervision (forecast) tasks. Note that there are several possibilities for segmenting the time axis for forecast task and only one such possibility is shown here.

3 Proposed Approach

In this section, we describe our STraTS model by first introducing the problem with relevant notation and definitions and then explaining the different components of the model which are illustrated in Figure 3.

3.1 Problem Definition

As stated in the previous sections, STraTS represents a time-series as a set of observation triplets. Formally, an observation triplet is defined as a triple (t, f, v), where t is the time, f is the feature/variable, and v is the value of the observation. A multivariate time-series is then defined as a set of such observation triplets, with its length given by the number of triplets in the set.
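To make this representation concrete, here is a minimal sketch (with hypothetical variable names and values) of a sparse, irregularly sampled series stored as a set of observation triplets rather than as a dense matrix:

```python
# A minimal, illustrative sketch (hypothetical values): each observation is a
# (time, variable, value) triplet, so nothing needs to be imputed or aggregated.
from typing import NamedTuple, List

class Observation(NamedTuple):
    time: float      # hours since ICU admission
    variable: str    # clinical variable name
    value: float     # measured value

# One ICU stay = a set (here, a list) of observation triplets.
time_series: List[Observation] = [
    Observation(0.5,  "HR",         92.0),
    Observation(0.5,  "SBP",       118.0),
    Observation(2.0,  "Hematocrit", 30.8),
    Observation(7.25, "HR",         97.0),   # irregular gap, only HR observed
]

# Variables that were never measured simply do not appear; there is no NaN matrix.
observed_vars = {obs.variable for obs in time_series}
print(sorted(observed_vars))  # ['HR', 'Hematocrit', 'SBP']
```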

Consider a dataset of labeled samples, where each sample contains a demographic vector, a multivariate time-series, and a corresponding binary label. In this work, each sample corresponds to a single ICU stay during which several clinical variables of the patient are measured at irregular time intervals, and the binary label indicates in-hospital mortality. The underlying set of time-series variables may include vitals (such as temperature), lab measurements (such as hemoglobin), and input/output events (such as fluid intake and urine output). Thus, the target task is to predict the binary label given the demographic vector and the multivariate time-series.

Our model also incorporates forecasting as a self-supervision task. For this task, we consider a bigger dataset in which each sample consists of an input time-series drawn from an observation window together with, for each variable, a forecast value and a forecast mask entry; the mask indicates whether the variable was observed in the forecast window, and the forecast values contain the corresponding observed values. The forecast mask is necessary because unobserved forecasts cannot be used in training and are hence masked out in the loss function. The time-series in this dataset are obtained from both the labeled and unlabeled time-series by considering different observation windows. Figure 2 illustrates the construction of inputs and outputs for the target task and the forecast task.
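As a rough illustration of this construction, the sketch below splits one triplet series into an input (the observation window) and a per-variable forecast target plus mask (the prediction window). The window lengths, and the choice of keeping the last value observed in the prediction window, are assumptions made for the example rather than the paper's exact procedure:

```python
import numpy as np

def make_forecast_sample(triplets, variables, obs_end, pred_end):
    """Split one triplet series into forecast inputs/outputs.

    triplets : list of (time, variable, value)
    variables: ordered list of all clinical variables
    obs_end  : end of the observation window (hours)
    pred_end : end of the prediction window (hours)
    Returns (input_triplets, target_values, target_mask).
    """
    var_index = {v: i for i, v in enumerate(variables)}
    inputs = [(t, v, x) for (t, v, x) in triplets if t <= obs_end]

    values = np.zeros(len(variables), dtype=np.float32)
    mask = np.zeros(len(variables), dtype=np.float32)
    for t, v, x in triplets:
        if obs_end < t <= pred_end:
            j = var_index[v]
            values[j] = x      # keep the last value seen in the prediction window
            mask[j] = 1.0      # this variable has a ground-truth forecast
    return inputs, values, mask

triplets = [(0.5, "HR", 92.0), (2.0, "Hct", 30.8), (25.0, "HR", 97.0)]
inputs, values, mask = make_forecast_sample(triplets, ["HR", "Hct"], obs_end=24.0, pred_end=26.0)
print(len(inputs), values, mask)  # 2 [97.  0.] [1. 0.]
```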

Notation Definition
# Time-series for target task
# Time-series for forecast task
Demographics vector
Set of clinical variables
Time of observation
Variable of observation
Value of observation
Observation triplet
Multivariate time-series
True and predicted outputs for target task
True and predicted outputs for forecast task
Forecast mask
CVE for time and value
Variable embedding
Initial triplet embedding
Time-series embedding
Demographics embedding
Table 1: Notations used in this paper.

3.2 The Proposed STraTS Architecture

The architecture of STraTS is illustrated in Figure 3. Unlike most existing approaches which take a time-series matrix as the input, STraTS defines its input as a set of observation triplets. Each observation triplet in the input is embedded using the Initial Triplet Embedding module. The initial triplet embeddings are then passed through a Contextual Triplet Embedding module which utilizes the Transformer architecture to encode the context for each triplet. The Fusion Self-attention module then combines these contextual embeddings via a self-attention mechanism to generate an embedding for the input time-series, which is concatenated with the demographics embedding and passed through a feed-forward network to make the final prediction. The notations used in the paper are summarized in Table 1.

3.2.1 Initial Triplet Embedding

Given an input time-series, the initial embedding for each triplet is computed by summing three component embeddings: (i) a feature embedding, (ii) a value embedding, and (iii) a time embedding. Feature embeddings are obtained from a simple lookup table, just like word embeddings. Since feature values and times are continuous, unlike feature names which are categorical objects, we cannot use a lookup table to embed these continuous values unless they are categorized. Some researchers (vaswani2017attention; yin2020identifying) have used sinusoidal encodings to embed continuous values. We instead propose a novel Continuous Value Embedding (CVE) technique that uses a one-to-many feed-forward network (FFN) with learnable parameters, one FFN for values and one for times. Both FFNs have a single input neuron, a single hidden layer with a nonlinear activation, and as many output neurons as the embedding dimension; the shapes of the weight matrices follow from the sizes of the hidden and output layers. Unlike sinusoidal encodings with fixed frequencies, this technique offers more flexibility by allowing end-to-end learning of continuous value and time embeddings without the need to categorize them.
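The following is a minimal sketch of the CVE idea; the hidden-layer width, the tanh activation, and the random initialization are illustrative assumptions, not the paper's exact specification:

```python
import numpy as np

rng = np.random.default_rng(0)

class ContinuousValueEmbedding:
    """One-to-many FFN: embeds a scalar (a time or a value) into d dimensions.

    A sketch with an assumed tanh hidden layer; in the full model these weights
    would be learned end-to-end with the rest of the network.
    """
    def __init__(self, d, hidden):
        self.W1 = rng.normal(scale=0.1, size=(1, hidden))
        self.b1 = np.zeros(hidden)
        self.W2 = rng.normal(scale=0.1, size=(hidden, d))

    def __call__(self, x):
        x = np.asarray(x, dtype=np.float32).reshape(-1, 1)  # (n, 1) scalars
        h = np.tanh(x @ self.W1 + self.b1)                  # (n, hidden)
        return h @ self.W2                                  # (n, d)

d = 32
time_cve, value_cve = ContinuousValueEmbedding(d, 8), ContinuousValueEmbedding(d, 8)
feature_table = rng.normal(scale=0.1, size=(129, d))         # lookup table for variables

# Initial embedding of a triplet (t = 2.0 h, variable index 5, value = 30.8):
# the sum of the time, feature, and value component embeddings.
e = time_cve(2.0)[0] + feature_table[5] + value_cve(30.8)[0]
print(e.shape)  # (32,)
```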

Figure 3: The overall architecture of the proposed STraTS model. The Initial Triplet Embedding module embeds each observation triplet, the Contextual Triplet Embedding module encodes contextual information for the triplets, and the Fusion Self-Attention module computes the time-series embedding, which is concatenated with the demographics embedding and passed through a dense layer to generate predictions for the target and self-supervision (forecast) tasks.

3.2.2 Contextual Triplet Embedding

The initial triplet embeddings are then passed through a Transformer architecture (vaswani2017attention) with M blocks, each containing a Multi-Head Attention (MHA) layer with h attention heads and an FFN with one hidden layer. Each block takes a sequence of input embeddings and outputs the corresponding output embeddings that capture contextual information. MHA layers use multiple attention heads to attend to information contained in different embedding projections in parallel. With input embeddings E, the computations of the MHA layer are given by

MHA(E) = Concat(head_1, ..., head_h) W^O,   head_j = softmax( (E W_j^Q)(E W_j^K)^T / sqrt(d_k) ) (E W_j^V).

Each head projects the input embeddings into query, key, and value subspaces using the matrices W_j^Q, W_j^K, and W_j^V. The queries and keys are then used to compute attention weights, which are used to compute weighted averages of the values. Finally, the outputs of all heads are concatenated and projected back to the original dimension with W^O. The FFN layer is a position-wise feed-forward network with one hidden layer and a nonlinear activation. Dropout, residual connections, and layer normalization are added for every MHA and FFN layer. In addition, attention dropout randomly masks out some positions in the attention matrix before the softmax computation during training. The output of each block is fed as input to the succeeding one, and the output of the last block gives the contextual triplet embeddings.
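For concreteness, a small numpy sketch of one multi-head self-attention step over the triplet embeddings is shown below; dropout, residual connections, layer normalization, and the position-wise FFN are omitted, and the weight shapes and initialization are illustrative:

```python
import numpy as np

rng = np.random.default_rng(0)

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(E, num_heads, rng):
    """Self-attention over n triplet embeddings E of shape (n, d)."""
    n, d = E.shape
    dk = d // num_heads
    heads = []
    for _ in range(num_heads):
        Wq, Wk, Wv = (rng.normal(scale=0.1, size=(d, dk)) for _ in range(3))
        Q, K, V = E @ Wq, E @ Wk, E @ Wv                 # (n, dk) each
        A = softmax(Q @ K.T / np.sqrt(dk))               # (n, n) attention weights
        heads.append(A @ V)                              # weighted average of values
    Wo = rng.normal(scale=0.1, size=(d, d))
    return np.concatenate(heads, axis=1) @ Wo            # back to (n, d)

E = rng.normal(size=(10, 32))            # 10 observation triplets, d = 32
contextual = multi_head_attention(E, num_heads=4, rng=rng)
print(contextual.shape)                  # (10, 32)
```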

3.2.3 Fusion Self-attention

After computing contextual embeddings using the Transformer, we fuse them using a self-attention layer to compute the time-series embedding. This layer first computes attention weights by passing each contextual embedding through an FFN with a single hidden layer and computing a softmax over all the FFN outputs. The time-series embedding is then computed as the attention-weighted sum of the contextual triplet embeddings.
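A minimal sketch of this fusion step (layer sizes and initialization are illustrative): an FFN scores each contextual embedding, a softmax over the scores gives attention weights, and the time-series embedding is the weighted sum:

```python
import numpy as np

rng = np.random.default_rng(0)

def fusion_self_attention(contextual, hidden=16):
    """contextual: (n, d) contextual triplet embeddings -> (d,) series embedding."""
    n, d = contextual.shape
    W = rng.normal(scale=0.1, size=(d, hidden))
    b = np.zeros(hidden)
    u = rng.normal(scale=0.1, size=(hidden,))
    scores = np.tanh(contextual @ W + b) @ u        # one scalar score per triplet
    alpha = np.exp(scores - scores.max())
    alpha = alpha / alpha.sum()                     # softmax over all observations
    return alpha @ contextual                       # attention-weighted sum, shape (d,)

contextual = rng.normal(size=(10, 32))
print(fusion_self_attention(contextual).shape)      # (32,)
```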

3.2.4 Demographics Embedding

We realize that demographics can be encoded as triplets with a default value for time. However, we found that the prediction models performed better in our experiments when demographics are processed separately by passing the demographics vector through an FFN with a single hidden layer, whose output is used as the demographics embedding.

3.2.5 Prediction Head

The final prediction for the target task is obtained by passing the concatenation of the demographics and time-series embeddings through a dense layer with sigmoid activation.

The model is trained on the target task using cross-entropy loss.

3.2.6 Self-supervision

We experimented with both masking and forecasting as pretext tasks for providing self-supervision and found that forecasting improved the results. The forecasting task uses the same architecture as the target task except for the prediction layer, which instead outputs a forecast value for every time-series variable. A masked MSE loss is used for training on the forecast task to account for missing values in the forecast outputs: the squared error for a variable in a sample contributes to the self-supervision loss only when the ground-truth forecast for that variable is available, as indicated by the forecast mask. The model is first pretrained on the self-supervision task and is then fine-tuned on the target task.
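A small sketch of a masked MSE of this kind is shown below; normalizing by the number of observed entries is an assumption made for the example:

```python
import numpy as np

def masked_mse(y_true, y_pred, mask):
    """MSE computed only over positions whose ground-truth forecast is available.

    y_true, y_pred, mask: arrays of shape (batch, n_variables); mask is 1 where
    the variable was observed in the prediction window and 0 otherwise.
    """
    se = mask * (y_true - y_pred) ** 2
    return se.sum() / np.maximum(mask.sum(), 1.0)   # avoid division by zero

y_true = np.array([[97.0, 0.0], [0.0, 30.8]])
mask   = np.array([[1.0,  0.0], [0.0,  1.0]])
y_pred = np.array([[95.0, 3.0], [2.0, 31.8]])
print(masked_mse(y_true, y_pred, mask))              # only observed entries count: (4 + 1)/2 = 2.5
```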

3.3 Interpretability

We also propose an interpretable version of our model, which we refer to as I-STraTS. Inspired by choi2016retain and zhang2020inprem, we alter the architecture of STraTS in such a way that the output can be expressed as a linear combination of components derived from individual features. Contrary to STraTS, (i) we combine the initial triplet embeddings, rather than the contextual ones, using the attention weights from the Fusion Self-attention module, and (ii) we directly use the raw demographics vector as the demographics embedding. The prediction can then be rewritten as a sigmoid applied to a sum of terms, one per demographic feature and one per time-series observation. Thus, we assign each demographic feature and each time-series observation a ‘contribution score’ given by its corresponding term in this sum.
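To illustrate how such a decomposition yields per-feature scores, the sketch below assumes a simplified I-STraTS-style head: a sigmoid over a sum of terms built from the raw demographics and the attention-weighted initial triplet embeddings. The exact formulation in the paper may differ; the point is only that the logit is a plain sum whose terms can be read off as contribution scores:

```python
import numpy as np

rng = np.random.default_rng(0)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# Assumed, simplified I-STraTS-style prediction head (illustrative only).
n_obs, d, n_demo = 5, 8, 2
E0    = rng.normal(size=(n_obs, d))     # initial (not contextual) triplet embeddings
alpha = np.full(n_obs, 1.0 / n_obs)     # fusion attention weights (uniform here)
demo  = np.array([85.0, 0.0])           # raw demographics vector, e.g. [age, gender]

w_ts   = rng.normal(scale=0.1, size=(d,))       # output weights for the series part
w_demo = rng.normal(scale=0.1, size=(n_demo,))  # output weights for demographics
bias   = 0.0

# Per-component terms: the logit is a plain sum of these, so each term can be
# interpreted as that component's 'contribution score'.
obs_scores  = alpha * (E0 @ w_ts)       # one score per observation triplet
demo_scores = w_demo * demo             # one score per demographic feature

logit = obs_scores.sum() + demo_scores.sum() + bias
print(sigmoid(logit), obs_scores, demo_scores)
```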

4 Experiments

We evaluated our proposed STraTS model against state-of-the-art baselines on two real-world EHR databases for the mortality prediction task. This section starts with a description of the datasets and baselines, followed by a discussion of results focusing on generalization and interpretability.

MIMIC-III PhysioNet-2012
# ICU stays 52,871 11,988
# ICU stays (supervised) 44,812 11,988
Avg. span of time-series 101.9h 47.3h
Avg. span of time-series (supervised) 23.5h 47.3h
# Variables 129 37
Avg. variable missing rate 89.7% 79.7%
Avg. # observations/stay 401 436
Demographics Age, Gender Age, Gender, Height, ICU Type
Task 24-hour mortality 48-hour mortality
% positive class 9.7% 14.2%
Table 2: Basic dataset statistics. (Avg. variable missing rate and Avg. # observations/stay are calculated using supervised samples only.)

4.1 Datasets

We experiment with time-series extracted from two real-world EHR datasets which are described below. The dataset statistics are summarised in Table 2.

MIMIC-III (mimiciii): This is a publicly available database containing medical records of critical care patients admitted to Beth Israel Deaconess Medical Center between 2001 and 2012. We filtered ICU stays to include only adult patients and extracted features from the following tables for each ICU stay: input events, output events, lab events, chart events, and prescriptions. For the mortality prediction task, we only include ICU stays that lasted at least one day with the patient alive at the end of the first day, and predict in-hospital mortality using the first 24 hours of data. For forecasting, a set of observation windows is defined (in hours) and the prediction window is a fixed-length period following each observation window. Note that we only consider samples which have at least one time-series measurement in both the observation and prediction windows. The data is split at the patient level into training, validation, and test sets.

PhysioNet Challenge 2012: This processed dataset from the PhysioNet Challenge 2012 (https://physionet.org/content/challenge-2012/1.0.0/) contains records of 11,988 ICU stays of adult patients. The target task aims to predict in-hospital mortality given the first 48 hours of data for each ICU stay. Since the demographic variables ‘gender’ and ‘height’ are not available for all ICU stays, we perform mean imputation and add missingness indicators for them as additional demographic variables. To generate inputs and outputs for forecasting, a set of observation windows is defined (in hours) and the prediction window is a fixed-length period following each observation window. The data from set-b and set-c together is split into training and validation (80:20) while set-a is used for testing.

4.2 Baseline Methods

To demonstrate the effectiveness of STraTS over the state-of-the-art, we compare it with the following baseline methods.

  • GRU (chung2014empirical): The input is a time-series matrix with hourly aggregation where missing variables are mean-imputed. Binary missingness indicators and the time (scaled to [0,1]) since the last observation of each variable are also included as additional features at each time step (see the input-construction sketch after this list). The final hidden state is transformed by a dense layer to generate the output.

  • TCN (bai2018empirical): This model takes the same input as GRU which is passed through a stack of temporal convolution layers with residual connections. The representation from the last time step of the last layer is transformed by a dense layer to generate output.

  • SaND (song2018attend): This model also has the same input representation as GRU and the input is passed through a Transformer with causal attention and a dense interpolation layer.

  • GRU-D (che2018recurrent): The GRU-D cell takes a vector of variable values at each time point where one or more measurements are observed. The GRU-D cell, which is a modification of the GRU cell, decays unobserved values in this vector to global mean values and also adjusts the hidden state according to the elapsed time since the last observation of each variable.

  • InterpNet (shukla2019interpolation): This model consists of a semi-parametric interpolation network that interpolates all variables at regular predefined time points, followed by a prediction network which is a GRU. It also uses a reconstruction loss to enhance the interpolation network. The input representation is similar to that of GRU-D and therefore, no aggregation is performed.

  • SeFT (horn2019set): This model also inputs a set of observation triplets, similar to STraTS. It uses sinusoidal encodings to embed times and the deep network used to combine the observation embeddings is formulated as a set function using a simpler but faster variation of multi-head attention.
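For reference, the sketch below shows one plausible way to build the hourly-aggregated input described for the GRU/TCN/SAnD baselines above (per-hour mean values, mean imputation of unobserved entries, binary missingness indicators, and scaled time since the last observation); the exact aggregation and scaling choices are assumptions:

```python
import numpy as np

def build_dense_input(triplets, variables, n_hours):
    """Return an (n_hours, 3 * n_vars) matrix: [values | missingness | time since last obs]."""
    n_vars = len(variables)
    var_index = {v: j for j, v in enumerate(variables)}

    sums = np.zeros((n_hours, n_vars)); counts = np.zeros((n_hours, n_vars))
    for t, v, x in triplets:
        h = min(int(t), n_hours - 1)
        sums[h, var_index[v]] += x
        counts[h, var_index[v]] += 1

    observed = counts > 0
    values = np.divide(sums, counts, out=np.zeros_like(sums), where=observed)

    # Mean-impute unobserved entries with each variable's overall mean.
    var_means = np.divide(values.sum(0), observed.sum(0),
                          out=np.zeros(n_vars), where=observed.sum(0) > 0)
    values = np.where(observed, values, var_means)

    # Hours since the last observation of each variable, scaled to [0, 1].
    delta = np.zeros((n_hours, n_vars)); last = np.full(n_vars, -1.0)
    for h in range(n_hours):
        delta[h] = np.where(last >= 0, h - last, h)
        last = np.where(observed[h], h, last)
    delta /= max(n_hours - 1, 1)

    return np.concatenate([values, observed.astype(float), delta], axis=1)

triplets = [(0.5, "HR", 92.0), (2.2, "HR", 97.0), (2.4, "Hct", 30.8)]
X = build_dense_input(triplets, ["HR", "Hct"], n_hours=4)
print(X.shape)   # (4, 6)
```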

For all the baselines, we use two dense layers to get the demographics encoding and concatenate it with the time-series representation before the last dense layer. All the baselines use sigmoid activation at the last dense layer for mortality prediction. The time-series measurements (per variable) and demographics vectors are normalized to have zero mean and unit variance. All models are trained using the Adam (kingma2014adam) optimizer. More implementation details are provided in the appendix.

4.3 Evaluation Metrics

The following metrics are used to quantitatively compare the baselines and proposed models for the binary classification task of mortality prediction.

  1. ROC-AUC: Area under the ROC curve.

  2. PR-AUC: Area under the precision-recall curve.

  3. min(Re, Pr): The maximum, over all classification thresholds, of the minimum of recall and precision (see the computation sketch after this list).
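These metrics can be computed with scikit-learn as sketched below; average precision is used here as the PR-AUC estimate, which is an implementation choice rather than necessarily the paper's:

```python
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score, precision_recall_curve

def evaluate(y_true, y_score):
    roc_auc = roc_auc_score(y_true, y_score)
    pr_auc = average_precision_score(y_true, y_score)
    precision, recall, _ = precision_recall_curve(y_true, y_score)
    min_re_pr = np.max(np.minimum(precision, recall))  # best 'min(recall, precision)' over thresholds
    return {"ROC-AUC": roc_auc, "PR-AUC": pr_auc, "min(Re,Pr)": min_re_pr}

y_true  = np.array([0, 0, 1, 1, 0, 1])
y_score = np.array([0.1, 0.4, 0.35, 0.8, 0.2, 0.7])
print(evaluate(y_true, y_score))
```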

ROC-AUC PR-AUC min(Re,Pr)
MIMIC-III GRU
TCN
SAnD
GRU-D
InterpNet
SeFT
STraTS
PhysioNet-2012 GRU
TCN
SAnD
GRU-D
InterpNet
SeFT
STraTS
Table 3: Mortality prediction performance on the MIMIC-III and PhysioNet-2012 datasets. The results show the mean and standard deviation of each metric over repeated runs, each with a different sampling of the labeled data.
Figure 4: Mortality prediction performance on MIMIC-III for different percentages of labeled data, averaged over the repeated runs.
Figure 5: Mortality prediction performance on the PhysioNet-2012 dataset for different percentages of labeled data, averaged over the repeated runs.

4.4 Prediction Performance

We train each model using several different random samplings of labeled data from the train and validation sets. Note that STraTS uses the entire labeled data and additional unlabeled data (if available) for self-supervision. Table 3 shows the results for mortality prediction on the MIMIC-III and PhysioNet-2012 datasets, averaged over these runs. STraTS achieves the best performance on all metrics, improving PR-AUC over the best baseline on both MIMIC-III and PhysioNet-2012. This shows that our design choices of triplet embedding, attention-based architecture, and self-supervision enable STraTS to learn superior representations. We expected the interpolation-based models GRU-D and InterpNet to outperform the simpler models GRU, TCN, and SaND. This was true in all cases except that GRU performed better than GRU-D and InterpNet on the MIMIC-III dataset, which needs to be investigated further.

To test the generalization ability of the different models, we evaluate STraTS and the baseline models by training them on varying percentages of labeled data. Lower proportions of labeled data arise in the real world, for example when many samples are right-censored. Figures 4 and 5 show the results for the MIMIC-III and PhysioNet-2012 datasets, respectively. The performance of all models declines with reduced labeled data, but STraTS has a clear advantage over the other models in the low-labeled-data settings, which can be attributed to self-supervision.

4.5 Ablation Study

We compared the predictive performance of STraTS and I-STraTS, with and without self-supervision. The results are shown in Table 4, where ‘ss+’ and ‘ss-’ indicate models trained with and without self-supervision, respectively. We observe that: (i) adding interpretability to STraTS hurts the prediction scores as a result of constraining the model's representations; (ii) adding self-supervision improves the performance of both STraTS and I-STraTS; and (iii) I-STraTS (ss+) outperforms STraTS (ss-) on all metrics on the MIMIC-III dataset, and on the PR-AUC metric on the PhysioNet-2012 dataset. This demonstrates that the performance drop from introducing interpretability can be compensated by the performance gains of self-supervision.

ROC-AUC PR-AUC min(Re,Pr)
MIMIC-III I-STraTS (ss-)
I-STraTS (ss+)
STraTS (ss-)
STraTS (ss+)
PhysioNet-2012 I-STraTS (ss-)
I-STraTS (ss+)
STraTS (ss-)
STraTS (ss+)
Table 4: Ablation Study: Comparing mortality prediction performance of STraTS and I-STraTS with and without self-supervision. (‘ss+’ and ‘ss-’ are used to indicate models trained with and without self-supervision respectively.)

4.6 Interpretability

To illustrate how I-STraTS explains its predictions, we present a case study for an 85-year-old female patient from the MIMIC-III dataset who expired in the hospital after ICU admission. I-STraTS predicts a high probability of in-hospital mortality for her using data collected just on the first day. The patient had measurements for many time-series variables; the top variables ordered by their cumulative ‘contribution score’, along with the range (if multiple observations) or value (if one observation), are shown in Table 5. We see that I-STraTS considers the abnormal Hematocrit values and old age as the most important observations in predicting that the patient is at high risk of mortality. Such predictions can not only guide the healthcare system in identifying high-risk patients for better resource allocation, but also help clinicians understand the contributing factors and make better diagnoses and treatment choices.

To get a more fine-grained intuition, the observed time-series for some variables in this ICU stay are plotted in Figure 6 along with the corresponding contribution scores. It is interesting to see that the contribution scores appear to be positively or negatively correlated with the underlying values. For example, as Hct decreases, the model gives more weight to the measurement. Similarly, as GCS-eye increases, the model pays less attention to it. Higher FiO2 implies that the patient is under ventilation and is hence considered important. The contribution scores of BP time-series also exhibit a pattern. Lower and more recent values of SBP and DBP contribute more towards the final prediction.

Variable Range/Value ‘contribution score’
Hematocrit [28.7, 30.8] 0.448
Age 85.0 0.395
Phosphate [2.7, 3.5] 0.237
RBC [3.0, 3.1] 0.116
MCV [95.0, 98.0] 0.087
MCHC [32.5, 33.9] 0.077
Potassium [3.8, 4.7] 0.072
Bilirubin (Total) [0.7, 0.8] 0.065
Table 5: Case study: Top variables ordered by ‘contribution score’ for an ICU stay from MIMIC-III dataset.
Figure 6: Case study: An illustration of a few time-series with contribution scores for a patient from the MIMIC-III dataset.

5 Conclusion

We proposed a Transformer-based model, STraTS, for prediction tasks on multivariate clinical time-series to address the challenges faced by existing methods in this domain. Our approach of using observation triplets as time-series components avoids the problems faced by aggregation and imputation methods for sparse and sporadic multivariate time-series. We leave it for future work to develop heuristics to quantify the gains of triplet-based representations over aggregation and interpolation based ones, in terms of accuracy and time-and-space complexity, based on the degree of sparsity and sporadicity in data. The self-supervision task of forecasting using unlabeled data enables STraTS to learn more generalized representations, thus outperforming state-of-the-art baselines. This motivates us to explore the effectiveness of more self-supervision tasks for clinical time-series data. We also proposed an interpretable version of STraTS called I-STraTS for which self-supervision compensates the drop in prediction performance from introducing interpretability.

References

Appendix A

A.1 Implementation details

Table 6 lists the hyperparameters used in the experiments for all models on the MIMIC-III and PhysioNet-2012 datasets. All models are trained with the Adam optimizer and a fixed batch size, and training is stopped when the sum of ROC-AUC and PR-AUC does not improve for a fixed number of epochs. For the pretraining phase using the self-supervision task, a separate patience (in epochs) and epoch size (in samples) are used. For the MIMIC-III dataset, we cap the maximum number of time-steps for GRU-D and InterpNet, and the maximum number of observations for STraTS, at a high percentile of the corresponding distributions; this is done to avoid memory overflow during batch training. The deep models are implemented using Keras with a TensorFlow backend. For InterpNet, we adapted the official code from https://github.com/mlds-lab/interp-net. For GRU-D and SeFT, we borrowed implementations from https://github.com/BorgwardtLab/Set_Functions_for_Time_Series. The experiments are conducted on a single NVIDIA GRID P40-12Q GPU. The implementation is publicly available at https://github.com/sindhura97/STraTS.

Model MIMIC-III PhysioNet-2012
GRU units=50, rec d/o=0.2, output d/o=0.2, lr=0.0001 units=43, rec d/o=0.2, output d/o=0.2, lr=0.0001
TCN layers=4, filters=128, kernel size=4, d/o=0.1, lr=0.0001 layers=6, filters=64, kernel size=4, d/o=0.1, lr=0.0005
SAnD N=4, r=24, M=12, d/o=0.3, d=64, h=2, he=8, lr=0.0005 N=4, r=24, M=12, d/o=0.3, d=64, h=2, he=8, lr=0.0005
GRU-D units=60, rec d/o=0.2, output d/o=0.2, lr=0.0001 units=49, rec d/o=0.2, output d/o=0.2, lr=0.0001
SeFT lr=0.001, n phi layers=4, phi width=128, phi dropout=0.2, n psi layers=2, psi width=64, psi latent width=128, dot prod dim=128, n heads=4, attn dropout=0.5, latent width=32, n rho layers=2, rho width=512, rho dropout=0.0, max timescale=100.0, n positional dims=4 lr=0.00081, n phi layers=4, phi width=128, phi dropout=0.2, n psi layers=2, psi width=64, psi latent width=128, dot prod dim=128, n heads=4, attn dropout=0.5, latent width=32, n rho layers=2, rho width=512, rho dropout=0.0, max timescale=100.0, n positional dims=4
InterpNet ref points=96, units=100, input d/o=0.2, rec d/o=0.2, lr=0.001 ref points=192, units=100, input d/o=0.2, rec d/o=0.2, lr=0.001
STraTS(ss-) & I-STraTS(ss-) d=32, M=2, h=4, d/o=0.2, lr=0.0005 d=32, M=2, h=4, d/o=0.2, lr=0.001
STraTS & I-STraTS d=50, M=2, h=4, d/o=0.2, lr=0.0005 d=50, M=2, h=4, d/o=0.2, lr=0.0005
Table 6: Hyperparameters used for experiments in this paper.