Multi-task Learning via Adaptation to Similar Tasks for Mortality Prediction of Diverse Rare Diseases

04/11/2020 ∙ by Luchen Liu, et al. ∙ 0

Mortality prediction of diverse rare diseases using electronic health record (EHR) data is a crucial task for intelligent healthcare. However, data insufficiency and the clinical diversity of rare diseases make it hard to directly train deep learning models on individual disease data or on all the data from different diseases. Mortality prediction for these patients with different diseases can be viewed as a multi-task learning problem with insufficient data and a large number of tasks. But the tasks with little training data also make it hard to train task-specific modules in multi-task learning models. To address the challenges of data insufficiency and task diversity, we propose an initialization-sharing multi-task learning method (Ada-SiT) which learns the parameter initialization for fast adaptation to dynamically measured similar tasks. We use Ada-SiT to train long short-term memory (LSTM) based prediction models on longitudinal EHR data. Experimental results demonstrate that the proposed model is effective for mortality prediction of diverse rare diseases.

1 Introduction

Mortality prediction [Sharma2017Mortality] of diseases plays a crucial role in clinical work: a timely alert about a patient's adverse health status helps doctors intervene early. With the immense accumulation of Electronic Health Records (EHR) [eicu, johnson2016mimic], deep learning models [liu2018learning], which typically require a large volume of data, have been developed and demonstrate state-of-the-art performance on mortality prediction of common diseases. However, mortality prediction of rare diseases is relatively unexplored in the domain of intelligent healthcare and personalized medicine.

Predicting mortality of rare diseases suffers from the problem of data insufficiency. A rare disease is one that affects a small percentage of the population and has a different disease mechanism from common ones. Although more than 300 million people worldwide live with one of the approximately 7,000 rare diseases (according to the US organization Global Genes, https://globalgenes.org/rare-list/), the volume of data for any specific rare disease is never sufficient. In some real-world data sources [johnson2016mimic], only tens of data samples on average could be collected for each rare disease.

Besides data insufficiency, the diversity of clinical behavior across these diseases is another challenge for mortality prediction of rare diseases. Behaviors of different diseases vary a lot and cause potential conflicts in the training of global deep learning models [hochreiter1997long, bai2018empirical, vaswani2017attention]. For example, a high heart rate raises the mortality risk of people with heart disease, but for patients with a cold it is a common symptom that does not indicate that the patient is in danger. So the data volume constraint of diverse rare diseases cannot be resolved simply by training a global model on samples from all diseases.

Multi-task learning models [song2018attend, Suresh2018Learning] can be used to settle the problem of disease behavior diversity. Mortality prediction for each kind of rare disease is viewed as a task and multi-task learning is supposed to capture the task-specific characteristics as well as the shared information of all tasks. However, the rare disease tasks with little training data make it hard to train specific modules in multi-task learning models, and further utilizing the shared characteristics of similar tasks is difficult.

Meta-learning methods can learn meta-knowledge of training models, which makes it possible to learn fast with few samples, such as few-shot learning [finn2017model]. To build a better multi-task model suitable for tasks with little training data, we bring in the idea of fast adaptation in meta learning, which learns a shared initialization as meta-knowledge for adaptation to new tasks. However, since the fast adaptation method adapts shared initialization to each task independently, it cannot directly take into account the relationship of similar tasks, which is important and can provide useful information to enhance multi-task learning [evgeniou2005learning, jacob2009clustered].

Method | Examples
Global Models | LSTM [hochreiter1997long], TCN [bai2018empirical]
Multi-task Models | Multi-SAnD [song2018attend], Multi-Dense [Suresh2018Learning]
Meta-learning Models | MAML [finn2017model]
Our Model | Ada-SiT
Table 1: Comparing the three kinds of methods in terms of their ability to handle three main challenges in mortality prediction for rare diseases (data insufficiency, task similarity and task diversity)

To deal with the above-mentioned challenges (summarized in Table 1), we propose an initialization-shared multi-task learning method, named Ada-SiT (Adaptation to Similar Tasks), in which the task similarity is dynamically measured in the process of meta-training according to our new definition. The task similarity, as a part of the learned meta-knowledge, can therefore enhance the fast adaptation procedure of generic meta-learning models. Moreover, Ada-SiT is model-agnostic and can employ any existing deep learning based approach as the basic predictive model. Experimental results on real medical datasets demonstrate that the proposed model is able to make similar tasks cooperate in initialization-shared multi-task learning, and that it outperforms state-of-the-art global models as well as multi-task methods for mortality prediction of diverse rare diseases.

The contributions of the proposed model are as follows:

  • To the best of our knowledge, this is the first attempt to simultaneously tackle the challenge of disease diversity and data insufficiency in mortality prediction of rare diseases.

  • We propose a novel initialization-shared multi-task learning method Ada-SiT, which can utilize information of task similarity for adaptation to each small sample-size task.

2 Related Works

2.1 Deep Learning for Healthcare

The accumulation of Electronic Health Records (EHR) has enabled research on deep learning methods for healthcare [liu2019learning, wang2019predictive, liu2019early]. Multi-layer Perceptron (MLP) [cheKDD15], Convolutional Neural Network (CNN) [suo2017personalized] and Recurrent Neural Network (RNN) [choi2016retain, suo2017multi, liu2018learning] models have been used in the healthcare domain. Among these methods, there are many works on mortality prediction. The good performance of these models depends on a large volume of EHR data, a requirement that cannot be satisfied in our scenario of mortality prediction for diverse rare diseases. As a result, these models cannot make precise mortality predictions for patients with different rare diseases. Our work is suitable for these settings because it simultaneously tackles the challenges of disease diversity and data insufficiency. Furthermore, our method is a general framework and can be applied to train deep learning models to improve their performance.

2.2 Multi-task Learning

Multi-task learning is an efficient method to improve performance by jointly learning multiple related tasks. In deep multi-task learning models, the information sharing mechanism is based on specific network structures, including shared layers [long2015learning], shared functions [chen2018meta] and additional constraints [misra2016cross]. However, task similarity cannot be directly interpreted in these models. We propose a model-agnostic multi-task learning method which shares the parameter initialization for fast adaptation to each task and in which task similarity is dynamically measured during the training process.

In the healthcare domain, multi-task learning is used for prediction of various clinical events [harutyunyan2017multitask], mortality prediction of multiple patient cohorts [Suresh2018Learning] and patient-specific diagnosis [nori2017learning], in which the "tasks" have different definitions. Similar to [Suresh2018Learning], our work also treats the mortality prediction of a certain patient cohort as a task. However, the method proposed by Suresh [Suresh2018Learning] is suitable for a small number of tasks with a large volume of data, whereas our method for mortality prediction of rare diseases is designed to deal with hundreds of tasks with insufficient data.

2.3 Optimization-based Meta Learning

To solve the problem of data insufficiency and task diversity, our method borrows the idea behind optimization-based meta-learning [finn2017model, andrychowicz2016learning], which can adapt to new environments with a few training samples by modifying the parameter optimization process. MAML (Model Agnostic Meta Learning) [finn2017model] uses fast adaptation to find a good initialization for the parameters and adapt it to new tasks. Our work is similar to MAML for using fast adaptation to get parameters of each task, but different from it in two ways: First, the objective of MAML is to learn a good parameter initialization for fast adapting to new tasks but our work is to find good model parameters for each given task. Second, our work measures task similarity in model space dynamically and uses samples in similar tasks to assist the adaptation to each task while MAML does fast adaptations to each task independently.

In the clinical scenario, MetaPred [Zhang2019MetaPred] uses MAML for clinical risk prediction with limited patient electronic health records. Its task is similar to ours but its method differs from ours in two ways: First, it trains a parameter initialization on a source domain and a simulated target domain via fast adaptation, whereas our method can learn from multiple small target domains without source domain knowledge. Second, like MAML, MetaPred does not consider task similarity.

3 Data and Task Descriptions

We give the notations and data descriptions of the predictive tasks in the following.

3.1 Heterogeneous Temporal Events in EHR data

The input of each mortality prediction task is a given episode of patient EHR, which can be represented as a sequence of heterogeneous temporal events [liu2018learning]: $x = (e_1, e_2, \dots, e_n)$. Each $e_j$ in this sequence is a tuple with four elements: $e_j = (c_j, a_j, v_j, t_j)$, where $c_j$ is the clinical event type, $a_j$ and $v_j$ are the categorical and numerical attributes of $e_j$, and $t_j$ is the record time of $e_j$.

3.2 Multi-Task Mortality Prediction

The mortality prediction of each rare disease is defined as a task. Specifically, assuming that there are M diseases, we refer to $D_i$ as the corpus of the i-th mortality prediction task $T_i$ for patients of rare disease $i$ with $N_i$ samples:

$$D_i = \{(x_i^k, y_i^k)\}_{k=1}^{N_i}, \quad i = 1, \dots, M \tag{1}$$

where $x_i^k$ and $y_i^k$ denote the k-th sample and its label respectively in the i-th task.

Specifically, $x_i^k$ is a given episode of a patient's EHR data, represented as heterogeneous temporal events, and $y_i^k \in \{0, 1\}$ is the binary label indicating whether the patient will die in 24 hours.
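The data layout described above can be sketched in a few lines of Python. This is only an illustrative sketch: the names `Event`, `build_task_corpora` and the field names are hypothetical and do not come from the original implementation, and the time unit is an assumption.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass(frozen=True)
class Event:
    """One heterogeneous temporal event: type, attributes and record time."""
    event_type: str               # clinical event type, e.g. "lab"
    categorical: Dict[str, str]   # categorical attributes of the event
    numerical: Dict[str, float]   # numerical attributes of the event
    time: float                   # record time (unit is an assumption, e.g. hours)


Episode = List[Event]
Sample = Tuple[Episode, int]      # (event sequence, binary mortality label)


def build_task_corpora(records) -> Dict[str, List[Sample]]:
    """Group (disease_code, episode, label) records into one corpus per disease.

    Each disease code becomes one task; events are put in temporal order.
    """
    corpora: Dict[str, List[Sample]] = {}
    for code, episode, label in records:
        ordered = sorted(episode, key=lambda e: e.time)
        corpora.setdefault(code, []).append((ordered, label))
    return corpora
```

Each key of the returned dictionary corresponds to one task, and each value is that task's list of labelled episodes.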

3.3 Patient Cohort Setup

3.3.1 Heterogeneous Temporal Event Datasets

We set up two heterogeneous temporal event datasets based on MIMIC-III [johnson2016mimic] database and eICU [eicu] database. MIMIC-III is a large, freely-available database comprising health data of patients in critical care units from Beth Israel Deaconess Medical Center between 2001 and 2012, and eICU is populated with data from a combination of many critical care units throughout the continental United States between 2014 and 2015.

The two datasets have the same data preprocessing framework: For each patient, we select all the events with their features from the original database and arrange the events in the temporal order. The descriptions of the selected events in eICU are listed in Table 2. The details of the events in MIMIC-III can be found in [amia]. Then we annotate the mortality label for each patient event sequence.

Table Name in eICU | Description
lab | Laboratory tests mapped to a standard set of measurements. E.g. labTypeID, labResult
intakeoutput | Intake and output recorded for patients. E.g. intakeTotal, outputTotal, dialysisTotal
medication | Active medication orders for patients. E.g. drugHiclSeqno, dosage
infusiondrug | Details of drug infusions. E.g. drugRate, infusionRate, drugAmount
careplan | Documentation relating to care planning. E.g. cplGeneralID, cplItemValue
admissiondrug | Details of medications that a patient was taking prior to admission to the ICU. E.g. drugUnit
nursecharting | Information entered in a semi-structured form by the nurse. E.g. nursingchartcelltypecat
physicalexam | Patients' results of the physical exam. E.g. physicalExamText, physicalExamValue
diagnosis | Diagnosis information of patients. E.g. ICD9Code, diagnosisPriority
respiratorycare | Information related to respiratory care for patients. E.g. airwayType, airwaySize, cuffPressure
allergy | Details of patient allergies. E.g. allergyType
Table 2: The tables in eICU used to construct heterogeneous temporal events

3.3.2 Rare Disease Selection

We reorganize the heterogeneous temporal event datasets to get two rare disease datasets, MiniMIMIC and MiniEICU.

For each ICD (International Classification of Diseases, http://www.icd-code.org) code in MIMIC-III and eICU, we calculate its sample size (i.e. the number of patients with this code). We select 858 ICD codes with less than 40 samples in MIMIC-III and 70 ICD codes with less than 100 samples in eICU as rare diseases. The heterogeneous temporal event sequences of patients with a selected rare disease $i$ form the corpus $D_i$. So MiniMIMIC is the task list $\{D_1, \dots, D_{858}\}$ and MiniEICU is the task list $\{D_1, \dots, D_{70}\}$.

The statistics of MiniMIMIC and MiniEICU are summarized in Table 3.

Each task corpus $D_i$ is split into 3 parts with fixed proportions, namely a training set $D_i^{train}$ (70%), a validation set $D_i^{val}$ (10%) and a test set $D_i^{test}$ (20%). The validation set is used for early stopping and for selecting hyper-parameters. The results of the evaluation metrics on the test set and their standard deviations are used to compare different models.
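A per-task split with these fixed proportions could look like the sketch below. The function name `split_task`, the shuffling and the seed are assumptions made for illustration; the paper does not state how samples are assigned to the three parts.

```python
import random


def split_task(samples, seed=0, train_pct=70, val_pct=10):
    """Split one task's samples into train/validation/test parts (70/10/20 by default).

    Integer percentages avoid floating-point rounding surprises on small tasks.
    The remainder after train and validation goes to the test part.
    """
    rng = random.Random(seed)
    shuffled = list(samples)
    rng.shuffle(shuffled)
    n_train = len(shuffled) * train_pct // 100
    n_val = len(shuffled) * val_pct // 100
    train = shuffled[:n_train]
    val = shuffled[n_train:n_train + n_val]
    test = shuffled[n_train + n_val:]
    return train, val, test
```

With only 10 to 40 samples per task, the validation part of a task may hold just one or two samples, which is one reason why early stopping decisions are noisy in this setting.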

4 Methodology

In this section, we begin with some basic notations for the adaptation to a single task and the backbone deep network for mortality prediction based on EHR data. Then we introduce the framework of our method Ada-SiT (Adaptation to Similar Tasks), which can be applied to train multiple mortality prediction models for different rare diseases.

4.1 Adaptation to a Single Task

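For a task $T_i$ of a given disease and given initial parameters $\theta_0$ (either random or learned), we formulate the learning process of model parameters $\theta_i$ as minimizing the loss function of model parameters on the data $D_i$ of the given task, starting from the initialization $\theta_0$:

$$\theta_i = \mathrm{Learn}(D_i; \theta_0) = \arg\min_{\theta} L(\theta; D_i) \tag{2}$$

We assume $f_\theta(x)$ is the mortality rate predicted by model $f$ with parameters $\theta$. The loss function of model parameters $\theta$ on task corpus $D_i$, which holds $N_i$ samples $(x_i^k, y_i^k)$, is defined using the cross entropy between the model output and the true label of the outcome:

$$L(\theta; D_i) = -\frac{1}{N_i} \sum_{k=1}^{N_i} \left[ y_i^k \log f_\theta(x_i^k) + (1 - y_i^k) \log\left(1 - f_\theta(x_i^k)\right) \right] \tag{3}$$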
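As a concrete reading of the cross-entropy loss described above, the following sketch computes the mean binary cross entropy between predicted mortality probabilities and true labels. Averaging over samples and clipping with `eps` are assumptions of this sketch, not details confirmed by the paper.

```python
import math


def cross_entropy(probs, labels, eps=1e-12):
    """Mean binary cross entropy between predicted probabilities and 0/1 labels.

    `eps` clips probabilities away from 0 and 1 so that log() stays finite.
    """
    total = 0.0
    for p, y in zip(probs, labels):
        p = min(max(p, eps), 1.0 - eps)
        total -= y * math.log(p) + (1 - y) * math.log(1.0 - p)
    return total / len(labels)
```

A model that always outputs 0.5 obtains a loss of log 2 regardless of the labels, which is a useful reference point when reading training curves.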

In this work, the mortality prediction model based on heterogeneous temporal events in EHR data is mainly composed of attributed event embedding [liu2018learning] and long short-term memory (LSTM). Firstly, clinical events with attributes are embedded into vectors via information fusion of their type, categorical attributes and numerical attributes. After the attributed event embedding module, the temporal dependencies encoded in the sequence of embedded vectors are captured by the LSTM [hochreiter1997long], which outputs the mortality prediction through a sigmoid layer at the last LSTM cell. The mortality prediction model is illustrated in the middle of Figure 2.

4.2 Adaptation to Similar Tasks

Figure 1: (A) Generic Architecture of Multi-Task Learning. The shared module (in blue), whose output will be taken as the input of specific modules, is shared between different tasks. (B) Generic Architecture of Meta-Learning based Multi-Task Learning. The parameter initialization $\theta_0$ (in blue), which will be adapted to each specific task $T_i$, is shared between all tasks. (C) Architecture of our Ada-SiT (Adaptation to Similar Tasks). The shared initialization $\theta_0$ will be adapted to the similar tasks of each task $T_i$, resulting in a predictive model for each task.

The architecture of our proposed multi-task learning method Ada-SiT is shown in Figure 1, where Ada-SiT is compared with generic multi-task methods and generic meta-learning based multi-task methods. In the following, we first introduce the architecture of Ada-SiT, and then the task similarity measurement module, which is a key component of Ada-SiT. The overall framework for training a mortality prediction model with Ada-SiT is illustrated in Figure 2.

Figure 2: Training Mortality Prediction Model with Ada-SiT

4.2.1 Architecture of Ada-SiT

Fast adaptation, one of the meta-learning methods, is applied in Ada-SiT for the multi-task learning scenario. Ada-SiT differs from the original fast adaptation [finn2017model], which adapts to each task separately: Ada-SiT dynamically measures task similarity and learns to adapt to multiple similar tasks.

The idea of fast adaptation is to learn a good parameter initialization for fast adaptation to new tasks. Specifically, in our scenario of multi-task learning, it means using patient data from all the disease tasks to find a good parameter initialization $\theta_0$ together with the mortality prediction model parameters $\theta_i$ of each disease, which are adapted from the found initial parameters $\theta_0$.

The initialization $\theta_0$ is learned by repeatedly simulating scenarios of mortality prediction of each disease with its similar diseases. We achieve this goal by defining the meta-objective function as:

$$L(\theta_0) = \mathbb{E}_{i} \, \mathbb{E}_{B_i, B_i' \sim \tilde{D}_i} \left[ L\left(\mathrm{Learn}(B_i; \theta_0); B_i'\right) \right] \tag{4}$$

where $\tilde{D}_i$ is the extended sample set composed of the tasks similar to $T_i$. Details of $\tilde{D}_i$ will be described in the next subsection.

Since $L$ is a loss, we minimize the meta-objective function using stochastic approximation with gradient descent. For each epoch, we find the similar tasks $S_i$ for each task $T_i$, and then independently sample two sample subsets ($B_i$ and $B_i'$) from the training sets of the tasks in $S_i$, i.e. from $\tilde{D}_i$. The former is used to simulate the learning process of mortality prediction models, and the latter is used to evaluate the precision of the learned models for the updating of the shared initialization. Here, a single-step gradient descent is applied in the training simulation:

$$\theta_i' = \theta_0 - \alpha \nabla_{\theta_0} L(\theta_0; B_i) \tag{5}$$

where $\alpha$ is the learning rate and $L(\cdot; B_i)$ represents the loss function calculated from the sample set $B_i$ drawn from $\tilde{D}_i$.

Next, we evaluate the updated task parameters $\theta_i'$ on $B_i'$. The gradient computed from the evaluation, referred to as the meta-gradient, is used to update the initial parameters $\theta_0$. Gradients from a batch of tasks are aggregated to update $\theta_0$ as follows:

$$\theta_0 \leftarrow \theta_0 - \beta \sum_{i} \nabla_{\theta_0} L(\theta_i'; B_i') \tag{6}$$

where $\beta$ is the meta learning rate, and $L(\cdot; B_i')$ represents the loss function calculated from the sample set $B_i'$ drawn from $\tilde{D}_i$.

When the meta-gradient is expanded by the chain rule, the second-order derivative term can be ignored without much accuracy loss [finn2017model], so the meta-gradient can be approximated by the following simplified gradient:

$$\nabla_{\theta_0} L(\theta_i'; B_i') = \nabla_{\theta_i'} L(\theta_i'; B_i') \left( I - \alpha H_{\theta_0}\left(L(\theta_0; B_i)\right) \right) \approx \nabla_{\theta_i'} L(\theta_i'; B_i') \tag{7}$$

where the term including $H_{\theta_0}$, the Hessian matrix (a square matrix of second-order partial derivatives of the loss function at $\theta_0$), is ignored.

As our goal is multi-task learning via adaptation to similar tasks, model parameters of each task are adapted from the newly updated initialization parameters at the end of each iteration. These parameters of tasks are comparable in the model space, and will be used for calculating the task similarity in model space in the next iteration.
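The training loop described in this subsection (a single-step adaptation on one sampled subset, followed by a first-order meta-update computed on a second subset) can be sketched as follows. This is a minimal illustration under stated assumptions: a logistic regression stands in for the LSTM predictor, the names `predict`, `grad`, `adapt` and `meta_step` are hypothetical, and each entry of `extended_sets` is assumed to already contain the samples of a task and its similar tasks.

```python
import math
import random


def predict(theta, x):
    """Logistic-regression stand-in for the LSTM mortality predictor."""
    z = sum(w * v for w, v in zip(theta, x))
    return 1.0 / (1.0 + math.exp(-z))


def grad(theta, batch):
    """Gradient of the mean cross-entropy loss over a batch of (x, y) pairs."""
    g = [0.0] * len(theta)
    for x, y in batch:
        err = predict(theta, x) - y
        for d, v in enumerate(x):
            g[d] += err * v / len(batch)
    return g


def adapt(theta0, batch, alpha):
    """Single-step gradient descent from the shared initialization."""
    return [w - alpha * g for w, g in zip(theta0, grad(theta0, batch))]


def meta_step(theta0, extended_sets, alpha, beta, batch_size, rng):
    """One first-order meta-update of the shared initialization.

    For each task, two subsets are sampled independently from its extended
    sample set: one to simulate training, one to evaluate the adapted model.
    """
    meta_grad = [0.0] * len(theta0)
    for samples in extended_sets:
        train_subset = rng.sample(samples, batch_size)
        eval_subset = rng.sample(samples, batch_size)
        theta_i = adapt(theta0, train_subset, alpha)
        # First-order approximation: take the gradient at the adapted
        # parameters and drop the Hessian term of the chain rule.
        for d, g in enumerate(grad(theta_i, eval_subset)):
            meta_grad[d] += g
    return [w - beta * g for w, g in zip(theta0, meta_grad)]
```

After each meta-update, calling `adapt` once per task from the new initialization gives the per-task parameters that the next iteration can compare in model space.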

4.2.2 Task Similarity Measurement in Model Space

To measure the similarity of all tasks in terms of clinical behavior, we define the similarity of tasks in the model space, where each predictive model for its corresponding task is represented as a vector composed of all the parameters.

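Formally, the set of tasks similar to $T_i$ is defined as $S_i$:

$$S_i = \left\{ T_j \mid \cos\langle \theta_i - \theta_0, \theta_j - \theta_0 \rangle > \epsilon \right\} \tag{8}$$

where $\theta_i$ and $\theta_j$ are the parameters of the corresponding models of $T_i$ and $T_j$, and $\theta_0$ is the initial parameters. $\cos\langle \cdot, \cdot \rangle$ is a function of the angle between the gradient directions of $T_i$ and $T_j$, and $\epsilon$ is the threshold of the function.

The models of similar diseases also have similar gradient directions when they are adapted from the initialization, so a large value (close to 1) of the cosine of the included angle between two gradient directions indicates that the tasks are similar.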

Notice that a natural alternative way to get similar tasks is to select the k nearest neighbors in the model space. However, the absolute distance between models is more meaningful than the relative distance, because the distance between models is generated by the gradient descent of the adaptation process from a common initialization. So selecting the models in the neighborhood of a certain model as its similar models is more suitable in our Ada-SiT method, which is demonstrated by the experimental results in Section 5.4.
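The contrast between threshold-based selection and k-nearest-neighbor selection can be illustrated with the sketch below. It assumes that similarity is the cosine between the directions in which two task models moved away from the shared initialization; the function names are hypothetical, and ranking the neighbors by cosine in `knn_tasks` is also an assumption of this sketch.

```python
import math


def cosine(u, v):
    """Cosine of the angle between two vectors (0.0 if either is all zeros)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0.0 or norm_v == 0.0:
        return 0.0
    return dot / (norm_u * norm_v)


def _directions(thetas, theta0):
    """Direction each task model moved away from the shared initialization."""
    return [[w - w0 for w, w0 in zip(t, theta0)] for t in thetas]


def similar_tasks(thetas, theta0, threshold):
    """For each task, indices of all tasks whose direction cosine exceeds the threshold."""
    deltas = _directions(thetas, theta0)
    return [[j for j, dj in enumerate(deltas) if cosine(di, dj) > threshold]
            for di in deltas]


def knn_tasks(thetas, theta0, k):
    """For each task, indices of the k tasks with the highest direction cosine."""
    deltas = _directions(thetas, theta0)
    result = []
    for di in deltas:
        ranked = sorted(range(len(deltas)), key=lambda j: -cosine(di, deltas[j]))
        result.append(ranked[:k])
    return result
```

For three toy models `[[1, 0], [0.9, 0.1], [-1, 0]]` adapted from `[0, 0]`, a threshold of 0.7 leaves the third task alone in its own set, whereas `k = 2` forces it to borrow samples from a task that moved in almost the opposite direction.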

5 Results and Discussions

5.1 Comparing Methods and Experimental Settings

We compare Ada-SiT to both global single-task learning methods and multi-task learning methods for mortality prediction. The data size of each task is too small to train separate single-task models, so these baselines have not been included.

The global single-task learning models are trained on all the patients in the training set.

LSTM: LSTM [hochreiter1997long] is used to learn the representation of the heterogeneous event sequence for each patient. Binary predictions for mortality based on the learned representations are generated with a logistic regression layer.

TCN: The architecture for prediction is the same as LSTM, except that TCN [bai2018empirical] is used instead of LSTM to learn patient representation vectors from the heterogeneous event sequences.

The multi-task learning methods use the training set $D_i^{train}$ of each task $T_i$ to train the model for that task and get prediction results such as the predicted label and probability on its test set $D_i^{test}$.

Multi-SAnD: Multi-SAnD [song2018attend] is a Transformer-based multi-task learning method which uses the weighted sum of the loss functions of all tasks as its loss function.

MMoE: MMoE [ma2018modeling] is a Multi-gate Mixture-of-Experts model which shares the expert submodels across all tasks and has a gating network trained to optimize each task.

Multi-Dense: Multi-Dense is proposed by Suresh [Suresh2018Learning]. It has a shared LSTM layer for representation learning, followed by task-specific dense layers and logistic regression output layers.

The models in this section are implemented with TensorFlow [abadi2016tensorflow] and trained with Adam. In Ada-SiT, the similarity function for the task similarity measurement is the cosine of the angle between adaptation directions, with a threshold of 0.7. The learning rate $\alpha$ and the meta learning rate $\beta$ are 0.0005 and 0.001 respectively.

5.2 Evaluation Metrics

The labels of the target prediction tasks are imbalanced, so metrics for binary labels such as accuracy are not suitable for measuring the performance. Similar to the work [choi2016retain], we adopt AUC (the area under the Receiver Operating Characteristic (ROC) curve) and AP (the area under the Precision-Recall curve (PRC)) for evaluation. They both reflect the overall quality of predicted scores at each decision time.
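Both metrics depend only on how the predicted scores rank positive samples against negative ones. The sketch below gives one common way to compute them in plain Python: AUC as the fraction of positive-negative pairs ranked correctly, and AP as the mean precision at the rank of each positive sample (a step-wise estimate of the area under the precision-recall curve). It assumes at least one positive and one negative sample, and the AP function assumes no tied scores; it is not necessarily the exact procedure used in the experiments.

```python
def auc(scores, labels):
    """Fraction of (positive, negative) pairs ranked correctly; ties count as half."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0
               for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def average_precision(scores, labels):
    """Mean of the precision values measured at the rank of each positive sample."""
    ranked = sorted(zip(scores, labels), key=lambda pair: -pair[0])
    hits, total = 0, 0.0
    for rank, (_, y) in enumerate(ranked, start=1):
        if y == 1:
            hits += 1
            total += hits / rank
    return total / hits
```

Unlike accuracy, neither value changes if every sample is simply predicted as the majority (survival) class with a constant threshold, which is why they suit a 7% to 13% positive rate.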

Name | MiniMIMIC | MiniEICU
# of tasks | 858 | 70
# of samples | 16610 | 7000
positive sample rate (mortality rate) | 7% | 13%
max # of samples per task | 40 | 100
min # of samples per task | 10 | 100
mean # of samples per task | 19.36 | 100
Table 3: Statistics of Datasets

Figure 3: Visualization of Tasks in the Model Space

5.3 Quantitative Results

Model Class | Model | MiniMIMIC AUC | MiniMIMIC AP | MiniEICU AUC | MiniEICU AP
Global Single-task | LSTM [hochreiter1997long] | 0.8162 (0.0026) | 0.3830 (0.0055) | 0.6642 (0.0227) | 0.2692 (0.0193)
Global Single-task | TCN [bai2018empirical] | 0.8008 (0.0024) | 0.4120 (0.0011) | 0.6107 (0.0055) | 0.1945 (0.0052)
Multi-task | Multi-SAnD [song2018attend] | 0.8036 (0.0161) | 0.2754 (0.0063) | 0.6215 (0.0075) | 0.1592 (0.0016)
Multi-task | MMoE [ma2018modeling] | 0.7181 (0.0117) | 0.2195 (0.0097) | 0.6300 (0.0023) | 0.1364 (0.0030)
Multi-task | Multi-Dense [Suresh2018Learning] | 0.8325 (0.0036) | 0.3997 (0.0096) | 0.6730 (0.0071) | 0.1147 (0.0039)
Ours | Ada-SiT | 0.8729 (0.0112) | 0.4543 (0.0241) | 0.6746 (0.0090) | 0.2961 (0.0103)
Table 4: Performance of different models on MiniMIMIC and MiniEICU

Table 4 shows the AUC and AP of Ada-SiT, global single-task models, and multi-task models on MiniMIMIC and MiniEICU. From the results in Table 4, we draw the following conclusions:

First, Ada-SiT can significantly improve on the performance of global single-task learning methods. On both datasets, Ada-SiT performs better than LSTM and TCN. For example, on MiniMIMIC, Ada-SiT improves AUC and AP by around 6.9% and 19.1% respectively compared to LSTM. We can conclude that Ada-SiT can capture specific characteristics of diverse tasks without interference from data conflicts.

Second, Ada-SiT outperforms the compared multi-task learning methods. For example, on MiniMIMIC, Ada-SiT improves the AUC and AP of Multi-SAnD by 8.6% and 65.0% respectively. On MiniEICU, it improves the AUC of Multi-SAnD and MMoE by 8.5% and 7.1% respectively. It should be noted that most of the multi-task baselines do not perform better than the global single-task baselines. A possible cause is that the task-specific parameters of multi-task models cannot be trained well because of the data insufficiency of each task. We can conclude that Ada-SiT has a more robust information-sharing mechanism among tasks on small datasets than the traditional multi-task learning baselines.
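The percentages quoted for the multi-task baselines are relative improvements, i.e. the gain divided by the baseline score. A small helper (the name `relative_improvement` is hypothetical) reproduces them from the Multi-SAnD and MMoE rows of Table 4:

```python
def relative_improvement(ours, baseline):
    """Relative gain of `ours` over `baseline`, in percent."""
    return (ours - baseline) / baseline * 100.0


# Ada-SiT vs. Multi-SAnD on MiniMIMIC (AUC, AP)
print(round(relative_improvement(0.8729, 0.8036), 1))  # 8.6
print(round(relative_improvement(0.4543, 0.2754), 1))  # 65.0
# Ada-SiT vs. Multi-SAnD and MMoE on MiniEICU (AUC)
print(round(relative_improvement(0.6746, 0.6215), 1))  # 8.5
print(round(relative_improvement(0.6746, 0.6300), 1))  # 7.1
```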

5.4 Ablation Experiments of Task Similarity Measurement

To evaluate the effect of the task similarity measurement in Ada-SiT, we vary this module while keeping the other parts of the model identical. We implement the following variants of Ada-SiT:

Ada-SiT (MAML): Ada-SiT (MAML) is Ada-SiT without similar task measurement (i.e. each task's set of similar tasks contains only the task itself, $S_i = \{T_i\}$), which is nearly the same as MAML [finn2017model].

Ada-SiT (Static): According to the work [ruder2017learning], many static features can be used to measure task similarity. In Ada-SiT (Static), we choose the mortality rate as the static feature in the clinical scenario and use it to measure task similarity instead of the proposed similarity measurement.

Ada-SiT (KNN): Ada-SiT (KNN) selects the k nearest neighbors of each task $T_i$, instead of the neighbors within a certain distance, when finding similar tasks.

Methods | AUC | AP
Ada-SiT (MAML) | 0.8577 (0.0015) | 0.3936 (0.0025)
Ada-SiT (Static) | 0.8474 (0.0123) | 0.4143 (0.0144)
Ada-SiT (KNN) | 0.8264 (0.0110) | 0.4059 (0.0112)
Ada-SiT | 0.8729 (0.0112) | 0.4543 (0.0241)
Table 5: Ablation study of task similarity measurement

Table 5 shows the results of the ablation experiments on the task similarity measurement. We can draw the following conclusions: First, the information in similar tasks can improve the performance of fast adaptation. Ada-SiT (KNN) and Ada-SiT improve the AP of Ada-SiT (MAML). Ada-SiT also improves the AUC of Ada-SiT (MAML). Second, the task similarity measurement in model space outperforms the compared static measurement. Ada-SiT improves the AUC and AP of Ada-SiT (Static) by 3.0% and 9.7% respectively. This is because traditional task similarity measurements via static features only leverage the metadata of tasks, whereas our measurement in model space can find latent information in the mapping between samples and labels. Third, finding neighbors within a certain distance is the most suitable way to get similar tasks. For example, Ada-SiT improves the AUC of Ada-SiT (KNN) by around 5.6%. It is noteworthy that the AUC of Ada-SiT (KNN) is even lower than that of Ada-SiT (MAML). This is possibly because some tasks among the k nearest neighbors of a task may still be far from it in model space, so they interfere with its fast adaptation process.

5.5 Relationship between Task Similarity and Mortality Rate

There is a strong correlation between task similarity and mortality rate. By treating each model's parameter vector as a point in the model space, we use t-SNE [maaten2008visualizing] to visualize similar task clusters in Figure 3. In Figure 3, the two task clusters represent two types of rare diseases. The average mortality rate of diseases in the blue cluster is 0.6% and that in the yellow cluster is 32.1%. Meanwhile, the overall average mortality rate of MiniMIMIC is 7%. We can see that the mortality rate is the main factor determining task similarity. This suggests that our task similarity measurement module is reasonable and consistent with clinical knowledge, because diseases with a low mortality rate and life-threatening diseases with a high mortality rate have different clinical behavior.

6 Conclusion

In this paper, we propose a novel method, Ada-SiT, for learning predictive models for diverse tasks with insufficient data. Ada-SiT has a new task similarity measurement method and a new knowledge-sharing scheme, where the shared initialization is learned to adapt fast to similar tasks. Experimental results show that our method is suitable for mortality prediction of diverse rare diseases, and that it improves performance compared to global single-task models and generic multi-task models.
