A Generative Modeling Approach to Limited Channel ECG Classification

Deepta Rajan et al., 02/18/2018

Processing temporal sequences is central to a variety of applications in health care, and in particular multi-channel Electrocardiogram (ECG) is a highly prevalent diagnostic modality that relies on robust sequence modeling. While Recurrent Neural Networks (RNNs) have led to significant advances in automated diagnosis with time-series data, they perform poorly when models are trained using a limited set of channels. A crucial limitation of existing solutions is that they rely solely on discriminative models, which tend to generalize poorly in such scenarios. In order to combat this limitation, we develop a generative modeling approach to limited channel ECG classification. This approach first uses a Seq2Seq model to implicitly generate the missing channel information, and then uses the latent representation to perform the actual supervisory task. This decoupling enables the use of unsupervised data and also provides highly robust metric spaces for subsequent discriminative learning. Our experiments with the Physionet dataset clearly evidence the effectiveness of our approach over standard RNNs in disease prediction.


I Introduction

With the unprecedented success of machine learning in solving challenging problems across multiple domains, there is increasing interest in applying state-of-the-art techniques to health care. Community-wide efforts to create large-scale benchmark repositories, such as MIMIC-III and the Physionet CinC challenge [1], have accelerated machine learning research in health care. Furthermore, with the increased adoption of automated systems for disease diagnosis, there is a huge opportunity to build robust data-driven solutions that alleviate pain points within clinical workflows. Broadly, careful modeling of health care data requires tackling inherent challenges, including multi-variate measurements, long-range temporal dependencies, and missing information, in order to make precise predictions. Despite the success of hand-engineered features in clinical models, regularized representation learning techniques, such as sparse and deep learning, have recently proven more effective. A thorough experimental study on the UCR time-series datasets revealed that simple deep learning architectures using 1-D Convolutional Neural Networks (CNNs) can easily outperform traditional task-specific models built on hand-engineered features [2]. More recently, Recurrent Neural Networks (RNNs) based on Long Short-Term Memory (LSTM) units have become the de facto solution for clinical time-series analysis. Choi et al. [3] demonstrated the effectiveness of LSTMs for clinical sequence classification. The state-of-the-art Intensive Care Unit (ICU) sequence modeling architectures are based on deep attention models that solve multiple prediction tasks jointly [4].

I-A Problem Statement

In this work, we investigate the classification of multi-channel Electrocardiogram (ECG) measurements, one of the most common diagnostic modalities. The overall goal is to learn features from the multi-variate sequences that can potentially help diagnose heart conditions such as Myocardial Infarction, Bundle Branch Blocks, Cardiomyopathy, and several others [5]. Generally speaking, accurately identifying whether an ECG depicts a normal sinus rhythm or an abnormal rhythm requires detecting the various wave segments (P-wave, QRS-complex, T-wave) and understanding their complex morphological relationships over time. This motivated the design of classical signal processing approaches such as the Pan-Tompkins algorithm [6] and wavelet analysis [7]. However, the task remains challenging due to its episodic nature, inherent measurement noise, inter-patient variability of wave patterns, and ambiguity of labels. Consequently, neural network based solutions have been developed for this problem, in particular for detecting Myocardial Infarction. Similar to [8, 9], in this paper we consider the additional challenge of performing predictions using a limited channel configuration at test time. We show that RNNs are plagued by overfitting and produce inferior predictions under the limited channel setting.

Fig. 1:

Proposed approach for limited channel ECG classification. In the first stage, we adopt an unsupervised, generative modeling approach to construct the latent space, while in the second stage we aggregate the temporal features from the latent space using dense interpolation and use them to train a random forest classifier.

I-B Proposed Work

Inspired by the recent surge in the success of generative models for image [10] and text classification [11], we propose a generative modeling approach to limited channel ECG classification. Unsupervised generative models, e.g. variational auto-encoders [12] and Seq2Seq models [13], enable the inference of latent features that can effectively describe the complex distribution of data. In our approach, we utilize Seq2Seq modeling to construct latent spaces that can predict the entire multi-channel (12-channel in our case) ECG measurements using only the limited channel data. The resulting latent feature representations implicitly exploit information from the missing channels and can generalize better to limited channel measurements at test time. In the classifier design stage, the latent representations are aggregated into feature vectors using a dense interpolation scheme [4] and are subsequently used to train a simple random forest classifier. Our experiments with the PhysioNet PTB dataset [5] show that the proposed generative approach is significantly superior to the conventional discriminative architecture, even in the 12-channel case, while remaining robust to limited channel conditions. Given the critical need to make accurate predictions from limited measurements (e.g., single-channel ECG) recorded by mobile ECG monitors in practical scenarios [14], [1], our generative modeling based approach can generalize well to different datasets and measurement configurations.

II Proposed Approach

Discriminative models have been commonly employed for a variety of multivariate clinical time-series classification tasks. In particular, sequence modeling techniques based on Recurrent Neural Networks (RNNs) have achieved state-of-the-art results in multi-channel ECG classification. Despite their widespread use, we show that discriminative architectures, which jointly learn the latent features and the classifier, can be non-robust when dealing with datasets characterized by imbalanced class distributions. In such scenarios, the models are often plagued by overfitting and can hence produce highly unreliable predictions. Furthermore, in cases where only a subset of channels is available for the actual prediction, RNNs based on LSTM units perform very poorly, thus motivating the need to design a robust prediction pipeline. In this section, we describe a novel, generative modeling approach to multi-variate sequence classification, which can handle both imbalanced class distributions and limited channel scenarios.

II-A Overview

Figure 1 illustrates an overview of the proposed approach for limited channel ECG classification. The key idea of our approach is to employ a generative architecture to model the multi-variate sequences prior to learning the classifier. Broadly referred to as classification by synthesis, such an approach has been adopted in image classification, particularly to deal with out-of-distribution samples at test time [15]. In this paper, we argue that using a generative approach can produce models that are highly robust to using partial measurements for prediction. In a nutshell, the proposed approach is comprised of two stages: (i) an unsupervised generative modeling stage, where we utilize a Seq2Seq architecture to construct a latent space that is most effective for predicting multi-channel ECG data using only a partial set of channels; and (ii) a supervised modeling stage, which builds a random forest classifier using latent features from Stage 1. Note that, in order to aggregate features from different time steps while preserving partial order, we employ a dense interpolation strategy similar to [4]. At test time, predictions are made solely based on the partial set of channels that was used to train the model.

II-B Stage 1: Generative Models for Limited Channel Data

As shown in Figure 1, we assume access to the entire multi-channel ECG measurements for the training data, and for a given limited channel configuration, we enable predictions using only that subset of channels at test time. We denote the multi-variate sequence dataset as $\{X_i\}_{i=1}^{N}$, where $N$ denotes the number of training samples and each $X_i \in \mathbb{R}^{T \times C}$, with $T$ denoting the number of time steps in each measurement and $C$ the total number of channels. The channel configuration is typically determined based on the disease to be diagnosed; for example, the channels {II, III, aVF} are known to be essential for detecting Myocardial Infarction [9]. Denoting the set of limited channels by $S$, with cardinality $|S| = c < C$, we extract the matrix $\tilde{X}_i \in \mathbb{R}^{T \times c}$ containing only the channels in $S$. In order to perform implicit completion of the missing data, we propose to build a generative model that attempts to recover $X_i$ using $\tilde{X}_i$. In this process, it infers a latent space that defines an effective metric to compare different samples.
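To make this setup concrete, the following Python sketch shows how the limited-channel input and the full-channel reconstruction target could be formed; the `LEAD_NAMES` ordering and the use of NumPy arrays in place of the actual PTB data loader are assumptions for illustration.

```python
import numpy as np

# Hypothetical ordering of the 12 standard leads; the actual PTB record layout
# should be checked when loading the data.
LEAD_NAMES = ["I", "II", "III", "aVR", "aVL", "aVF",
              "V1", "V2", "V3", "V4", "V5", "V6"]

def make_seq2seq_pairs(X, limited_leads):
    """X: array of shape (N, T, C) holding all C = 12 channels.
    Returns the limited-channel inputs and the full-channel targets."""
    idx = [LEAD_NAMES.index(lead) for lead in limited_leads]
    X_limited = X[:, :, idx]    # shape (N, T, c), with c = |S|
    return X_limited, X         # the decoder target is the full 12-lead signal

# Example: the {II, III, aVF} configuration used for Myocardial Infarction.
# X_lim, X_full = make_seq2seq_pairs(X, ["II", "III", "aVF"])
```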

More specifically, we build an encoder-decoder architecture, commonly referred to as Seq2Seq [13], with an optional attention mechanism. Though originally developed for machine translation, such architectures are applicable to more general sequence-to-sequence transformation tasks. The architecture is comprised of two RNNs (based on LSTM units), one each for the encoder and the decoder. The encoder transforms an input sequence $\tilde{X}_i$ into a fixed-length vector, either from the last time step of the sequence or by concatenating the hidden representations from all time steps. The decoder then predicts the output sequence, in our case $X_i$, using the encoder output. Optionally, the decoder can also attend to a certain part of the encoder states through an attention mechanism. The attention mechanism often uses both content from the encoder states and context from the sequence generated so far at the decoder. Our RNNs are designed using Long Short-Term Memory units, which are capable of learning long-term dependencies. Each LSTM cell is comprised of the following operations, implemented using fully connected networks:

$$
\begin{aligned}
f_t &= \sigma(W_f\,[h_{t-1}, x_t] + b_f), &\quad i_t &= \sigma(W_i\,[h_{t-1}, x_t] + b_i), \\
\tilde{c}_t &= \tanh(W_c\,[h_{t-1}, x_t] + b_c), &\quad c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c}_t, \\
o_t &= \sigma(W_o\,[h_{t-1}, x_t] + b_o), &\quad h_t &= o_t \odot \tanh(c_t).
\end{aligned}
\tag{1}
$$

The LSTM has the ability to remove or add information to the cell state, carefully regulated by structures referred to as gates. While the forget gate updates the cell state by determining which information to ignore based on context, the input gate determines which information needs to be added to the cell state based on the previous hidden state and the content at the current time step. Finally, the output gate produces a filtered version of the cell state, based on the context and the previous hidden state. The generative model is trained with a reconstruction loss at the decoder output. Note that our architecture attempts to reconstruct the input channels as well as predict the missing channel measurements.
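As an illustration of this stage, here is a minimal PyTorch sketch of such an encoder-decoder model, omitting the optional attention mechanism; the layer sizes and the use of mean squared error as the reconstruction loss are assumptions, not the paper's exact configuration.

```python
import torch
import torch.nn as nn

class Seq2SeqECG(nn.Module):
    """Encode a limited-channel sequence and decode the full 12-channel sequence."""
    def __init__(self, in_ch, out_ch=12, hidden=128, layers=2):
        super().__init__()
        self.encoder = nn.LSTM(in_ch, hidden, layers, batch_first=True)
        self.decoder = nn.LSTM(hidden, hidden, layers, batch_first=True)
        self.project = nn.Linear(hidden, out_ch)

    def forward(self, x_limited):                  # x_limited: (B, T, c)
        enc_out, state = self.encoder(x_limited)   # enc_out: (B, T, hidden)
        T = x_limited.size(1)
        # Feed the final encoder output at every decoder step
        # (a simple choice in place of attention).
        dec_in = enc_out[:, -1:, :].repeat(1, T, 1)
        dec_out, _ = self.decoder(dec_in, state)
        return self.project(dec_out), enc_out      # reconstruction and latent features

# Training with a reconstruction loss at the decoder output:
# recon, _ = model(x_lim)
# loss = nn.functional.mse_loss(recon, x_full)
```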

II-C Stage 2: Classifier Design

We now design a classifier stage that exploits the latent space from the generative model trained for missing channel prediction. Interestingly, compared to discriminative models, this approach utilizes additional channel information at the training stage and builds a more effective metric for the whole data space, rather than one tuned solely to discriminate the normal/abnormal classes. Furthermore, since the first stage is unsupervised, we can use even unlabeled data to construct a more robust latent space.

II-C1 Dense Interpolation Embedding

For the given limited channel sequences $\tilde{X}_i$, the encoder returns the latent features $H_i \in \mathbb{R}^{T \times d}$, where $d$ is the number of hidden dimensions in the Seq2Seq model. The simplest approach to obtain an aggregated representation of these features, while preserving order, is to concatenate the embeddings at every time step. However, in our case, this can lead to a very high-dimensional, “cursed” representation which is not suitable for learning and inference. Consequently, we utilize a dense interpolation strategy similar to [4]. Denoting the hidden representation at time $t$ as $h_t$, the interpolated embedding vector $v_i$ will have dimension $d \times M$, where $M$ is the dense interpolation factor. Note that when $M = T$, this reduces to the concatenation case. The main idea of this scheme is to determine weights $w_{t,m}$, denoting the contribution of $h_t$ to position $m$ of the final vector representation $v_i$.
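A small NumPy sketch of one common formulation of dense interpolation is given below; the exact weighting used in [4] may differ slightly, so treat the weight formula as an assumption.

```python
import numpy as np

def dense_interpolation(H, M):
    """H: hidden states of shape (T, d); M: dense interpolation factor.
    Returns an aggregated embedding of dimension d * M."""
    T, d = H.shape
    U = np.zeros((M, d))
    for t in range(1, T + 1):
        s = M * t / T                        # relative position of time step t
        for m in range(1, M + 1):
            w = (1.0 - abs(s - m) / M) ** 2  # contribution of h_t to slot m
            U[m - 1] += w * H[t - 1]
    return U.reshape(-1)                     # concatenate the M aggregated vectors
```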

II-C2 Classifier

Using the embeddings $v_i$ from dense interpolation, we build a random forest classifier to predict the labels $y_i$. At test time, each limited channel sequence is passed through the encoder of the generative model to obtain the latent representation, which is subsequently processed using the dense interpolation strategy and the random forest classifier.
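A sketch of this second stage with scikit-learn is shown below; the number of trees and the class weighting are illustrative choices, the latter added to cope with the imbalanced label distribution.

```python
from sklearn.ensemble import RandomForestClassifier

# Z_train / Z_test: dense-interpolated embeddings (one row per ECG frame),
# obtained by passing limited-channel sequences through the trained encoder.
clf = RandomForestClassifier(n_estimators=200, class_weight="balanced",
                             random_state=0)
clf.fit(Z_train, y_train)
probs = clf.predict_proba(Z_test)   # per-class probabilities, used for AUROC
```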

III Experiments

In this section, we describe the dataset used in the experiments, and evaluate our proposed generative modeling framework on the task of ECG classification. In addition, we discuss the evaluation metrics, and the choices for the hyperparameters. Finally, we elaborate on the disease-based limited channel configurations selected for experimentation and highlight the comparative performance of standard RNNs against the proposed approach.

The Physionet PTB database (PTBDB) [5] is comprised of ECGs collected from healthy volunteers as well as patients with a wide range of heart diseases. It contains a total of 549 records from 290 subjects, with each record containing the standard 12 leads along with the 3 Frank lead ECGs (vx, vy, vz). A single raw ECG record is about 30 seconds in duration, with measurements sampled at 1000 Hz. As part of data pre-processing, we normalized the data, defined an ECG frame to be 5 seconds long with a 2-second overlap between consecutive frames, and finally down-sampled the frames to a fixed number of time steps.
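A possible framing routine consistent with this description is sketched below; the normalization scheme and the down-sampled frame length (TARGET_LEN) are placeholders, since the exact values are not specified above.

```python
import numpy as np

FS = 1000                  # PTB sampling rate in Hz
FRAME_S, OVERLAP_S = 5, 2  # 5-second frames with a 2-second overlap
TARGET_LEN = 500           # placeholder down-sampled length

def frame_record(sig):
    """sig: (num_samples, C) ECG record. Returns (num_frames, TARGET_LEN, C)."""
    frame_len = FRAME_S * FS
    stride = (FRAME_S - OVERLAP_S) * FS
    frames = []
    for start in range(0, sig.shape[0] - frame_len + 1, stride):
        frame = sig[start:start + frame_len]
        # Per-channel z-score normalization (one reasonable choice).
        frame = (frame - frame.mean(axis=0)) / (frame.std(axis=0) + 1e-8)
        idx = np.linspace(0, frame_len - 1, TARGET_LEN).astype(int)
        frames.append(frame[idx])            # simple decimation to TARGET_LEN steps
    return np.stack(frames)
```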

Channels Standard RNN Proposed
V1 to V6 0.78 0.83
V1, V2, V3 0.60 0.69
II, III, aVF 0.56 0.60
II, III, V3 0.55 0.69
V1, V6 0.53 0.74
II, V1 0.52 0.59
II 0.51 0.62
TABLE I: AUROC scores for the normal/abnormal binary classification task under different channel configurations, comparing standard RNNs with the proposed approach.
Disease Channels Acc. (%) Sens. (%) Spec. (%)
Myocardial Infarction V1, V2, V3 86 96 84
Bundle Branch Block V1, V6 94 97 99
TABLE II: Performance of disease-specific models trained using appropriate channel subsets with the proposed approach.

III-A Setup and Evaluation Metrics

III-A1 Task I: Normal/Abnormal Classification

Although PTBDB is a multi-class dataset, containing several distinct abnormal conditions, a miscellaneous class, and a healthy control group, in this task we formulate the problem as a two-class problem of separating normal and abnormal ECGs by grouping all disease classes into a single abnormal class. This makes the problem highly ill-posed, challenging a classifier to ignore the large variations within the abnormal class while discriminating it from the normal class. Problems using clinical data are commonly solved with such a normal-versus-abnormal formulation, with an optional second-stage classification to determine the actual type of abnormality. However, in practice, this formulation makes it very challenging for discriminative models (e.g. standard RNNs). In contrast, our approach exploits additional data at the training stage, in the form of missing channel information, and consequently produces more reliable latent spaces. As we will show in our results, this two-stage approach is significantly superior to standard RNNs, even when we reduce the number of channels drastically (down to a single channel). In this task, the extracted frames were split into training and test sets. For evaluation, we use a popular summary statistic, the micro-averaged AUROC, which computes the area under the ROC curve for both classes together.
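For reference, one way to compute this metric with scikit-learn, assuming y_true holds the 0/1 frame labels and y_score the two-column class probabilities (e.g. from the random forest above):

```python
import numpy as np
from sklearn.metrics import roc_auc_score

# Micro-averaging pools the per-class decisions into a single ROC curve.
y_onehot = np.eye(2)[y_true]        # y_true must be an integer array of 0/1 labels
auroc_micro = roc_auc_score(y_onehot, y_score, average="micro")
```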

III-A2 Task II: Disease-specific Prediction

In this task, we consider the subsets of the data corresponding to specific abnormalities, Myocardial Infarction and Bundle Branch Block, and build a predictive model to detect each condition. For both cases, we used channel subsets that are commonly used in clinical settings for the prediction of those diseases (Table II). In each case, the corresponding frames were split into training and test sets. Following common practice, we use three metrics: Accuracy, Sensitivity, and Specificity.

For both the standard RNN and the proposed approach, we considered multi-layer LSTM networks with a fixed number of hidden dimensions. The dense interpolation factor was fixed across experiments, producing feature vectors of dimension $d \times M$. All models were trained using the Adam optimizer with a fixed learning rate.
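Putting the pieces together, a minimal Stage 1 training loop (reusing the Seq2SeqECG sketch from Section II-B, with placeholder epoch count and learning rate since the exact values are not given above) might look like:

```python
import torch
import torch.nn.functional as F

model = Seq2SeqECG(in_ch=x_lim.size(-1))                  # x_lim: (N, T, c) tensor
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) # placeholder learning rate

for epoch in range(50):                                   # placeholder epoch count
    optimizer.zero_grad()
    recon, _ = model(x_lim)
    loss = F.mse_loss(recon, x_full)                      # reconstruct all 12 channels
    loss.backward()
    optimizer.step()
```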

III-B Results

Table I reports the results for Task I obtained using standard RNNs and the proposed approach under different channel configuration settings. As discussed earlier, when the entire 12-lead data is not available at test time, it is important for models to generalize to limited channel scenarios. As can be observed from the results, the proposed approach is significantly better than standard RNNs in all cases, with the AUROC metric indicating its reliability in predicting both normal and abnormal class samples. For example, in the case of using channels V1 to V6, the proposed approach achieves an AUROC of 0.83, which is superior to the 0.78 achieved by standard RNNs with a similar number of parameters. Interestingly, even in the single-channel case, the proposed approach is more effective in dealing with the imbalance in the label distribution, providing highly robust predictions in comparison to discriminative models.

The results for the disease-specific classification models are shown in Table II. Though not reported, we found improvements over standard RNNs similar to those observed in the previous task. For example, in the case of Myocardial Infarction detection, the proposed approach achieves an accuracy of 86%, with sensitivity and specificity of 96% and 84% respectively. Note that the accuracy and specificity scores are similar to the state-of-the-art reported in [9], while the sensitivity is superior. We observe similar results for the other channel configurations and disease predictions.

In summary, the generative modeling based two-stage architecture produces highly accurate models in limited channel settings, even with datasets that are distributed non-uniformly. This motivates the use of Seq2Seq models in clinical modeling pipelines, both to exploit unsupervised data and to build more meaningful metrics that can in turn lead to better supervised models.

IV Acknowledgments

We thank Dr. Girish Narayan (cardiologist) for lending his expertise in the experimental design. We also thank Rushil Anirudh for helping us with the data preparation. This work was performed under the auspices of the U.S. Dept. of Energy by Lawrence Livermore National Laboratory under Contract DE-AC52-07NA27344.

References