Co-Attentive Cross-Modal Deep Learning for Medical Evidence Synthesis and Decision Making

09/13/2019 ∙ by Devin Taylor, et al. ∙ University of Cambridge

Modern medicine requires generalised approaches to the synthesis and integration of multimodal data, often at different biological scales, that can be applied to a variety of evidence structures, such as complex disease analyses and epidemiological models. However, current methods are either slow and expensive, or ineffective due to their inability to model the complex relationships between data modes which differ in scale and format. We address these issues by proposing a cross-modal deep learning architecture and co-attention mechanism to accurately model the relationships between the different data modes, while further reducing patient diagnosis time. Differentiating Parkinson's Disease (PD) patients from healthy patients forms the basis of the evaluation. The model outperforms the previous state-of-the-art unimodal analysis by 2.35 %, while being more parameter efficient than the industry standard cross-modal model. Furthermore, the evaluation of the attention coefficients allows for qualitative insights to be obtained. Through the coupling with bioinformatics, a novel link between the interferon-gamma-mediated pathway, DNA methylation and PD was identified. We believe that our approach is general and could optimise the process of medical evidence synthesis and decision making in an actionable way.




1 Introduction

One of the biggest challenges facing the medical field is understanding and diagnosing complex diseases: diseases caused by a combination of genetic, environmental and lifestyle factors craig2008complex. Understanding the influence of each of these factors requires the study of biomedical data that can exist at drastically different biological scales and in different formats. Complex disease analyses primarily focus on unimodal data and do not necessarily consider the interdependence between the different modes of data available. This work aims to exploit these relationships through the use of Machine Learning (ML) to diagnose, and gain data-driven insights into, complex diseases.

Cross-modal deep learning is a machine learning technique that utilises information obtained from multiple different data sources, each of which independently encodes information about a specific outcome, to obtain a greater understanding of the problem domain. Cross-modal models commonly use specialist subnetworks to independently extract features from the different data modes before concatenating the features together to perform a final prediction. This approach has been used successfully in audio-visual speech classification ngiam2011multimodal, object recognition eitel2015multimodal, and cancer subtype classification liang2015integrative. However, the shortfall of this approach lies in understanding complex relationships between data modes, because it depends on a final feed-forward neural network to model all relationships between the independently extracted features in the concatenated latent space.
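The concatenation-fusion pattern can be sketched with hypothetical encoder outputs; the dimensions and the single linear head here are illustrative only, not the architecture of any specific model from the literature:

```python
import numpy as np

rng = np.random.default_rng(0)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# Hypothetical features produced by two specialist encoders.
image_features = rng.normal(size=(4, 16))   # batch of 4, 16-dim image latent
genomic_features = rng.normal(size=(4, 8))  # batch of 4, 8-dim genomic latent

# Concatenation fusion: join the latents, after which a single linear
# head must model ALL cross-modal interactions on its own.
fused = np.concatenate([image_features, genomic_features], axis=1)  # (4, 24)
w, b = rng.normal(size=(24,)), 0.0
prediction = sigmoid(fused @ w + b)  # (4,) probabilities
```

The limitation described above is visible in the structure: no interaction between the two modes is modelled before the final head sees the concatenated vector.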

This work leverages cutting-edge research in Neural Machine Translation (NMT) and Visual Question Answering (VQA) to propose a cross-modal model that focuses on better modelling these complex relationships. Specifically, the model combines the concept of co-attention from VQA lu2016hierarchical with multi-head self-attention from NMT vaswani2017attention to form the Multi-Head Co-Attention (MHCA) model. Attention has previously shown benefits in cross-modal learning for graph-structured data deac2019attentive. The architecture proposed in this research places an emphasis on learning a joint similarity distribution between the data modes, thus providing a generalised solution able to account for biomedical data from different biological scales.

Observing the attention weights further enables qualitative validation of the model predictions. Interpretability is a dominant theme in ML for health cleophas2015machine; cabitza2017unintended. Existing solutions focus on unimodal models and have separate modules running in parallel to the model, monitoring its properties to produce statistically based interpretations ribeiro2016should; luo2016automatically; lahav2018interpretable. In contrast, the solution presented in this work has interpretability as an integral component of the model architecture.

Parkinson’s Disease (PD) is a complex neurodegenerative disorder which primarily affects the motor functionality of a patient lang1998parkinson. The cause of the disease is not yet well understood, resulting in ineffective preventative and treatment measures. This has seen the incidence and death rates of PD increase year-on-year dorsey2018global. Consequently, PD is one of the most active areas of medical research. There have been very few applications of ML in PD research, focusing either on patient diagnosis salvatore2014machine or on symptom and biomarker identification kubota2016machine; tsanas2012accurate. All of these applications have considered only unimodal data; no focus has been placed on leveraging the diverse multimodal data available. Therefore, PD is used as a case study for the remainder of this paper.

2 Dataset and Preprocessing

This study makes use of the brain SPECT images and corresponding DNA methylation (DNA-m) data from the PPMI database (data last accessed 15 August 2019). SPECT images are 3-D images of the brain used to model dopamine-transporter functionality, while DNA-m is the only epigenetic marker for which a detailed mechanism of mitotic inheritance has been described, providing a lifetime record of an individual’s environmental exposures bird2002dna. The dataset consists of healthy and PD patients.

The SPECT images were co-registered and normalised. The DNA-m data consists of β-value features, filtered according to zhou2017comprehensive, in the range [0, 1]. β-values are the ratio of intensities between methylated and unmethylated cytosine-guanine dinucleotide (CpG) sites.
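How β-values arise from probe intensities can be sketched as follows; the intensities are hypothetical, and production pipelines typically add a small offset to the denominator for numerical stability, which is omitted here:

```python
import numpy as np

# Hypothetical methylated / unmethylated signal intensities at 4 CpG sites.
methylated = np.array([900.0, 120.0, 500.0, 30.0])
unmethylated = np.array([100.0, 880.0, 500.0, 970.0])

# Beta-value: fraction of the total signal that is methylated, so it is
# bounded in [0, 1] by construction.
beta = methylated / (methylated + unmethylated)
```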

Feature selection, using the Recursive Feature Elimination algorithm with an XGBoost model, was used to address the curse of dimensionality associated with the DNA-m data (see Appendix A for details). This resulted in 441 highly predictive features. Upfront feature selection enables interpretability for the DNA-m data by allowing a one-to-one mapping between attention weights and features, something that a neural-network-generated latent representation (see Section 3.1) would prevent.

3 Multi-Head Co-Attention Model Design

Figure 1 provides an overview of the MHCA model applied to the SPECT images and the DNA-m data. The specialist encoder networks generate the keys, K, queries, Q, and values, V, for the attention mechanism. The MHCA mechanism generates attention coefficients for the keys and queries from K and Q, and these coefficients are used to weight the two modes, producing the attended outputs. Residual connections are added to each of the weighted outputs to account for information loss within the attention mechanism he2016deep. The weighted outputs are then added together, a concept adapted from co-attention mechanisms in VQA lu2016hierarchical, and passed through a pointwise linear layer with sigmoid activation to perform the final prediction. This highlights how the different data modes influence the features extracted from one another, instead of the features being extracted independently.





Figure 1: Overview of the proposed MHCA model for two data modes; see text for symbol definitions. Note that the proposed model generalises to other data modes.

3.1 Encoders

The encoder subnetworks extract important and correlated features between data modes and reshape the output into the format features × embedding. For DNA-m the feature selection is done upfront, so the encoder simply assigns an embedding dimension to the input.

The SPECT encoder consists of blocks of a convolution layer followed by a max pooling layer with ReLU activation nair2010rectified. This is followed by a final convolution layer with no activation function, which is responsible for generating the SPECT features. The output is collapsed into the format features × embedding. The fact that the convolution and max pooling operators maintain the spatial relationships between their inputs and outputs makes it possible to upsample the attention weights back to the dimensions of the original inputs for interpretability.
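The collapse into features × embedding, and the upsampling of attention weights back to input resolution, can be sketched with hypothetical dimensions (the paper's actual shapes are not reproduced here); nearest-neighbour repetition stands in for whatever upsampling the authors used:

```python
import numpy as np

# Hypothetical SPECT encoder output: E=32 embedding channels over a
# 4 x 4 x 4 spatial grid (channels-first, as is typical for a 3-D CNN).
E, d1, d2, d3 = 32, 4, 4, 4
conv_out = np.arange(E * d1 * d2 * d3, dtype=float).reshape(E, d1, d2, d3)

# Collapse to (features, embedding): each spatial location becomes one
# feature vector, preserving the location -> feature correspondence.
features = conv_out.reshape(E, -1).T            # (64, 32)

# One attention weight per spatial feature; upsample back to a
# hypothetical 16 x 16 x 16 input volume by nearest-neighbour repetition.
attn = np.linspace(0.0, 1.0, d1 * d2 * d3).reshape(d1, d2, d3)
scale = 4
mask = attn.repeat(scale, 0).repeat(scale, 1).repeat(scale, 2)  # (16, 16, 16)
```

Because pooling and convolution preserve spatial ordering, feature i in the collapsed matrix corresponds to a known voxel neighbourhood, which is what makes the overlay masks in Figure 2 possible.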

3.2 Multi-Head Co-Attention Mechanism

The MHCA mechanism leverages the scaled dot-product attention from the Transformer network vaswani2017attention. Equation 1 defines the similarity scores, S, as a function of the dot product between the SPECT features, Q, and the DNA-m features, K, and the feature dimensionality of the DNA-m, d_k, where ^T refers to the transpose operator:

S = \frac{Q K^T}{\sqrt{d_k}} (1)

The Transformer network vaswani2017attention uses the similarity matrix to determine a single set of attention coefficients from one language to the other, as translation is directed. Given that cross-modal learning is not directed, this research leverages the commutative property of the dot product operator to determine two sets of attention coefficients, exploiting the fact that the distribution of attention weights will likely differ when looking from the keys to the queries compared to the queries to the keys. The attention coefficients for the keys and queries are defined in Equations 2 and 3, respectively:

\alpha_K = \mathrm{softmax}(S^T) (2)

\alpha_Q = \mathrm{softmax}(S) (3)

Thereafter, the hidden space, H, is obtained by taking the sum of the attended data modes, \hat{K} and \hat{Q} (each including its residual connection), as defined in Equation 4, where LayerNorm is the layer normalisation operator ba2016layer:

H = \mathrm{LayerNorm}(\hat{Q} + \hat{K}) (4)
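A minimal NumPy sketch of the mechanism follows. It assumes each mode attends over the other mode's values and that both modes share the same feature count and embedding size (as arranged by the feature selection in Section 2); the paper's exact wiring may differ, so treat this as an illustration rather than the reference implementation:

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def layer_norm(x, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    sigma = x.std(axis=-1, keepdims=True)
    return (x - mu) / (sigma + eps)

# Hypothetical encoder outputs: both modes share the feature count (n=5)
# and embedding size (d=8).
rng = np.random.default_rng(1)
n, d = 5, 8
Q, V_q = rng.normal(size=(n, d)), rng.normal(size=(n, d))  # SPECT
K, V_k = rng.normal(size=(n, d)), rng.normal(size=(n, d))  # DNA-m

S = Q @ K.T / np.sqrt(d)          # cross-modal similarity (Eq. 1)
A_q = softmax(S, axis=-1)         # SPECT -> DNA-m coefficients (Eq. 3)
A_k = softmax(S.T, axis=-1)       # DNA-m -> SPECT coefficients (Eq. 2)

# Attend to the other mode's values, add residual connections,
# sum the two attended modes, then normalise (Eq. 4).
Q_hat = A_q @ V_k + Q
K_hat = A_k @ V_q + K
H = layer_norm(Q_hat + K_hat)     # joint hidden space
```

Note that only one similarity matrix S is computed; the two coefficient sets come from normalising it along each direction, which is the undirected property the text describes.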


4 Experiments and Results

Experimentation is centred on classifying healthy patients versus PD patients. The model was compared to the unimodal XGBoost DNA-m model, the unimodal SPECT encoder network with a final linear layer (SpectNet), a cross-modal model which concatenates independently extracted features together (ConcatNet), and the state-of-the-art (SOTA) 3-D CNN for PPMI PD classification choi2017refining. The dataset was split into 80 % train and 20 % holdout test sets. Three runs of 5-fold cross validation (CV) were performed for each experiment, where each fold is a distinct split of the dataset. Data augmentation was used for all experiments (achieved using the multiaug library), and the networks were trained with Adam optimisation kingma2014adam.

Table 1 presents the accuracy, AUROC, and number of parameters for each of the experiments, where available. The results show that the MHCA model achieved SOTA results, improving on the previous SOTA accuracy by 2.35 %. The MHCA model also required fewer parameters than the traditional ConcatNet model. These results highlight the benefits of learning a joint similarity distribution in cross-modal learning.

Model | DNA-m | SPECT | Accuracy (%) | AUROC | No. Parameters
XGBoost | ✓ | | 95.08 | 0.917 |
SpectNet | | ✓ | 88.07 ± 2.54 | 0.879 ± 0.01 |
3D-CNN from choi2017refining | | ✓ | 96.00 | |
ConcatNet | ✓ | ✓ | 97.12 ± 0.58 | 0.968 ± 0.00 |
MHCA | ✓ | ✓ | 98.35 ± 0.58 | 0.976 ± 0.00 | 2,020,883

Table 1: Mean and standard deviation for 3 runs of 5-fold CV on the test set. The 3D-CNN row shows the previous SOTA result, taken directly from the literature as the results were not replicable.


Qualitatively, Figure 2 provides slices from different angles of the brain for a healthy patient and a PD patient. Notably, the weight masks validate existing research booth2015role by highlighting that the primary differentiating factor in SPECT images is the dopaminergic cell loss associated with PD. This is evident as the model focuses equally on both putamens (white regions) in the healthy patient, but only on the putamen without apparent degradation in the PD patient.

(a) PD, top. (b) Healthy, top. (c) PD, back. (d) Healthy, back. (e) PD, side. (f) Healthy, side.
Figure 2: Patient samples with weight mask overlay. The subcaptions are of the format: label, viewing angle. Figures best observed in colour.

PD is commonly linked to environmental factors, hence DNA-m plays a fundamental role in understanding gene-environment interactions miranda2017implications. To investigate these interactions, the CpG sites on which the attention weights placed the most emphasis were selected (136 sites). The Bioconductor R package hansen2016bioconductor was used to obtain the set of 90 genes corresponding to these sites. DAVID, a bioinformatics tool, was used to determine the biological pathways relevant to the genes huang2008bioinformatics; huang2009systematic. This analysis identified the interferon-gamma-mediated (IFN-γ) pathway, alternative splicing, and transcriptional activator activity as the most important pathways. The IFN-γ pathway, which had the most significant p-value, has featured in recent studies linking gene expression data to the PD phenotype mount2007involvement; barcia2011ifn; liscovitch2014differential. These studies all focus on gene expression patterns, whereas this study presents a novel finding that links IFN-γ to PD through DNA-m data. This finding is important as DNA-m patterns are considered more reliable and stable than gene expression patterns. It also suggests that changes in DNA-m, as a result of lifestyle, could influence the progression of PD. These results highlight how the interpretability can be used not only to support patient diagnosis but also to facilitate exploratory analyses of data.

5 Conclusion

This research presents the MHCA model, a generalised cross-modal deep learning model capable of handling multimodal data at different biological scales. A focus is placed on modelling the complex relationships between different data modes by obtaining a joint similarity distribution through the use of an attention mechanism. PD forms the basis of the evaluation, with the model exceeding the SOTA classification results by 2.35 % while being more parameter efficient. The model also produced granular and informative interpretations, further linking novel biomarkers to the IFN-γ pathway. The results obtained highlight the value of attention in cross-modal learning. Future work will focus on scalability and memory efficiency to enable evaluation of higher-dimensional datasets.


Appendix A DNA-Methylation Preprocessing

This appendix provides additional information on the upfront feature selection process performed for the DNA-methylation data, which is necessary for reproducing the results obtained.

Feature selection is a powerful machine learning technique which aims to reduce the number of features in a dataset by selecting a subset of predictive features. Feature selection involves training a model, which has built-in feature importance, to obtain a ranking of the features in terms of their predictive power. Thereafter, the top features are selected based on a user-defined metric.

An XGBoost model is used in this study. XGBoost is a popular gradient boosting algorithm with built-in ensembling Chen:2016:XST:2939672.2939785, favoured for its ability to handle data with few samples, which makes it suitable for the given application. The XGBoost model was combined with the Recursive Feature Elimination (RFE) algorithm to perform the final feature selection. RFE is a recursive algorithm that trains an XGBoost model on the set of available features, then uses the feature importance from XGBoost to remove a subset of the least predictive features, repeating this process until a user-defined number of features remains. RFE is significantly more powerful than simply using the feature importance from a single XGBoost model, as it accounts for the fact that the importance of features may change in the presence or absence of other features. The model was tasked with classifying Parkinson’s Disease patients versus healthy patients using the PPMI dataset. Table A.1 defines the hyperparameters for the models, obtained by performing a parameter sweep. The number of features to be selected by the RFE algorithm was set to 441 to keep the dimensionality consistent with that of the SPECT features.

Parameter | Algorithm | Value
learning_rate | XGBoost | 0.1
n_estimators | XGBoost | 100
max_depth | XGBoost | 6
min_child_weight | XGBoost | 2
gamma | XGBoost | 0
subsample | XGBoost | 0.8
colsample_bytree | XGBoost | 0.8
nthreads | XGBoost | 15
scale_pos_weight | XGBoost | 1
n_features_to_select | RFE | 441
step | RFE | 0.05

Table A.1: Hyperparameters for the RFE XGBoost feature selection model.
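The RFE loop described above can be sketched generically. A simple correlation-based importance stands in for XGBoost's gain importance, and the dimensions and step fraction are illustrative only, not the paper's settings:

```python
import numpy as np

def rfe(X, y, n_select, step=0.5):
    """Generic recursive feature elimination.

    Each round ranks the remaining features by a stand-in importance
    (absolute correlation with the label) and drops the least
    predictive fraction, until n_select features remain.
    """
    remaining = list(range(X.shape[1]))
    while len(remaining) > n_select:
        imp = np.abs([np.corrcoef(X[:, j], y)[0, 1] for j in remaining])
        # Drop a fraction of the least predictive features, at least one,
        # but never overshooting the target count.
        n_drop = max(1, int(len(remaining) * step))
        n_drop = min(n_drop, len(remaining) - n_select)
        order = np.argsort(imp)                       # ascending importance
        drop = {remaining[i] for i in order[:n_drop]}
        remaining = [j for j in remaining if j not in drop]
    return sorted(remaining)

# Synthetic check: features 0 and 1 determine y, the rest are noise.
rng = np.random.default_rng(2)
X = rng.normal(size=(200, 10))
y = (X[:, 0] + X[:, 1] > 0).astype(float)
selected = rfe(X, y, n_select=2)
```

Re-ranking on every round is what lets RFE account for importances shifting as other features disappear, the property the text contrasts against a single XGBoost ranking.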

In order to validate that reducing the set of features did not negatively impact the data’s predictive power, the XGBoost model was trained on both the original dataset and the reduced dataset. The dataset was split into 80 % train and 20 % holdout test sets. The results, shown in Table A.2, demonstrate that the reduced set of features actually improved performance. This can be attributed to the unfiltered data suffering from the curse of dimensionality. The reduced model forms a baseline for future comparisons.

Number of features | Accuracy (%) | AUROC
765,373 | 67.21 | 0.477
441 | 95.08 | 0.917

Table A.2: Test set results for the XGBoost model before and after feature selection using the RFE algorithm.