RDPD: Rich Data Helps Poor Data via Imitation

September 6, 2018 · Shenda Hong et al. · IBM, Georgia Institute of Technology, Peking University

In many situations, we have both rich- and poor-data environments: in a rich-data environment (e.g., intensive care units), we have high-quality, multi-modality data. In a poor-data environment (e.g., at home), we often only have access to a single data modality with low quality. How can we learn an accurate and efficient model for the poor-data environment by leveraging multi-modality data from the rich-data environment? In this work, we propose a knowledge distillation model, RDPD, to enhance a small model trained on poor data with a complex model trained on rich data. In an end-to-end fashion, RDPD trains a student model built on a single data modality (poor data) to imitate the behavior and performance of a teacher model built on multimodal data (rich data) by jointly optimizing the combined loss of attention imitation and target imitation. We evaluated RDPD on three real-world datasets. RDPD consistently outperformed all baselines across all three datasets, especially achieving the greatest performance improvement over a standard neural network model trained on the common features (the Direct model), by 24.56%, and over a standard knowledge distillation model, by 5.91%.


Introduction

Many rich-data environments encompass multiple data modalities. For example, multiple motion sensors in a lab can collect activity signals from various locations on a human body, where the signal from each location can be viewed as one modality. The multiple leads of electrocardiogram (ECG) signals recorded in hospital are used for diagnosing heart diseases, and each lead is considered a modality. Multiple physiological signals are measured in intensive care units (ICU), where each type of measurement is a modality. A series of recent studies has confirmed that finding patterns in rich multimodal data can increase the accuracy of diagnosis and prediction, and the overall performance of deep learning models [24].

Despite the promise that rich multimodal data bring, in practice we more often face poor-data environments with fewer modalities of limited quality. For example, unlike in a rich-data environment such as a hospital, where patients wear multiple electrodes to collect 12-lead ECG signals, home monitoring devices typically measure only the lead I ECG signal from the arms. Although deep learning models often perform well in rich-data environments, their performance in poor-data environments is less impressive due to limited data modalities and lower quality [22].

We argue that, given both rich and poor data from similar contexts, a model built on rich multi-modal data can help improve a model built on poor data with fewer modalities, or even a single modality. For example, a heart disease detection model trained on 12-lead ECG signals in hospital can help improve a similar heart disease detection model trained on single-lead ECG signals at home.

The recent development of mimic learning or knowledge distillation [11, 1, 18] has provided a way of transferring information from a complex model (teacher model) to a simpler model (student model). Knowledge distillation or mimic learning essentially compresses the knowledge learned by a complex model into a simpler model that is much easier to deploy. However, these methods often require the same data for the teacher and student models. Domain adaptation techniques address the problem of learning models on some source data distribution that generalize to a different target distribution. Deep learning based domain adaptation methods have focused mainly on learning domain-invariant representations [10, 5, 3]. However, they often need to be trained jointly on source and target domain data and are therefore unappealing in settings where the target data source is unavailable during training.

In this paper, we propose RDPD (Rich Data to Poor Data) to build accurate and efficient models for poor data with the help of rich data. In particular, RDPD transfers knowledge from a teacher model trained on rich data to a student model operating on poor data by directly leveraging the multimodal data in the training process. Given a teacher model along with attention weights learned from multimodal data, RDPD trains the student model operating on poor data end-to-end to imitate the behavior (attention imitation) and performance (target imitation) of the teacher model. In particular, RDPD jointly optimizes the combined loss of attention imitation and target imitation. The target imitation loss can utilize both hard labels from the data and soft labels provided by the teacher model. The main contributions of this work are:

  • We formally define the task of learning from rich data to poor data, which has many real-world applications, including healthcare.

  • We propose the RDPD algorithm based on mimic learning, which takes a joint optimization approach to transfer knowledge from a teacher model trained on rich data to a student model on poor data. The resulting model is also much more lightweight than the original teacher model and can therefore be more easily deployed.

  • RDPD consistently outperformed all baselines across multiple datasets, especially demonstrating the greatest performance improvement, on both PR-AUC and ROC-AUC, over the Direct model trained on the features common to rich and poor data, and over the standard distillation model of [11].

Method

In this section, we first describe the task and then introduce the design of RDPD (shown in Fig. 1).

Task Description

Figure 1: The framework of RDPD. Given a teacher model along with attention weights learned from rich data, RDPD trains the student model on poor data while imitating the behavior and performance of the teacher model. In particular, RDPD jointly optimizes the combined loss of attention imitation (behavior) and target imitation (performance). The target imitation loss uses both hard labels from the data and soft labels provided by the teacher model.

Consider data collected as continuous time series. Given a teacher model trained in a rich-data environment, we want to teach a student model that runs on poor data only. We hope the student model can benefit from the information contained in the rich data via the teacher model, by imitating the teacher model in terms of both its learning outcome and its learning process. In this work, we call the former objective target imitation and the latter behavior imitation. Target imitation is achieved by imitating the final predictions (i.e., soft labels) of the teacher model, while behavior imitation is achieved by imitating the attention weights over the temporal segments of the time series.

Mathematically, the multi-modal rich data, with its full set of modalities, is available only in the training phase, while the poor data, with fewer modalities, is available in both the training and testing phases. The modalities in the poor data are a subset of those in the rich data, and the two share the same labels. Our task is to build a student model that takes only the poor data as input and benefits from knowledge transferred from the rich data.

Overview. In RDPD, the student model trained on poor data imitates both the intermediate learning behavior and the final learning performance of the teacher model trained on rich data and hard labels. The imitation of learning behavior is achieved by minimizing the information loss from the distribution of attention in the student model to the distribution of attention in the teacher model, while the performance imitation is achieved by jointly optimizing the hard-label loss, the soft-label loss, and a trainable combination of the two. Fig. 2 illustrates the learning process of RDPD. In the following we detail each step of RDPD.

Figure 2: The computation process of training the teacher model and of imitation in RDPD.

Training Teacher Model

Although RDPD can be applied to time series in general, in this paper we only consider regularly sampled continuous time series (e.g., sensor data). Assume a patient has time series from multiple modalities; we split the time series in each modality into consecutive segments of equal length, and denote the result as the multi-modal segmented input time series.
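As a concrete illustration, the following sketch (our own, not the authors' code) splits a multi-channel time series into fixed-length segments; the function name and the segment length `seg_len` are hypothetical parameters we introduce for illustration.

```python
import torch

def segment_time_series(x: torch.Tensor, seg_len: int) -> torch.Tensor:
    """Split a (channels, length) series into (num_segments, channels, seg_len).

    A minimal sketch of the segmentation step: each modality's series is cut
    into equal-length, non-overlapping segments; `seg_len` is an assumed
    hyper-parameter.
    """
    channels, length = x.shape
    num_segments = length // seg_len            # drop any trailing remainder
    x = x[:, :num_segments * seg_len]
    # unfold the time axis into non-overlapping windows of size seg_len
    segments = x.unfold(1, seg_len, seg_len)    # (channels, segments, seg_len)
    return segments.permute(1, 0, 2).contiguous()  # (segments, channels, seg_len)
```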

We apply stacked 1-D convolutional neural networks (CNN) on each segment and a recurrent neural network (RNN) across segments. Such a design has been demonstrated to be effective in many previous studies on multivariate time series modeling [19, 6]. In detail, we apply a 1-D CNN with mean pooling on each segment, as given by Eq. 1. The CNN parameters, including the number of filters, the filter size, and the stride, are shared among segments and vary across datasets; details are given in the Experiment Setup section.

(1)

Then, we concatenate all convolved and pooled segments, each represented by a vector whose dimensionality equals the number of CNN filters. Next, we apply an RNN layer over the segment sequence, whose per-segment output dimensionality equals the number of hidden units in the RNN layer. Here we use the widely applied self-attention mechanism [17], a natural choice for exploiting the correlations and relative importance of segments; it also generates attention weights that represent the teacher's behavior on each segment. The attention weights are calculated by Eq. 2.

(2)

We then multiply the RNN output of each segment by its corresponding attention weight. The weighted output is given by Eq. 3.

(3)

Finally, the weighted output is further transformed by a dense layer to produce the output logits, as given by Eq. 4.

(4)

For simplicity, we summarize Eq. 1 to Eq. 4 into the teacher model of Eq. 5: it takes the segmented rich data as input and outputs logits and attention weights.

(5)
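Below is a minimal PyTorch sketch of a teacher network in the spirit of Eq. 1-5: a 1-D CNN with mean pooling per segment, an RNN across segments, self-attention, and a dense output layer that returns both logits and attention weights. The layer sizes, class name, and argument names (e.g., `n_filters`, `hidden`) are illustrative assumptions, not the authors' exact configuration.

```python
import torch
import torch.nn as nn

class CnnRnnAttention(nn.Module):
    """Sketch of the CNN + RNN + self-attention encoder (Eq. 1-5); sizes are assumptions."""

    def __init__(self, in_channels, n_filters=64, kernel_size=8, stride=4,
                 hidden=32, n_classes=12):
        super().__init__()
        self.conv = nn.Conv1d(in_channels, n_filters, kernel_size, stride)  # Eq. 1: shared 1-D CNN
        self.rnn = nn.LSTM(n_filters, hidden, batch_first=True,
                           bidirectional=True)                              # RNN across segments
        self.attn = nn.Linear(2 * hidden, 1)                                # Eq. 2: attention scorer
        self.out = nn.Linear(2 * hidden, n_classes)                         # Eq. 4: dense output layer

    def forward(self, x):
        # x: (batch, segments, channels, seg_len)
        b, s, c, l = x.shape
        h = self.conv(x.view(b * s, c, l))           # convolve each segment independently
        h = h.mean(dim=-1).view(b, s, -1)            # mean pooling over time, regroup segments
        h, _ = self.rnn(h)                           # (b, s, 2 * hidden)
        alpha = torch.softmax(self.attn(h).squeeze(-1), dim=1)   # attention weights over segments
        z = self.out((alpha.unsqueeze(-1) * h).sum(dim=1))       # Eq. 3-4: weighted sum, then logits
        return z, alpha                              # Eq. 5: logits and attention weights
```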

The objective function of the teacher model measures prediction accuracy and also provides knowledge to the student model. Typically, the logits are transformed by a softmax into final predicted probabilities, which can be used as distilled knowledge for the student model to imitate. However, a sharp distribution (e.g., near-hard labels) is less informative. To alleviate this issue, we follow the idea in [11] to produce more informative soft labels. Compared with hard labels, soft labels have a much smoother probability distribution over classes and thus carry richer (higher-entropy) information. Concretely, we modify the classic softmax by dividing the original logits by a predefined hyper-parameter T (larger than 1), usually referred to as the temperature. The modified softmax, which produces the i-th soft probability, is given by Eq. 6, and the soft predictions are given by Eq. 7.

p_i = exp(z_i / T) / Σ_j exp(z_j / T)    (6)
ŷ_soft = (p_1, p_2, …, p_C)    (7)
where z_i denotes the i-th logit, T the temperature, and C the number of classes.

Finally, we use a cross-entropy loss as the prediction loss (Eq. 8) to measure the difference between the soft predictions and the ground truth, and we train the teacher model by minimizing this loss.

(8)
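As a hedged illustration of Eq. 6-8, following the standard temperature softmax of [11], the teacher's soft predictions and its cross-entropy training loss could be computed as below; the temperature `T` matches the paper's hyper-parameter, while the function names are ours.

```python
import torch
import torch.nn.functional as F

def soft_predictions(logits: torch.Tensor, T: float) -> torch.Tensor:
    """Temperature-scaled softmax (Eq. 6-7): a larger T gives a smoother distribution."""
    return F.softmax(logits / T, dim=-1)

def teacher_loss(logits: torch.Tensor, labels: torch.Tensor, T: float) -> torch.Tensor:
    """Cross-entropy between temperature-softened predictions and the ground truth (Eq. 8)."""
    log_probs = F.log_softmax(logits / T, dim=-1)
    return F.nll_loss(log_probs, labels)
```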

Imitating Attentions and Targets

After training the teacher model on rich data, we now describe the imitation process for the student model. By attention imitation, we mean mimicking the attention weights. For target imitation, the student model imitates the following components: 1) the soft labels, which are more informative; 2) the hard labels, which can improve performance (according to [11]); and 3) a trainable combination of the soft and hard labels. Again, we construct the student model using a CNN + RNN architecture, but with fewer filters in the CNN and fewer hidden units in the RNN. In our experiments, we roughly keep the ratio of the student's hyper-parameters to the teacher's (the number of CNN filters and the number of RNN hidden units) the same as the ratio of the number of poor-data modalities to the number of rich-data modalities. Similar to Eq. 5, the student model takes the segmented poor data as input and outputs logits and attention weights, as in Eq. 9.

(9)

Attention Imitation

In Eq. 2 we defined attention weights that represent the influence of different time segments on the final predictions. We assume that the attention behavior of the student model should resemble that of the teacher model, and formulate attention imitation as follows. Given Eq. 5 and Eq. 9, to enforce the two attention vectors to have similar distributions, we minimize the Kullback-Leibler (KL) divergence given by Eq. 10, which measures the information loss from the distribution of attention in the student model to the distribution of attention in the teacher model.

(10)
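A minimal sketch of the attention-imitation loss in Eq. 10, assuming the attention weights of teacher and student are already normalized distributions over segments; the KL direction shown here (teacher distribution as the target) is one common choice and our assumption.

```python
import torch

def attention_imitation_loss(alpha_student: torch.Tensor,
                             alpha_teacher: torch.Tensor,
                             eps: float = 1e-8) -> torch.Tensor:
    """KL divergence between teacher and student attention distributions (Eq. 10)."""
    # Both tensors: (batch, segments); each row sums to 1 after the attention softmax.
    kl = alpha_teacher * (torch.log(alpha_teacher + eps) - torch.log(alpha_student + eps))
    return kl.sum(dim=1).mean()
```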

Imitating Hard Labels

For hard-label imitation, we optimize the student model by minimizing a cross-entropy loss (Eq. 11) that measures the difference between the student's predicted values and the ground-truth values over the target classes.

(11)


Imitating Soft Labels

Given the soft labels from the teacher model, we produce soft predictions from the student model using the same temperature in its softmax. Then, we optimize a cross-entropy loss (Eq. 12) that measures the difference between the student's and the teacher's soft predictions.

(12)

Here, the teacher's soft predictions are those defined in Eq. 7. Since the magnitudes of the gradients of Eq. 12 are scaled by 1/T² because we divided the logits by T, we multiply the soft-label imitation loss by T² in the implementation to keep its gradients comparable to the other losses.
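The soft-label imitation loss of Eq. 12, with the T² gradient rescaling mentioned above, might look like the following sketch (the naming is ours):

```python
import torch
import torch.nn.functional as F

def soft_label_loss(student_logits: torch.Tensor,
                    teacher_logits: torch.Tensor,
                    T: float) -> torch.Tensor:
    """Cross-entropy between student and teacher soft predictions (Eq. 12), scaled by T^2."""
    teacher_soft = F.softmax(teacher_logits / T, dim=-1)        # Eq. 7: teacher soft labels
    student_log_soft = F.log_softmax(student_logits / T, dim=-1)
    loss = -(teacher_soft * student_log_soft).sum(dim=-1).mean()
    return loss * (T ** 2)                                      # keep gradients comparable
```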

Imitating Combined Label

While hard labels provide definite prediction outcomes and soft labels provide probabilistic predictions, the two may even be contradictory. To bridge the gap between them, a reasonable solution is to combine them into a single probabilistic prediction over classes. Moreover, while hard-label imitation helps the student model learn more from the data and soft-label imitation transfers more knowledge from the teacher model (a smoother distribution), each alone leads to either more bias (from the data) or more variance (from the model). To leverage their benefits and make them complement each other, we propose to use a learnable linear combination of hard labels and soft labels, as follows:

(13)

where the combination weights are learnable parameters. For the combined imitation, we also use a cross-entropy loss (Eq. 14), with the combined label serving as the target.

(14)
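One way to realize the learnable combination of Eq. 13-14 is sketched below. This is one plausible reading of the combined imitation (a learnable mixture of the one-hot hard label and the teacher's soft label used as the student's target); the two scalar weights and their softmax normalization are our assumptions about the exact parameterization.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CombinedLabelLoss(nn.Module):
    """Learnable combination of hard and soft labels (Eq. 13) with cross-entropy (Eq. 14)."""

    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.zeros(2))    # learnable combination weights

    def forward(self, student_logits, hard_labels, teacher_soft):
        # hard_labels: class indices; teacher_soft: (batch, classes) soft labels from the teacher
        w = F.softmax(self.w, dim=0)             # keep the combined target a valid distribution
        hard_onehot = F.one_hot(hard_labels, teacher_soft.size(-1)).float()
        combined_target = w[0] * hard_onehot + w[1] * teacher_soft      # Eq. 13
        log_probs = F.log_softmax(student_logits, dim=-1)
        return -(combined_target * log_probs).sum(dim=-1).mean()        # Eq. 14
```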

Joint Optimization

Finally, for the student model to imitate attentions and targets simultaneously, we jointly optimize all of the loss functions above. Since they are all cross-entropy-style losses and have been rescaled to comparable magnitudes (e.g., the soft-label loss is multiplied by T²), we simply sum them up to obtain the final objective function given by Eq. 15. We summarize the RDPD method in Algorithm 1.

(15)
1:  Build the teacher model on the rich data
2:  Compute the teacher's logits and attention weights (Eq. 5) and its soft predictions and prediction loss (Eq. 6-8)
3:  while not converged do
4:     Update the weights of the teacher model by optimizing its prediction loss using back-propagation
5:  end while
6:  Build the student model on the poor data
7:  Compute the student's logits and attention weights (Eq. 9), the attention-imitation loss (Eq. 10), the hard-, soft-, and combined-label losses (Eq. 11-14), and the joint objective (Eq. 15)
8:  while not converged do
9:     Update the weights of the student model by optimizing the joint objective using back-propagation
10: end while
Algorithm 1: RDPD
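Putting the pieces together, one training step of the student model under the joint objective of Eq. 15 might look like the sketch below. Here `student`, `teacher`, `combined_loss`, `soft_label_loss`, and `attention_imitation_loss` are the hypothetical components sketched earlier, and the unit weighting of the terms follows the paper's statement that the losses are simply summed.

```python
import torch
import torch.nn.functional as F

def student_training_step(student, teacher, combined_loss, optimizer,
                          rich_x, poor_x, labels, T):
    """One optimization step of the student model under the joint objective (Eq. 15)."""
    with torch.no_grad():                                   # the teacher is frozen during imitation
        teacher_logits, teacher_attn = teacher(rich_x)
    student_logits, student_attn = student(poor_x)

    loss_hard = F.cross_entropy(student_logits, labels)                      # Eq. 11
    loss_soft = soft_label_loss(student_logits, teacher_logits, T)           # Eq. 12 (scaled by T^2)
    loss_comb = combined_loss(student_logits, labels,
                              F.softmax(teacher_logits / T, dim=-1))         # Eq. 13-14
    loss_attn = attention_imitation_loss(student_attn, teacher_attn)         # Eq. 10
    loss = loss_hard + loss_soft + loss_comb + loss_attn                     # Eq. 15: simple sum

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```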

Experiments

In this section, we evaluated the performance and effectiveness of RDPD.

Experiment Setup

Datasets

We used the following datasets in performance evaluation. Data statistics are summarized in Table 1.

The PAMAP2 Physical Activity Monitoring Data Set (PAMAP2) [21] contains 52 channels of sensor signals from 9 subjects wearing 3 inertial measurement units (IMU, 100 Hz) and a heart rate monitor (HR, 9 Hz). The average recording length per subject is about 42k points. We down-sample the signals to 50 Hz. Following the "frame-by-frame analysis" in [21], we pre-process the time series with sliding windows of 5.12 seconds duration and 1-second stepping between adjacent windows. The task is to classify signals into one of the 12 physical activities (e.g., walking, running, standing, etc.). In our experiments, we use the data of subject 105 for validation, subject 101 for testing, and the remaining subjects for training.

The PTB Diagnostic ECG Database (PTBDB) includes 15 channels of ECG signals collected from healthy controls and patients with heart diseases [4]. The database contains 549 records from 290 subjects. We down-sample the signals to 200 Hz. Similar to PAMAP2, we pre-process the data using "frame-by-frame analysis" with sliding windows of 10 seconds duration and 5-second stepping between adjacent windows. Our task is to classify signals into one of the 6 patient groups. In our experiments, we randomly divided the data into training (80%), validation (10%), and test (10%) sets by subject.

The Medical Information Mart for Intensive Care (MIMIC-III) database was collected from ICU patients at the Beth Israel Deaconess Medical Center (BIDMC) between June 2001 and October 2012 [13]. In our experiment, we focus on patients with the following diseases: 1) acute myocardial infarction, 2) chronic ischemic heart disease, 3) heart failure, 4) intracerebral hemorrhage, 5) complications of specified procedures, 6) lung diseases, 7) endocardium diseases, and 8) septicaemia, with 9,488 subjects in total. In detail, we extract 6 vital-sign time series from the first 48 hours, including heart rate (HR), respiratory rate (RR), mean blood pressure, systolic blood pressure, diastolic blood pressure, and SpO2. We resample the time series to 1 point per hour. Our task is to classify the vital-sign series into one of the 8 disease groups. In our experiments, we randomly divided the data into training (80%), validation (10%), and test (10%) sets by patient.

Statistic | PAMAP2 | PTBDB | MIMIC-III
# subjects | 9 | 290 | 9,488
# classes | 12 | 6 | 8
# attributes | 52 | 15 | 6
Total time series length | 2,872,533 | 59,619,455 | 455,424
Sample frequency | 100 Hz (IMU), 9 Hz (HR) | 1,000 Hz | 1 per hour
Table 1: Statistics of the datasets

Evaluations

Performance was measured by the area under the receiver operating characteristic curve (ROC-AUC), the area under the precision-recall curve (PR-AUC), and the macro F1 score (macro-F1). ROC-AUC and PR-AUC are evaluated between the predicted probabilities and the ground truth. PR-AUC is considered a better measure for imbalanced data with many more negative samples, as in our setting [7]. Macro-F1 is commonly used with a threshold that determines whether a given probability is predicted as positive (larger than the threshold) or negative (smaller than the threshold).
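For reference, the three metrics can be computed with scikit-learn roughly as follows. This is a sketch assuming one-vs-rest macro averaging and a default threshold of 0.5; the exact averaging scheme and threshold used by the authors are not specified here.

```python
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score, f1_score

def evaluate(y_true: np.ndarray, y_prob: np.ndarray, threshold: float = 0.5):
    """y_true: (n, classes) one-hot labels; y_prob: (n, classes) predicted probabilities."""
    roc_auc = roc_auc_score(y_true, y_prob, average="macro")            # ROC-AUC
    pr_auc = average_precision_score(y_true, y_prob, average="macro")   # PR-AUC
    macro_f1 = f1_score(y_true, (y_prob >= threshold).astype(int), average="macro")
    return roc_auc, pr_auc, macro_f1
```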

Implementation Details

Models are trained with mini-batches of 128 samples for 200 iterations, which was a sufficient number of iterations to achieve the best performance on the classification task. The final model was selected using early stopping on the validation set. We then tested each model 10 times using different random seeds and report the mean values with standard deviations. All models were implemented in PyTorch version 0.5.0 and trained on a system equipped with 64 GB RAM, 12 Intel Core i7-6850K 3.60 GHz CPUs, and an Nvidia GeForce GTX 1080 GPU. All models were optimized using Adam [14] with the learning rate set to 0.001.

Comparative Methods

  • Teacher: The teacher model is trained and tested on all channels. It has better accuracy and a much heavier architecture, and is only available in the in-hospital setting where all channels of signals are available. It serves as an empirical upper bound on performance.

  • Direct: The Direct model is built on the partially observed data using the same RCNN architecture, without attention imitation or soft-label imitation; it is equivalent to optimizing the hard-label loss alone.

  • Knowledge Distillation (KD): The KD [11] model is built on the partially observed data with soft-label imitation and hard-label imitation; it is equivalent to optimizing the hard-label and soft-label losses.

  • RDPD without attention imitation: a reduced version of RDPD whose objective function omits the attention-imitation (KL) term.

  • RDPD without combined labels: a reduced version of RDPD equivalent to the KD model plus attention imitation; its objective function omits the combined-label term.

  • RDPD: Our full model, which contains all proposed imitations and uses the joint objective of Eq. 15.

For all models, we use a 1-layer 1-D CNN and a 1-layer bi-directional LSTM. In the teacher model, for PAMAP2, the number of filters is set to 64, the filter size to 8, the stride to 4, and the number of hidden units to 32. For PTBDB, they are set to 128, 32, 8, and 32, respectively. For MIMIC-III, they are set to 64, 4, 2, and 32, respectively. Since RDPD and the compared baselines have fewer input modalities, they use a smaller number of CNN filters and RNN hidden units, set proportionally as introduced before. However, the data length remains the same, so their filter size and stride are unchanged. The temperature is set to 5 for PAMAP2 and PTBDB, and to 2.5 for MIMIC-III. The effect and choice of the temperature are discussed in the Effect of Temperature section.

Experiments

To provide an extensive evaluation of RDPD, we compare it with the baselines in the following aspects: a. classification performance; b. reduction of model complexity; c. evaluation against different training data sizes; d. evaluation in the setting of low-quality poor data; e. effect of temperature.

Results

a. Classification Performance

We compared the results of RDPD against the other baselines and the reduced versions of RDPD in Table 2 (PAMAP2 dataset), Table 3 (PTBDB dataset), and Table 4 (MIMIC-III dataset). RDPD outperformed the other methods (except Teacher) in most cases, demonstrating that the proposed attention imitation and target imitation successfully improve the performance of the student model. The teacher model performs best among all methods since it is trained on the full dataset with multiple modalities; it serves as an empirical upper bound on performance. In Table 3, RDPD works better than its reduced versions in PR-AUC and F1 score but not in ROC-AUC. The reason is that the classes in the PTBDB dataset are very imbalanced, and occasional samples from rare classes distort the final result.

Data | Method | ROC-AUC | PR-AUC | macro-F1
All | Teacher | 0.928 ± 0.014 | 0.708 ± 0.039 | 0.608 ± 0.045
Wrist | Direct | 0.800 ± 0.032 | 0.452 ± 0.051 | 0.376 ± 0.049
Wrist | Distill | 0.825 ± 0.020 | 0.469 ± 0.052 | 0.380 ± 0.060
Wrist | (RDPD ablation) | 0.837 ± 0.025 | 0.491 ± 0.037 | 0.406 ± 0.053
Wrist | (RDPD ablation) | 0.836 ± 0.018 | 0.478 ± 0.038 | 0.401 ± 0.049
Wrist | RDPD | 0.838 ± 0.012 | 0.491 ± 0.045 | 0.425 ± 0.057
Chest | Direct | 0.836 ± 0.035 | 0.519 ± 0.065 | 0.449 ± 0.069
Chest | Distill | 0.868 ± 0.025 | 0.575 ± 0.043 | 0.486 ± 0.065
Chest | (RDPD ablation) | 0.872 ± 0.028 | 0.605 ± 0.030 | 0.518 ± 0.037
Chest | (RDPD ablation) | 0.879 ± 0.027 | 0.600 ± 0.051 | 0.478 ± 0.048
Chest | RDPD | 0.883 ± 0.016 | 0.609 ± 0.052 | 0.529 ± 0.051
Ankle | Direct | 0.811 ± 0.035 | 0.513 ± 0.065 | 0.405 ± 0.080
Ankle | Distill | 0.901 ± 0.015 | 0.621 ± 0.044 | 0.492 ± 0.070
Ankle | (RDPD ablation) | 0.889 ± 0.021 | 0.581 ± 0.071 | 0.443 ± 0.095
Ankle | (RDPD ablation) | 0.904 ± 0.019 | 0.629 ± 0.041 | 0.473 ± 0.069
Ankle | RDPD | 0.910 ± 0.014 | 0.639 ± 0.030 | 0.511 ± 0.033
Table 2: Performance comparison on the PAMAP2 dataset. The task is multi-class classification (12 classes). "All" contains all 52 channels; "Wrist" contains 17 channels of signals from 1 IMU on the wrist of the dominant arm; "Chest" contains 17 channels of signals from 1 IMU on the chest; "Ankle" contains 17 channels of signals from 1 IMU on the ankle of the dominant side. Rows marked "(RDPD ablation)" are the two reduced versions of RDPD described in Comparative Methods.
Data | Method | ROC-AUC | PR-AUC | macro-F1
All | Teacher | 0.737 ± 0.035 | 0.293 ± 0.018 | 0.288 ± 0.028
Lead I | Direct | 0.701 ± 0.023 | 0.279 ± 0.017 | 0.164 ± 0.020
Lead I | Distill | 0.676 ± 0.045 | 0.282 ± 0.022 | 0.217 ± 0.016
Lead I | (RDPD ablation) | 0.677 ± 0.036 | 0.255 ± 0.029 | 0.139 ± 0.027
Lead I | (RDPD ablation) | 0.707 ± 0.073 | 0.282 ± 0.044 | 0.218 ± 0.024
Lead I | RDPD | 0.706 ± 0.075 | 0.293 ± 0.025 | 0.218 ± 0.019
Table 3: Performance comparison on the PTBDB dataset. The task is multi-class classification (6 classes). "All" contains 15 channels of ECG signals; "Lead I" contains the single-channel lead I ECG signal, which is typically what mobile devices record. Rows marked "(RDPD ablation)" are the two reduced versions of RDPD described in Comparative Methods.
Data | Method | ROC-AUC | PR-AUC | macro-F1
All | Teacher | 0.696 ± 0.011 | 0.281 ± 0.009 | 0.256 ± 0.012
BP | Direct | 0.610 ± 0.016 | 0.204 ± 0.011 | 0.149 ± 0.013
BP | Distill | 0.611 ± 0.013 | 0.206 ± 0.007 | 0.150 ± 0.005
BP | (RDPD ablation) | 0.607 ± 0.012 | 0.203 ± 0.003 | 0.148 ± 0.003
BP | (RDPD ablation) | 0.613 ± 0.020 | 0.205 ± 0.009 | 0.147 ± 0.007
BP | RDPD | 0.614 ± 0.018 | 0.207 ± 0.010 | 0.150 ± 0.006
HR | Direct | 0.556 ± 0.019 | 0.176 ± 0.013 | 0.089 ± 0.042
HR | Distill | 0.564 ± 0.021 | 0.175 ± 0.012 | 0.109 ± 0.030
HR | (RDPD ablation) | 0.566 ± 0.010 | 0.178 ± 0.004 | 0.132 ± 0.005
HR | (RDPD ablation) | 0.571 ± 0.011 | 0.176 ± 0.008 | 0.123 ± 0.016
HR | RDPD | 0.581 ± 0.014 | 0.182 ± 0.004 | 0.130 ± 0.010
RR | Direct | 0.570 ± 0.019 | 0.176 ± 0.012 | 0.109 ± 0.039
RR | Distill | 0.614 ± 0.023 | 0.201 ± 0.009 | 0.162 ± 0.015
RR | (RDPD ablation) | 0.611 ± 0.014 | 0.202 ± 0.007 | 0.160 ± 0.016
RR | (RDPD ablation) | 0.614 ± 0.017 | 0.205 ± 0.006 | 0.169 ± 0.010
RR | RDPD | 0.619 ± 0.022 | 0.207 ± 0.008 | 0.169 ± 0.007
Table 4: Performance comparison on the MIMIC-III dataset. The task is multi-class classification (8 classes). "All" contains 6 channels of patient vital signs. "BP" contains systolic and diastolic blood pressure, which is usually monitored by a home sphygmomanometer. "HR" is heart rate; "RR" is respiration rate. Rows marked "(RDPD ablation)" are the two reduced versions of RDPD described in Comparative Methods.

b. Reduction of Model Complexity

We analyzed model complexity by comparing the model sizes of the teacher model and RDPD. Table 5 shows that the model size of RDPD is only about 6-7% of the model size of the teacher model. According to the previous results, the other methods have model sizes comparable to our approach, but their performance is worse. In real-world applications such as mobile health or real-time ICU modeling, it is very important that RDPD is both lighter in model size and better in performance.

Dataset Model Number of Parameters
PAMAP2 Teacher 118.3k
RDPD 8.2k
PTBDB Teacher 335.0k
RDPD 19.8k
MIMIC-III Teacher 60.2k
RDPD 4.0k
Table 5: Model Complexity Comparison

c. Evaluation against Size of Rich Data

We evaluated the dependency on the size of the rich data. We used the same validation and test data, but scaled down the amount of rich data used in training. Fig. 3 shows that RDPD outperformed the baselines even when little rich data was available, and performed better as more rich data was added. This demonstrates the efficacy of RDPD in extracting useful knowledge from rich data and teaching the student model even when rich data is insufficient.

Figure 3: Performance comparison across training data sizes on the PAMAP2 dataset.

d. Evaluation for the Setting of Low Quality Poor-data

Here we also assess how much benefit multi-modality data can bring when the poor data is of low quality. We perform experiments by adding different levels of noise to the entire single modality. The noise is added as

(16)

where the noisy data is obtained from the original data by element-wise addition of noise whose amplitude is controlled by a scaling parameter. From Fig. 4, we can see that as the noise amplitude increases, the performance of both Direct and RDPD decreases. However, RDPD still works better than Direct thanks to the knowledge transferred from the Teacher model.
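A minimal sketch of the noise-injection step of Eq. 16; we assume zero-mean Gaussian noise scaled to the signal's per-channel standard deviation, since the exact noise distribution is not given in the text.

```python
import torch

def add_noise(x: torch.Tensor, amplitude: float) -> torch.Tensor:
    """Element-wise noise injection (Eq. 16): larger `amplitude` means a noisier signal.

    Assumes zero-mean Gaussian noise scaled by the per-channel standard deviation;
    the exact noise model used in the paper may differ.
    """
    noise = torch.randn_like(x) * x.std(dim=-1, keepdim=True)
    return x + amplitude * noise
```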

Figure 4: Performance comparison across noise amplitudes on the PAMAP2 dataset.

e. Effect of Temperature

The temperature T is a critical hyperparameter that controls the degree of smoothness of the soft labels: the larger the temperature, the smoother the probability distribution over classes. To determine a proper value for T, we based our decision on the ROC-AUC and PR-AUC on the validation set. In Figure 5 we plot the effect of T. Both ROC-AUC and PR-AUC increase slightly as T grows from small values, then start to drop as T becomes larger. The reason is that a larger T leads to a softer probability distribution, making the model less discriminative and consequently yielding lower ROC-AUC and PR-AUC. In our experiments, we set T = 5 on the PAMAP2 dataset as a proper choice.

Figure 5: Performance of RDPD under different temperatures on the PAMAP2 dataset. (Left) ROC-AUC; (Right) PR-AUC.

Case Study

To further analyze how and why RDPD improves performance, we choose several subjects and compare their attention weights and prediction results as learned by the different methods.

Figure 6: An example on PAMAP2 data.

Fig. 6 shows an example from PAMAP2. Both RDPD and Teacher correctly predict that the subject is descending stairs, while Direct predicts Nordic walking and KD predicts rope jumping. The upper-left panel of Fig. 6 plots the z-axis acceleration from the chest sensor, which captures the vertical acceleration of the whole body. Although these activities are similar, the transition between walking on the floor and on stairs distinguishes descending stairs from Nordic walking and rope jumping. The Teacher model makes the correct prediction by looking at more channels and thus pays more attention to the vertical acceleration. RDPD also predicts correctly because it imitates the Teacher's attention, as shown in the lower-left panel of Fig. 6.

Figure 7: An example of PTBDB data and model predictions. (Left Upper) Original Lead I ECG data. (Left Bottom) Attention weights of different methods. (Right) Predicted probabilities distribution of different methods.

An example of PTBDB data and model predictions is shown in Figure 7. RDPD and Teacher give correct predictions while Direct and KD are wrong. RDPD outperforms KD and Direct for two reasons. On the one hand, RDPD and KD imitate the Teacher's soft labels, so they correct, to some extent, the wrong predictions made by Direct. On the other hand, RDPD also imitates the Teacher's attention weights (shown in the lower-left panel; purple is Teacher and red is RDPD), so RDPD gives more confident predictions of myocardial infarction than KD and thereby corrects the final prediction given by KD. We also found that RDPD gives more confident predictions of myocardial infarction than the Teacher model. The reason is that RDPD also considers the combined label, which leverages both the Teacher's soft label and the hard label and therefore becomes more confident when the soft and hard labels agree.

Figure 8: An example of MIMIC-III data and model predictions. (Left Upper) Original heart rate data. (Left Bottom) Attention weights of different methods. (Right) Predicted probabilities distribution of different methods.

An example of MIMIC-III data and model predictions is shown in Figure 8. RDPD and Teacher give correct predictions while Direct and KD are wrong. Looking at the attention weights (lower-left panel), Direct assigns roughly uniform weights across all parts, while Teacher, KD, and RDPD emphasize the beginning of the series. Moreover, in the first and second parts of the attention weights, RDPD lies between KD and Teacher, showing that RDPD successfully shifts its attention from KD's toward the Teacher's. In the right panel, the probability of the correct prediction increases from Direct to KD and increases further for RDPD.

Related Work

Domain adaptation

Domain adaptation techniques address the problem of learning models on some source data distribution that generalize to a different target distribution. Deep learning based domain adaptation methods have focused mainly on learning domain-invariant representations. For example, [10] and [5] used stacked Denoising Auto-encoders (SDA) and marginalized SDA, respectively, to extract meaningful representations. [9] added a Gradient Reversal Layer that hinders the model's ability to discriminate between domains. Moreover, [27] transferred source examples to the target domain and vice versa using Bi-Transferring Deep Neural Networks, while [3] proposed Domain Separation Networks. However, these methods need to be trained jointly on source and target domain data and are therefore unappealing in settings where data from both domains cannot be trained on jointly.

Knowledge Distillation

Knowledge distillation [11] or mimic learning [1] is a family of approaches that aims to transfer the predictive power of more accurate deep models ("teacher models") to smaller models ("student models"), such as shallow neural networks [11] or soft decision trees [8], by training the smaller models on soft labels produced by the deep models. It has been widely used in model compression [23], omni-supervised learning [20], fast optimization, network minimization, and transfer learning [25]. Extensions of knowledge distillation unify distillation and privileged information into a generalized distillation framework to learn from multiple machines and data representations [18]. The performance of distilled shallow neural networks is often better than that of models directly built on the training data. The biggest difference between our approach and knowledge distillation is that knowledge distillation focuses on transferring the teacher's predictive power to the student model, while our approach is designed to transfer both behaviors and predictions from rich data modalities to poor data (a single modality).

Attention Transfer

The attention mechanism [2] was proposed to improve the performance of machine translation by paying more attention to relevant parts of the data. Recently, several works have studied attention transfer [26, 12] to enhance shallow neural networks. The goal is achieved by learning similar attention in the smaller network, either by defining attention as the gradient with respect to the input [26] or by using a regularization term [12] to make the two models have similar attention weights. Attention transfer has been used in video recognition from web images [15], cross-domain sentiment classification [16], and so on. The biggest difference between our approach and attention transfer is that attention transfer is used for model compression on a single dataset, while our approach transfers across datasets with very different data modalities.

Conclusion

In this paper we proposed RDPD to leverage the power of rich data to improve learning from poor data. RDPD trains, end-to-end, a student model built on poor data to imitate the behavior (attention imitation) and performance (target imitation) of a teacher model by jointly optimizing the combined loss of attention imitation and target imitation. We evaluated RDPD across multiple datasets and demonstrated its promising utility and efficacy. Future extensions of RDPD include modeling static meta-information as an additional modality and learning from fewer labels.

References