What does an LSTM look for in classifying heartbeats?

05/23/2017 ∙ by Jos van der Westhuizen, et al. ∙ University of Cambridge

Long short-term memory (LSTM) recurrent neural networks are renowned for being uninterpretable "black boxes". In the medical domain, where LSTMs have shown promise, this is especially concerning because it is imperative to understand the decisions made by machine learning models in such acute situations. This study employs techniques used in the convolutional neural network domain to elucidate the inputs that are important when LSTMs classify electrocardiogram signals. Of the various techniques available to determine input feature saliency, it was found that learning an occlusion mask is the most effective.


1 Introduction

Long short-term memory (LSTM) networks have been shown to be effective on medical datasets (Lipton et al., 2015b; Choi et al., 2016a; Rajkomar et al., 2018; Jagannatha & Yu, 2016; Teijeiro et al., 2017; Clifford et al., 2017). There is a snag, however. Researchers struggle to understand exactly how the self-adapted LSTMs do what they do. Hence, two main points drive a deeper understanding of LSTMs. First, clinicians need to be able to explain how a model’s decisions are based on a patient’s data. Second, models that achieve breakthrough performance may have identified patterns in the data that practitioners would like to understand.

Previous work has provided insights into the operation of the LSTM (Karpathy et al., 2015; Li et al., 2016); however, this was focussed on the discrete-valued sequences of natural language processing tasks. Conversely, our work explores visualization techniques for understanding LSTMs applied to continuous-valued sequences, specifically, single-lead electrocardiograms (ECGs). We learn from recently proposed visualization techniques, which we describe next.

2 Related work

A common form of interpretability for deep learning models is through attention mechanisms (Xu et al., 2015; Bahdanau et al., 2015; Deming et al., 2016; Rajkomar et al., 2018; Ching et al., 2018). By revealing which input features are used for different outputs, the attention mechanisms provide insights into the model’s decision-making process. In the clinical domain, Choi et al. (2016b) leveraged attention mechanisms to highlight which aspects of a patient’s medical history were most relevant for making diagnoses. Choi et al. (2017) later extended this work to take into account the structure of disease ontologies and found that the concepts represented by the model aligned with medical knowledge – an exercise we strive to imitate. However, interpretation strategies that rely on attention mechanisms do not provide insight into the logic used by the attention layer.

Backpropagation-based methods are also popular for interpreting deep learning models. These methods propagate the signal from a target output neuron back to the input layer. The simplest form of this was proposed by Simonyan et al. (2013), and many proposals since have improved its utility (Bach et al., 2015; Kindermans et al., 2016; Springenberg et al., 2014; Mahendran & Vedaldi, 2016). However, as with the attention mechanisms and deconvolution (Zeiler & Fergus, 2014), many of these visualization techniques require architectural modifications (Fong & Vedaldi, 2017).

On the other hand, perturbation-based interpretation approaches do not require changes in the model architecture; instead, they change parts of the input and observe the impact on the output of the network. These include visualizing the drop in classification score as constant value masks are applied to different input patches of images (Zeiler & Fergus, 2014). A recent study by Fong & Vedaldi (2017) proposed to use gradients to learn the minimal input “deletions” that minimize the class score. This technique requires access to only the model’s inputs and outputs and provides aesthetically pleasing explanations of image input salience. We modify this technique to visualize continuous inputs that are analysed by LSTMs.

Visualizations of recurrent neural networks (RNNs) have successfully been explored in natural language processing (Li et al., 2016; Karpathy et al., 2015). In both these studies, the aim was to improve the understanding of RNNs by leveraging human knowledge of the structured language. Such well-understood discrete input lends itself to interpretable isolated explanations for the importance of each input to the model. In the medical paradigm, firstly, we are less certain about the underlying biological processes that govern the generation of physiological signals, and secondly, due to inputs being continuous, isolation of salient features is more difficult. RNNs have also been visualized by Lanchantin et al. (2017) in the domain of DNA sequences. Here again, the DNA inputs are discrete compared with the continuous-valued medical time series that we analyse.

3 Visualization techniques

Before describing the four visualization techniques explored, some preliminaries are required. Each input sequence $x$ with $T$ time steps has a corresponding class label $c$. An LSTM provides pre-softmax activations $S(x) \in \mathbb{R}^{C}$ for each input sequence, with $S_c(x)$ denoting the score for the correct class of the input sequence, and $C$ the number of classes.

3.1 Temporal output score

A first approach to understanding LSTM classifications is to illustrate the progression of model decisions over time for a specific input sequence; i.e., incrementally longer subsequences of the original time series are classified and visualized. More formally, for a sequence $x$ with length $T$ we classify the prefix $x_{1:t}$ of the original sequence as $t$ ranges from 1 to $T$ (Lanchantin et al., 2017); here, prefix refers to a subsequence at the start of the original sequence. Fortunately, the LSTM outputs a classification at each time step and we can simply record the output vector given by

$$y_t = W h_t$$

at each time step to obtain the temporal output scores, where $h_t \in \mathbb{R}^{N}$ is the hidden output of the LSTM (Lipton et al., 2015a) with $N$ units, and $W$ is the output weight matrix. Note that training is unchanged: the single label of an input sequence is compared with the classification only at the final time step $T$.

These predictions at each time step can then be superimposed onto the original signal, either as colour-coded predicted categories or as the probability of being in the correct class. For the latter, we record the element in the softmax vector corresponding to the true label of the input sequence. As we show later, this temporal output score technique is limited by its cumulative nature; it provides the tipping points during model decision making, but it does not indicate which input features were most salient.
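As a minimal sketch of this recording step, the Python snippet below turns a matrix of hidden outputs into per-step class probabilities and predicted classes. The random hidden states, the array shapes, and the function name `temporal_output_scores` are illustrative assumptions; in practice the hidden outputs would come from the trained LSTM.

```python
import numpy as np

def temporal_output_scores(hidden_states, W, true_class):
    """Per-time-step class probabilities from LSTM hidden outputs.

    hidden_states: array (T, N), hidden output at every time step (assumed given).
    W: array (C, N), output weight matrix.
    Returns the true-class probability and the predicted class at each step.
    """
    logits = hidden_states @ W.T                          # (T, C) pre-softmax scores
    logits = logits - logits.max(axis=1, keepdims=True)   # for numerical stability
    probs = np.exp(logits)
    probs /= probs.sum(axis=1, keepdims=True)             # softmax over classes
    return probs[:, true_class], probs.argmax(axis=1)

# Stand-in data: 216 time steps, 128 hidden units, 4 classes.
rng = np.random.default_rng(0)
H = rng.normal(size=(216, 128))
W = rng.normal(size=(4, 128))
p_true, predicted = temporal_output_scores(H, W, true_class=2)
```

Both returned arrays have one entry per time step, so they can be drawn directly on top of the input signal.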

3.2 Input derivatives

Owing to the LSTM being highly nonlinear, it is difficult to determine the influence of each input x on the class score $S_c(x)$. We can, however, approximate $S_c(x)$ with a linear function in the neighbourhood of a specific input sequence $x_0$ by using the first-order Taylor expansion ($S_c$ is assumed to be smooth at $x_0$):

$$S_c(x) \approx S_c(x_0) + \nabla_x S_c(x_0)^{\top} (x - x_0).$$

This linear formulation allows us to see that the magnitude of the elements in $\nabla_x S_c(x_0)$ determines the importance of the corresponding elements in x. Visualizing these magnitudes enables interpretation of the model decisions. The gradient can, fortunately, be computed with a single step of backpropagation. However, in section 4.1.2 we show that this is a bad approximation for LSTMs for reasons explained in Fong & Vedaldi (2017). Instead of the arbitrary variations explored with this approach, we could follow a more controlled method – perturbing the input by deleting (occluding) subsections.
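The idea of reading importance off gradient magnitudes can be illustrated with a small Python sketch. It uses central finite differences as a stand-in for backpropagation, and a toy differentiable score in place of the LSTM class score; `toy_score`, `weights`, and both function names are hypothetical and chosen only for illustration.

```python
import numpy as np

def input_gradient(score_fn, x0, eps=1e-5):
    """Central finite-difference estimate of d score / d x at x0."""
    grad = np.zeros_like(x0, dtype=float)
    for i in range(x0.size):
        step = np.zeros_like(x0, dtype=float)
        step.flat[i] = eps
        grad.flat[i] = (score_fn(x0 + step) - score_fn(x0 - step)) / (2 * eps)
    return grad

def saliency(grad):
    """Gradient magnitudes scaled to the range [0, 1]."""
    mag = np.abs(grad)
    span = mag.max() - mag.min()
    return (mag - mag.min()) / span if span > 0 else np.zeros_like(mag)

# Toy score standing in for the LSTM class score (an assumption).
weights = np.linspace(-1.0, 1.0, 8)
toy_score = lambda x: float(np.sum(weights * np.tanh(x)))

x0 = np.linspace(-2.0, 2.0, 8)
g = input_gradient(toy_score, x0)
sal = saliency(g)
```

With a real network, the finite-difference loop would be replaced by one backward pass through the model.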

3.3 Occlusion

Zeiler & Fergus (2014) proposed iteratively occluding regions of an image in order to visualize which regions are salient. We apply this to the sequential paradigm of LSTMs by occluding subsequences of an input sequence with a predetermined occlusion value.

This technique iteratively occludes subsequence $x_{t:t+w-1}$ of $x$ by some constant $k$ as $t$ ranges from 1 to $T$, where $w$ denotes the width of the occlusion. At the start and end of the sequence, the occlusion is shortened such that each element of $x$ is occluded the same number of times. Over these iterations, a new sequence $a$ accumulates in element $a_i$ the class score of every occluded sequence in which element $x_i$ is not occluded. If an input is more important for classification, then its corresponding sum of class scores will be high. For visualization, the sequence $a$ is scaled to range from zero to one and superimposed onto the original input sequence. For images processed in scanline order, we can use more creative occlusions, such as a 5 × 5 pixel block. For 28 × 28 pixel images, this tailored filter would occlude 5 contiguous pixels every 28 time steps, repeated 5 times.
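The accumulation described above can be sketched in a few lines of Python. This is one possible reading under stated assumptions: the window is clipped at both ends so that every element is occluded exactly `width` times, and `toy_score` is a hypothetical stand-in for the LSTM class score.

```python
import numpy as np

def occlusion_salience(score_fn, x, width, k=0.0):
    """Slide a constant-valued occlusion over x and accumulate class scores.

    Each element collects the scores of the iterations in which it was visible,
    so elements whose occlusion lowers the score end up with a high total.
    """
    T = len(x)
    acc = np.zeros(T)
    # Start before the sequence so the clipped windows cover the first elements
    # as often as the rest.
    for start in range(-(width - 1), T):
        lo, hi = max(start, 0), min(start + width, T)
        occluded = x.copy()
        occluded[lo:hi] = k
        score = score_fn(occluded)
        visible = np.ones(T, dtype=bool)
        visible[lo:hi] = False
        acc[visible] += score
    span = acc.max() - acc.min()
    return (acc - acc.min()) / span if span > 0 else np.zeros(T)

# Toy score that only depends on elements 10..14 (an assumption).
toy_score = lambda s: float(s[10:15].sum())
x = np.ones(30)
sal = occlusion_salience(toy_score, x, width=5, k=0.0)
```

In this toy case the salience peaks in the middle of the region the score depends on and is zero far away from it.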

The manual nature of the occlusion approach limits its efficacy. First, the correct occlusion value that resembles a “deletion” of information is difficult to determine. Second, the width and the structure of the occluded regions have to be specified and are fixed, leaving endless possible combinations unexplored. Next, we introduce a method that mitigates these issues by employing gradients to efficiently explore the space of possible input perturbations.

3.4 Learning an input mask

With focus on convolutional neural network (CNN) applications, Fong & Vedaldi (2017) recently proposed a method for learning a mask that optimally deletes information in an image. Here we describe an adaptation that makes their proposal suitable for LSTMs applied to physiological signals.

The aim is to find a mask $m$ that minimally deletes information in the input sequence $x$ while minimizing the class score $S_c$. With $D$ denoting the number of inputs per time step $t$, we have $m \in [0, 1]^{T \times D}$. The function that masks each element of the input is defined as

$$\Phi(x; m)_{t,d} = m_{t,d}\, x_{t,d} + (1 - m_{t,d})\, k, \qquad (1)$$

where $k$ is a constant resembling a deletion of information. We seek a mask with elements that are near binary, either deleting completely or not at all, and with as few zero-elements as possible. The $L_1$ norm provides a suitable technique to encourage this property, but also provides enticing conditions for the mask to trigger artefacts in the LSTM (Fong & Vedaldi, 2017). To prevent artefacts, the total variation of $m$ is added to the objective function, which encourages the mask to be smooth. The resulting objective function is given by

$$m^{*} = \arg\min_{m \in [0,1]^{T \times D}} \; \lambda_1 \lVert 1 - m \rVert_1 + \lambda_2 \sum_{t=1}^{T-1} \sum_{d=1}^{D} \lvert m_{t+1,d} - m_{t,d} \rvert + S_c(\Phi(x; m)), \qquad (2)$$

where $\lambda_1$ and $\lambda_2$ weight the penalty on zero-valued mask elements and the preference for mask smoothness, respectively. For visualization the learned mask is scaled to range from zero to one, subtracted from one, and superimposed onto the original input.
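A rough Python sketch of the masking function and objective follows. It is not the authors' implementation: it swaps in plain projected gradient descent with finite-difference gradients for the gradient-based optimizer, uses a toy linear score instead of an LSTM, and the weights `lam1` and `lam2`, the step size, and all function names are assumptions made for illustration.

```python
import numpy as np

def apply_mask(x, m, k=0.0):
    """Blend each element with the deletion constant k (m=1 keeps, m=0 deletes)."""
    return m * x + (1.0 - m) * k

def mask_objective(m, x, score_fn, k=0.0, lam1=0.1, lam2=0.05):
    """Deletion penalty + total variation + class score of the masked input."""
    deletion = lam1 * np.sum(1.0 - m)            # equals the L1 norm for m in [0, 1]
    smoothness = lam2 * np.sum(np.abs(np.diff(m, axis=0)))
    return deletion + smoothness + score_fn(apply_mask(x, m, k))

def learn_mask(x, score_fn, steps=100, lr=0.05, eps=1e-4, **kw):
    """Projected gradient descent on the mask, starting from all ones."""
    m = np.ones_like(x, dtype=float)
    for _ in range(steps):
        grad = np.zeros_like(m)
        for i in range(m.size):                  # finite-difference gradient
            d = np.zeros_like(m)
            d.flat[i] = eps
            grad.flat[i] = (mask_objective(m + d, x, score_fn, **kw)
                            - mask_objective(m - d, x, score_fn, **kw)) / (2 * eps)
        m = np.clip(m - lr * grad, 0.0, 1.0)     # keep the mask inside [0, 1]
    return m

# Toy score that only depends on elements 8..11 (an assumption).
toy_score = lambda s: float(s[8:12].sum())
x = np.ones(20)
m = learn_mask(x, toy_score)
salience = 1.0 - m   # high where the mask deletes information
```

On this toy problem the mask drops to zero only over the elements the score depends on, which is the behaviour the two regularizers are meant to encourage.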

Fong & Vedaldi (2017) additionally upsample the input images and jitter them to further mitigate the triggering of artefacts. Essentially, this jitter randomly translates the input by a few elements. By minimizing the expected class score over these random translations, the mask is required to have a similar effect on neighbouring elements. In our experiments, we did not find signs of artefacts and applying jitter to the input resulted in larger-than-required masks.

All of our experiments used Adam (Kingma & Ba, 2015) with a learning rate of 0.001 to optimize the mask over 500 iterations. Among initial values of 0, 0.5, and 1 for the mask, the latter convincingly yielded the best results. We experimented with various values for the two weights in the objective. A single setting of one weight provided good results over all the datasets analysed, and values of the other are shown in section 4.1.2. Learning an input mask is largely insensitive to these hyperparameters: changing them changes the size of the important regions, but the important features remain the same.

4 Results

In order to apply these visualization techniques, we require a trained LSTM. Thus two LSTM classifiers are trained, each on a different dataset, using the cross-entropy objective function. The datasets comprise the MNIST dataset (LeCun, 1998), processed in scanline order, and the MIT-BIH arrhythmia dataset (Moody & Mark, 2001). Similar to medical time series, the MNIST dataset, processed in scanline order, consists of continuous-valued sequences, which enables intuitive evaluation of the visualization techniques.

For the MIT-BIH dataset, single heartbeats were extracted from longer filtered signals on channel 1 by means of the BioSPPy package (Carreiras et al., 2015). The signals were filtered using a bandpass FIR filter between 3 Hz and 45 Hz. The Hamilton QRS detector (Hamilton, 2002) was used to detect and segment single heartbeats. We chose the four heartbeat classes that are best represented over different patients in the dataset: normal, right bundle branch block (RBBB), paced, and premature ventricular contraction (PVC). This resulted in 89,670 heartbeats from 47 patients. We randomly split the data over patients to have heartbeats from 33 train-, 5 validation-, and 9 test-patients (70:10:20). An acceptable split was considered to have all classes in each set contain at least smallest-class-size data points, where is the split-fraction (0.7, 0.1, or 0.2). For MNIST the standard data split was used and all the examples shown in this section are from test sets.
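Splitting over patients, rather than over individual heartbeats, keeps all beats from one patient in a single set. A minimal Python sketch of such a split is shown below; the patient identifiers, the seed, and the function name `split_patients` are hypothetical, and the class-size acceptance check described above is left out.

```python
import random

def split_patients(patient_ids, n_val, n_test, seed=0):
    """Randomly assign whole patients to train, validation, and test sets."""
    ids = sorted(set(patient_ids))
    random.Random(seed).shuffle(ids)
    test = set(ids[:n_test])
    val = set(ids[n_test:n_test + n_val])
    train = set(ids[n_test + n_val:])
    return train, val, test

# 47 placeholder patient identifiers, split 33 / 5 / 9 as in the text.
patients = [f"p{i:02d}" for i in range(47)]
train, val, test = split_patients(patients, n_val=5, n_test=9)
```

Heartbeats would then be routed to a set according to the patient they were recorded from, and the split redrawn if it fails the acceptance criterion.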

The LSTMs, without peepholes, are based on Lipton et al. (2015a) and were trained with Adam at a learning rate of 0.001, a dropout probability of 0.1, and a minibatch size of 200. Optimization was run for 100 epochs and the validation loss was used to determine the best model. With 2 layers of 128 hidden units, an accuracy of 98.7% was achieved on MNIST, similar to Arjovsky et al. (2016) and Cooijmans et al. (2016). With a single layer of 128 units and 0.001 weight decay, an accuracy of 80.3% was achieved on the MIT-BIH Arrhythmia dataset. We manually explored hyperparameter values to find adequate results.

4.1 Qualitative evaluation

4.1.1 Temporal output scores

In figure 1 we illustrate the temporal output scores for the MNIST and MIT-BIH datasets. For both datasets, the probability of being in the correct class, as well as the most probable class, is displayed at each time step.

Figure 1: Temporal output scores for the MNIST dataset (left) and the MIT-BIH dataset (right). The probability that the LSTM assigns to the true class is shown, along with the predicted most likely class. For the MIT-BIH dataset, the 4 heartbeat classes are normal, right bundle branch block (RBBB), paced, and premature ventricular contraction (PVC). The ECG signals are plotted with the same y-scale. Note, the model incorrectly classified the nine-digit example and all ECG examples were correctly classified. (Best viewed electronically.)

Usefully, this technique allows tracking of LSTM decisions over time. Considering the most likely class on MNIST, the digits are all classified as nine for the first six time steps before the prediction switches to seven – an interesting prior learned by the model. The one-digit marginally constitutes the largest class of the MNIST training and validation datasets, with seven being the close second. We speculate that the model is quicker to identify a seven (it has a long flat region at the top) than a one, which could explain the learned prior. With such balanced datasets, the prior could also depend on the order of the data points seen during the final epoch of training. On the other hand, the normal class of the MIT-BIH dataset is much larger than the other categories, which explains why the model initially assumes all ECGs to be normal.

A neat feature of this technique, not provided by any of the other visualization techniques explored, is the ability to determine the minimum sequence length required for classification. Compare, for example, classification of the five-digit, which is confidently correct from halfway through the sequence, to the classification of the four-digit, which switches to the correct class at the very last moment – it would be detrimental to shorten the MNIST digit sequences. Similarly, for the MIT-BIH dataset, the LSTM switches to the correct class at the very last moment for all of the examples.

Although this is a fun and easy-to-implement technique, it is limited to cumulative sequential explanations. For example, in the RBBB heartbeat example, what the model looks at is not apparent. The technique depicts when the predictions are changed, but not the extent to which each time step contributed to the prediction.

4.1.2 Input feature salience

In this section, the input derivative, occlusion, and mask learning visualization techniques (sections 3.2 to 3.4) are visually compared. For the deletion techniques we experimented with 0, 0.5, and the mean of the input as values of the deletion constant $k$. For the occlusion width $w$, we explored values of 2, 5, 10, 15, 20, 25, and 50 on both datasets. Of the various hyperparameters explored for each visualization technique, we present the set of most informative results.

We start the comparison on the easy-to-understand MNIST dataset in figure 2. The standard occlusion technique successfully finds interpretable salient features; however, when the occlusion width $w$ is small, many seemingly important features go unnoticed, for example, the seven-digit in row 2. On the other hand, a large occlusion width results in overly elongated salient regions and exemplifies the difficulty of hyperparameter selection for this technique.

Figure 2: Input feature salience for examples of the MNIST dataset. The different features that are important according to the occlusion, mask, or input derivative techniques are displayed on a scale of 0 to 1, where 1 is important. Each column shows the analysis of the same, correctly classified, original input. The occlusion width is denoted by $w$ and the deletion value by $k$. The block refers to a carefully structured occlusion, which is the equivalent of a 5 × 5 pixel block being occluded for each iteration (see section 3.3). Most of the techniques find some interpretable salient features, with the mask learning technique, row 5, producing the best results.

In the domain of MNIST digits, setting the deletion value is sensible because, given the digit strokes, the background is unimportant and zero-valued. Nevertheless, occluding with yields interpretability by producing the negative of the original input, and the most effective occlusion with is achieved when using a carefully structured block occlusion (row 4). When , however, the block occlusion is less effective than the standard occlusion.

Evidently, learning a mask is the most effective salience technique for this task (row 5). Setting the deletion value results in a mask that covers most of the digit stroke patterns, which is intuitively the best option. Fong & Vedaldi (2017) use the mean of the input as the deletion value, which in our case yields less interpretable results. It would seem that the selection of the deletion value is task-specific.

With a better understanding of how deletion values and occlusion widths influence visualizations, we proceed to visualize salient features for the LSTM on the MIT-BIH dataset. A standard ECG with annotated cardiologist interest points, P, Q, R, S, and T, is shown in figure 3. With occasional reference to these interest points, we discuss the comparison of the different visualization techniques on the MIT-BIH dataset, as illustrated in figure 4.

Figure 3: A standard single-lead ECG with the interest points labelled.

Figure 4: Input feature salience for ECG signals of the MIT-BIH dataset. The input feature importance for LSTM classification, as given by the occlusion, mask, and input derivative techniques, is displayed on a scale of 0 to 1, where 1 is important. The deletion value is denoted by $k$, and the occlusion width by $w$. On the left-hand side, the true label of the ECG is indicated, with the four classes of heartbeats being normal, right bundle branch block (RBBB), paced, and premature ventricular contraction (PVC). Where the LSTM incorrectly classified an input, the correct class is indicated by *. All the signals are displayed with the same y-scale and have a length of 216 time steps (x-axis).

At first glance, it’s evident that the occlusion and mask techniques have some overlap (rows 2, 3, 4, 7, and 8). However, later in our discussion, it becomes apparent that the additional regions explained by the learned mask render it superior to the other techniques. Over different examples in each class, similar features are found salient by the mask and occlusion techniques. Furthermore, as expected in clinical practice (Moody & Mark, 2001), the ECG leads varied among subjects (see the difference between rows 1 and 2), which makes generalization harder for LSTMs.

We consulted a cardiologist to help identify the salient features detected by LSTMs that align with medical theory. The analysis is summarized as follows:

  • We begin with the normal heartbeat; the model correctly identifies the QRS complex as important, with the learned mask indicating that more importance is placed on the R-peak and S-wave (row 2). Because the signal in row 1 is from a different ECG lead, it has a low R-peak, which given the importance of the R-peak for normal heartbeats could explain the misclassification. As found in practice, the model finds the Q-wave to be less helpful.

  • A wide S-wave is usually seen in right bundle branch block (RBBB) heartbeats, which was correctly identified as salient by the model – highlighted by the mask and occlusion techniques for both RBBB examples (rows 3 and 4). The learned mask additionally shows that the LSTM correctly identifies the extra bump leading up to the R-peak (a characteristic of an RSR pattern) to be important for classifying heartbeats. Note that RBBB heartbeats can look similar to LBBB heartbeats depending on the ECG lead.

  • Depending on where the pacing leads are placed within the heart, the R-wave (or sometimes Q-wave) follows the pacing spike (see rows 5 and 6). Thus identifying the narrow upstroke of the pacing spike could provide a means of detecting paced heartbeats. According to the learned mask, the LSTM learns to identify this pacing spike, whereas the occlusion technique seems to find the R-to-S transition important.

  • Lastly, in rows 7 and 8, the learned mask shows that the model considers the ratio of the R-peak and S-wave as an important feature of premature ventricular contraction. Medically this is relevant for some ECG leads, with the S-wave being deep relative to the R-peak. In practice, the duration of the QRS complex is primarily used to determine premature ventricular contraction, but it is difficult to establish whether this is something the model finds salient.

The analysis demonstrates that the input features considered salient align well with medical theory. Note, however, cardiologists usually classify heartbeat arrhythmias by means of multi-lead ECG signals and take the current patient health status (e.g., chest pain) into account. Such additional inputs could improve LSTM performance.

A visualization technique that has not been given much attention thus far is the input derivative. On both of the aforementioned datasets, this technique yields the least interpretable salient features. An additional advantage of the occlusion and mask techniques is that they show the effect of unwanted perturbations on classification, such as missing values in physiological signals.

We inspected the visualization techniques on a substantial portion of the datasets and found the salient regions to be consistent for each class; the examples presented in this section are therefore representative of their corresponding datasets. The following section describes a quantitative effort to find a measure of efficacy for the entire test dataset.

4.2 Quantitative evaluation

In this section, we compare the efficacy of the input feature salience techniques by calculating how much they reduce the class score on average. Here the total size of the salient regions is not a concern; instead, the important input regions should be as relevant as possible regardless of their size. To compare the different techniques, however, we need to penalize larger deletion areas because simply deleting the whole input could yield the largest score reduction. Therefore, we scale the score reduction for each input sequence by a ratio of $n$ and $T$, where $n$ is the number of input elements of the sequence $x$ that are occluded, and $T$ is the total number of time steps. Before occluding a sequence and computing the reduced class score, a threshold has to be specified, above which the input features are considered important enough to occlude.
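The thresholding and scaling steps can be sketched in Python as follows. The scaling ratio used here, (T - n) / T, is an assumption chosen only to show how a penalty on large deletions could enter the computation, and it may differ from the ratio the authors used; `toy_score` and the function name are likewise hypothetical.

```python
import numpy as np

def scaled_score_reduction(score_fn, x, salience, threshold, k=0.0):
    """Occlude elements whose salience exceeds the threshold and report the
    class-score reduction, scaled down as more elements are occluded."""
    occlude = salience > threshold
    n, T = int(occlude.sum()), len(x)
    occluded = np.where(occlude, k, x)
    reduction = score_fn(x) - score_fn(occluded)
    return reduction * (T - n) / T   # assumed penalty on large deletions

# Toy score that only depends on the first two elements (an assumption).
toy_score = lambda s: float(s[:2].sum())
x = np.ones(10)

focused = np.zeros(10)
focused[:2] = 1.0                    # salience concentrated on the relevant part
everything = np.ones(10)             # salience that marks the whole input

r_focused = scaled_score_reduction(toy_score, x, focused, threshold=0.5)
r_everything = scaled_score_reduction(toy_score, x, everything, threshold=0.5)
```

Under this assumed ratio, a salience map that marks the whole input gains nothing, whereas a focused map keeps most of its raw score reduction.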

Note, this metric does not consider whether the occlusions are correct, for example, the ECGs could be occluded at clinically irrelevant regions. It solely provides a measure of how efficiently each technique finds salient input features for the LSTM. Hence, establishing the utility of the visualization techniques still requires a qualitative evaluation. To have the average class score reductions be related to the utility of the salience technique, we select the hyperparameters based on visual analysis. We set for all techniques and use for MNIST occlusion and for MIT-BIH occlusion.

In figure 5 we illustrate the average score reductions over a range of values for the threshold. Evidently, the learned mask most efficiently deletes information from the input. In contrast to the other techniques, the learned mask still greatly reduces the class score at a high threshold, meaning that the most salient features are truly important.

Figure 5: Average class score reductions for the three input feature salience techniques. The y-scale is the same for each graph. Learning a mask is clearly the superior technique.

5 Discussion and conclusion

This paper deals with the important problem of providing insights into how black box models, specifically LSTMs, make their decisions. The premise is simple – for many practical applications, one needs justification of the decision, instead of a bare prediction. LSTMs and other neural networks do not provide this kind of information. When these models are applied to medical data, understanding why they make certain decisions would allow clinicians to build better trust in them and could provide novel insights on medical phenomena.

This work goes some way to improving our understanding of LSTMs. We compared the efficacy of four visualization techniques for the LSTM. In this comparison, it is argued that explaining what an LSTM focusses on depends in large part on the meaning of varying the input to the model. It was found that learning a deletion mask yields the most interpretable results and the largest reduction of the class score. Performance of this technique seemed fairly unaffected by different deletion values, but this could be data-dependent. Furthermore, we found that the ECG input features considered salient by the LSTM align well with medical theory.

While performing experiments we also investigated class mode visualization (Simonyan et al., 2013), but the results were completely uninterpretable. We have reported results for only univariate inputs. During our analysis, visualization techniques were applied to multivariate physiological signals from intensive care unit patients and traumatic brain injury patients. In such multivariate scenarios, the visualizations become convoluted and lose their interpretability. Understanding what LSTMs look for in multivariate signals thus remains an open problem.

Acknowledgements

We thank Vadir Baktash and Steve Foulkes for useful discussions on medical theory.

References