SCEHR: Supervised Contrastive Learning for Clinical Risk Prediction using Electronic Health Records

10/11/2021 ∙ by Chengxi Zang, et al. ∙ Cornell University

Contrastive learning has demonstrated promising performance in image and text domains, either in a self-supervised or a supervised manner. In this work, we extend the supervised contrastive learning framework to clinical risk prediction problems based on longitudinal electronic health records (EHR). We propose a general supervised contrastive loss ℒ_Contrastive Cross Entropy + λℒ_Supervised Contrastive Regularizer for learning both binary classification (e.g. in-hospital mortality prediction) and multi-label classification (e.g. phenotyping) in a unified framework. Our supervised contrastive loss implements the key idea of contrastive learning, namely pulling similar samples closer and pushing dissimilar ones apart from each other, simultaneously through its two components: ℒ_Contrastive Cross Entropy tries to contrast samples with learned anchors that represent the positive and negative clusters, and ℒ_Supervised Contrastive Regularizer tries to contrast samples with each other according to their supervised labels. We propose two versions of the above supervised contrastive loss, and our experiments on real-world EHR data demonstrate that the proposed loss functions improve the performance of strong baselines and even state-of-the-art models on benchmarking tasks for clinical risk prediction. Our loss functions also work well with extremely imbalanced data, which is common in clinical risk prediction problems, and can easily replace the (binary or multi-label) cross-entropy loss adopted in existing clinical predictive models. The PyTorch code is released at <https://github.com/calvin-zcx/SCEHR>.


I Introduction

With the accumulation and improved availability of electronic health records (EHR) [28, 32], health analytics has become one of the most important frontiers for data mining and artificial intelligence [35]. Public EHR databases [14] and benchmark suites [12] provide great resources for developing advanced data mining and machine learning algorithms for critical clinical risk prediction problems, including in-hospital mortality prediction, disease phenotyping, hospital readmission, etc. [12, 30]. These problems can be formulated as binary or multi-label classification problems over longitudinal EHR event sequences (obtained by concatenating the visits of individual patients over time) and solved by minimizing a corresponding classification loss, e.g. the (multi-label or binary) cross-entropy loss [12, 30, 36]. Although great endeavors have been devoted to developing complex deep learning models for these clinical risk prediction problems [12, 9, 21, 19, 4, 25, 7, 33, 26, 24, 1], limited progress has been made over the past few years in terms of their performance [1]. In contrast with the majority of current research on designing more advanced predictive models, in this paper we show that replacing the widely adopted cross-entropy loss with a supervised contrastive loss is a promising way to improve the performance of existing models for clinical risk prediction based on longitudinal EHR data.

Fig. 1: An illustration of our SCEHR. We propose a general supervised contrastive learning loss for clinical risk prediction problems using longitudinal electronic health records. The overall goal is to improve the performance of binary classification (e.g. in-hospital mortality prediction) and multi-label classification (e.g. phenotyping) by pulling similar samples closer and pushing dissimilar samples apart from each other. The contrastive cross entropy ℒ_CCE tries to contrast sample representations with learned positive and negative anchors, and the supervised contrastive regularizer ℒ_SCR tries to contrast sample representations with the others in a mini-batch according to their labels. For brevity, we only highlight the contrastive pulling and pushing forces associated with one sample in a mini-batch consisting of two positive samples and three negative samples.

Recently, contrastive learning [20], which aims at learning data instance representations by bringing similar instances closer and pushing dissimilar instances further away from each other, has shown promising results in image classification [5, 2], medical image understanding [38], and beyond [23]. These methods mainly follow a self-supervised strategy [23, 13], which builds augmented data with pseudo-labels to compensate for the lack of sufficient supervised information. The latest research finds that supervised information can provide additional benefits for contrastive learning in both computer vision [16] and natural language processing tasks [10]. We argue that the general idea of contrastive learning should also be helpful for clinical risk prediction tasks. However, applying contrastive learning in clinical risk prediction scenarios is challenging because: 1) the patient data (such as EHRs) used for clinical risk prediction are usually more complex than images or texts, in that the clinical events involved are of mixed types, high-dimensional, sparse, and noisy; 2) it is hard to augment EHR data with computational methods because of the intrinsic complexity of disease mechanisms; and 3) the predicted clinical outcomes can also be heterogeneous. Therefore, whether contrastive learning strategies can benefit clinical risk prediction problems remains an open question.

In this paper, we propose SCEHR, a Supervised Contrastive learning framework for clinical risk prediction using longitudinal Electronic Health Record data. We illustrate the idea of SCEHR in Figure 1. The key component of SCEHR is a general supervised contrastive loss for solving binary classification (e.g. in-hospital mortality prediction) and multi-label classification (e.g. phenotyping) in a unified framework. We propose two versions (Eq. 10 and Eq. 11) of this supervised contrastive loss to implement the key idea of contrastive learning, i.e., pulling similar samples closer and pushing dissimilar ones apart from each other, which is achieved by minimizing the two components of our loss ℒ. Specifically, for an arbitrary neural encoder that maps clinical time series into embedding representations, the ℒ_CCE learns a positive anchor and a negative anchor (for each class) and tries to contrast the distance between targeted samples and the learned positive anchor versus the distance between the targeted samples and the learned negative anchor, guided by the supervised labels (e.g. positive/dead for in-hospital mortality prediction, or the existence of certain medical concepts for phenotyping classification). The ℒ_SCR tries to contrast every pair of samples with the same label versus every pair of samples with different labels in a mini-batch. By leveraging supervised information, SCEHR needs neither data augmentation nor pseudo-labels. In addition, we also demonstrate the relationship between our losses and the triplet loss [3].

We validate SCEHR together with the two versions of our proposed supervised contrastive losses on benchmarking clinical risk prediction tasks, including in-hospital mortality prediction and phenotyping [12], on a large real-world EHR database (MIMIC-III) [14]. We find that both versions of our proposed loss functions improve strong baseline models and state-of-the-art models. We further investigate our models' performance as the level of data imbalance changes, and find that our proposed loss functions work much better than the binary cross entropy loss under extreme imbalance (very low positive ratios), which is common in prediction problems with rare clinical outcomes. We further visualize the learned embeddings to interpret the effects of our proposed supervised contrastive losses. It is worthwhile to highlight our contributions as follows:

  • Novelty. We propose a general supervised contrastive loss and its two instances for solving supervised binary classification and multi-label classification in a unified framework. SCEHR is among the first works to apply supervised contrastive learning to clinical risk prediction with longitudinal EHR data.

  • Effectiveness. SCEHR improves both strong baseline models and state-of-the-art models on clinical risk prediction tasks, including in-hospital mortality prediction and phenotyping. SCEHR also performs well under extreme data imbalance.

  • Flexibility.

    Our proposed supervised contrastive loss functions can easily replace the (multi-label or binary) cross entropy loss in existing clinical predictive models. Our PyTorch code is open-sourced at

    https://github.com/calvin-zcx/SCEHR.

The rest of this paper is organized as follows: related work (Sec. II), problem definition (Sec. III), the proposed method SCEHR (Sec. IV), experiments (Sec. V), and conclusions (Sec. VI).

II Related Work

Deep predictive models using EHR data. Applying deep models to clinical risk prediction problems (e.g. in-hospital mortality prediction, phenotyping, decompensation, length-of-stay prediction, readmission, etc.) based on longitudinal electronic health record (EHR) data [28, 30, 32] shows great potential for improving health care. These tasks are usually formulated as binary or multi-label classification problems optimized with multi-label or binary cross-entropy loss. Most research endeavors have been devoted to developing more advanced deep models or incorporating more data to capture the complexity of diseases and EHR data, including but not limited to RNNs [12, 9], transformers [21], reverse distillation [19], variational inference [4], deep feature selection [25], attention mechanisms [7, 33, 26, 24], and so on. However, despite the fast pace of modeling innovations, much slower progress has been made over the past few years on these tasks in terms of performance [1]. Instead of designing more complex deep predictive models, here we explore another direction: innovating the default (binary or multi-label) cross entropy loss widely used in existing clinical predictive models. We focus on state-of-the-art models [12, 26, 1] that were benchmarked on the public MIMIC-III data [14], considering the limitations of using private EHR data.

Contrastive Learning. Contrastive learning [20, 23], which aims at learning good representations by bringing similar samples closer and pushing dissimilar samples away from each other through contrastive loss functions, has shown promising results in image classification [5, 2], medical image understanding [38], videos [11], etc. The idea of "contrastive" loss functions dates back to metric learning [37], the triplet loss [3], Siamese neural networks [6], and the negative sampling loss of word2vec [27]. The majority of the contrastive learning literature adopts self-supervised techniques [23, 13, 15, 22] that build augmented data with pseudo-labels. Recently, by explicitly using supervised labels, supervised contrastive learning has shown better performance on image classification [16] and NLP tasks [10]. To the best of our knowledge, only one paper [36] has tried the contrastive idea for binary classification with EHR data, adopting the negative sampling loss of word2vec [27] with negative sampling over constructed heterogeneous information networks [8]. Different from all the above research, we propose a general supervised contrastive loss (together with its two versions) for solving binary classification and multi-label classification in a unified framework using longitudinal EHR data.

III Problem Definition

In this section, we define the clinical risk prediction problems we focus on with longitudinal electronic health records (EHR) data. Let X = (x₁, x₂, …, x_T) represent one patient's clinical time series, which consists of D-dimensional clinical concepts (e.g. individual measurements during his/her stay in ICU) over timestamps t = 1, …, T. Specifically, x_{t,d} represents the d-th clinical concept (e.g. diastolic blood pressure) measured at timestamp t for the patient. In total, there are N patients denoted as X⁽¹⁾, …, X⁽ᴺ⁾, and the length T⁽ⁱ⁾ (i = 1, …, N) usually varies for different patients according to their length of stay, say, in ICU. Additional static features, e.g. demographic features, are denoted as s, where s⁽ⁱ⁾ represents patient i's static features. For simplicity, we use X to represent all the clinical time series and additional static features (if they exist) for modeling. We use y to denote the targeted clinical outcomes, e.g. in-hospital mortality events or the existence of phenotype conditions, which will occur beyond the observational window [0, T] for each patient, with y ∈ {0, 1}ᴷ.

Our primary goal is to learn a predictive model f_θ, which predicts the probability of the occurrence of clinical outcomes, denoted as ŷ = f_θ(X). The θ are learnable modeling parameters. Regarding the value of K, the above problem formulation encompasses two special cases:

  • Binary classification problem (K = 1), namely y ∈ {0, 1}. Tasks including in-hospital mortality prediction, physiologic decompensation, etc., belong to this category.

  • Multi-label classification problem (K > 1), namely y ∈ {0, 1}ᴷ, which can be formulated as solving K binary classifications simultaneously. The phenotype classification (phenotyping) task belongs to this category.

We will detail the above tasks in the experiment section. We learn the parameters θ of f_θ by minimizing a loss function:

θ* = argmin_θ ℒ(ŷ, y),    (1)

given the supervised information y, where ŷ = f_θ(X) are the predicted outcomes.

In contrast with the majority of existing efforts in designing f_θ, in this paper we show that the supervised contrastive learning loss ℒ proposed below is also an effective way to improve the performance of clinical predictive models.

IV Supervised Contrastive Learning Framework for EHR

In this section, we introduce our Supervised Contrastive Learning for EHR (SCEHR) framework in detail. Figure 1 outlines SCEHR as a roadmap for this section, and Algorithm 1 summarizes its overall learning process.

IV-A General Supervised Contrastive Loss

Let Enc_θ(·) be any learnable neural encoder for clinical time series X, which maps X into its embedding representation h by h = Enc_θ(X). We further define a linear mapping W and a non-linear squeeze function σ(·) (e.g. the sigmoid or softmax function), which map the learned representation to the predicted probability by ŷ = σ(W h). We propose the following general form of Supervised Contrastive Loss for binary or multi-label classification problems:

ℒ = ℒ_CCE + λ ℒ_SCR.    (2)

Our loss ℒ consists of two parts: a (supervised) contrastive cross entropy loss ℒ_CCE, which is a function of the predicted labels ŷ against their ground-truth labels y; and a supervised contrastive regularizer ℒ_SCR, which regularizes the learned embedding representations h using the supervised information y. The regularizer is scaled by a non-negative hyper-parameter λ. We detail several choices of these losses for both binary classification and multi-label classification below.

IV-B Contrastive Cross Entropy for Binary Classification

Let X, h, y ∈ {0, 1}, and ŷ represent the clinical time series of one patient, its embedding representation, its ground-truth clinical outcome, and its predicted outcome, respectively. We use μ₊ and μ₋ to represent the learned anchors of the positive and negative clusters respectively, which are modeled as the row vectors of the weight matrix W of a linear mapping h ↦ W h.

The Binary Cross Entropy (BCE) loss is widely used for clinical risk classification when the two outcomes are coded as 1 or 0, say mortality for positive cases and non-mortality for negative cases. The equation for the BCE loss, denoted as ℒ_BCE, is:

ℒ_BCE = −(1/N) Σᵢ₌₁ᴺ [ yᵢ log σ(μᵀhᵢ) + (1 − yᵢ) log(1 − σ(μᵀhᵢ)) ],    (3)

where σ(·) is the sigmoid function and ŷᵢ = σ(μᵀhᵢ). If we define a similarity measure as the dot product of two vectors, intuitively, minimizing the BCE loss tries to make positive samples (yᵢ = 1) close to the single anchor μ. Similarly, for negative samples (yᵢ = 0), the BCE loss pushes hᵢ away from μ, making σ(μᵀhᵢ) close to 0.

Here we propose the Contrastive Binary Cross Entropy (CBCE) loss, denoted as ℒ_CBCE, as follows:

ℒ_CBCE = −(1/N) Σᵢ₌₁ᴺ [ yᵢ (log σ(μ₊ᵀhᵢ) + log(1 − σ(μ₋ᵀhᵢ))) + (1 − yᵢ)(log σ(μ₋ᵀhᵢ) + log(1 − σ(μ₊ᵀhᵢ))) ],    (4)

which is the first version of our ℒ_CCE term. The above loss explicitly learns the positive anchor μ₊ and the negative anchor μ₋ separately. Minimizing the CBCE loss makes a positive sample hᵢ (when yᵢ = 1) closer to the positive anchor μ₊ than to the negative anchor μ₋ by pulling hᵢ closer to μ₊ and at the same time pushing hᵢ away from μ₋. Similarly, for a negative sample (when yᵢ = 0), minimizing the loss makes hᵢ closer to the negative anchor μ₋ than to the positive anchor μ₊ by pulling hᵢ closer to μ₋ and at the same time pushing hᵢ away from μ₊. Intuitively, the two learned anchors μ₊ and μ₋ represent the positive cluster and the negative cluster respectively, and the location of each sample representation hᵢ is determined by contrasting the pulling force with the pushing force in a product form. We show the math of these contrastive forces in Section IV-D. In all, Equation 4 contrasts each sample with the positive and negative anchors in a product form.
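To make the product form concrete, here is a minimal NumPy sketch of a CBCE-style loss; the function and variable names (`cbce_loss`, `mu_pos`, `mu_neg`) are ours for illustration, not taken from the released code.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def cbce_loss(h, y, mu_pos, mu_neg, eps=1e-12):
    """Product-form contrastive binary cross entropy (sketch).
    h: (N, d) embeddings; y: (N,) binary labels in {0, 1};
    mu_pos, mu_neg: (d,) learned positive / negative anchors."""
    s_pos = sigmoid(h @ mu_pos)   # similarity to the positive anchor
    s_neg = sigmoid(h @ mu_neg)   # similarity to the negative anchor
    # positives: pull h toward mu_pos AND push it away from mu_neg;
    # negatives: pull h toward mu_neg AND push it away from mu_pos.
    per_sample = -(y * (np.log(s_pos + eps) + np.log(1.0 - s_neg + eps))
                   + (1.0 - y) * (np.log(s_neg + eps) + np.log(1.0 - s_pos + eps)))
    return float(per_sample.mean())
```

Swapping the two anchors should raise the loss when the embeddings agree with the labels, which is a quick sanity check for any implementation of this form.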

Following a similar idea to ℒ_CBCE, we can also view a two-dimensional softmax cross entropy as a second instance of the contrastive cross entropy loss ℒ_CCE. We denote this Contrastive Softmax Cross Entropy as ℒ_CSCE, defined by the following equation:

ℒ_CSCE = −(1/N) Σᵢ₌₁ᴺ [ yᵢ log ( e^{μ₊ᵀhᵢ} / (e^{μ₊ᵀhᵢ} + e^{μ₋ᵀhᵢ}) ) + (1 − yᵢ) log ( e^{μ₋ᵀhᵢ} / (e^{μ₊ᵀhᵢ} + e^{μ₋ᵀhᵢ}) ) ].    (5)

Equation 5 contrasts each sample with the positive and negative anchors in a ratio form, i.e., a two-dimensional softmax function followed by a negative log-likelihood loss. Taking a positive sample (when yᵢ = 1) as an example, minimizing the above loss tries to pull hᵢ closer to the positive anchor μ₊ than to the negative anchor μ₋, by pulling hᵢ toward μ₊ and at the same time pushing it away from μ₋.
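The ratio form can likewise be sketched in NumPy; again the names are our own illustrative choices, not the released implementation.

```python
import numpy as np

def csce_loss(h, y, mu_pos, mu_neg, eps=1e-12):
    """Ratio-form contrastive softmax cross entropy (sketch):
    a two-way softmax over similarities to the two anchors,
    followed by a negative log-likelihood."""
    # column 0 = negative class, column 1 = positive class
    logits = np.stack([h @ mu_neg, h @ mu_pos], axis=1)
    logits = logits - logits.max(axis=1, keepdims=True)   # numerical stability
    probs = np.exp(logits)
    probs /= probs.sum(axis=1, keepdims=True)
    p_true = probs[np.arange(len(y)), y.astype(int)]      # probability of the true class
    return float(-np.log(p_true + eps).mean())
```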

IV-C Supervised Contrastive Regularizer

Compared with the ℒ_CCE, which compares each sample's distance to the learned positive anchor with its distance to the learned negative anchor, the ℒ_SCR tries to explore pair-wise relationships between data samples in a mini-batch. Specifically, the ℒ_SCR pulls data pairs with the same label closer and pushes data pairs with different labels away from each other. Based on the supervised contrastive loss proposed in [16], we propose a simplified Supervised Contrastive loss as the Regularizer (SCR), defined by the following equation:

ℒ_SCR = Σᵢ₌₁ᴺ (−1/N_{yᵢ}) Σ_{j≠i, yⱼ=yᵢ} log ( e^{hᵢᵀhⱼ/τ} / Σ_{k≠i} e^{hᵢᵀhₖ/τ} ),    (6)

where N is the number of samples in a mini-batch, N_{yᵢ} is the number of other samples sharing the same label as sample i, and τ > 0 is the temperature hyper-parameter. Here we do not adopt any self-supervised data augmentation strategy [16, 5] and only use the existing supervised information y. As a result, for each data sample i, we consider its distances to all other samples in the mini-batch and contrast these pair-wise distances, according to whether two samples share the same label, in the ratio form detailed in Equation 6.
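A minimal NumPy sketch of such a supervised contrastive regularizer over a mini-batch, assuming ℓ2-normalized embeddings (an illustration under our own naming, not the released code):

```python
import numpy as np

def scr_loss(h, y, tau=0.5):
    """Supervised contrastive regularizer over one mini-batch (sketch)."""
    h = h / np.linalg.norm(h, axis=1, keepdims=True)   # work on the unit sphere (our assumption)
    n = len(y)
    sim = (h @ h.T) / tau
    np.fill_diagonal(sim, -np.inf)                     # a sample never pairs with itself
    row_max = sim.max(axis=1, keepdims=True)           # log-sum-exp stabilization
    log_den = row_max + np.log(np.exp(sim - row_max).sum(axis=1, keepdims=True))
    log_p = sim - log_den                              # log-softmax over all other samples
    same = (y[:, None] == y[None, :]) & ~np.eye(n, dtype=bool)
    n_same = same.sum(axis=1)
    keep = n_same > 0                                  # skip samples with no same-label partner
    num = np.where(same, log_p, 0.0).sum(axis=1)
    return float(-(num[keep] / n_same[keep]).mean())
```

Label-consistent clusters should score a lower regularizer value than the same embeddings with shuffled labels.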

IV-D Relationship with Triplet Loss

All the above contrastive losses ℒ_CBCE, ℒ_CSCE and ℒ_SCR can be approximated by a triplet loss. As for the ℒ_CBCE, the (product form) contrastive term between a sample representation hᵢ and the two anchors can be approximated as:

−log[ σ(μ₊ᵀhᵢ)(1 − σ(μ₋ᵀhᵢ)) ] ≈ max(0, α − μ₊ᵀhᵢ) + max(0, α + μ₋ᵀhᵢ),    (7)

where α is a positive scalar and θ = {μ₊, μ₋} represents the learnable parameters of the two anchors. The above two approximations are achieved by −log σ(x) ≈ max(0, α − x) and −log(1 − σ(x)) ≈ max(0, α + x).

As for the ℒ_CSCE, the (ratio form) contrastive term can be approximated as:

−log ( e^{μ₊ᵀhᵢ} / (e^{μ₊ᵀhᵢ} + e^{μ₋ᵀhᵢ}) ) = log(1 + e^{μ₋ᵀhᵢ − μ₊ᵀhᵢ}) ≈ max(0, α + μ₋ᵀhᵢ − μ₊ᵀhᵢ),    (8)

where the approximation is achieved by log(1 + eˣ) ≈ max(0, α + x) and α is a positive scalar.

Though different in form, both contrastive cross entropy losses ℒ_CBCE and ℒ_CSCE try to make the distance between hᵢ and the targeted anchor smaller than the distance between hᵢ and the opposite anchor. A similar argument applies to the ℒ_SCR, whose ratio-form contrastive term pulls same-label pairs closer than different-label pairs. This is the major reason why all the above losses are named contrastive.
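The softplus-to-hinge approximation that underlies the triplet-loss view can be checked numerically; a small sketch of our own:

```python
import numpy as np

# softplus(x) = log(1 + e^x) upper-bounds the hinge max(0, x), and the gap
# is largest (log 2) exactly at x = 0 -- this is the sense in which the
# contrastive cross-entropy terms behave like triplet-style hinge losses.
x = np.linspace(-6.0, 6.0, 241)
softplus = np.log1p(np.exp(x))
hinge = np.maximum(0.0, x)

assert np.all(softplus >= hinge)
assert abs(np.max(softplus - hinge) - np.log(2.0)) < 1e-6
```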

IV-E Generalization to Multi-label Classification

We further generalize the above binary classification losses to multi-label classification losses. A typical clinical application is phenotyping, which tries to predict the existence of multiple clinical conditions. We model multi-label classification as solving multiple binary classifications simultaneously. Here we define the general multi-label form of our loss as follows:

ℒ = (1/K) Σₖ₌₁ᴷ ( ℒ_CCE⁽ᵏ⁾ + λ ℒ_SCR⁽ᵏ⁾ ),    (9)

where K is the number of classes. Equation 2 is a special case of Equation 9 when K = 1.

Based on the aforementioned contrastive cross entropy losses ℒ_CBCE and ℒ_CSCE (Sec. IV-B), and the supervised contrastive regularizer ℒ_SCR (Sec. IV-C), we propose the following two versions of our general supervised contrastive loss:

  • Our general multi-label ℒ_CBCE+SCR form is:

    ℒ_CBCE+SCR = (1/K) Σₖ₌₁ᴷ ( ℒ_CBCE⁽ᵏ⁾ + λ ℒ_SCR⁽ᵏ⁾ );    (10)

  • Our general multi-label ℒ_CSCE+SCR form is:

    ℒ_CSCE+SCR = (1/K) Σₖ₌₁ᴷ ( ℒ_CSCE⁽ᵏ⁾ + λ ℒ_SCR⁽ᵏ⁾ ).    (11)

It is worthwhile to mention that the above two multi-label classification losses encompass the binary classification losses as special cases when K = 1. For simplicity, we use the general forms to denote both the binary and multi-label cases.

IV-F Summary

Input: Data X, labels y
Output: Enc_θ: targeted neural encoder for h = Enc_θ(X); {μ₊⁽ᵏ⁾, μ₋⁽ᵏ⁾}: positive and negative anchors for each of the K classes
for each epoch do
       Step 1: Sample a mini-batch
       Step 2: Generate data representations h = Enc_θ(X)
       Step 3: Compute the supervised contrastive loss by Eq. 10 or Eq. 11
       Step 4: Update θ and {μ₊⁽ᵏ⁾, μ₋⁽ᵏ⁾} by minimizing the above loss
end for
Return: Enc_θ, {μ₊⁽ᵏ⁾, μ₋⁽ᵏ⁾}
Algorithm 1 The learning framework of our SCEHR

We summarize the overall learning framework of our SCEHR in Algorithm 1 and illustrate its main idea in Figure 1. The major outputs of the algorithm are the targeted neural encoder Enc_θ for h = Enc_θ(X), and the learned positive and negative anchors μ₊⁽ᵏ⁾ and μ₋⁽ᵏ⁾ for each of the K classes. The predicted probability of data Xᵢ belonging to the positive cases of class k (e.g. the predicted risk of in-hospital mortality for the mortality prediction task, where positive represents mortality) is σ(μ₊⁽ᵏ⁾ᵀhᵢ) for Eq. 10 and e^{μ₊⁽ᵏ⁾ᵀhᵢ} / (e^{μ₊⁽ᵏ⁾ᵀhᵢ} + e^{μ₋⁽ᵏ⁾ᵀhᵢ}) for Eq. 11, respectively. In general, our SCEHR can be applied to existing clinical risk prediction models for binary or multi-label classification by replacing their cross entropy losses with our Eq. 10 or Eq. 11. The PyTorch implementations of our SCEHR are open-sourced at https://github.com/calvin-zcx/SCEHR.
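The loop in Algorithm 1 can be sketched end-to-end on toy data. The following NumPy illustration (ours, not the released code) trains only the two anchors with the ratio-form contrastive cross entropy on a linearly separable toy "embedding" set, using full-batch gradient descent for brevity and omitting the ℒ_SCR term (λ = 0):

```python
import numpy as np

rng = np.random.default_rng(0)
# toy stand-in "embeddings": positives cluster near +1, negatives near -1
h = np.concatenate([rng.normal(1.0, 0.5, (50, 4)),
                    rng.normal(-1.0, 0.5, (50, 4))])
y = np.concatenate([np.ones(50), np.zeros(50)])

mu_pos = rng.normal(0.0, 0.1, 4)   # learnable positive anchor
mu_neg = rng.normal(0.0, 0.1, 4)   # learnable negative anchor
lr = 0.1
for epoch in range(200):
    # two-way softmax over anchor similarities (ratio form, Eq. 5 style)
    logits = np.stack([h @ mu_neg, h @ mu_pos], axis=1)
    logits = logits - logits.max(axis=1, keepdims=True)
    p = np.exp(logits)
    p /= p.sum(axis=1, keepdims=True)
    err = p[:, 1] - y                      # softmax-CE gradient wrt the positive logit
    mu_pos -= lr * (err @ h) / len(y)      # pull mu_pos toward positive samples
    mu_neg -= lr * (-err @ h) / len(y)     # pull mu_neg toward negative samples

# classify by comparing similarity to the two learned anchors
pred = (h @ mu_pos > h @ mu_neg).astype(float)
acc = (pred == y).mean()
```

On this separable toy set the anchors quickly settle into the two cluster directions and the anchor-comparison rule classifies nearly all samples correctly.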

V Experiments

We validate our SCEHR on a real-world, publicly available electronic health records (EHR) database, the Medical Information Mart for Intensive Care (MIMIC-III) [14]. Following benchmarking works [12], we validate our SCEHR by answering the following questions:

  • In-hospital mortality prediction (Sec. V-A) tries to predict the in-hospital mortality status of ICU patients, namely a binary classification task, given their first 48-hour data in ICU. Early prediction of at-risk patients is key to patient stratification and improved healthcare outcomes. Our question is: Can our SCEHR improve the performance of benchmarking models for the in-hospital mortality prediction task?

  • Phenotyping classification (Sec. V-B) tries to predict the existence of 25 common clinical conditions (coded by ICD-9 codes in the EHR) for ICU patients, namely a multi-label classification task, given their ICU data of varying time length. Phenotyping is key to diagnosis, comorbidity detection, and quality surveillance [29]. Our question is: Can our SCEHR improve the performance of typical benchmarking models for the phenotyping task?

  • Data Imbalance Analysis (Sec. V-C). Positive cases in EHR data typically make up a much smaller proportion than negative cases. Our question is: How will our SCEHR perform under different levels of data imbalance?

  • Embedding Visualization (Sec. V-D). Our SCEHR is supposed to pull similar data embeddings closer and push dissimilar ones apart. Our question is: What will the learned embeddings look like by our SCEHR on the real-world EHR data?

Datasets. Following the benchmark tasks [12] on the MIMIC-III dataset [14], medical concepts (including capillary refill rate, diastolic blood pressure, fraction of inspired oxygen, heart rate, etc.) observed over time are selected as features, which are further feature-engineered into multivariate medical time series for the predictive models. For mortality prediction, the first 48-hour time series in ICU are used for each patient, with binary in-hospital mortality labels. The latest works [26] also include additional static features based on demographics (e.g. ethnicity, gender, age, height, weight, etc.) to improve performance. For phenotyping classification, the time series length varies depending on the length of stay in ICU, and the labels are 25-dimensional multi-labels. The splits of the train, validation, and test datasets are summarized in Table I, and the statistics of the varying time series lengths for phenotyping classification are summarized in Table II.

#Train #Validation #Test
Mortality 14,681 (13.53%) 3,222 (13.53%) 3,236 (11.56%)
Phenotyping 29,250 (16.54%) 6,371 (16.31%) 6,281 (16.53%)
TABLE I: Statistics of datasets. The ratio of positive cases is shown in the round brackets. The mortality data have binary labels, and the phenotyping data have 25-dimensional multi-labels.
Phenotyping #Train #Validation #Test
min 1 2 2
max 2804 1843 1993
mean 86.81 88.79 88.75
std. 123.87 125.56 127.66
TABLE II: Statistics of the varying time series length (number of timestamps) per patient in the phenotyping dataset.

We implemented our code in Python 3.9.1, PyTorch 1.7.1, and CUDA 10.1, and trained all models on a GeForce RTX 2080 Ti GPU with 16 CPU cores on a Linux server running Ubuntu 18.04.2 LTS. We open-source our code at https://github.com/calvin-zcx/SCEHR, and refer to [14] for the public MIMIC-III dataset and [12] for the data pre-processing and benchmarking code.

V-A In-hospital Mortality Prediction

TABLE III: In-hospital mortality prediction results (AUROC, AUPRC, Accuracy, min(Se, P+)) by the benchmarking LSTM model [12] under different losses. BCE: Binary Cross Entropy; CBCE: Contrastive Binary Cross Entropy; CSCE: Contrastive Softmax Cross Entropy; SCR: Supervised Contrastive Regularizer. We highlight the best performance w.r.t. different metrics. We also report in round brackets the standard deviation of bootstrapped results obtained by re-sampling the test set 100 times with replacement.

TABLE IV: In-hospital mortality prediction results (AUROC, AUPRC, Accuracy, min(Se, P+)) by the benchmarking Concare model [26] under different losses. Additional static demographic features are used in this experiment.

Setup. The in-hospital mortality prediction task, formulated as a binary classification problem, is typically learned by optimizing the binary cross entropy (BCE) loss in existing works [12, 26]. In this task, we evaluate SCEHR's capability of improving benchmark models for mortality prediction by replacing the BCE loss.

To be comparable with benchmark models, we adopt the most widely used models: a) an LSTM-based model (a 2-layered LSTM) [12]; and b) the state-of-the-art attention-based model Concare (a complex channel-wise GRU model with attention layers that uses additional static demographic features) [26]. We compare these models under a) their original binary cross entropy loss ℒ_BCE; b) the binary cross entropy loss with the supervised contrastive regularizer, ℒ_BCE + λℒ_SCR; c) our contrastive binary cross entropy loss with the supervised contrastive regularizer, ℒ_CBCE + λℒ_SCR; and d) our contrastive softmax cross entropy loss with the supervised contrastive regularizer, ℒ_CSCE + λℒ_SCR. To be consistent with the baseline implementations, we control for the same learning settings, including the Adam optimizer [17] with the same learning rate, dropout, and weight decay as the baselines, and we only grid search for the best AUROC performance over two hyper-parameters, namely the batch size and λ. The hidden dimension of h, namely the penultimate layer used by the contrastive regularizer, follows each baseline's setting for LSTM and Concare. We set the maximum number of training epochs for LSTM and Concare to 100 and 150 respectively, and fix the temperature τ for all the following experiments.

We evaluate the performance of this binary classification task with the widely adopted benchmark metrics: AUROC, the area under the receiver operating characteristic curve; AUPRC, the area under the precision-recall curve (recall is also known as sensitivity); Accuracy, the ratio of correctly predicted cases to total cases; and min(Se, P+), the upper bound of the minimum over sensitivity-precision pairs across decision thresholds.
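The min(Se, P+) metric can be sketched as a sweep over decision thresholds; this is our own illustration of the definition above (ties in scores are not handled):

```python
import numpy as np

def min_se_pplus(y_true, y_score):
    """Best achievable min(sensitivity, precision) over all thresholds (sketch)."""
    order = np.argsort(-y_score)          # rank samples by descending risk score
    y = np.asarray(y_true)[order]
    tp = np.cumsum(y)                     # true positives when cutting after rank k
    k = np.arange(1, len(y) + 1)          # number of predicted positives at each cut
    precision = tp / k
    sensitivity = tp / y.sum()
    return float(np.minimum(sensitivity, precision).max())
```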

Results. Table III and Table IV show that our SCEHR improves the best performance of both the benchmark LSTM model and the state-of-the-art Concare model with respect to all four metrics for the in-hospital mortality prediction task on the MIMIC-III dataset. More specifically, we find that both contrastive losses ℒ_CBCE + λℒ_SCR and ℒ_CSCE + λℒ_SCR outperform ℒ_BCE w.r.t. all the metrics. The ℒ_CBCE + λℒ_SCR achieved the best AUROC, AUPRC, and Accuracy, while the ℒ_CSCE + λℒ_SCR achieved similar AUROC and the best min(Se, P+) for both models, regardless of their different complexity. Besides, simply adding the regularizer λℒ_SCR to ℒ_BCE also improves the best AUROC performance of the bare ℒ_BCE for LSTM.

We observe similar empirical running times for the different losses under the same predictive model: all the above loss functions finish 100 epochs with batch size 256 within minutes for both the LSTM-based model and the Concare model.

In conclusion, ℒ_CBCE + λℒ_SCR or ℒ_CSCE + λℒ_SCR improves the performance of the strong benchmarking LSTM model and the state-of-the-art Concare model by replacing the BCE loss. Both supervised contrastive terms, namely the contrastive cross entropy ℒ_CCE and the supervised contrastive regularizer ℒ_SCR, introduce additional performance improvements.

V-B Phenotyping Classification

TABLE V: Phenotype prediction results (Micro, Macro, and Weighted AUROC) by the benchmarking LSTM model [12] under different losses. BCE: Multi-label Binary Cross Entropy; CBCE: Multi-label Contrastive Binary Cross Entropy; CSCE: Multi-label Contrastive Softmax Cross Entropy; SCR: Multi-label Supervised Contrastive Regularizer. We highlight the best performance w.r.t. different metrics.

Setup. Phenotyping, formulated as a multi-label classification problem, is learned by optimizing the mean of multiple binary cross entropy (BCE) losses in existing benchmarking models [12]. In this task, we evaluate SCEHR's ability to improve the benchmarking phenotyping models by replacing the BCE loss.

We examined the LSTM-based model (a 1-layered LSTM)¹ [12] under different losses, including a) the multi-label cross entropy loss ℒ_BCE; b) the multi-label cross entropy loss with the multi-label supervised contrastive regularizer, ℒ_BCE + λℒ_SCR; c) our multi-label contrastive binary cross entropy loss with the regularizer, ℒ_CBCE + λℒ_SCR; and d) our multi-label contrastive softmax cross entropy loss with the regularizer, ℒ_CSCE + λℒ_SCR. We evaluate the multi-label classification performance by standard metrics including Micro-AUROC, Macro-AUROC, and Weighted-AUROC [31]. We adopt the same settings for consistency, including the Adam optimizer with the same learning rate, dropout, and weight decay as the baseline, and we grid search for the best Micro-AUROC performance over two hyper-parameters, namely the batch size and λ. The hidden dimension of h, namely the penultimate layer used by the contrastive regularizer, follows the baseline's setting.

¹We choose the standard LSTM benchmarking model because the different LSTM benchmarks in [12] have similar AUROC performance, and the state-of-the-art Concare [26] cannot be applied to time series with varying lengths.

Results. Table V reports the different AUROC scores. We find that our SCEHR improves the benchmarking LSTM model w.r.t. all the metrics. More specifically, our contrastive losses and applying λℒ_SCR directly to the BCE loss achieved the best performance, indicating the benefits of introducing the supervised contrastive terms.

V-C Data Imbalance Analysis

Fig. 2: In-hospital mortality prediction under different data imbalance levels.

Setup. We further investigate the performance of our loss functions when the number of positive cases in the training data is imbalanced at different levels. We study in-hospital mortality prediction with the benchmarking LSTM model. As shown in Table I, the original ratio of positive cases in the training dataset is 13.53%. We downsample the training data to several lower ratios of positive cases while keeping the test data unchanged. Following the same experimental setting as Section V-A, we search for the best AUROC performance over the hyper-parameter space spanned by the batch size and λ.
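The downsampling protocol can be sketched as follows (our own illustration of the idea, not the benchmark code): keep every negative case and subsample positives until the positive ratio hits the target.

```python
import numpy as np

def downsample_positives(y, target_pos_ratio, rng):
    """Return indices of a subsample whose positive ratio is (approximately)
    target_pos_ratio, keeping all negatives (sketch)."""
    pos = np.flatnonzero(y == 1)
    neg = np.flatnonzero(y == 0)
    # n_pos / (n_pos + len(neg)) == target  =>  n_pos = target * len(neg) / (1 - target)
    n_pos = int(round(target_pos_ratio * len(neg) / (1.0 - target_pos_ratio)))
    keep = rng.choice(pos, size=min(n_pos, len(pos)), replace=False)
    return np.sort(np.concatenate([keep, neg]))
```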

Results. We report the AUROC achieved by the different losses under different data imbalance levels (ratios of positive cases) in Figure 2. We find consistent improvements of our contrastive binary cross entropy (CBCE + SCL) and contrastive softmax cross entropy (CSCE + SCL) losses over the BCE loss across all imbalance levels. Besides, adding the supervised contrastive regularizer to BCE (BCE + SCL) also improves performance, but not as significantly as CBCE + SCL and CSCE + SCL. When positive cases are very rare, our CBCE + SCL and CSCE + SCL outperform BCE by a large margin.

In conclusion, our experimental results imply that when the clinical outcome of interest is rare in EHR datasets (e.g. rare diseases), namely, when positive cases form a very small fraction of the total population, replacing the BCE loss with our CBCE + SCL or CSCE + SCL losses can improve binary classification performance.

V-D Embedding Visualization

(a) BCE
(b) BCE + SCL
(c) CBCE + SCL
(d) CSCE + SCL
Fig. 3: t-SNE plots of patients' embedding representations learned by the same LSTM-based mortality predictive model under BCE and different supervised contrastive losses on the test dataset. Orange crosses and blue dots represent the positive and negative cases respectively. The positive cases account for 11.56% of the total population. We highlight the learned positive anchor by a red cross and the negative anchor by a red dot.

Setup. We visualize the embedding representations of each patient in the test dataset learned under different losses to illustrate the effect of the supervised contrastive terms. All representations are learned by the same LSTM-based mortality predictive model as discussed in Section V-A under different losses, including: a) the BCE loss; b) the BCE loss with the supervised contrastive regularizer (BCE + SCL); c) our contrastive binary cross entropy loss with the supervised contrastive regularizer (CBCE + SCL); d) our contrastive softmax cross entropy loss with the supervised contrastive regularizer (CSCE + SCL). We fix the batch size at 256 for all the learning processes. We plot the 16-dimensional hidden representations by t-SNE [34] with 50 perplexity under 1000 iterations. The t-SNE is initialized by PCA as suggested in [18].
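This visualization setup maps directly onto scikit-learn's `TSNE`. A minimal sketch, using synthetic Gaussian blobs as stand-ins for the learned 16-dimensional patient embeddings (scikit-learn's default of 1000 optimization iterations matches the setting above):

```python
import numpy as np
from sklearn.manifold import TSNE

rng = np.random.default_rng(0)
# synthetic stand-in for 16-dimensional penultimate-layer embeddings:
# two loosely separated blobs mimicking negative and positive cases
emb = np.vstack([
    rng.normal(0.0, 1.0, size=(180, 16)),  # "negative" cases
    rng.normal(2.0, 1.0, size=(20, 16)),   # "positive" cases (~10%)
])

# perplexity 50 and PCA initialization, as in the setup above
tsne = TSNE(n_components=2, perplexity=50, init="pca", random_state=0)
xy = tsne.fit_transform(emb)  # 2-D coordinates, one row per patient
print(xy.shape)
```

The resulting 2-D coordinates can then be scattered with one color per label to reproduce plots like Figure 3.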

Results. We show the embedding visualizations in Figure 3. Compared with the BCE plot (Figure 3a), we find that all the loss functions with supervised contrastive terms (Figure 3b-d) better squeeze positive samples near the red cross and negative samples near the red dot, implying their ability to pull representations with the same label closer and push representations with different labels apart. Moreover, compared with BCE + SCL, our CBCE + SCL and CSCE + SCL show more complex structures while maintaining a relatively clear gap between classes, which may account for their better performance. Visual inspection suggests that our CBCE + SCL (Figure 3c) achieves the best class separation, consistent with its best AUROC. Besides, many points are located among data clusters with different labels, indicating the intrinsic difficulty of clinical risk prediction with longitudinal EHR data [1].

Vi Conclusion

In this paper, we propose a general supervised contrastive loss form for solving both binary classification and multi-label classification in a unified framework for clinical risk prediction using EHR data. Our proposed losses improve the performance of strong baselines and even state-of-the-art models on benchmark clinical risk prediction tasks using real-world longitudinal EHR data, work well with extremely imbalanced data, and can be easily applied to existing clinical risk predictive models by replacing their (binary or multi-label) cross entropy loss. Our PyTorch code is released at https://github.com/calvin-zcx/SCEHR. For future work, more instances of the above supervised contrastive loss can be proposed, and more clinical risk predictive models, EHR datasets, and self-supervised data augmentation techniques for longitudinal EHR data deserve further investigation.
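As a rough illustration of the general loss form ℒ = ℒ_Contrastive Cross Entropy + λℒ_Supervised Contrastive Regularizer, the sketch below combines a plain BCE term with a Khosla-style supervised contrastive regularizer on L2-normalized embeddings. This is a simplified NumPy stand-in, not the released CBCE/CSCE PyTorch implementation (see the repository for the exact losses); the embeddings, probabilities, temperature τ = 0.1, and λ = 0.1 are all illustrative:

```python
import numpy as np

def bce(p, y, eps=1e-8):
    """Binary cross entropy over predicted probabilities p and labels y."""
    return -np.mean(y * np.log(p + eps) + (1 - y) * np.log(1 - p + eps))

def supervised_contrastive_regularizer(z, y, tau=0.1):
    """Khosla-style supervised contrastive term: for each anchor, pull
    same-label embeddings closer and push different-label ones apart."""
    z = z / np.linalg.norm(z, axis=1, keepdims=True)   # L2-normalize
    sim = z @ z.T / tau                                # temperature-scaled similarity
    np.fill_diagonal(sim, -np.inf)                     # exclude self-pairs
    log_prob = sim - np.log(np.exp(sim).sum(axis=1, keepdims=True))
    same = (y[:, None] == y[None, :]) & ~np.eye(len(y), dtype=bool)
    n_pos = same.sum(axis=1)
    per_anchor = -np.where(same, log_prob, 0.0).sum(axis=1) / np.maximum(n_pos, 1)
    return per_anchor[n_pos > 0].mean()

rng = np.random.default_rng(0)
y = rng.integers(0, 2, size=32)                 # toy binary labels
z = rng.normal(0.0, 1.0, (32, 16)) + y[:, None] # toy embeddings with class shift
p = 1.0 / (1.0 + np.exp(-z.mean(axis=1)))       # toy predicted probabilities
lam = 0.1
loss = bce(p, y) + lam * supervised_contrastive_regularizer(z, y)
print(round(float(loss), 4))
```

In a real model, `z` would be the penultimate-layer representation and `p` the sigmoid output head, so the regularizer acts as a drop-in addition next to the existing cross entropy term.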

Acknowledgement

This work was supported by NSF 1750326, ONR N00014-18-1-2585 and NIH RF1AG072449. The authors would also like to acknowledge the support from Google Faculty Research Award and Amazon Web Services Machine Learning for Research Award.

References

  • [1] D. Bellamy, L. Celi, and A. L. Beam (2020) Evaluating progress on machine learning for longitudinal electronic healthcare data. arXiv preprint arXiv:2010.01149. Cited by: §I, §II, §V-D.
  • [2] M. Caron, I. Misra, J. Mairal, P. Goyal, P. Bojanowski, and A. Joulin (2020) Unsupervised learning of visual features by contrasting cluster assignments. arXiv preprint arXiv:2006.09882. Cited by: §I, §II.
  • [3] G. Chechik, V. Sharma, U. Shalit, and S. Bengio (2010) Large scale online learning of image similarity through ranking. Cited by: §I, §II.
  • [4] C. Chen, J. Liang, F. Ma, L. M. Glass, J. Sun, and C. Xiao (2020) UNITE: uncertainty-based health risk prediction leveraging multi-sourced data. arXiv preprint arXiv:2010.11389. Cited by: §I, §II.
  • [5] T. Chen, S. Kornblith, M. Norouzi, and G. Hinton (2020) A simple framework for contrastive learning of visual representations. In International conference on machine learning, pp. 1597–1607. Cited by: §I, §II, §IV-C.
  • [6] D. Chicco (2021) Siamese neural networks: an overview. Artificial Neural Networks, pp. 73–94. Cited by: §II.
  • [7] E. Choi, M. T. Bahadori, J. A. Kulas, A. Schuetz, W. F. Stewart, and J. Sun (2016) Retain: an interpretable predictive model for healthcare using reverse time attention mechanism. arXiv preprint arXiv:1608.05745. Cited by: §I, §II.
  • [8] E. Choi, M. T. Bahadori, L. Song, W. F. Stewart, and J. Sun (2017) GRAM: graph-based attention model for healthcare representation learning. In Proceedings of the 23rd ACM SIGKDD international conference on knowledge discovery and data mining, pp. 787–795. Cited by: §II.
  • [9] J. Gao, C. Xiao, Y. Wang, W. Tang, L. M. Glass, and J. Sun (2020) Stagenet: stage-aware neural networks for health risk prediction. In Proceedings of The Web Conference 2020, pp. 530–540. Cited by: §I, §II.
  • [10] B. Gunel, J. Du, A. Conneau, and V. Stoyanov (2020) Supervised contrastive learning for pre-trained language model fine-tuning. arXiv preprint arXiv:2011.01403. Cited by: §I, §II.
  • [11] T. Han, W. Xie, and A. Zisserman (2020) Self-supervised co-training for video representation learning. arXiv preprint arXiv:2010.09709. Cited by: §II.
  • [12] H. Harutyunyan, H. Khachatrian, D. C. Kale, G. Ver Steeg, and A. Galstyan (2019) Multitask learning and benchmarking with clinical time series data. Scientific data 6 (1), pp. 1–18. Cited by: §I, §I, §II, §V-A, §V-A, §V-B, §V-B, TABLE III, TABLE V, §V, §V, §V, footnote 1.
  • [13] A. Jaiswal, A. R. Babu, M. Z. Zadeh, D. Banerjee, and F. Makedon (2021) A survey on contrastive self-supervised learning. Technologies 9 (1), pp. 2. Cited by: §I, §II.
  • [14] A. E. Johnson, T. J. Pollard, L. Shen, H. L. Li-Wei, M. Feng, M. Ghassemi, B. Moody, P. Szolovits, L. A. Celi, and R. G. Mark (2016) MIMIC-iii, a freely accessible critical care database. Scientific data 3 (1), pp. 1–9. Cited by: §I, §I, §II, §V, §V, §V.
  • [15] Y. Kalantidis, M. B. Sariyildiz, N. Pion, P. Weinzaepfel, and D. Larlus (2020) Hard negative mixing for contrastive learning. arXiv preprint arXiv:2010.01028. Cited by: §II.
  • [16] P. Khosla, P. Teterwak, C. Wang, A. Sarna, Y. Tian, P. Isola, A. Maschinot, C. Liu, and D. Krishnan (2020) Supervised contrastive learning. arXiv preprint arXiv:2004.11362. Cited by: §I, §II, §IV-C.
  • [17] D. P. Kingma and J. Ba (2014) Adam: a method for stochastic optimization. arXiv preprint arXiv:1412.6980. Cited by: §V-A.
  • [18] D. Kobak and G. C. Linderman (2021-02) Initialization is critical for preserving global data structure in both t -sne and umap. Nature Biotechnology, pp. 1–2. External Links: ISSN 1546-1696, Document Cited by: §V-D.
  • [19] R. S. Kodialam, R. Boiarsky, and D. Sontag (2020) Deep contextual clinical prediction with reverse distillation. arXiv preprint arXiv:2007.05611. Cited by: §I, §II.
  • [20] P. H. Le-Khac, G. Healy, and A. F. Smeaton (2020) Contrastive representation learning: a framework and review. IEEE Access. Cited by: §I, §II.
  • [21] Y. Li, S. Rao, J. R. A. Solares, A. Hassaine, R. Ramakrishnan, D. Canoy, Y. Zhu, K. Rahimi, and G. Salimi-Khorshidi (2020) BEHRT: transformer for electronic health records. Scientific reports 10 (1), pp. 1–12. Cited by: §I, §II.
  • [22] Y. Li, P. Hu, Z. Liu, D. Peng, J. T. Zhou, and X. Peng (2020) Contrastive clustering. External Links: 2009.09687 Cited by: §II.
  • [23] X. Liu, F. Zhang, Z. Hou, Z. Wang, L. Mian, J. Zhang, and J. Tang (2020) Self-supervised learning: generative or contrastive. arXiv preprint arXiv:2006.08218 1 (2). Cited by: §I, §II.
  • [24] J. Luo, M. Ye, C. Xiao, and F. Ma (2020) HiTANet: hierarchical time-aware attention networks for risk prediction on electronic health records. In Proceedings of the 26th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining, pp. 647–656. Cited by: §I, §II.
  • [25] L. Ma, J. Gao, Y. Wang, C. Zhang, J. Wang, W. Ruan, W. Tang, X. Gao, and X. Ma (2020) Adacare: explainable clinical health status representation learning via scale-adaptive feature extraction and recalibration. In Proceedings of the AAAI Conference on Artificial Intelligence, Vol. 34, pp. 825–832. Cited by: §I, §II.
  • [26] L. Ma, C. Zhang, Y. Wang, W. Ruan, J. Wang, W. Tang, X. Ma, X. Gao, and J. Gao (2020) Concare: personalized clinical feature embedding via capturing the healthcare context. In Proceedings of the AAAI Conference on Artificial Intelligence, Vol. 34, pp. 833–840. Cited by: §I, §II, §V-A, §V-A, TABLE IV, §V, footnote 1.
  • [27] T. Mikolov, I. Sutskever, K. Chen, G. Corrado, and J. Dean (2013) Distributed representations of words and phrases and their compositionality. arXiv preprint arXiv:1310.4546. Cited by: §II.
  • [28] R. Miotto, F. Wang, S. Wang, X. Jiang, and J. T. Dudley (2018) Deep learning for healthcare: review, opportunities and challenges. Briefings in bioinformatics 19 (6), pp. 1236–1246. Cited by: §I, §II.
  • [29] A. Oellrich, N. Collier, T. Groza, D. Rebholz-Schuhmann, N. Shah, O. Bodenreider, M. R. Boland, I. Georgiev, H. Liu, K. Livingston, et al. (2016) The digital revolution in phenotyping. Briefings in bioinformatics 17 (5), pp. 819–830. Cited by: 2nd item.
  • [30] A. Rajkomar, E. Oren, K. Chen, A. M. Dai, N. Hajaj, M. Hardt, P. J. Liu, X. Liu, J. Marcus, M. Sun, et al. (2018) Scalable and accurate deep learning with electronic health records. NPJ Digital Medicine 1 (1), pp. 1–10. Cited by: §I, §II.
  • [31] scikit-learn.org (2021 (accessed January 29, 2021)) Compute area under the receiver operating characteristic curve (roc auc) from prediction scores.. Note: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html External Links: Link Cited by: §V-B.
  • [32] J. R. A. Solares, F. E. D. Raimondi, Y. Zhu, F. Rahimian, D. Canoy, J. Tran, A. C. P. Gomes, A. H. Payberah, M. Zottoli, M. Nazarzadeh, et al. (2020) Deep learning for electronic health records: a comparative review of multiple deep neural architectures. Journal of biomedical informatics 101, pp. 103337. Cited by: §I, §II.
  • [33] H. Song, D. Rajan, J. Thiagarajan, and A. Spanias (2018) Attend and diagnose: clinical time series analysis using attention models. In Proceedings of the AAAI Conference on Artificial Intelligence, Vol. 32. Cited by: §I, §II.
  • [34] L. Van der Maaten and G. Hinton (2008) Visualizing data using t-sne.. Journal of machine learning research 9 (11). Cited by: §V-D.
  • [35] F. Wang and A. Preininger (2019) AI in health: state of the art, challenges, and future directions. Yearbook of medical informatics 28 (1), pp. 16. Cited by: §I.
  • [36] T. Wanyan, H. Honarvar, S. K. Jaladanki, C. Zang, N. Naik, S. Somani, J. K. De Freitas, I. Paranjpe, A. Vaid, R. Miotto, et al. (2021) Contrastive learning improves critical event prediction in covid-19 patients. arXiv preprint arXiv:2101.04013. Cited by: §I, §II.
  • [37] K. Q. Weinberger and L. K. Saul (2009) Distance metric learning for large margin nearest neighbor classification.. Journal of machine learning research 10 (2). Cited by: §II.
  • [38] Y. Zhang, H. Jiang, Y. Miura, C. D. Manning, and C. P. Langlotz (2020) Contrastive learning of medical visual representations from paired images and text. arXiv preprint arXiv:2010.00747. Cited by: §I, §II.