Cardiac Complication Risk Profiling for Cancer Survivors via Multi-View Multi-Task Learning

09/25/2021 ∙ by Thai Hoang Pham, et al. ∙ The Ohio State University

Complication risk profiling is a key challenge in the healthcare domain due to the complex interactions between heterogeneous entities (e.g., visits, diseases, medications) in clinical data. With the availability of real-world clinical data such as electronic health records and insurance claims, many deep learning methods have been proposed for complication risk profiling. However, these existing methods face two open challenges. First, data heterogeneity: these methods leverage clinical data from a single view only, whereas the data can be considered from multiple views (e.g., a sequence of clinical visits or a set of clinical features). Second, generalized prediction: most of these methods focus on single-task learning, in which each complication onset is predicted independently, leading to suboptimal models. To tackle these issues, we propose a multi-view multi-task network (MuViTaNet) for predicting the onset of multiple complications. In particular, MuViTaNet enriches the patient representation by using a multi-view encoder that extracts information by considering clinical data as both sequences of clinical visits and sets of clinical features. In addition, it leverages information from both related labeled and unlabeled datasets through a new multi-task learning scheme, which generates more generalized representations and yields more accurate predictions. The experimental results show that MuViTaNet outperforms existing methods for profiling the development of cardiac complications in breast cancer survivors. Furthermore, thanks to its multi-view multi-task architecture, MuViTaNet also provides an effective mechanism for interpreting its predictions from multiple perspectives, thereby helping clinicians discover the underlying mechanisms triggering complication onset and make better treatment decisions in real-world scenarios.


I Introduction

Cardiovascular diseases are widely recognized as leading causes of mortality in breast cancer survivors [1, 2, 3, 4]. With the recent substantial improvement in breast cancer survival rates, predicting the onset of multiple cardiac complications has become a critical task for improving patients’ quality of life. It is also key to cost-effective disease management and prevention. However, this task is highly challenging because of the complex interactions between heterogeneous clinical entities. Effectively capturing these interactions may lead to more precise prediction and treatment for cancer survivors.

Over the past few decades, the rapid growth of real-world clinical data such as electronic health records (EHRs) and insurance claims has made them valuable data sources for data-driven (e.g., deep learning) systems for clinical risk prediction, especially complication risk profiling [5, 6, 7]. As shown in Figure 1, these data include heterogeneous clinical entities (e.g., visits, diseases, medications) and can be considered from multiple views (i.e., a sequence of visits or a set of features). However, most existing studies consider each clinical outcome prediction separately and extract information from only a single view of the clinical data. This makes them ill-suited for complication risk profiling and raises two challenges.

Fig. 1: Visit-view (sequence of clinical visits (rows)) and feature-view (set of clinical codes (columns)) of clinical data.

C1. Clinical data is highly complex due to its heterogeneous and hierarchical structure. Thus, encoding patient records from a single view cannot provide comprehensive representations of these patients, which limits prediction performance. In particular, by considering patient records as sequences of visits, previous works only learn the dependencies among clinical visits. They cannot explicitly capture the dynamic patterns of clinical features and their interactions at the global (i.e., sequence) level.

C2. Treating each complication onset prediction independently can lead to suboptimal models, especially when datasets are limited. This is because such an approach cannot capture the dependencies among complications that are manifestations of a common underlying condition. Moreover, it cannot exploit meaningful clinical patterns from unlabeled data. Such data are much easier to collect and can be used to improve prediction performance when labeled data are limited.

To tackle the two aforementioned challenges, we propose a new neural network-based framework named Multi-View Multi-Task Network (MuViTaNet) for cardiac complication risk profiling. The proposed model consists of a multi-view encoder (dealing with C1) and a novel multi-task learning (MTL) scheme (dealing with C2). In particular, the multi-view encoder includes visit-view and feature-view encoders that capture information from clinical visits and clinical features simultaneously. The visit-view encoder considers a patient record as a sequence of clinical visits and captures the temporal relations among visits using a Gated Recurrent Unit (GRU) network. The feature-view encoder considers the patient record as a set of temporal medical features and leverages convolutional neural networks (CNNs) to extract temporal patterns from each feature separately. A max-pooling operation is then applied to extract the most significant signals from these temporal sequences.

The MTL scheme utilizes an attention mechanism to learn complication-specific representations from the shared information generated by the multi-view encoder. This scheme allows MuViTaNet to exploit additional information from related complications and unlabeled data to generate more generalized patient representations, which enables more accurate predictions. Moreover, the proposed model leverages the attention mechanisms associated with the multi-view encoder to provide an efficient way of interpreting its predictions from multiple perspectives. This helps clinicians discover the underlying mechanisms triggering complication onset and make better treatment decisions. We demonstrate that the proposed model significantly outperforms current state-of-the-art approaches for the complication risk profiling task on multiple datasets derived from an insurance claims database. In summary, our contributions include the following:

  • We design a multi-view multi-task neural network architecture (MuViTaNet; code is available at https://github.com/pth1993/MuViTaNet) that accurately predicts multiple complication onsets and efficiently interprets its predictions.

  • We develop a multi-view encoder to explicitly capture dependencies among clinical visits and clinical features from multiple views of clinical data.

  • We also introduce a new MTL scheme that utilizes a complication-specific attention mechanism on top of the multi-view encoder to capture additional clinical information from related complications and unlabeled datasets.

  • Finally, we conduct a comprehensive empirical study to demonstrate the effectiveness of MuViTaNet in terms of both prediction performance and interpretability compared to a wide range of previous approaches for cardiac complication risk profiling.

The remainder of the paper is organized as follows. Section II summarizes related works on clinical risk prediction in general and complication risk profiling in particular. Section III describes the technical details of the proposed model (MuViTaNet). Section IV presents experimental results and discussions. Finally, Section V concludes the paper.

II Related Works

In this section, we briefly review existing works related to our study including patient representation learning and MTL for clinical risk prediction, as well as complication risk profiling.

Patient representation learning.

The abundance of real-world data in recent years creates an unprecedented opportunity to apply machine learning and data mining methods to clinical risk prediction. With advances in deep learning theory and in computational technologies, neural network-based architectures can significantly improve prediction performance due to their ability to extract rich representations from data. Because of the temporal nature of clinical data, most existing methods rely on recurrent neural network architectures to learn patient representations, which are then used to predict future clinical events (e.g., diagnosis, mortality, readmission) [5, 6, 7, 8, 9]. These works focused on designing attention mechanisms to capture dependencies among clinical visits [5, 8, 9] and time-aware mechanisms to incorporate temporal information [6, 10, 11] into patient representations for making better predictions. Nonetheless, these models cannot explicitly capture the relationships among clinical features. Instead of considering EHR data as sequences of clinical visits, ConCare [12] treats the record as a set of clinical features and extracts the dynamic patterns of these features separately. The predictions are then made by aggregating the representations of all clinical features. However, all the existing methods extract information from only a single view of clinical data, which makes the learned patient representations suboptimal. In contrast, we propose a multi-view model that captures information from multiple views of clinical data simultaneously.

Multi-task learning. Multi-task learning (MTL) has been widely used across many applications of machine learning and data mining. By sharing information among related tasks, a prediction model can generalize better. In the healthcare domain, some existing works applied MTL techniques to leverage information from related tasks to improve model performance in clinical risk prediction. In particular, both classical machine learning models [13, 14, 15] and deep learning models [16, 17, 18] have been formulated as MTL frameworks. They have been applied to a wide range of healthcare applications, including disease progression modeling [13], mortality prediction [16], disease onset prediction [17], and diagnosis classification [18].

Complication risk profiling.

Mitigating the risk of complications is crucial for many disease management programs. Despite its importance, few existing methods have been designed for this task. Unlike a single clinical risk prediction task, complication risk profiling requires multiple predictions for the onset of complications. Thus, capturing the relationships among related complications is crucial to achieving good prediction performance. Some methods have been proposed to predict the onset of complications of certain diseases and clinical procedures. For example, multi-task logistic regression has been used to predict complication risks in diabetes care [14, 19]. Besides linear models, a deep learning method has also been used to predict complications of this chronic disease [20], but this work considers each complication independently. For breast cancer survivors, the relationships between cardiac complications and cancer have also been investigated [21, 4, 3], showing a correlation between these two diseases.

III Methodology

Fig. 2: General schemes for learning from clinical data. (a) single-view single-task learning, (b) single-view multi-task learning, (c) multi-view multi-task learning. Our proposed model belongs to multi-view multi-task learning with the multi-view encoder (i.e., visit-view and feature-view) and the task-specific attention mechanisms and decoders for both labeled and unlabeled datasets.
Fig. 3: The overall architecture of MuViTaNet. The proposed framework consists of four main components: feature-view encoder, visit-view encoder, task-specific attention, and task-specific decoder. Given a patient record, MuViTaNet first extracts information from clinical visits and features by looking at the record in two different ways: sequence of clinical visits and set of clinical features. Then the shared representation learned by these two encoders is put into the task-specific attention to learn the task-specific representation. Finally, the clinical predictions are generated by the task-specific decoders. Note that the figure only shows the task-specific attention for one prediction task for simplicity.

In this section, we first give a brief introduction to patient records, the complication risk profiling task, and the corresponding notation. Then, we present our proposed model, MuViTaNet.

III-A Definitions and Basic Notations

Definitions and notations used in this study are shown in the following paragraphs and are summarized in Table I.

Patient Records. The heterogeneous and hierarchical structure of a patient record is defined as follows.

  • Definition 1 (Clinical Code). C is the set of unique clinical codes, including diagnosis, procedure, and medication codes, and |C| is the number of these unique codes. Each code can be represented by a one-hot binary vector of length |C| in which the element at that code’s index is 1 and all other elements are 0.

  • Definition 2 (Clinical Visit). A visit is a hospital stay from admission to discharge. Each visit is a tuple consisting of the set of clinical codes recorded during the visit (given by their indexes in C) and the timestamp of the visit. A visit can be represented by a binary vector of length |C| whose element is 1 if the visit contains the corresponding code. Besides this vector representation, a visit can also be expressed as a matrix whose rows are the binary vectors of its codes.

  • Definition 3 (Patient Record). A patient record is a time-ordered sequence of visits. Like a clinical visit, a patient record can be represented at two different granularities. At the visit level, it is a binary matrix whose t-th row is the binary vector of the t-th visit. At the feature level, it is the sequence of the visits’ matrix representations.

  • Definition 4 (Demographic Information). Besides clinical information, a patient record can include demographic information about the patient such as age, gender, and region. It can be represented by a binary vector whose length is the number of demographic attributes.
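To make the two granularities concrete, the following minimal Python sketch builds the visit-level binary matrix of a toy patient record and reads off one code’s temporal sequence. In the matrix, rows are visits and columns are codes, as in Figure 1. A code’s column is the input the feature-view encoder consumes. The code vocabulary, visits, and timestamps are hypothetical, and the helper names are illustrative rather than taken from the authors’ code.

```python
from typing import Dict, List, Tuple

# Hypothetical types: a visit is (list of clinical codes, timestamp in days).
Visit = Tuple[List[str], float]

def visit_vector(codes: List[str], vocab: Dict[str, int]) -> List[int]:
    """Binary vector of a visit: element j is 1 if the visit contains code j."""
    row = [0] * len(vocab)
    for code in codes:
        row[vocab[code]] = 1
    return row

def record_matrix(visits: List[Visit], vocab: Dict[str, int]) -> List[List[int]]:
    """Visit-level view: one row per visit (in time order), one column per code."""
    ordered = sorted(visits, key=lambda v: v[1])
    return [visit_vector(codes, vocab) for codes, _t in ordered]

def feature_column(matrix: List[List[int]], code_index: int) -> List[int]:
    """Feature-level view of one code: its binary sequence across visits."""
    return [row[code_index] for row in matrix]

vocab = {"I49": 0, "E11": 1, "R03": 2}  # toy 3-code vocabulary (hypothetical)
visits = [(["I49"], 0.0), (["E11", "R03"], 30.0), (["I49", "E11"], 95.0)]
X = record_matrix(visits, vocab)
print(X)                     # [[1, 0, 0], [0, 1, 1], [1, 1, 0]]
print(feature_column(X, 1))  # [0, 1, 1]
```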

Clinical Risk Profiling. The aim of this task is to find a set of functions, one per complication, that predict the onset of each complication from a patient record. In the MTL setting, these functions generally share some parameters so that they learn shared information from related tasks for better predictions.

Notation Description
Set of clinical codes/features
A patient record
clinical codes in set C
vector representation of code
clinical visit in
set of clinical codes in visit
timestamp of visit
vector representation of visit
matrix representation of
visit-level representation of
feature-level representation of
vector representation of demographics
attention weights of codes in visit
task-specific attention weights for features
task-specific attention weights for visits
temporal encoding vector of visit
representation learned by visit-view encoder
patient representation
representation learned by feature-view encoder
visit-view task-specific representation for task
feature-view task-specific representation for task
task-specific representation for task
ground-truth output for task
predicted output for task
TABLE I: Notation definition
Input: Datasets , set of clinical codes , batch sizes ,
Output: Trained model parameters
1 Randomly initialize ;
2 Calculate sampling rate for each dataset ( if , otherwise);
3 for epoch to E do
4       repeat
5             Select dataset ;
6             Initialize loss ;
7             Select sample batch from dataset ;
8             for patient in batch  do
9                   ;
10                   Obtain feature-view representation from using Eq. (1), (2);
11                   Obtain visit-view representation and patient representation from using Eq. (3)-(11);
12                   Calculate task-specific attention weights from , using Eq. (12);
13                   Obtain task-specific representations using Eq. (13);
14                   if  then
15                         Calculate prediction using Eq. (14);
16                         Calculate BCE loss using Eq. (16);
17                   else
18                         Project multi-view representations to unit hypersphere using Eq. (15);
19                         Calculate CL loss using Eq. (17);
20                        
21                   ;
22                  
23             end for
24            Update parameters using gradient of ;
25             ;
26            
27      until ;
28 end for
Algorithm 1 Training procedure for MuViTaNet

III-B Proposed Model

Overview Architecture. This section presents our proposed multi-view multi-task network (MuViTaNet) for predicting the onset of multiple complications from patient records. MuViTaNet is designed to explicitly capture the dependencies among clinical visits and clinical features in patient records. It uses MTL to leverage additional information from both related labeled and unlabeled data, which yields accurate predictions and efficient interpretation. In particular, MuViTaNet consists of the following four main components. (1) Feature-view Encoder. This component considers a patient record as a set of temporal clinical features and encodes the information of each feature separately. (2) Visit-view Encoder. This component formulates a patient record as a sequence of visits and learns a representation for each visit in its sequential context. Specifically, this component is designed as a hierarchical model that exploits patient records at two levels: the feature level and the visit level. (3) Task-specific Attention. After the feature-view and visit-view encoders learn the shared representations, an attention mechanism extracts a task-specific representation for each task from them. (4) Task-specific Decoder. The task-specific representations are fed into the corresponding task-specific decoders. For patients in the complication datasets, these decoders predict clinical outcomes. For patients in the unlabeled dataset, they project the representations onto the unit hypersphere. Figure 3 shows the overall architecture of MuViTaNet, and the technical details of its components are presented as follows.

Feature-view Encoder. This component treats patient data as a set of clinical codes which are represented by the set of temporal sequences (i.e., columns of matrix ). In particular, given clinical code , its temporal data can be represented by a binary vector which is column of . Then, one-dimensional convolutional neural networks (Conv1d) and max-pooling (MaxPool) operation are employed to extract temporal patterns from each clinical code separately. In particular, Conv1d with kernel size (i.e., in our setting) takes as inputs the sub-sequences of length k from vector to learn the representation of code as follows.

(1)

where are the output of and is the number of filters used in convolution operations. Next, the row-wise max-pooling is applied to to generate vector representation for clinical code .

(2)

Note that the weights of Conv1d are not shared between clinical codes. The output of feature-view encoder is matrix .
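The per-code convolution and max-pooling of Eqs. (1) and (2) can be sketched in plain Python as follows. The kernel size k = 2 and the number of filters F = 2 are placeholders, since the paper’s values are not given in this copy. No nonlinearity is applied between the convolution and the pooling, because the text mentions none. The function names are illustrative. Each code gets its own filters, which matches the note that Conv1d weights are not shared between codes.

```python
import random
from typing import List, Tuple

Filters = Tuple[List[List[float]], List[float]]  # (F filters of width k, F biases)

def conv1d_valid(x: List[float], filters: List[List[float]], biases: List[float]) -> List[List[float]]:
    """Valid 1-D convolution (cross-correlation, as in deep-learning Conv1d layers) of one
    code's temporal vector; returns F output sequences of length T - k + 1."""
    k = len(filters[0])
    return [[sum(w[j] * x[i + j] for j in range(k)) + b for i in range(len(x) - k + 1)]
            for w, b in zip(filters, biases)]

def encode_code(x: List[float], filters: List[List[float]], biases: List[float]) -> List[float]:
    """Row-wise max-pooling over time: keep the strongest response of each filter."""
    return [max(seq) for seq in conv1d_valid(x, filters, biases)]

def feature_view_encoder(record: List[List[int]], params: List[Filters]) -> List[List[float]]:
    """record: T x |C| binary matrix; params[c] holds the filters of code c (unshared).
    Returns a |C| x F matrix with one row per clinical code."""
    num_codes = len(record[0])
    return [encode_code([row[c] for row in record], *params[c]) for c in range(num_codes)]

def init_params(num_codes: int, num_filters: int, k: int, seed: int = 0) -> List[Filters]:
    """Random per-code filters (illustrative initialization)."""
    rng = random.Random(seed)
    return [([[rng.gauss(0.0, 0.1) for _ in range(k)] for _ in range(num_filters)],
             [0.0] * num_filters)
            for _ in range(num_codes)]

print(encode_code([0, 1, 1, 0], [[1, 1], [1, -1]], [0, 0]))  # [2, 1]
record = [[1, 0, 0], [0, 1, 1], [1, 1, 0], [0, 0, 1]]  # toy record: T=4 visits, |C|=3 codes
H = feature_view_encoder(record, init_params(num_codes=3, num_filters=2, k=2))
print(len(H), len(H[0]))  # 3 2
```

In a PyTorch implementation, the same "one set of filters per code" behavior could be obtained with a single `torch.nn.Conv1d` whose `in_channels` equals the number of codes and whose `groups` argument equals that same number. Each input channel is then convolved only with its own filters.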

Visit-view Encoder. This component formulates patient data as a sequence of visits in which each visit can be seen as a set of clinical codes. Due to the hierarchical characteristic of this data structure, the visit-view encoder is also designed hierarchically to capture information at different levels. A given visit is represented by its matrix, which is an element of the feature-level sequence of the patient record. Different clinical codes associated with the same visit can have disparate impacts. Therefore, instead of treating these clinical codes uniformly when aggregating them into a visit representation, a location attention mechanism is employed to learn the contribution of each code to the visit representation. In particular, given the binary representation of a code, a 1-layer feed-forward neural network is applied to learn a dense representation from the sparse vector of this clinical code as follows.

(3)

where is the learned weight matrix of clinical codes, is the bias vector, and ReLU is the rectified linear unit activation function. Then a 2-layer feed-forward neural network with activation function is used to generate the attention score for this clinical code as follows.

(4)

The attention vector, which represents the contributions of the clinical codes in the visit, is fed into a softmax layer to obtain the normalized attention vector.

(5)

Then, the representation of the visit is computed as the weighted average of the representations of its clinical codes.

(6)

where denotes the visit’s representation. To generate a personalized representation for each visit, demographic information, including age and region, is incorporated into every clinical visit as follows.

(7)

where is the concatenation operation and is the weight matrix mapping concatenated vectors to the original embedding space. Besides clinical codes, each visit is also associated with its timestamp. In order to capture the elapsed time between visits, we add the temporal encoding vector to each visit as follows.

(8)

where is the temporal encoding vector whose design is inspired by the positional encoding used in Transformer architecture [22]. In particular, it is computed by trigonometric functions as follows.

(9)

where . From Equation (9), we can see that temporal embedding encodes similar time intervals into similar vectors in embedding space.
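Eq. (9) is not reproduced in this copy. The sketch below assumes that it follows the Transformer positional encoding [22], with the visit timestamp taking the place of the token position. It also assumes timestamps measured in days and a base of 10000. Under those assumptions, the temporal encoding and the addition of Eq. (8) look as follows. The function names are hypothetical.

```python
import math
from typing import List

def temporal_encoding(t: float, dim: int, base: float = 10000.0) -> List[float]:
    """Transformer-style sinusoidal encoding with timestamp t in place of the token position.
    The j-th (sin, cos) pair uses frequency base ** (-2j / dim)."""
    enc: List[float] = []
    for pair in range((dim + 1) // 2):
        freq = base ** (-2 * pair / dim)
        enc.append(math.sin(t * freq))
        if 2 * pair + 1 < dim:  # odd dim: the last pair keeps only the sine term
            enc.append(math.cos(t * freq))
    return enc

def add_temporal_encoding(visit_vec: List[float], t: float) -> List[float]:
    """Eq. (8)-style step (assumed form): add the timestamp encoding to the visit vector."""
    return [v + e for v, e in zip(visit_vec, temporal_encoding(t, len(visit_vec)))]

def sq_dist(a: List[float], b: List[float]) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))

print(temporal_encoding(0.0, 4))  # [0.0, 1.0, 0.0, 1.0]
```

Each (sin, cos) pair with frequency w contributes 2 - 2cos(w·Δt) to the squared distance between two encodings. That distance therefore depends only on the time gap Δt, which is one way to see why similar intervals map to similar vectors.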

To generate representations of the visits in their sequential context, we feed the independent visit representations learned in the previous steps into a bidirectional GRU layer. Specifically, the sequential representations of these visits are computed as follows.

(10)

where . Then, the patient representation is computed based on the last visit in the visit sequence.

(11)

In summary, the outputs of the visit-view encoder include the sequential representations of clinical visits and the patient representation .
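The code-level attention of Eqs. (3) to (6) can be sketched as the function below. It embeds each one-hot code with a ReLU layer, scores each embedding with a 2-layer network, normalizes the scores with a softmax, and takes the weighted average of the embeddings. Several details are assumptions: the tanh hidden activation of the scorer (the activation is not named in this copy), the parameter names `W_emb`, `b_emb`, `W1`, `b1`, `w2`, and the toy weights. The demographic step (Eq. (7)), the temporal step (Eq. (8)), and the bidirectional GRU (Eq. (10)) are left out to keep the sketch short.

```python
import math
from typing import List, Tuple

Vector = List[float]
Matrix = List[List[float]]

def affine(W: Matrix, x: Vector, b: Vector) -> Vector:
    """Compute W x + b."""
    return [sum(w * a for w, a in zip(row, x)) + bi for row, bi in zip(W, b)]

def softmax(scores: Vector) -> Vector:
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def visit_representation(code_onehots: List[Vector], W_emb: Matrix, b_emb: Vector,
                         W1: Matrix, b1: Vector, w2: Vector) -> Tuple[Vector, Vector]:
    # Eq. (3)-style dense code embeddings: ReLU(W_emb x + b_emb) for each one-hot code x
    embs = [[max(0.0, a) for a in affine(W_emb, x, b_emb)] for x in code_onehots]
    # Eq. (4)-style 2-layer scorer: tanh hidden layer (assumed), then a linear output
    scores = [sum(w * math.tanh(h) for w, h in zip(w2, affine(W1, e, b1))) for e in embs]
    # Eq. (5)-style normalization over the codes of this visit
    alphas = softmax(scores)
    # Eq. (6)-style weighted average of the code embeddings
    visit = [sum(a * e[d] for a, e in zip(alphas, embs)) for d in range(len(embs[0]))]
    return visit, alphas

codes = [[1, 0, 0], [0, 1, 0]]  # two codes from a 3-code vocabulary (toy)
W_emb, b_emb = [[1, 0, 0], [0, 2, 0]], [0, 0]
visit, alphas = visit_representation(codes, W_emb, b_emb, [[1, 0], [0, 1]], [0, 0], [0, 0])
print(visit, alphas)  # [0.5, 1.0] [0.5, 0.5]
```

With an all-zero output layer every code receives the same weight, so the visit vector reduces to the plain mean of the code embeddings. The learned scorer is what lets MuViTaNet weight codes differently.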

Task-specific Attention. Given the shared representations generated by feature-view and visit-view encoders, attention mechanisms are employed to generate the task-specific representations for the patient. Specifically, the attention weights of clinical features and visits for task are computed as follows.

(12)

where are 2-layer feed-forward neural networks with activation function that compute the weights of clinical features and visits from their representations. Then, we obtain the task-specific representation for task as follows.

(13)
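Eqs. (12) and (13) can be sketched as one pair of scorers per task. The first attends over the feature-view rows and the second over the visit-view rows, and each produces a softmax-weighted sum. This sketch assumes that the two attended views are simply concatenated into the task-specific representation and that the scorers use a tanh hidden layer. Both are assumptions, as are the helper names and the toy inputs.

```python
import math
from typing import Callable, Dict, List, Sequence, Tuple

Vector = List[float]
Scorer = Callable[[Vector], float]

def softmax(scores: Vector) -> Vector:
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def make_scorer(W1: List[Vector], b1: Vector, w2: Vector) -> Scorer:
    """Hypothetical 2-layer scorer: tanh hidden layer followed by a linear output."""
    def score(v: Vector) -> float:
        hidden = [math.tanh(sum(w * a for w, a in zip(row, v)) + b) for row, b in zip(W1, b1)]
        return sum(w * h for w, h in zip(w2, hidden))
    return score

def attend(vectors: Sequence[Vector], score: Scorer) -> Vector:
    """Attention-weighted sum of the vectors, with weights from a softmax over their scores."""
    weights = softmax([score(v) for v in vectors])
    return [sum(w * v[d] for w, v in zip(weights, vectors)) for d in range(len(vectors[0]))]

def task_representations(feature_reps: Sequence[Vector], visit_reps: Sequence[Vector],
                         scorers: Dict[str, Tuple[Scorer, Scorer]]) -> Dict[str, Vector]:
    """One (feature scorer, visit scorer) pair per task; the attended views are concatenated (assumed)."""
    return {task: attend(feature_reps, f_score) + attend(visit_reps, v_score)
            for task, (f_score, v_score) in scorers.items()}

flat = lambda v: 0.0  # uniform attention, for illustration
reps = task_representations([[1.0, 0.0], [3.0, 2.0]], [[0.0, 4.0, 2.0]], {"AF": (flat, flat)})
print(reps["AF"])  # [2.0, 1.0, 0.0, 4.0, 2.0]
```

Because each task owns its scorers, two complications can attend to different codes and visits of the same shared encoding. This per-task weighting is what Table IV reads out as per-complication feature importance.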

Task-specific Decoder. For a patient in a labeled dataset (i.e., a complication dataset), the 2-layer feed-forward neural network with activation function at the last layer is employed to predict the probability of complication onset for this patient.

(14)

For a patient in unlabeled dataset, the 2-layer feed forward neural network with normalization operation () is used to project the feature-view and visit-view representations of this patient on the unit hypersphere.

(15)
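The two decoder heads of Eqs. (14) and (15) can be sketched as follows. The first is a 2-layer network with a sigmoid output that gives the onset probability. The second is a 2-layer projection followed by L2 normalization, which places the vector on the unit hypersphere. The sigmoid output is an assumption, because the last-layer activation is not named in this copy. The ReLU hidden layers and the function names are also assumptions.

```python
import math
from typing import List

Vector = List[float]

def dense(W: List[Vector], b: Vector, x: Vector) -> Vector:
    """Compute W x + b."""
    return [sum(w * a for w, a in zip(row, x)) + bi for row, bi in zip(W, b)]

def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

def predict_onset(task_rep: Vector, W1: List[Vector], b1: Vector, w2: Vector, b2: float) -> float:
    """Labeled task: ReLU hidden layer (assumed), sigmoid output = onset probability."""
    hidden = [max(0.0, h) for h in dense(W1, b1, task_rep)]
    return sigmoid(sum(w * h for w, h in zip(w2, hidden)) + b2)

def l2_normalize(v: Vector, eps: float = 1e-12) -> Vector:
    norm = math.sqrt(sum(a * a for a in v))
    return [a / max(norm, eps) for a in v]

def project_to_hypersphere(view_rep: Vector, W1: List[Vector], b1: Vector,
                           W2: List[Vector], b2: Vector) -> Vector:
    """Unlabeled data: 2-layer projection head (ReLU hidden layer assumed), then L2 normalization."""
    hidden = [max(0.0, h) for h in dense(W1, b1, view_rep)]
    return l2_normalize(dense(W2, b2, hidden))

print(l2_normalize([3.0, 4.0]))  # [0.6, 0.8]
```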

Optimization. To train MuViTaNet in MTL setting, we follow the alternating training strategy [23] in which each task is selected randomly and then is optimized for a fixed number of parameter updates before switching to other tasks. In our setting, different tasks have datasets of different sizes, so we select a task to optimize with probability , where and are the dataset and batch size for task, and is the number of complication datasets.
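The exact sampling formula is not reproduced in this copy. One plausible reading, assumed here, is that each dataset is chosen with probability proportional to its number of batches, |D_i| / B_i, normalized over the complication datasets plus the unlabeled dataset. Under that assumption, the alternating schedule can be sketched as follows. The dataset and batch sizes in the example are hypothetical.

```python
import random
from typing import List

def sampling_probabilities(dataset_sizes: List[int], batch_sizes: List[int]) -> List[float]:
    """Pick probability for each dataset, proportional to its number of batches (assumed rule)."""
    batches = [n / b for n, b in zip(dataset_sizes, batch_sizes)]
    total = sum(batches)
    return [x / total for x in batches]

def sample_task_schedule(dataset_sizes: List[int], batch_sizes: List[int],
                         num_steps: int, seed: int = 0) -> List[int]:
    """Indices of the dataset to draw a batch from at each training step."""
    rng = random.Random(seed)
    probs = sampling_probabilities(dataset_sizes, batch_sizes)
    return rng.choices(range(len(probs)), weights=probs, k=num_steps)

# Hypothetical: two complication datasets plus one unlabeled dataset.
print(sampling_probabilities([100, 300, 1200], [10, 10, 40]))  # proportional to [10, 30, 30]
```

In a full training loop, each sampled index would select a batch from that dataset. The loop would then apply either the BCE loss (labeled data) or the contrastive loss (unlabeled data), as in Algorithm 1.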

For labeled datasets, the binary cross-entropy (BCE) loss function is used to optimize the prediction based on ground-truth labels. Specifically, for task with dataset , the loss function for this task is computed as follows.

(16)

where and are the ground-truth and predicted outputs for task respectively. For unlabeled dataset, we leverage the contrastive (CL) loss function [24] to pull together the normalized representations of feature-view and visit-view of the same patient and to push apart these representations from representations of other patients.

(17)

where in which .
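Eqs. (16) and (17) are not reproduced in this copy. The sketch below computes the mean BCE and an InfoNCE-style cross-view contrastive loss. For each patient in a batch, the loss rewards the feature-view vector for scoring highest against the same patient’s visit-view vector, with the other patients’ visit-view vectors serving as negatives. The temperature of 0.1 and the use of cross-view negatives only are assumptions, so the details may differ from Eq. (17) and [24].

```python
import math
from typing import List

def bce_loss(y_true: List[int], y_prob: List[float], eps: float = 1e-7) -> float:
    """Mean binary cross-entropy; probabilities are clipped to avoid log(0)."""
    total = 0.0
    for y, p in zip(y_true, y_prob):
        p = min(max(p, eps), 1 - eps)
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(y_true)

def dot(a: List[float], b: List[float]) -> float:
    return sum(x * y for x, y in zip(a, b))

def cross_view_contrastive_loss(z_feat: List[List[float]], z_visit: List[List[float]],
                                temperature: float = 0.1) -> float:
    """z_feat[i], z_visit[i]: unit-norm feature-view / visit-view vectors of patient i.
    The positive pair is (i, i); other patients' visit-view vectors are negatives."""
    n = len(z_feat)
    loss = 0.0
    for i in range(n):
        logits = [dot(z_feat[i], z_visit[j]) / temperature for j in range(n)]
        m = max(logits)  # log-sum-exp with max subtraction for stability
        log_denom = m + math.log(sum(math.exp(l - m) for l in logits))
        loss += -(logits[i] - log_denom)
    return loss / n

print(round(bce_loss([1, 0], [0.5, 0.5]), 4))  # 0.6931
```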

IV Experiments

Complication | Description | ICD-10 Codes | #Subjects
Atrial Fibrillation | An irregular, often rapid heart rate that commonly causes poor blood flow | I48 | 322
Coronary Artery Disease | Damage or disease in the heart’s major blood vessels | I20-I25 | 769
Heart Failure | A chronic condition in which the heart doesn’t pump blood as well as it should | I11, I13, I42, I50 | 1124
Hypertension | A condition in which the force of the blood against the artery walls is too high | I10, I16 | 6787
Peripheral Arterial Disease | A circulatory condition in which narrowed blood vessels reduce blood flow to the limbs | I70 | 340
Stroke | Damage to the brain from interruption of its blood supply | I60-I69 | 592
TABLE II: Cardiac complications in female breast cancer cohort and their corresponding ICD codes and numbers of positive instances.
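Setting | Group | Method | AF | CAD | HF | Hypertension | PAD | Stroke | Average
Single-task | Classical | LR
Single-task | Classical | RF
Single-task | Recurrent-based | GRU
Single-task | Recurrent-based | Bi-GRU
Single-task | Time-aware | T-LSTM
Single-task | Attention-based | Dipole
Single-task | Attention-based | RETAIN
Single-task | Attention-based | Transformer
Single-task | Attention-based | LSAN
Multi-task | Recurrent-based | GRU
Multi-task | Recurrent-based | Bi-GRU
Multi-task | Time-aware | T-LSTM
Multi-task | Attention-based | Dipole
Multi-task | Attention-based | RETAIN
Multi-task | Attention-based | Transformer
Multi-task | Attention-based | LSAN
Ours | Multi-view multi-task | MuViTaNet
TABLE III: Comparison of prediction performance measured by AU-ROC scores on six complication risk profiling tasks. We report the average AU-ROC scores and their corresponding standard deviations. (AF: Atrial Fibrillation, CAD: Coronary Artery Disease, HF: Heart Failure, PAD: Peripheral Arterial Disease).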

In this section, we evaluate the performance of MuViTaNet on six real-world insurance claim datasets and compare its results with state-of-the-art clinical risk prediction models to demonstrate the effectiveness of our method. Besides accurate prediction, we also show the effectiveness of MuViTaNet in terms of interpretability.

IV-A Datasets

Breast cancer cohort construction. We extract clinical records of female breast cancer patients from the MarketScan Commercial Claims and Encounters (CCAE) database provided by Truven Health (https://truvenhealth.com/markets/life-sciences/products/data-tools/marketscan-databases) to construct the cardiac complication risk profiling datasets. Following previous work [19], records of de-identified patients from 2012 to 2017 are selected based on the following criteria.

  • Ages of the selected patients are from 18 to 65 at the initial diagnosis of breast cancer.

  • The selected patients have at least six months of records and ten clinical visits before being diagnosed with breast cancer.

  • There is no cardiac complication diagnosis until the initial diagnosis of breast cancer of the selected patients.

Cardiac complication datasets construction. After constructing the breast cancer cohort, we create a distinct dataset for each cardiac complication onset prediction task. In our setting, we focus on profiling the risk of developing cardiac complications within a six-month window after the initial diagnosis of breast cancer. Positive instances are defined as patients who develop cardiac complications in this window. Following previous clinical research [3, 4], we identify six cardiac complications: atrial fibrillation (AF), coronary artery disease (CAD), heart failure (HF), hypertension, peripheral arterial disease (PAD), and stroke. Descriptions, ICD codes, and the corresponding numbers of positive instances of these complications are shown in Table II. The negative instances are randomly selected from the breast cancer cohort at a ratio of 3:1 relative to the positive instances.

Unlabeled dataset construction. The negative patients that are not selected for complication datasets are used to construct a dataset for contrastive learning. MuViTaNet leverages this dataset as additional information to improve the prediction performances of complication onset prediction tasks.
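The construction of the labeled complication datasets and the leftover unlabeled dataset can be sketched as follows. The input schema is hypothetical: a dictionary from patient ID to the set of complications diagnosed in the six-month window. The sketch also assumes that only patients with no cardiac complication in the window count as negatives, which is one possible reading of how negatives are defined.

```python
import random
from typing import Dict, List, Set, Tuple

def build_datasets(cohort: Dict[str, Set[str]], complications: List[str],
                   neg_ratio: int = 3, seed: int = 0) -> Tuple[Dict[str, List[Tuple[str, int]]], List[str]]:
    """Returns one labeled dataset per complication (positives plus neg_ratio x sampled negatives)
    and the list of negatives never sampled, which form the unlabeled dataset."""
    rng = random.Random(seed)
    negatives = sorted(pid for pid, comps in cohort.items() if not comps)  # assumed definition
    used_negatives: Set[str] = set()
    datasets: Dict[str, List[Tuple[str, int]]] = {}
    for comp in complications:
        positives = sorted(pid for pid, comps in cohort.items() if comp in comps)
        k = min(len(negatives), neg_ratio * len(positives))
        sampled = rng.sample(negatives, k)
        used_negatives.update(sampled)
        datasets[comp] = [(pid, 1) for pid in positives] + [(pid, 0) for pid in sampled]
    unlabeled = [pid for pid in negatives if pid not in used_negatives]
    return datasets, unlabeled

# Toy cohort (hypothetical IDs): one AF case, one stroke case, ten patients without complications.
cohort = {"p0": {"AF"}, "p1": {"Stroke"}, **{f"n{i}": set() for i in range(10)}}
datasets, unlabeled = build_datasets(cohort, ["AF", "Stroke"])
print(len(datasets["AF"]), len(datasets["Stroke"]))  # 4 4
```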

Feature selection. We use the following information to profile cardiac complications for breast cancer patients.

  • Demographics including age and region information. We cluster patients into three age groups (i.e., ) and five region groups.

  • Clinical codes including diagnosis, procedure, and medication codes. For diagnosis codes, all ICD-9 codes are converted to ICD-10 codes. To alleviate data sparsity, we group all diagnosis and procedure codes based on their first three characters and remove codes that appear in less than 200 patients. For medication codes, we group them by their therapeutic classes. This preprocessing step results in 1188 features.
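The grouping and frequency filter described above can be sketched as follows. Codes are truncated to their first three characters, and a grouped code is kept only if it appears in at least 200 distinct patients. The input schema is hypothetical, and so is the medication-to-therapeutic-class lookup: a real pipeline would need an actual drug-classification table and an ICD-9 to ICD-10 mapping, neither of which is shown here.

```python
from collections import Counter
from typing import Dict, Iterable, Optional

def group_code(code: str) -> str:
    """Group a diagnosis or procedure code by its first three characters (e.g., 'I4891' -> 'I48')."""
    return code[:3]

def group_medication(drug_id: str, drug_to_class: Dict[str, str]) -> Optional[str]:
    """Map a medication to its therapeutic class using a hypothetical lookup table."""
    return drug_to_class.get(drug_id)

def build_vocabulary(patients: Dict[str, Iterable[str]], min_patients: int = 200) -> Dict[str, int]:
    """patients: patient_id -> raw diagnosis/procedure codes (hypothetical schema).
    Keeps grouped codes that appear in at least `min_patients` distinct patients."""
    counts: Counter = Counter()
    for codes in patients.values():
        counts.update({group_code(c) for c in codes})  # count each patient once per group
    kept = sorted(g for g, n in counts.items() if n >= min_patients)
    return {g: i for i, g in enumerate(kept)}

toy = {"a": ["I4891", "E119"], "b": ["I480", "R030"], "c": ["I481"]}
print(build_vocabulary(toy, min_patients=2))  # {'I48': 0}
```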

Atrial Fibrillation | Coronary Artery Disease | Heart Failure
Nonrheumatic mitral valve disorders (I34) | Other cardiac arrhythmias (I49) | Other cardiac arrhythmias (I49)
Other cardiac arrhythmias (I49) | Nonrheumatic mitral valve disorders (I34) | Varicose veins of lower extremities (I83)
Complications and ill-defined heart disease (I51) | Varicose veins of lower extremities (I83) | Diseases of capillaries (I78)
Paroxysmal tachycardia (I47) | Diseases of capillaries (I78) | Other disorders of veins (I87)
Diseases of capillaries (I78) | Type 2 diabetes mellitus (E11) | Embolism and thrombosis (I82)
Embolism and thrombosis (I82) | Other peripheral vascular diseases (I73) | Type 2 diabetes mellitus (E11)
Other conduction disorders (I45) | Embolism and thrombosis (I82) | Complications and ill-defined heart disease (I51)
Varicose veins of lower extremities (I83) | Hypotension (I95) | Nonrheumatic mitral valve disorders (I34)
Nonrheumatic aortic valve disorders (I35) | Other disorders of veins (I87) | Other peripheral vascular diseases (I73)
Other disorders of veins (I87) | Angina pectoris (I20) | Overweight and obesity (E66)

Hypertension | Peripheral Arterial Disease | Stroke
Other cardiac arrhythmias (I49) | Other cardiac arrhythmias (I49) | Other cardiac arrhythmias (I49)
Abnormal blood-pressure reading, without diagnosis (R03) | Varicose veins of lower extremities (I83) | Nonrheumatic mitral valve disorders (I34)
Type 2 diabetes mellitus (E11) | Diseases of capillaries (I78) | Varicose veins of lower extremities (I83)
Nonrheumatic mitral valve disorders (I34) | Nonrheumatic mitral valve disorders (I34) | Other peripheral vascular diseases (I73)
Varicose veins of lower extremities (I83) | Other disorders of veins (I87) | Embolism and thrombosis (I82)
Overweight and obesity (E66) | Nonspecific lymphadenitis (I88) | Type 2 diabetes mellitus (E11)
Diseases of capillaries (I78) | Other peripheral vascular diseases (I73) | Other disorders of veins (I87)
Other peripheral vascular diseases (I73) | Embolism and thrombosis (I82) | Hypotension (I95)
Other disorders of veins (I87) | Other noninfective disorders of lymphatic vessels (I89) | Pain in throat and chest (R07)
Pain in throat and chest (R07) | Type 2 diabetes mellitus (E11) | Complications and ill-defined heart disease (I51)
TABLE IV: Top 10 most important clinical features (i.e., with the highest attention weights) for each cardiac complication as identified by MuViTaNet.

IV-B Experimental Setup

Baseline Models. To validate the performance of the proposed model for cardiac complication risk profiling task, we compare it with several state-of-the-art models. Based on their architectures, these models are categorized into four main groups including classical model, recurrent-based model, attention-based model, and time-aware model. The details of these models are presented as follows.

  • Logistic Regression (LR). A classical model used in binary classification. To deal with insurance claim data, a patient record is converted to the count vector whose element is the frequency of clinical code in that record, and is then fed into LR.

  • Random Forest (RF) [25]. A classical ensemble model whose prediction is the average of the predictions of a number of decision tree classifiers. Inputs for RF are the same as for LR.

  • Gated Recurrent Unit (GRU) [26]. A variant of the recurrent neural network (RNN) that uses gating mechanisms.

  • Bidirectional GRU (Bi-GRU) [20]. An extension of GRU that employs an additional GRU to process the sequence in reverse order.

  • Dipole [5]. An attention-based model that applies an attention mechanism over the hidden states generated by a Bi-GRU to learn the dependencies between visits.

  • RETAIN [8]. An attention-based model that first employs RNNs to process clinical records in reverse time order to mimic physicians' practice. Two attention modules are then used to identify significant visits and variables.

  • T-LSTM [6]. A time-aware model designed to handle irregularly spaced visits in clinical records. The LSTM memory cell is modified to capture the time interval between two consecutive visits.

  • Transformer [22]. A fully attention-based model that uses multi-head attention mechanisms to learn the dependencies among elements in sequential data.

  • LSAN [27]. An attention-based model that uses a Transformer to capture global information and a CNN to capture local information.

  • MTL Models: For each of the aforementioned neural network-based models, we develop an MTL version by adding task-specific attention and decoder modules on top of the outputs generated by the model.

  • : A variant of MuViTaNet obtained by removing the visit-view encoder.

  • : A variant of MuViTaNet obtained by removing the feature-view encoder.

  • : A variant of MuViTaNet obtained by removing the task-specific attention and decoder, corresponding to the single-task learning (STL) setting.

  • : A variant of MuViTaNet trained on the labeled datasets only.
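The count-vector input used for the LR and RF baselines can be sketched as follows. This is a minimal illustration under stated assumptions: each patient record is taken to be a list of visits, each visit a list of clinical codes, and the helper names `build_vocabulary` and `to_count_vector` are hypothetical rather than taken from the authors' code. The resulting rows could then be passed to Scikit-Learn's `LogisticRegression` or `RandomForestClassifier`.

```python
from collections import Counter
from typing import Dict, List

# Hypothetical record format: a list of visits, each visit a list of clinical codes.
Record = List[List[str]]


def build_vocabulary(records: List[Record]) -> Dict[str, int]:
    """Assign a column index to every clinical code seen in the training records."""
    codes = sorted({code for record in records for visit in record for code in visit})
    return {code: idx for idx, code in enumerate(codes)}


def to_count_vector(record: Record, vocab: Dict[str, int]) -> List[int]:
    """Count how often each vocabulary code occurs across all visits of a record.

    Visit order and timing are discarded; codes unseen during training are ignored.
    """
    counts = Counter(code for visit in record for code in visit)
    vector = [0] * len(vocab)
    for code, freq in counts.items():
        if code in vocab:
            vector[vocab[code]] = freq
    return vector


if __name__ == "__main__":
    train = [
        [["250.00", "278.00"], ["250.00", "M-174"]],
        [["796.2"], ["278.00"]],
    ]
    vocab = build_vocabulary(train)
    X = [to_count_vector(r, vocab) for r in train]
    print(X)  # [[2, 1, 0, 1], [0, 1, 1, 0]]
```

Because the count vector discards visit order and time, it is a natural input for the classical models, whereas the sequence-based baselines consume the visits directly.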

Implementation Details. All neural network-based architectures are implemented in PyTorch (https://pytorch.org/). For the classical models, LR and RF, we use their Python implementations from Scikit-Learn [28]. We use the Adam algorithm [29] to optimize the neural network-based models. The batch size is set to for the labeled datasets and for the unlabeled dataset, and the initial learning rate is .

Evaluation Metric. We conduct experiments under a 5-fold cross-validation setting. instances from the training set are used to construct the validation set, and the results on the testing set are reported for the model that achieves the best results on the validation set. The area under the receiver operating characteristic curve (AU-ROC) is used to measure the performance of the prediction models for cardiac complication risk profiling.
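The evaluation protocol above can be illustrated with a small standard-library sketch: AU-ROC computed as the probability that a randomly chosen positive patient is scored above a randomly chosen negative one (the Mann-Whitney formulation, with ties counted as one half), plus a plain random 5-fold split. The function names are hypothetical, the folds here are not stratified, and the validation hold-out inside each training fold is omitted; in practice Scikit-Learn's `roc_auc_score` and `KFold` provide the same functionality.

```python
import random
from typing import List, Sequence, Tuple


def auroc(labels: Sequence[int], scores: Sequence[float]) -> float:
    """AU-ROC as the probability that a random positive outranks a random negative.

    Uses the Mann-Whitney U statistic with average ranks, so tied scores count 1/2.
    """
    n_pos = sum(1 for y in labels if y == 1)
    n_neg = len(labels) - n_pos
    if n_pos == 0 or n_neg == 0:
        raise ValueError("AU-ROC is undefined unless both classes are present")
    order = sorted(range(len(scores)), key=lambda i: scores[i])
    ranks = [0.0] * len(scores)
    i = 0
    while i < len(order):
        j = i
        # Extend j over a block of tied scores.
        while j + 1 < len(order) and scores[order[j + 1]] == scores[order[i]]:
            j += 1
        avg_rank = (i + j) / 2 + 1  # ranks are 1-based
        for k in range(i, j + 1):
            ranks[order[k]] = avg_rank
        i = j + 1
    rank_sum_pos = sum(r for r, y in zip(ranks, labels) if y == 1)
    u = rank_sum_pos - n_pos * (n_pos + 1) / 2
    return u / (n_pos * n_neg)


def k_fold_indices(n: int, k: int = 5, seed: int = 0) -> List[Tuple[List[int], List[int]]]:
    """Shuffle indices once and split them into k (train, test) index pairs."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    folds = [idx[f::k] for f in range(k)]
    return [
        ([i for g, fold in enumerate(folds) if g != f for i in fold], folds[f])
        for f in range(k)
    ]


if __name__ == "__main__":
    print(auroc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```

Each index appears in exactly one test fold, so averaging the five test-fold AU-ROC scores uses every patient once for testing.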

IV-C Results

Models Multi-view Multi-task AU-ROC
F V L U
MuViTaNet-task-specific
MuViTaNet-feature-view
MuViTaNet-visit-view
MuViTaNet-unlabeled
MuViTaNet
TABLE V: Average performances of MuViTaNet variants over 6 complication datasets (F: Feature-view, V: Visit-view, L: Labeled, U: Unlabeled).

We conduct experiments to answer the following questions.

  • Q1. How accurate is MuViTaNet on the cardiac complication risk profiling task compared to previous work?

  • Q2. How does each component of MuViTaNet contribute to its prediction performance?

  • Q3. How can the predictions made by MuViTaNet be interpreted effectively?

Cardiac complication risk profiling. As shown in Table III, MuViTaNet achieves the best AU-ROC scores among all baselines on the cardiac complication risk profiling task. Overall, it achieves an average AU-ROC score (over the six datasets) of , which is better than that of the best previous method. Looking into each complication dataset, we also observe that MuViTaNet consistently outperforms the other methods in terms of AU-ROC. These improvements indicate the advantage of MuViTaNet in using (1) a multi-view encoder to extract comprehensive information and (2) an MTL scheme to leverage information from both related labeled and unlabeled datasets to improve its prediction performance.

For the baseline methods, we observe that formulating complication risk profiling as MTL substantially improves their prediction performance. The improvements are more noteworthy for the small datasets, including AF (), CAD (), PAD (), and stroke (). These results demonstrate the importance of leveraging task-related information for predicting the onset of complications. We also see that GRU-based models achieve slightly better performance than the other neural network models. In the STL setting, the average prediction performance of the deep learning models is on par with RF and much better than LR. To investigate further, we examine the prediction performance on each dataset and observe that RF outperforms the deep learning models on the AF, CAD, PAD, and stroke datasets, which are relatively small compared to the HF and hypertension datasets. This result is reasonable because deep learning methods generally require large amounts of training data to achieve good prediction performance.

Ablation study. To investigate the contribution of each component of MuViTaNet, we conduct an ablation study comparing MuViTaNet with its simpler variants, namely MuViTaNet-visit-view, MuViTaNet-feature-view, MuViTaNet-task-specific, and MuViTaNet-unlabeled, on the six aforementioned datasets. The AU-ROC scores of these models are shown in Table V. We observe that encoding clinical data with a single-view encoder is not as effective as using a multi-view encoder: the AU-ROC score of MuViTaNet decreases to (resp. ) when only the visit-view (resp. feature-view) encoder is used. This result demonstrates the necessity of aggregating information from multiple views. The performance of MuViTaNet also drops considerably when we remove the task-specific attention mechanism and decoder, which further confirms the importance of formulating the complication risk profiling task as MTL with both labeled and unlabeled datasets.

Model interpretability. Deploying data-driven systems in real-world healthcare applications requires not only models with good prediction performance but also effective mechanisms for explaining automated decisions to clinicians. By leveraging its multi-view multi-task architecture, our proposed model can interpret the prediction for each complication from multiple perspectives, thereby helping clinicians understand which clinical entities contribute most to the prediction.
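To make the per-task interpretation concrete, the following standard-library sketch shows one way a task-specific attention head can yield both a prediction and a ranking of visits or features. It assumes dot-product attention with one query vector per task and a logistic decoder; this is a hypothetical simplification with toy, untrained parameters, not necessarily the attention form MuViTaNet uses. The names `task_head` and `top_k` are illustrative.

```python
import math
from typing import Dict, List, Sequence, Tuple

Vector = List[float]


def dot(u: Sequence[float], v: Sequence[float]) -> float:
    return sum(a * b for a, b in zip(u, v))


def softmax(xs: Sequence[float]) -> List[float]:
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def task_head(
    items: Sequence[Vector], query: Vector, weight: Vector, bias: float
) -> Tuple[float, List[float]]:
    """One task's attention pooling plus logistic decoder over shared item vectors.

    Returns the predicted onset probability and the attention weight of each item.
    """
    attn = softmax([dot(h, query) for h in items])
    pooled = [sum(a * h[d] for a, h in zip(attn, items)) for d in range(len(query))]
    prob = 1.0 / (1.0 + math.exp(-(dot(weight, pooled) + bias)))
    return prob, attn


def top_k(names: Sequence[str], attn: Sequence[float], k: int = 5) -> List[Tuple[str, float]]:
    """Rank items (visits or codes) by attention weight, highest first."""
    return sorted(zip(names, attn), key=lambda p: p[1], reverse=True)[:k]


if __name__ == "__main__":
    # Toy, untrained parameters: three shared feature vectors read by two task heads.
    names = ["250.00", "M-174", "796.2"]
    items = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
    heads: Dict[str, Tuple[Vector, Vector, float]] = {
        "heart_failure": ([2.0, 0.0], [1.0, 1.0], 0.0),
        "hypertension": ([0.0, 2.0], [1.0, -1.0], 0.0),
    }
    for task, (query, weight, bias) in heads.items():
        prob, attn = task_head(items, query, weight, bias)
        print(task, round(prob, 3), top_k(names, attn, k=2))
```

Because each task has its own query, the same shared representation can yield different rankings per complication, which is what allows a per-complication explanation.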

Positive patient from heart failure dataset
Visits Visit 9 (0.11) Visit 3 (0.11) Visit 11 (0.10) Visit 8 (0.09) Visit 6 (0.09)
Features 796.2 (0.26) 250.00 (0.25) 278.00 (0.12) 882.0 (0.05) 19083 (0.04)
Negative patient from hypertension dataset
Visits Visit 9 (0.11) Visit 11 (0.11) Visit 7 (0.10) Visit 4 (0.10) Visit 3 (0.09)
Features M-174 (0.56) 250.00 (0.22) S0612 (0.13) J3010 (0.02) 82043 (0.02)
TABLE VI: Top 5 most important clinical visits and features (i.e., with the highest attention weights) for the 2 patients illustrated in Figure 4.

To characterize cardiac complications, we identify the most important features for each complication by averaging the feature-view attention weights of each clinical feature over all positive patients in the corresponding complication dataset. Because the number of features varies across patients, we rescale each attention weight by multiplying it by the number of features appearing in the corresponding record before averaging. The top-10 clinical features for the six cardiac complications are shown in Table IV. We observe that these complications share many common features, such as I34 (Nonrheumatic mitral valve disorders) and I49 (Other cardiac arrhythmias). This result is reasonable because all of these complications belong to the cardiovascular disease class. Moreover, many of the important features identified by our model are known to be clinically associated with the corresponding complications. For example, people with type 2 diabetes are two to four times more likely to develop heart disease than people without diabetes [30]. Obesity is another major known risk factor for heart failure and hypertension [31, 32]. Angina pectoris is a type of chest pain caused by reduced blood flow to the heart and is considered a symptom of coronary artery disease [33].
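The rescaled averaging described above can be sketched as follows. Two points are assumptions not fixed by the text and are marked in the code: each positive patient's feature-view attention is given as a mapping from code to weight (weights summing to 1 over the distinct features in the record), and a code absent from a record contributes 0 while the average is taken over all positive patients. The function name `rank_features` is hypothetical.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

# Assumed input format: for one positive patient, code -> feature-view attention weight.
PatientAttention = Dict[str, float]


def rank_features(patients: List[PatientAttention], k: int = 10) -> List[Tuple[str, float]]:
    """Average rescaled attention per code over positive patients and return the top-k."""
    if not patients:
        return []
    totals: Dict[str, float] = defaultdict(float)
    for attn in patients:
        n = len(attn)  # number of features in this record
        for code, w in attn.items():
            # Rescale so a uniform weight (1/n) maps to 1 regardless of record length.
            totals[code] += w * n
    # Assumption: divide by all positive patients, so absent codes count as 0.
    averages = {code: s / len(patients) for code, s in totals.items()}
    return sorted(averages.items(), key=lambda p: p[1], reverse=True)[:k]


if __name__ == "__main__":
    positives = [
        {"250.00": 0.5, "278.00": 0.3, "796.2": 0.2},
        {"250.00": 0.2, "M-174": 0.8},
    ]
    print([code for code, _ in rank_features(positives)])
    # ['250.00', 'M-174', '278.00', '796.2']
```

Without the rescaling, a code in a short record would receive an inflated weight simply because attention is spread over fewer features.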

(a) Positive patient from heart failure dataset
(b) Negative patient from hypertension dataset
Fig. 4: Visualization of 2 patient records (i.e., positive patient from heart failure dataset and negative patient from hypertension dataset) from breast cancer cohort. We only show important visits in clinical records due to limited space.

Case study. To further investigate the interpretability of MuViTaNet, we present two case studies that visualize the learned attention weights to identify risk factors for each complication: a positive patient from the heart failure dataset and a negative patient from the hypertension dataset. Their clinical records are illustrated in Figure 4. The most important visits and features, as determined by the attention weights of the visit-view and feature-view task-specific attention components, are shown in Table VI. For the positive patient (Figure 4(a)), the predicted probability of heart failure onset is . As shown in Table VI, the visit-view attention focuses more on visits 3 and 9, which include clinical codes 250.00 (Type II diabetes mellitus) and 278.00 (Obesity); these codes are also identified as the most important features by the feature-view attention. This result is consistent with clinical research showing that type 2 diabetes mellitus and obesity are common risk factors for heart failure [30, 32], thereby demonstrating the effectiveness of MuViTaNet in capturing the correlation between risk factors and the corresponding diseases. To further verify that the model relies on these entities, we remove the important visits and features indicating heart failure risk factors from the patient record and predict the probability of heart failure onset from the modified record to measure the change in model output. Figure 4(a) shows that the predicted score decreases to and when removing visits (3 and 9) and codes (250.00, 278.00, and 796.2), respectively. Thus, MuViTaNet is able to focus on clinically relevant visits and features when predicting the onset of complications.

Figure 4(b) shows the clinical record of the negative patient, who has type 2 diabetes mellitus but is also treated with M-174 (Metformin). Table VI indicates that MuViTaNet pays more attention to M-174 and 250.00 when predicting the onset of hypertension. To verify whether our model captures the relationship between disease and treatment, we remove these codes from the patient record as we did for the positive patient. Figure 4(b) shows that the predicted probability increases from to when removing Metformin (diabetes medication) and decreases to when removing code 250.00 (diabetes). This result indicates that MuViTaNet considers the impact of both disease and treatment on complication development when making predictions.
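The removal analysis used in both case studies can be sketched as a record edit followed by re-scoring. In this sketch `predict` stands for any trained model mapping a record to an onset probability; `toy_predict` is a hypothetical stand-in used only for illustration and is not MuViTaNet. Visits are numbered from 1, and visits left empty after code removal are dropped; both are assumptions, as are the helper names.

```python
from typing import Callable, Iterable, List

# Hypothetical record format: a list of visits, each a list of clinical codes.
Record = List[List[str]]


def remove_entities(record: Record, visits: Iterable[int] = (), codes: Iterable[str] = ()) -> Record:
    """Return a copy of the record without the given visits (1-based) and codes."""
    drop_visits = set(visits)
    drop_codes = set(codes)
    edited = []
    for number, visit in enumerate(record, start=1):
        if number in drop_visits:
            continue
        kept = [c for c in visit if c not in drop_codes]
        if kept:  # assumption: visits that become empty are dropped
            edited.append(kept)
    return edited


def prediction_change(predict: Callable[[Record], float], record: Record, **removal) -> float:
    """Change in predicted probability after removing the given visits or codes."""
    return predict(remove_entities(record, **removal)) - predict(record)


def toy_predict(record: Record) -> float:
    """Hypothetical stand-in model: risk rises with diabetes and obesity codes."""
    codes = {c for visit in record for c in visit}
    return 0.2 + 0.3 * ("250.00" in codes) + 0.2 * ("278.00" in codes)


if __name__ == "__main__":
    record = [["I10"], ["250.00"], ["278.00", "796.2"]]
    print(prediction_change(toy_predict, record, visits=[2, 3]))
    print(prediction_change(toy_predict, record, codes=["250.00"]))
```

A large drop after removing highly attended entities, and a small one otherwise, is the behavior the case studies look for.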

V Conclusions

Complication risk profiling is a crucial problem in the healthcare prediction domain. In this paper, we propose a novel multi-view multi-task network (MuViTaNet) that leverages clinical data to profile multiple complications for patients. To tackle the issues of existing methods, MuViTaNet views each patient record both as a sequence of clinical visits and as a set of clinical features, and then employs a multi-view encoder to extract meaningful information from both the feature view and the visit view of the record. Owing to the relatedness among different complications, we organize MuViTaNet as an MTL architecture in which the shared representation learned by the multi-view encoder is fed into multiple task-specific attention components to learn task-specific representations for patients in both labeled and unlabeled datasets. Finally, the predicted probability of each complication onset is generated from the task-specific representation by the corresponding decoder. We evaluate the prediction performance of MuViTaNet on an insurance claim database comprising six cardiac complication datasets for breast cancer survivors. The experimental results demonstrate that our proposed model outperforms other state-of-the-art models on the complication risk profiling task. More importantly, MuViTaNet provides an effective mechanism to interpret its predictions from multiple perspectives, thereby helping clinicians make better decisions in real-world scenarios.

Acknowledgment

This work was funded in part by the National Science Foundation under award number CBET-2037398.

References

  • [1] C. Schairer, P. J. Mink, L. Carroll, and S. S. Devesa, “Probabilities of death from breast cancer and other causes among female breast cancer patients,” Journal of the National Cancer Institute, vol. 96, no. 17, 2004.
  • [2] J. L. Patnaik, T. Byers, C. DiGuiseppi, D. Dabelea, and T. D. Denberg, “Cardiovascular disease competes with breast cancer as the leading cause of death for older females diagnosed with breast cancer: a retrospective cohort study,” Breast Cancer Research, vol. 13, no. 3, 2011.
  • [3] H. Abdel-Qadir, P. Thavendiranathan, K. Fung, E. Amir, P. C. Austin, G. S. Anderson, and D. S. Lee, “Association of early-stage breast cancer and subsequent chemotherapy with risk of atrial fibrillation,” JAMA Network Open, vol. 2, no. 9, 2019.
  • [4] H. Strongman, S. Gadd, A. Matthews, K. E. Mansfield, S. Stanway, A. R. Lyon, I. dos Santos-Silva, L. Smeeth, and K. Bhaskaran, “Medium and long-term risks of specific cardiovascular diseases in survivors of 20 adult cancers: a population-based cohort study using multiple linked UK electronic health records databases,” The Lancet, vol. 394, no. 10203, 2019.
  • [5] F. Ma, R. Chitta, J. Zhou, Q. You, T. Sun, and J. Gao, “Dipole: Diagnosis prediction in healthcare via attention-based bidirectional recurrent neural networks,” in KDD’17, 2017.
  • [6] I. M. Baytas, C. Xiao, X. Zhang, F. Wang, A. K. Jain, and J. Zhou, “Patient subtyping via time-aware LSTM networks,” in KDD’17, 2017.
  • [7] J. Gao, C. Xiao, Y. Wang, W. Tang, L. M. Glass, and J. Sun, “StageNet: Stage-aware neural networks for health risk prediction,” in WWW’20, 2020.
  • [8] E. Choi, M. T. Bahadori, J. A. Kulas, A. Schuetz, W. F. Stewart, and J. Sun, “RETAIN: An interpretable predictive model for healthcare using reverse time attention mechanism,” in NIPS’16, 2016.
  • [9] H. Song, D. Rajan, J. Thiagarajan, and A. Spanias, “Attend and diagnose: Clinical time series analysis using attention models,” in AAAI’18, vol. 32, no. 1, 2018.
  • [10] T. Bai, S. Zhang, B. L. Egleston, and S. Vucetic, “Interpretable representation learning for healthcare via capturing disease progression through time,” in KDD’18, 2018.
  • [11] B. C. Kwon, M.-J. Choi, J. T. Kim, E. Choi, Y. B. Kim, S. Kwon, J. Sun, and J. Choo, “RetainVis: Visual analytics with interpretable and interactive recurrent neural networks on electronic medical records,” IEEE Transactions on Visualization and Computer Graphics, vol. 25, no. 1, 2018.
  • [12] L. Ma, C. Zhang, Y. Wang, W. Ruan, J. Wang, W. Tang, X. Ma, X. Gao, and J. Gao, “ConCare: Personalized clinical feature embedding via capturing the healthcare context,” in AAAI’20, vol. 34, no. 01, 2020.
  • [13] J. Zhou, L. Yuan, J. Liu, and J. Ye, “A multi-task learning formulation for predicting disease progression,” in KDD’11, 2011.
  • [14] B. Liu, Y. Li, Z. Sun, S. Ghosh, and K. Ng, “Early prediction of diabetes complications from electronic health records: A multi-task survival analysis approach,” in AAAI’18, vol. 32, no. 1, 2018.
  • [15] J. Wiens, J. Guttag, and E. Horvitz, “Patient risk stratification with time-varying parameters: a multitask learning approach,” The Journal of Machine Learning Research, vol. 17, no. 1, 2016.
  • [16] N. Nori, H. Kashima, K. Yamashita, H. Ikai, and Y. Imanaka, “Simultaneous modeling of multiple diseases for mortality prediction in acute hospital care,” in KDD’15, 2015.
  • [17] N. Razavian, J. Marcus, and D. Sontag, “Multi-task prediction of disease onsets from longitudinal laboratory tests,” in MLHC’16, PMLR, 2016.
  • [18] Z. C. Lipton, D. C. Kale, C. Elkan, and R. Wetzel, “Learning to diagnose with LSTM recurrent neural networks,” in ICLR’16, 2016.
  • [19] B. Liu, Y. Li, S. Ghosh, Z. Sun, K. Ng, and J. Hu, “Complication risk profiling in diabetes care: A Bayesian multi-task and feature relationship learning approach,” IEEE Transactions on Knowledge and Data Engineering, vol. 32, no. 7, 2019.
  • [20] B. Ljubic, A. A. Hai, M. Stanojevic, W. Diaz, D. Polimac, M. Pavlovski, and Z. Obradovic, “Predicting complications of diabetes mellitus using advanced machine learning algorithms,” Journal of the American Medical Informatics Association, vol. 27, no. 9, 2020.
  • [21] A. Guo, K. W. Zhang, K. Reynolds, and R. E. Foraker, “Coronary heart disease and mortality following a breast cancer diagnosis,” BMC Medical Informatics and Decision Making, vol. 20, 2020.
  • [22] A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, L. Kaiser, and I. Polosukhin, “Attention is all you need,” in NIPS’17, 2017.
  • [23] D. Dong, H. Wu, W. He, D. Yu, and H. Wang, “Multi-task learning for multiple language translation,” in ACL’15, 2015.
  • [24] T. Chen, S. Kornblith, M. Norouzi, and G. Hinton, “A simple framework for contrastive learning of visual representations,” in ICML’20, PMLR, 2020.
  • [25] L. Breiman, “Random forests,” Machine Learning, vol. 45, no. 1, 2001.
  • [26] K. Cho, B. Van Merriënboer, C. Gulcehre, D. Bahdanau, F. Bougares, H. Schwenk, and Y. Bengio, “Learning phrase representations using RNN encoder-decoder for statistical machine translation,” in EMNLP’14, 2014.
  • [27] M. Ye, J. Luo, C. Xiao, and F. Ma, “LSAN: Modeling long-term dependencies and short-term correlations with hierarchical attention for risk prediction,” in CIKM’20, 2020.
  • [28] F. Pedregosa, G. Varoquaux, A. Gramfort, V. Michel, B. Thirion, O. Grisel, M. Blondel, P. Prettenhofer, R. Weiss, V. Dubourg, J. Vanderplas, A. Passos, D. Cournapeau, M. Brucher, M. Perrot, and E. Duchesnay, “Scikit-learn: Machine learning in Python,” Journal of Machine Learning Research, vol. 12, 2011.
  • [29] D. P. Kingma and J. Ba, “Adam: A method for stochastic optimization,” in ICLR’15, 2015.
  • [30] H. C. Kenny and E. D. Abel, “Heart failure in type 2 diabetes mellitus: impact of glucose-lowering agents, heart failure therapies, and novel therapeutic strategies,” Circulation Research, vol. 124, no. 1, 2019.
  • [31] N. Mikhail, M. S. Golub, and M. L. Tuck, “Obesity and hypertension,” Progress in Cardiovascular Diseases, vol. 42, no. 1, 1999.
  • [32] I. A. Ebong, D. C. Goff Jr, C. J. Rodriguez, H. Chen, and A. G. Bertoni, “Mechanisms of heart failure in obesity,” Obesity Research & Clinical Practice, vol. 8, no. 6, 2014.
  • [33] M. Mosseri, R. Yarom, M. Gotsman, and Y. Hasin, “Histologic evidence for small-vessel coronary artery disease in patients with angina pectoris and patent large coronary arteries,” Circulation, vol. 74, no. 5, 1986.