Distilled Wasserstein Learning for Word Embedding and Topic Modeling

09/12/2018, by Hongteng Xu, et al.

We propose a novel Wasserstein method with a distillation mechanism, yielding joint learning of word embeddings and topics. The proposed method is based on the fact that the Euclidean distance between word embeddings may be employed as the underlying distance in the Wasserstein topic model. The word distributions of topics, their optimal transports to the word distributions of documents, and the embeddings of words are learned in a unified framework. When learning the topic model, we leverage a distilled underlying distance matrix to update the topic distributions and smoothly calculate the corresponding optimal transports. Such a strategy provides the updating of word embeddings with robust guidance, improving the algorithmic convergence. As an application, we focus on patient admission records, in which the proposed method embeds the codes of diseases and procedures and learns the topics of admissions, obtaining superior performance on clinically-meaningful disease network construction, mortality prediction as a function of admission codes, and procedure recommendation.

1 Introduction

Word embedding and topic modeling play important roles in natural language processing (NLP), as well as other applications with textual and sequential data. Many modern embedding methods mikolov2013efficient; pennington2014glove; liu2015topical assume that words can be represented and predicted by contextual (surrounding) words. Accordingly, the word embeddings are learned to inherit those relationships. Topic modeling methods blei2003latent, in contrast, typically represent documents by the distribution of words, or other “bag-of-words” techniques gerard1983introduction; joachims2002learning, ignoring the order and semantic relationships among words. The distinction between how word order is (or is not) accounted for when learning topics and word embeddings manifests a potential methodological gap or mismatch.

This gap is important when considering clinical-admission analysis, the motivating application of this paper. Patient admissions in hospitals are recorded by the codes of the International Classification of Diseases (ICD). For each admission, one may observe a sequence of ICD codes corresponding to certain kinds of diseases and procedures, and each code is treated as a “word.” To reveal the characteristics of the admissions and the relationships between different diseases/procedures, we seek to model the “topics” of admissions and also learn an embedding for each ICD code. However, while we want embeddings of similar diseases/procedures to be nearby in the embedding space, learning the embedding vectors based on surrounding ICD codes for a given patient admission is less relevant, as there is often a diversity in the observed codes for a given admission, and the code order may hold less meaning. Take the MIMIC-III dataset johnson2016mimic as an example. The ICD codes in each patient’s admission are ranked according to a manually-defined priority, and adjacent codes are often not clinically correlated with each other. Therefore, we desire a model that jointly learns topics and word embeddings, and that for both does not consider the word (ICD code) order. Interestingly, even in the context of traditional NLP tasks, it has been recognized recently that effective word embeddings may be learned without considering word order shen2018swen, although that work did not consider topic modeling or our motivating application.

Although some works have applied word embeddings to represent ICD codes and related clinical data choi2016multi ; huang2018empirical , they ignore the fact that the clinical relationships among the diseases/procedures in an admission may not be approximated well by their neighboring relationships in the sequential record. Most existing works either treat word embeddings as auxiliary features for learning topic models das2015gaussian or use topics as the labels for supervised embedding liu2015topical . Prior attempts at learning topics and word embeddings jointly shi2017jointly have fallen short from the perspective of these two empirical strategies.

Figure 1: Consider two admissions with mild and severe diabetes, which are represented by two distributions of diseases (associated with ICD codes) in red and orange, respectively. They are two dots in the Wasserstein ambient space, corresponding to two weighted barycenters of Wasserstein topics (the colored stars). The optimal transport matrix between these two admissions is built on the distance between disease embeddings in the Euclidean latent space. Large values in the matrix (the dark-blue elements) indicate that it is easy to transfer diabetes to its complications, such as nephropathy, whose embeddings are a short distance away (short blue arrows).

We seek to fill the aforementioned gap, while applying the proposed methodology to clinical-admission analysis. As shown in Fig. 1, the proposed method is based on a Wasserstein-distance model, in which (i) the Euclidean distance between ICD code embeddings works as the underlying distance (also referred to as the cost) of the Wasserstein distance between the distributions of the codes corresponding to different admissions kusner2015word; and (ii) the topics are “vertices” of a geometry in the Wasserstein space and the admissions are the “barycenters” of the geometry with different weights schmitz2017wasserstein. When learning this model, both the embeddings and the topics are inferred jointly. A novel learning strategy based on the idea of model distillation hinton2015distilling; lopez2015unifying is proposed, improving the convergence and the performance of the learning algorithm.

The proposed method unifies word embedding and topic modeling in a framework of Wasserstein learning. Based on this model, we can calculate the optimal transport between different admissions and explain the transport by the distances between ICD code embeddings. Accordingly, the admissions of patients become more interpretable and predictable. Experimental results show that our approach is superior to previous state-of-the-art methods on various tasks, including admission-type prediction, mortality prediction for a given admission, and procedure recommendation.

2 A Wasserstein Topic Model Based on Euclidean Word Embeddings

Assume that we have $M$ documents and a corpus with $N$ words, i.e., respectively, admission records and the dictionary of ICD codes. These documents can be represented by $X=[x_1,\dots,x_M]\in\mathbb{R}^{N\times M}$, where $x_m\in\Sigma^N$, $m=1,\dots,M$, is the distribution of the words in the $m$-th document, and $\Sigma^N$ is an $N$-dimensional simplex. These distributions can be represented by some basis (i.e., topics), denoted as $B=[b_1,\dots,b_K]\in\mathbb{R}^{N\times K}$, where $b_k\in\Sigma^N$ is the $k$-th base distribution. The word embeddings can be formulated as $V(\theta)=[v_1,\dots,v_N]\in\mathbb{R}^{d_v\times N}$, where $v_n\in\mathbb{R}^{d_v}$, the embedding of the $n$-th word, $n=1,\dots,N$, is obtained by a model $g_\theta$ with parameters $\theta$ and a predefined representation of the word (e.g., a one-hot vector for each word). The distance between two word embeddings is denoted $d_{ij}=d(v_i,v_j)$, and generally it is assumed to be Euclidean, i.e., $d_{ij}=\|v_i-v_j\|_2$. These distances can be formulated as a parametric distance matrix $D(\theta)=[d_{ij}]\in\mathbb{R}^{N\times N}$.
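To make the notation concrete, the following minimal NumPy sketch builds the parametric distance matrix $D(\theta)$ from embeddings, assuming (as in our experiments) that $g_\theta$ is a linear projection of one-hot code representations, so the embeddings are simply the columns of $\theta$; the function names and toy sizes are illustrative, not part of a released implementation.

```python
import numpy as np

def embed(theta):
    """With one-hot inputs, the embedding map g_theta reduces to selecting
    the columns of theta, so V(theta) = theta (shape d_v x N)."""
    return theta

def cost_matrix(V):
    """Pairwise Euclidean distances: D[i, j] = ||v_i - v_j||_2."""
    sq = np.sum(V ** 2, axis=0)                           # squared norms of the N embeddings
    sq_dist = sq[:, None] + sq[None, :] - 2.0 * (V.T @ V)
    return np.sqrt(np.maximum(sq_dist, 0.0))              # clip tiny negatives from rounding

# Toy sizes: 81 ICD codes embedded in R^50 (matching the experiments below).
rng = np.random.default_rng(0)
theta = rng.normal(size=(50, 81))
D = cost_matrix(embed(theta))                             # (81, 81), symmetric, zero diagonal
```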

Denote the space of the word distributions as the ambient space and that of their embeddings as the latent space. We aim to model and learn the topics in the ambient space and the embeddings in the latent space in a unified framework. We show that recent developments in the methods of Wasserstein learning provide an attractive solution to achieve this aim.

2.1 Revisiting topic models from a geometric viewpoint

Traditional topic models blei2003latent often decompose the distribution of words conditioned on the observed document into two factors: the distribution of words conditioned on a certain topic, and the distribution of topics conditioned on the document. Mathematically, it corresponds to a low-rank factorization of $X$, i.e., $X\approx B\Lambda$, where $B\in\mathbb{R}^{N\times K}$ contains the word distributions of different topics and $\Lambda=[\lambda_1,\dots,\lambda_M]\in\mathbb{R}^{K\times M}$, $\lambda_m\in\Sigma^K$, contains the topic distributions of different documents. Given $B$ and $\lambda_m$, $x_m$ can be equivalently written as

$$x_m = B\lambda_m = \arg\min_{x\in\Sigma^N}\ \sum_{k=1}^{K}\lambda_{km}\,\|x-b_k\|_2^2,\qquad(1)$$

where $\lambda_{km}$ is the probability of topic $k$ given document $m$. From a geometric viewpoint, the $\{b_k\}_{k=1}^K$ in (1) can be viewed as vertices of a geometry, whose “weights” are $\{\lambda_{km}\}_{k=1}^K$. Then, $x_m$ is the weighted barycenter of the geometry in the Euclidean space.

Following this viewpoint, we can extend (1) to another metric space, i.e.,

$$x_m = \arg\min_{x\in\Sigma^N}\ \sum_{k=1}^{K}\lambda_{km}\, d^2(x, b_k),\qquad(2)$$

where $x_m$ is the barycenter of the geometry, with vertices $\{b_k\}_{k=1}^K$ and weights $\{\lambda_{km}\}_{k=1}^K$, in the space with metric $d(\cdot,\cdot)$.

2.2 Wasserstein topic model

When the distance in (2) is the Wasserstein distance, we obtain a Wasserstein topic model, which has a natural and explicit connection with word embeddings. Mathematically, let $\Omega$ be an arbitrary space with metric $d(\cdot,\cdot)$, and let $P(\Omega)$ be the set of Borel probability measures on $\Omega$.

Definition 2.1.

For $p\in[1,\infty)$ and probability measures $\mu$ and $\nu$ in $P(\Omega)$, their $p$-order Wasserstein distance villani2008optimal is

$$W_p(\mu,\nu)=\Bigl(\inf_{\pi\in\Pi(\mu,\nu)}\int_{\Omega\times\Omega} d^p(x,y)\,\mathrm{d}\pi(x,y)\Bigr)^{1/p},$$

where $\Pi(\mu,\nu)$ is the set of all probability measures on $\Omega\times\Omega$ with $\mu$ and $\nu$ as marginals.

Definition 2.2.

The $p$-order weighted Fréchet mean in the Wasserstein space (also called the Wasserstein barycenter) agueh2011barycenters of $K$ measures $\{\mu_k\}_{k=1}^K$ in $P(\Omega)$ is

$$\bar\mu=\arg\min_{\mu\in P(\Omega)}\ \sum_{k=1}^{K}\lambda_k\, W_p^p(\mu,\mu_k),$$

where $\lambda=[\lambda_1,\dots,\lambda_K]\in\Sigma^K$ decides the weights of the measures.

When $\Omega$ is a discrete state space, i.e., $\Omega=\{1,\dots,N\}$, the Wasserstein distance is also called the optimal transport (OT) distance schmitz2017wasserstein. More specifically, the Wasserstein distance with $p=1$ corresponds to the solution to the discretized Monge-Kantorovich problem:

$$W(u,v;D)=\min_{T\in\Pi(u,v)}\ \mathrm{tr}\bigl(D^\top T\bigr),\qquad(3)$$

where $u,v\in\Sigma^N$ are two distributions of the discrete states and $D=[d_{ij}]\in\mathbb{R}^{N\times N}$ is the underlying distance matrix, whose element $d_{ij}$ measures the distance between different states. $\Pi(u,v)=\{T\ge 0\mid T\mathbf{1}_N=u,\ T^\top\mathbf{1}_N=v\}$, and $\mathrm{tr}(\cdot)$ represents the matrix trace. The matrix $\hat{T}$ is called the optimal transport matrix when the minimum in (3) is achieved.
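In practice, (3) is rarely solved exactly; the learning algorithm in Section 3 relies on its entropic (Sinkhorn) regularization, defined formally in (10). Below is a minimal sketch of the standard Sinkhorn scaling iterations cuturi2013sinkhorn for two histograms under a given cost matrix; `gamma` denotes the entropic weight, and the routine is illustrative rather than the exact solver used in our implementation.

```python
import numpy as np

def sinkhorn_ot(u, v, D, gamma=0.05, n_iter=200):
    """Entropic-regularized OT between histograms u and v (length N) under cost D.
    Returns the transport plan T and the transport cost <D, T> = tr(D^T T)."""
    K = np.exp(-D / gamma)               # Gibbs kernel of the cost matrix
    a = np.ones_like(u)
    for _ in range(n_iter):              # alternating Sinkhorn scaling updates
        b = v / (K.T @ a)
        a = u / (K @ b)
    T = a[:, None] * K * b[None, :]      # T = diag(a) K diag(b)
    return T, float(np.sum(D * T))

# Example: OT between two random histograms over 81 codes with a random symmetric cost.
rng = np.random.default_rng(0)
D = np.abs(rng.normal(size=(81, 81))); D = 0.5 * (D + D.T); np.fill_diagonal(D, 0.0)
u = rng.random(81); u /= u.sum()
v = rng.random(81); v /= v.sum()
T, w = sinkhorn_ot(u, v, D)
```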

Applying the discrete Wasserstein distance in (3) to (2), we obtain our Wasserstein topic model, i.e.,

$$x_m=\arg\min_{x\in\Sigma^N}\ \sum_{k=1}^{K}\lambda_{km}\, W\bigl(x,b_k;D(\theta)\bigr).\qquad(4)$$

In this model, the discrete states correspond to the words in the corpus, and the distance between different words can be calculated by the Euclidean distance between their embeddings.

In this manner, we establish the connection between the word embeddings and the topic model: the distance between different topics (and different documents) is achieved by the optimal transport between their word distributions, built on the embedding-based underlying distance. For any two word embeddings, the more similar they are, the smaller the underlying distance between them, and the more easily mass can be transferred between them. In the learning phase (as shown in the following section), we can learn the embeddings and the topic model jointly. This model is especially suitable for clinical admission analysis. As discussed above, we not only care about the clustering structure of admissions (the relative proportion by which each topic is manifested in an admission), but also want to know the mechanism, or the tendency, of their transfers at the level of individual diseases. As shown in Fig. 1, using our model, we can calculate the Wasserstein distance between different admissions at the level of diseases and obtain the optimal transport from one admission to another explicitly. The hierarchical architecture of our model helps represent each admission by its topics, which are the typical diseases/procedures (ICD codes) appearing in a class of admissions.
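For intuition, the barycenter in (4) can be approximated with the iterative Bregman projections of benamou2015iterative once the OT terms are entropically regularized. The sketch below is a generic implementation of that fixed-point scheme, not the exact routine of the paper; `gamma` and the iteration count are illustrative.

```python
import numpy as np

def sinkhorn_barycenter(B, lam, D, gamma=0.05, n_iter=200):
    """Entropic Wasserstein barycenter of the K topic histograms in B (N x K)
    with weights lam (length K, summing to one), under cost D, via the
    iterative Bregman projections of Benamou et al. (2015)."""
    N, K_topics = B.shape
    Kern = np.exp(-D / gamma)
    v = np.ones((N, K_topics))
    x = np.full(N, 1.0 / N)
    for _ in range(n_iter):
        u = B / (Kern @ v)                                    # match the topic marginals
        Ku = np.maximum(Kern.T @ u, 1e-300)
        x = np.exp((np.log(Ku) * lam[None, :]).sum(axis=1))   # weighted geometric mean
        v = x[:, None] / Ku                                   # match the barycenter marginal
    return x / x.sum()

# Example: barycenter of K = 8 random topics over N = 81 codes.
rng = np.random.default_rng(0)
D = np.abs(rng.normal(size=(81, 81))); D = 0.5 * (D + D.T); np.fill_diagonal(D, 0.0)
B = rng.random((81, 8)); B /= B.sum(axis=0, keepdims=True)
lam = np.full(8, 1.0 / 8)
x_hat = sinkhorn_barycenter(B, lam, D)
```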

3 Wasserstein Learning with Model Distillation

Given the word-document matrix $X$ and a predefined number of topics $K$, we wish to jointly learn the basis $B$, the weight matrix $\Lambda$, and the model $g_\theta$ of word embeddings. This learning problem can be formulated as

$$\min_{B,\Lambda,\theta}\ \sum_{m=1}^{M}\mathcal{L}\bigl(x_m,\ \hat{x}_m(B,\lambda_m;D(\theta))\bigr),\quad \text{s.t. } b_k\in\Sigma^N,\ \lambda_m\in\Sigma^K.\qquad(5)$$

Here, $\hat{X}=[\hat{x}_1,\dots,\hat{x}_M]$ and the element $\hat{x}_m(B,\lambda_m;D(\theta))$ is the Wasserstein barycenter in (4). The loss function $\mathcal{L}(\cdot,\cdot)$ measures the difference between $x_m$ and its estimation $\hat{x}_m$. We can solve this problem based on the idea of alternating optimization. In each iteration, we first learn the basis $B$ and the weights $\Lambda$ given the current parameters $\theta$. Then, we learn the new parameters $\theta$ based on the updated $B$ and $\Lambda$.

3.1 Updating word embeddings to enhance the clustering structure

Suppose that we have obtained updated $B$ and $\Lambda$. Given the current $\theta$, we denote the optimal transport between document $x_m$ and topic $b_k$ as $T_{km}$. Accordingly, the Wasserstein distance between $x_m$ and $b_k$ is $W(x_m,b_k;D(\theta))=\mathrm{tr}(D(\theta)^\top T_{km})$. Recall from the topic model in (4) that each document is represented as the weighted barycenter of $\{b_k\}$ in the Wasserstein space, and the weights $\{\lambda_{km}\}$ represent the closeness between the barycenter and the different bases (topics). To enhance the clustering structure of the documents, we update $\theta$ by minimizing the Wasserstein distance between the documents and their closest topics. Consequently, the documents belonging to different clusters would be far away from each other. The corresponding objective function is

$$\min_\theta\ \sum_{m=1}^{M}\mathrm{tr}\bigl(D(\theta)^\top T_{k_m m}\bigr),\qquad(6)$$

where $T_{k_m m}$ is the optimal transport between $x_m$ and its closest base $b_{k_m}$, $k_m=\arg\max_k\lambda_{km}$. The aggregation of these transports is given by $T=\sum_{m=1}^{M}T_{k_m m}$, and $V(\theta)$ are the word embeddings. Considering the symmetry of $D(\theta)$, we can replace $T$ in (6) with $S=\frac{1}{2}(T+T^\top)$. The objective function can be further written as $\mathrm{tr}\bigl(V(\theta)\,L\,V(\theta)^\top\bigr)$, where $L=\mathrm{diag}(S\mathbf{1}_N)-S$ is the Laplacian matrix. To avoid trivial solutions like $V(\theta)=\mathbf{0}$, we add a smoothness regularizer and update $\theta$ by optimizing the following problem:

$$\min_\theta\ \mathrm{tr}\bigl(V(\theta)\,L\,V(\theta)^\top\bigr)+\beta\,\|\theta-\theta_c\|_F^2,\qquad(7)$$

where $\theta_c$ is the current parameters and $\beta$ controls the significance of the regularizer. Similar to Laplacian Eigenmaps belkin2003laplacian, the aggregated optimal transport $S$ works as the similarity measurement between the proposed embeddings. However, instead of requiring the solution of (7) to be the eigenvectors of $L$, we enhance the stability of the updating by ensuring that the new $\theta$ is close to the current one.
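A minimal sketch of one such embedding update is given below, assuming the linear embedding map from the earlier sketch (so $V(\theta)=\theta$) and using the fact that the gradient of $\mathrm{tr}(VLV^\top)$ with respect to $V$ is $2VL$ for symmetric $L$; the explicit gradient step and step size are our illustrative choices.

```python
import numpy as np

def laplacian_from_transport(T_agg):
    """Symmetrize the aggregated optimal transport and build its graph Laplacian L."""
    S = 0.5 * (T_agg + T_agg.T)
    return np.diag(S.sum(axis=1)) - S

def update_embeddings(theta, T_agg, beta=1.0, lr=0.01):
    """One gradient step on  tr(V L V^T) + beta * ||theta - theta_c||_F^2,
    with V(theta) = theta (linear embedding of one-hot codes). theta_c is the
    current parameter value, so the regularizer only penalizes moving away from it."""
    theta_c = theta.copy()
    L = laplacian_from_transport(T_agg)
    grad = 2.0 * theta @ L + 2.0 * beta * (theta - theta_c)   # second term is zero at theta_c
    return theta - lr * grad
```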

3.2 Updating topic models based on the distilled underlying distance

Given the updated word embeddings and the corresponding underlying distance $D(\theta)$, we wish to further update the basis $B$ and the weights $\Lambda$. The problem is formulated as a Wasserstein dictionary-learning problem, as proposed in schmitz2017wasserstein. Following the same strategy as schmitz2017wasserstein, we rewrite $B$ and $\Lambda$ as

$$b_k=\frac{\exp(r_k)}{\sum_{n=1}^{N}\exp(r_{nk})},\qquad \lambda_m=\frac{\exp(a_m)}{\sum_{k=1}^{K}\exp(a_{km})},\qquad(8)$$

where $R=[r_1,\dots,r_K]\in\mathbb{R}^{N\times K}$ and $A=[a_1,\dots,a_M]\in\mathbb{R}^{K\times M}$ are new parameters. Based on (8), the normalization of $B$ and $\Lambda$ is met naturally, and we can reformulate (5) to an unconstrained optimization problem, i.e.,

$$\min_{R,A}\ \sum_{m=1}^{M}\mathcal{L}\bigl(x_m,\ \hat{x}_m(B(R),\lambda_m(A);D(\theta))\bigr).\qquad(9)$$
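The change of variables in (8) amounts to a column-wise softmax, as in the following sketch (the parameter names $R$ and $A$ follow the notation above; the toy sizes are illustrative).

```python
import numpy as np

def softmax_columns(Z):
    """Column-wise softmax, mapping each column of Z onto the probability simplex."""
    Z = Z - Z.max(axis=0, keepdims=True)   # subtract the column max for numerical stability
    E = np.exp(Z)
    return E / E.sum(axis=0, keepdims=True)

# Unconstrained parameters: R (N x K) for the topics, A (K x M) for the weights.
rng = np.random.default_rng(0)
R = rng.normal(size=(81, 8))
A = rng.normal(size=(8, 32))
B = softmax_columns(R)      # each topic b_k is a distribution over the N words
Lam = softmax_columns(A)    # each lambda_m is a distribution over the K topics
```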

Different from schmitz2017wasserstein, we introduce a model distillation method to improve the convergence of our model. The key idea is that the model with the current underlying distance works as a “teacher,” while the proposed model with new basis and weights is regarded as a “student.” Through the underlying distance $D(\theta_c)$, the teacher provides the student with guidance for its updating. We find that if we use the current underlying distance directly to calculate the basis $B$ and weights $\Lambda$, we encounter a serious “vanishing gradient” problem when solving (7) in the next iteration. Because the transport in (6) is already optimal under the current underlying distance and the new $B$ and $\Lambda$, it is difficult to further update $\theta$.

Inspired by recent model distillation methods in hinton2015distilling; lopez2015unifying; pereyra2017regularizing, we use a smoothed underlying distance matrix to solve the “vanishing gradient” problem when updating $B$ and $\Lambda$. In particular, the Wasserstein distance in (9) is replaced by a Sinkhorn distance with the smoothed underlying distance, i.e., $D^{(\tau)}(\theta_c)=[d_{ij}^{\tau}]$, where $\tau\in(0,1)$ and $(\cdot)^{\tau}$ is an element-wise power function of a matrix. The Sinkhorn distance is defined as

$$W_\gamma\bigl(u,v;D^{(\tau)}\bigr)=\min_{T\in\Pi(u,v)}\ \mathrm{tr}\bigl((D^{(\tau)})^\top T\bigr)+\gamma\,\mathrm{tr}\bigl(T^\top\log T\bigr),\qquad(10)$$

where $\log(\cdot)$ calculates the element-wise logarithm of a matrix. The parameter $\tau$ works as the reciprocal of the “temperature” in the smoothed softmax layer used in the original distillation method hinton2015distilling; lopez2015unifying.
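Constructing the distilled distance is a one-line element-wise operation; the sketch below shows it together with how it would plug into the Sinkhorn routine sketched earlier (the value of `tau` is illustrative).

```python
import numpy as np

def distilled_distance(D, tau=0.5):
    """Element-wise power smoothing of the underlying distance: D^(tau) = [d_ij ** tau].
    tau = 1 recovers the undistilled distance; tau -> 0 flattens it toward a constant,
    i.e., the guidance from the 'teacher' becomes weaker."""
    return np.power(D, tau)

# The smoothed matrix replaces D(theta_c) when computing the Sinkhorn distances and
# gradients in (9)-(11), e.g. with the earlier sketch:
#   T, w = sinkhorn_ot(u, v, distilled_distance(D, tau=0.5), gamma=0.05)
```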

The principle of our distilled learning method is that when updating $B$ and $\Lambda$, the smoothed underlying distance is used to provide “weak” guidance. Consequently, the student (i.e., the proposed new model with updated $B$ and $\Lambda$) will not completely rely on the information from the teacher (i.e., the underlying distance obtained in a previous iteration), and will tend to explore new basis and weights. In summary, the optimization problem for learning the Wasserstein topic model is

$$\min_{R,A}\ \sum_{m=1}^{M}\mathcal{L}\bigl(x_m,\ \hat{x}_m\bigl(B(R),\lambda_m(A);D^{(\tau)}(\theta_c)\bigr)\bigr),\qquad(11)$$

which can be solved under the same algorithmic framework as that in schmitz2017wasserstein.

Our algorithm is shown in Algorithm 1. The details of the algorithm and the influence of our distilled learning strategy on the convergence of the algorithm are given in the Supplementary Material. Note that our method is compatible with existing embedding techniques: it can work as a fine-tuning method when the underlying distance is initialized by predefined embeddings. When the topic of each document is given, the closest base in (6) is predefined and the proposed method can work in a supervised way.

1:  Input: The distributions of words for the $M$ documents, $\{x_m\}_{m=1}^M$. The distillation parameter $\tau$. The number of epochs. The batch size. The weight $\gamma$ in the Sinkhorn distance. The weight $\beta$ in (7). The learning rate $\eta$.
2:  Output: The parameters $\theta$, basis $B$, and weights $\Lambda$.
3:  Initialize $\theta$, $R$, and $A$, and calculate $B$ and $\Lambda$ by (8).
4:  For each epoch
5:   For each batch of documents
6:    Calculate the Sinkhorn gradients with distillation: $\nabla_{R}$ of (11) and $\nabla_{A}$ of (11).
7:    $R\leftarrow R-\eta\nabla_{R}$, $A\leftarrow A-\eta\nabla_{A}$.
8:    Calculate $D(\theta)$, the aggregated optimal transport $T$, and the gradient $\nabla_\theta$ of (7), then update $\theta\leftarrow\theta-\eta\nabla_\theta$.
Algorithm 1 Distilled Wasserstein Learning (DWL) for Joint Word Embedding and Topic Modeling
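For concreteness, the following schematic NumPy skeleton mirrors the alternating structure of Algorithm 1, reusing the helper sketches introduced above (`cost_matrix`, `distilled_distance`, `softmax_columns`, `sinkhorn_barycenter`, `sinkhorn_ot`, `update_embeddings`). The exact Sinkhorn gradients of Algorithm 2 are replaced here by a finite-difference stand-in, so this is an outline of the procedure rather than a faithful or efficient implementation.

```python
import numpy as np

def dwl_epoch(X, theta, R, A, tau=0.5, gamma=0.05, eta=0.1, beta=1.0):
    """One schematic epoch of Algorithm 1 on the full batch of documents X (N x M)."""
    N, M = X.shape
    D_tau = distilled_distance(cost_matrix(theta), tau)       # distilled "teacher" distance

    def recon_loss(R_, A_):
        # Squared loss between each document and its Sinkhorn barycenter, as in (9)/(11).
        B, Lam = softmax_columns(R_), softmax_columns(A_)
        return sum(np.sum((X[:, m] - sinkhorn_barycenter(B, Lam[:, m], D_tau, gamma)) ** 2)
                   for m in range(M))

    def num_grad(f, Z, eps=1e-4):
        # Finite-difference gradient: a slow stand-in for the recursions of Algorithm 2.
        G = np.zeros_like(Z)
        for idx in np.ndindex(Z.shape):
            Zp, Zm = Z.copy(), Z.copy()
            Zp[idx] += eps
            Zm[idx] -= eps
            G[idx] = (f(Zp) - f(Zm)) / (2.0 * eps)
        return G

    # Step 1 (Sec. 3.2): update topics and weights under the distilled distance.
    R = R - eta * num_grad(lambda Z: recon_loss(Z, A), R)
    A = A - eta * num_grad(lambda Z: recon_loss(R, Z), A)

    # Step 2 (Sec. 3.1): aggregate the optimal transports between each document and its
    # closest topic under the (undistilled) distance, then update the embeddings via (7).
    D = cost_matrix(theta)
    B, Lam = softmax_columns(R), softmax_columns(A)
    T_agg = np.zeros((N, N))
    for m in range(M):
        k = int(np.argmax(Lam[:, m]))
        T_m, _ = sinkhorn_ot(X[:, m], B[:, k], D, gamma)
        T_agg += T_m
    theta = update_embeddings(theta, T_agg, beta=beta, lr=eta)
    return theta, R, A
```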

4 Related Work

Word embedding, topic modeling, and their application to clinical data. Traditional topic models, like latent Dirichlet allocation (LDA) blei2003latent and its variants, rely on the “bag-of-words” representation of documents. Word embedding mikolov2013efficient provides another choice, which represents documents as a fusion of the embeddings le2014distributed. Recently, many new word embedding techniques have been proposed, e.g., GloVe pennington2014glove and the linear ensemble embedding in muromagi2017linear, which achieve encouraging performance on word and document representation. Some works try to combine word embedding and topic modeling. As discussed above, they either use word embeddings as features for topic models shi2017jointly; das2015gaussian or regard topics as labels when learning embeddings wang2017topic; liu2015topical. A unified framework for learning topics and word embeddings was still absent prior to this paper.

Focusing on clinical data analysis, word embedding and topic modeling have been applied to many tasks. Considering ICD code assignment as an example, many methods have been proposed to estimate the ICD codes based on clinical records shi2017towards ; baumel2017multi ; mullenbach2018explainable ; huang2018empirical , aiming to accelerate diagnoses. Other tasks, like clustering clinical data and the prediction of treatments, can also be achieved by NLP techniques bajor2018embedding ; harutyunyan2017multitask ; choi2016multi .

Wasserstein learning and its application in NLP. The Wasserstein distance has been proven useful in distribution estimation boissard2015distribution, alignment zemel2017fr, and clustering agueh2011barycenters; ye2017fast; cuturi2014fast, avoiding over-smoothed intermediate interpolation results. It can also be used as a loss function when learning generative models courty2017learning; arjovsky2017wasserstein. The main bottleneck in the application of Wasserstein learning is its high computational complexity. This problem has been greatly eased since the Sinkhorn distance was proposed in cuturi2013sinkhorn. Based on the Sinkhorn distance, we can apply iterative Bregman projection benamou2015iterative to approximate the Wasserstein distance and achieve near-linear time complexity altschuler2017near. Many more complicated models have been proposed based on the Sinkhorn distance genevay2017sinkhorn; schmitz2017wasserstein. Focusing on NLP tasks, the methods in kusner2015word; huang2016supervised use the same framework as ours, computing underlying distances based on word embeddings and measuring the distance between documents in the Wasserstein space. However, the work in kusner2015word does not update the pretrained embeddings, while the model in huang2016supervised does not have a hierarchical architecture for topic modeling.

Model distillation. As a kind of transfer learning technique, model distillation was originally proposed to learn a simple model (student) under the guidance of a complicated model (teacher) hinton2015distilling. When learning the target distilled model, a regularizer based on the smoothed outputs of the complicated model is imposed. Essentially, the distilled complicated model provides the target model with some privileged information lopez2015unifying. This idea has been widely used in many applications, e.g., textual data modeling inan2016tying, healthcare data analysis che2015distilling, and image classification gupta2016cross. Beyond transfer learning, the idea of model distillation has been extended to control the learning process of neural networks pereyra2017regularizing; rusu2016progressive; wang2016learning. To the best of our knowledge, our work is the first attempt to combine model distillation with Wasserstein learning.

5 Experiments

To demonstrate the feasibility and the superiority of our distilled Wasserstein learning (DWL) method, we apply it to the analysis of patient admission records and compare it with state-of-the-art methods. We consider a subset of the MIMIC-III dataset johnson2016mimic, containing patient admissions associated with a dictionary of disease and procedure ICD codes, where each admission is represented as a sequence of the ICD codes of its diseases and procedures. Using different methods, we learn the embeddings of the ICD codes and the topics of the admissions, and test them on three tasks: mortality prediction, admission-type prediction, and procedure recommendation. For all the methods, we split the admissions into training, validation, and testing sets for each task. For our method, the embeddings are obtained by a linear projection of the one-hot representations of the ICD codes, similar to Word2Vec mikolov2013efficient and Doc2Vec le2014distributed, and the loss function $\mathcal{L}$ is the squared loss. The hyperparameters of our method are set via cross validation, including the batch size, the Sinkhorn weight $\gamma$, the regularizer weight $\beta$, the number of topics ($K=8$), the embedding dimension (50), and the learning rate $\eta$. The number of epochs is set separately for the case in which the embeddings are initialized by Word2Vec and for training from scratch. The distillation parameter $\tau$ is set empirically; its influence on the learning results is shown in the Supplementary Material.

5.1 Admission classification and procedure recommendation

The admissions of patients often have a clustering structure. According to the seriousness of the admissions, they are categorized into four classes in the MIMIC-III dataset: elective, emergency, urgent, and newborn. Additionally, diseases and procedures may lead to mortality, and the admissions can be clustered based on whether or not the patients die during their admissions. Even if learned in an unsupervised way, the proposed embeddings should reflect the clustering structure of the admissions to some degree. We test our DWL method on the prediction of admission type and mortality. For the admissions, we can either represent them by the distributions of their codes and calculate the Wasserstein distance between them, or represent them by the average pooling of their code embeddings and calculate the Euclidean distance between them. A simple KNN classifier can be applied under these two metrics, and we consider 1-NN and 5-NN classifiers. We compare the proposed method with the following baselines: (i) bag-of-words-based methods like TF-IDF gerard1983introduction and LDA blei2003latent; (ii) word/document embedding methods like Word2Vec mikolov2013efficient, GloVe pennington2014glove, and Doc2Vec le2014distributed; and (iii) the Wasserstein-distance-based method in kusner2015word. We tested the various methods in multiple trials. In each trial, we trained the different models on a subset of the training admissions, tested them on the same testing set, and calculated the averaged results and their confidence intervals.
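As a concrete example of the Euclidean variant, the sketch below builds average-pooled admission features from code embeddings and classifies them with scikit-learn's KNeighborsClassifier; the inputs are hypothetical stand-ins for the MIMIC-III splits.

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

def admission_features(code_ids_per_admission, V):
    """Average pooling of the code embeddings (columns of V, shape d_v x N)
    for each admission, giving one Euclidean feature vector per admission."""
    return np.stack([V[:, ids].mean(axis=1) for ids in code_ids_per_admission])

# Hypothetical stand-ins: lists of ICD-code indices per admission and binary labels.
rng = np.random.default_rng(0)
V = rng.normal(size=(50, 81))
train_codes = [rng.choice(81, size=8, replace=False) for _ in range(200)]
test_codes = [rng.choice(81, size=8, replace=False) for _ in range(50)]
y_train = rng.integers(0, 2, size=200)

for k in (1, 5):                                   # the 1-NN and 5-NN settings of Table 1
    clf = KNeighborsClassifier(n_neighbors=k)
    clf.fit(admission_features(train_codes, V), y_train)
    y_pred = clf.predict(admission_features(test_codes, V))
```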

The classification accuracies for the various methods are shown in Table 1. Our DWL method is superior to its competitors in classification accuracy. Beyond this encouraging result, we also observe two interesting and important phenomena. First, for our DWL method the model trained from scratch has performance comparable to that fine-tuned from Word2Vec's embeddings, which means that our method is robust to initialization when exploring the clustering structure of admissions. Second, compared with measuring the Wasserstein distance between documents, representing the documents by the average pooling of embeddings and measuring their Euclidean distance obtains comparable results. Considering that measuring Euclidean distance has much lower complexity than measuring Wasserstein distance, this phenomenon implies that although our DWL method is time-consuming in the training phase, the trained models can be easily deployed for large-scale data in the testing phase.

| Word Feature | Doc. Feature | Metric | Dim. | Mortality 1-NN | Mortality 5-NN | Adm. Type 1-NN | Adm. Type 5-NN |
|---|---|---|---|---|---|---|---|
| — | TF-IDF gerard1983introduction | Euclidean | 81 | 69.98±0.05 | 75.32±0.04 | 82.27±0.03 | 88.28±0.02 |
| — | LDA blei2003latent | Euclidean | 8 | 66.03±0.06 | 69.05±0.06 | 81.41±0.04 | 86.57±0.04 |
| Word2Vec mikolov2013efficient | Doc2Vec le2014distributed | Euclidean | 50 | 57.98±0.08 | 59.80±0.08 | 70.57±0.08 | 79.94±0.07 |
| Word2Vec mikolov2013efficient | AvePooling | Euclidean | 50 | 70.42±0.05 | 75.21±0.04 | 84.88±0.07 | 89.16±0.06 |
| Glove pennington2014glove | AvePooling | Euclidean | 50 | 66.94±0.06 | 73.21±0.04 | 81.91±0.05 | 88.21±0.05 |
| DWL (Scratch) | AvePooling | Euclidean | 50 | 71.01±0.12 | 74.74±0.11 | 84.54±0.13 | 89.49±0.12 |
| DWL (Finetune) | AvePooling | Euclidean | 50 | 71.52±0.07 | 75.44±0.07 | 85.54±0.09 | 89.28±0.09 |
| Word2Vec mikolov2013efficient | Topic weight schmitz2017wasserstein | Euclidean | 8 | 70.31±0.04 | 74.89±0.04 | 83.63±0.05 | 89.25±0.04 |
| DWL (Scratch) | Topic weight schmitz2017wasserstein | Euclidean | 8 | 70.45±0.08 | 74.88±0.07 | 83.82±0.12 | 88.80±0.12 |
| DWL (Finetune) | Topic weight schmitz2017wasserstein | Euclidean | 8 | 70.88±0.07 | 75.67±0.07 | 84.26±0.09 | 89.13±0.08 |
| Word2Vec mikolov2013efficient | Word distribution | Wasserstein kusner2015word | 81 | 70.61±0.04 | 75.92±0.04 | 84.08±0.05 | 89.06±0.05 |
| Glove pennington2014glove | Word distribution | Wasserstein kusner2015word | 81 | 70.64±0.06 | 75.97±0.05 | 83.92±0.08 | 89.17±0.07 |
| DWL (Scratch) | Word distribution | Wasserstein kusner2015word | 81 | 71.01±0.10 | 75.88±0.09 | 84.23±0.12 | 89.33±0.11 |
| DWL (Finetune) | Word distribution | Wasserstein kusner2015word | 81 | 70.65±0.07 | 76.00±0.06 | 84.35±0.08 | 89.61±0.07 |

Table 1: Admission classification accuracy (%) for various methods; values are mean ± confidence interval over trials.
| Method | Top-1 P (%) | Top-1 R (%) | Top-1 F1 (%) | Top-3 P (%) | Top-3 R (%) | Top-3 F1 (%) | Top-5 P (%) | Top-5 R (%) | Top-5 F1 (%) |
|---|---|---|---|---|---|---|---|---|---|
| Word2Vec mikolov2013efficient | 39.95 | 13.27 | 18.25 | 31.70 | 33.46 | 29.30 | 28.89 | 46.98 | 32.59 |
| Glove pennington2014glove | 32.66 | 13.01 | 17.22 | 29.45 | 30.99 | 27.41 | 27.93 | 44.79 | 31.47 |
| DWL (Scratch) | 37.89 | 12.42 | 17.16 | 30.14 | 29.78 | 27.14 | 27.39 | 43.81 | 30.81 |
| DWL (Finetune) | 40.00 | 13.76 | 18.71 | 31.88 | 33.71 | 29.58 | 30.59 | 48.56 | 34.28 |

Table 2: Top-$C$ procedure recommendation results for various methods.

The third task is recommending procedures according to the diseases in the admissions. In our framework, this task can be solved by establishing a bipartite graph between diseases and procedures based on the Euclidean distance between their embeddings. The proposed embeddings should reflect the clinical relationships between procedures and diseases, such that procedures are assigned to the diseases within a short distance. For the $m$-th admission, we may recommend a list of procedures of length $C$, denoted as $\hat{E}_m$, based on its diseases, and evaluate the recommendation results against the ground-truth list of procedures, denoted as $E_m$. In particular, given $\{\hat{E}_m,E_m\}_{m=1}^M$, we calculate the top-$C$ precision, recall, and F1-score as follows: $P=\frac{1}{M}\sum_{m}\frac{|\hat{E}_m\cap E_m|}{|\hat{E}_m|}$, $R=\frac{1}{M}\sum_{m}\frac{|\hat{E}_m\cap E_m|}{|E_m|}$, and $\mathrm{F1}=\frac{2PR}{P+R}$. Table 2 shows the performance of various methods with $C\in\{1,3,5\}$. We find that although our DWL method trained from scratch is not as good as Word2Vec, which may be caused by the much smaller number of epochs we executed, it indeed outperforms the other methods when the model is fine-tuned from Word2Vec.
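A small sketch of this pipeline is given below: procedures are ranked by their average embedding distance to an admission's diseases, and the recommended lists are scored with the top-$C$ metrics above. The mean-distance aggregation and the averaging convention for F1 are our illustrative choices, not necessarily the exact rules used in the paper.

```python
import numpy as np

def recommend(disease_ids, procedure_ids, V, C=5):
    """Recommend the C procedures whose embeddings are closest, on average,
    to the embeddings of the admission's diseases (columns of V, shape d_v x N)."""
    diff = V[:, procedure_ids][:, :, None] - V[:, disease_ids][:, None, :]
    dist = np.linalg.norm(diff, axis=0)                 # (num procedures) x (num diseases)
    order = np.argsort(dist.mean(axis=1))[:C]
    return [procedure_ids[i] for i in order]

def topk_metrics(recommended, ground_truth):
    """Top-C precision/recall averaged over admissions, and the F1 of the averages."""
    P, R = [], []
    for rec, truth in zip(recommended, ground_truth):
        hits = len(set(rec) & set(truth))
        P.append(hits / max(len(rec), 1))
        R.append(hits / max(len(truth), 1))
    P, R = float(np.mean(P)), float(np.mean(R))
    F1 = 2.0 * P * R / (P + R) if P + R > 0 else 0.0
    return P, R, F1
```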

5.2 Rationality Analysis

To verify the rationality of our learning results, in Fig. 2 we visualize the KNN graph of diseases and procedures. We find that the diseases in Fig. 2(a) have an obvious clustering structure, while the procedures are dispersed according to their connections with matched diseases. Furthermore, the three typical subgraphs in Fig. 2 can be interpreted from a clinical viewpoint. Figure 2(b) clusters cardiovascular diseases like hypotension (d4589, d45829) and hyperosmolality (d2762) with their common procedure, i.e., diagnostic ultrasound of the heart (p8872). Figure 2(c) clusters coronary artery bypass (p3615) with typical postoperative responses like hyperpotassemia (d2767), cardiac complications (d9971), and congestive heart failure (d4280). Figure 2(d) clusters chronic pulmonary heart diseases (d4168) with its common procedures like cardiac catheterization (p3772) and abdominal drainage (p5491), and the procedures are connected with potential complications like septic shock (d78552). The rationality of our learning results is also demonstrated by the topics shown in Table 3. According to the top-3 ICD codes, some topics have obvious clinical interpretations. Specifically, topic 1 is about kidney disease and its complications and procedures; topics 2 and 5 are about serious cardiovascular diseases; topic 4 is about diabetes and its cardiovascular complications and procedures; topic 6 is about neonatal diseases and procedures. We show the map between ICD codes and the corresponding diseases/procedures in the Supplementary Material.

(a) Full graph
(b) Enlarged part 1
(c) Enlarged part 2
(d) Enlarged part 3
Figure 2: (a) The KNN graph of diseases and procedures. An enlarged version is in the Supplementary Material. The ICD codes related to diseases have the prefix “d” and their nodes are blue, while those related to procedures have the prefix “p” and their nodes are orange. (b-d) Three enlarged subgraphs corresponding to the red frames in (a). In each subfigure, the nodes/dots in blue are diseases, while the nodes/dots in orange are procedures.
| Rank | Topic 1 | Topic 2 | Topic 3 | Topic 4 | Topic 5 | Topic 6 | Topic 7 | Topic 8 |
|---|---|---|---|---|---|---|---|---|
| 1 | d5859 Chronic kidney disease | d4241 Aortic valve disorders | d311 Mycobacteria | p8856 Coronary arteriography | d2449 Hypothyroidism | d7742 Neonatal jaundice | p9904 Cell transfusion | d311 Mycobacteria |
| 2 | d2859 Anemia | p3891 Arterial catheterization | dV3001 Single liveborn | d41071 Subendocardial infarction | d2749 Gout | p9672 Ventilation | d5119 Pleural effusion | d5119 Pleural effusion |
| 3 | p8872 Heart ultrasound | d9971 Cardiac complications | d5849 Kidney failure | d2851 Posthemorrhagic anemia | d41401 Atherosclerosis | p9907 Serum transfusion | p331 Incision of lung | d42731 Atrial fibrillation |

Table 3: Top-3 ICD codes in each topic, associated with the corresponding diseases/procedures.

6 Conclusion and Future Work

We have proposed a novel method to jointly learn Euclidean word embeddings and a Wasserstein topic model in a unified framework. An alternating optimization method was applied to iteratively update the topics, their weights, and the embeddings of words. We introduced a simple but effective model distillation method to improve the performance of the learning algorithm. Tested on clinical admission records, our method shows superiority over other competitive models on various tasks. Currently, the proposed learning method shows potential for more-traditional textual data analysis (documents), but its computational complexity is still too high for large-scale document applications (because the vocabulary for real documents is typically much larger than the number of ICD codes considered here in the motivating hospital-admissions application). In the future, we plan to further accelerate the learning method, e.g., by replacing the Sinkhorn-based updating procedure with variants like the Greenkhorn-based updating method altschuler2017near.

7 Acknowledgments

This research was supported in part by DARPA, DOE, NIH, ONR and NSF. Morgan A. Schmitz kindly helped us by sharing his Wasserstein dictionary learning code. We also thank Prof. Hongyuan Zha at Georgia Institute of Technology for helpful discussions.

References

Appendix

The derivation of Sinkhorn gradient

The key part of our learning algorithm is calculating the Sinkhorn gradient given the distilled underlying distance matrix $D^{(\tau)}$. Similar to the method in schmitz2017wasserstein, we use the following algorithm to calculate the gradients with respect to the parameters $R$ and $A$ for each document $x_m$.

1:  Input: An arbitrary document $x_m$. The underlying distance $D^{(\tau)}$. The distillation parameter $\tau$. The number of inner iterations. The weight $\gamma$ in the Sinkhorn distance. Current basis $B$ and weights $\Lambda$.
2:  Output: and .
3:  Calculate .
4:  Forward loop:
5:  Initialize for .
6:  for  do
7:      for .
8:     .
9:     .
10:  end for
11:  Backward loop for weights:
12:  Initialize , , .
13:  for  do
14:      for .
15:      for .
16:     
17:  end for
18:  Backward loop for basis:
19:  Initialize , .
20:  for  do
21:     .
22:     .
23:     .
24:     .
25:  end for
26:   and .
Algorithm 2 Computation of Sinkhorn gradient

Here, $\odot$ is element-wise multiplication, $\oslash$ is element-wise division, $(\cdot)^{\odot 2}$ is the element-wise square, and $\log(\cdot)$ is the element-wise logarithm. More details of the algorithm can be found in schmitz2017wasserstein.

Influence of the distillation parameter

The distillation parameter $\tau$ has a significant influence on the convergence and the performance of our learning algorithm. We visualize the convergence of our DWL method with respect to different $\tau$'s in the task of admission-type prediction. In Fig. 3, we find that when $\tau=1$, which means that the model is learned without distillation, the increase of the training accuracy is very slow because of the vanishing gradient problem. On the contrary, when $\tau$ is close to $0$, which means that we use model distillation heavily in the training phase and the “student” leverages little information from the “teacher”, the training accuracy increases rapidly but converges to an unsatisfactory level. This is because the distilled underlying distance is over-smoothed, and cannot provide sufficient guidance to further update the basis and weights. To achieve a trade-off between the convergence and the performance of our algorithm, we finally choose $\tau$ empirically according to the experimental results.

It should be noted that although we set the distillation parameter empirically, as hinton2015distilling; lopez2015unifying did, we give a reasonable range: $\tau$ should be smaller than $1$ (to achieve distillation) and larger than $0$ (to avoid over-smoothing). We will study the setting of this parameter in our future work.

Figure 3: The convergence of our DWL method with respect to different $\tau$'s in the task of admission-type prediction.

Sentiment analysis on Twitter dataset

Besides the MIMIC-III dataset, we compared our method against the Wasserstein-distance-based method kusner2015word on sentiment analysis using the Twitter dataset in that paper. Our method obtains comparable results, i.e., a testing error slightly lower than that in kusner2015word.

The enlarged graph of ICD codes

Fig. 2(a) in the paper is enlarged and shown below for better visual clarity. The map between ICD codes and diseases/procedures is provided as well.

Figure 4: The enlarged KNN graph of diseases and procedures.
  ICD code Disease/Procedure
  d4019 Unspecified essential hypertension
  d41401 Coronary atherosclerosis of native coronary artery
  d4241 Aortic valve disorders
  dV4582 Percutaneous transluminal coronary angioplasty status
  d2724 Other and unspecified hyperlipidemia
  d486 Pneumonia, organism unspecified
  d99592 Severe sepsis
  d51881 Acute respiratory failure
  d5990 Urinary tract infection, site not specified
  d5849 Acute kidney failure, unspecified
  d78552 Septic shock
  d25000 Diabetes mellitus without mention of complication, type II or unspecified type
  d2449 Unspecified acquired hypothyroidism
  d41071 Subendocardial infarction, initial episode of care
  d4280 Congestive heart failure, unspecified
  d4168 Other chronic pulmonary heart diseases
  d412 Pneumococcus infection in conditions classified elsewhere and of unspecified site
  d2761 Hyposmolality and/or hyponatremia
  d2720 Pure hypercholesterolemia
  d2762 Acidosis
  d389 Unspecified septicemia
  d4589 Hypotension, unspecified
  d42731 Atrial fibrillation
  d2859 Anemia, unspecified
  d311 Cutaneous diseases due to other mycobacteria
  dV3001 Single liveborn, born in hospital, delivered by cesarean section
  dV053 Need for prophylactic vaccination and inoculation against viral hepatitis
  d4240 Mitral valve disorders
  dV3000 Single liveborn, born in hospital, delivered without mention of cesarean section
  d7742 Neonatal jaundice associated with preterm delivery
  d42789 Other specified cardiac dysrhythmias
  d5070 Pneumonitis due to inhalation of food or vomitus
  dV502 Routine or ritual circumcision
  d2760 Hyperosmolality and/or hypernatremia
  dV1582 Personal history of tobacco use
  d40390 Hypertensive chronic kidney disease, unspecified, with chronic kidney disease stage I through stage IV, or unspecified
  dV4581 Aortocoronary bypass status
  dV290 Observation for suspected infectious condition
  d5845 Acute kidney failure with lesion of tubular necrosis
  d2875 Thrombocytopenia, unspecified
  d2767 Hyperpotassemia
  d32723 Obstructive sleep apnea (adult)(pediatric)
  dV5861 Long-term (current) use of anticoagulants
  d2851 Acute posthemorrhagic anemia
  d53081 Esophageal reflux
  d496 Chronic airway obstruction, not elsewhere classified
  d40391 Hypertensive chronic kidney disease, unspecified, with chronic kidney disease stage V or end stage renal disease
  d9971 Gross hematuria
  d5119 Unspecified pleural effusion
  d2749 Gout, unspecified
  d5859 Chronic kidney disease, unspecified
  d49390 Asthma, unspecified type, unspecified
  d45829 Other iatrogenic hypotension
  d3051 Tobacco use disorder
  dV5867 Long-term (current) use of insulin
  d5180 Pulmonary collapse
  p9604 Insertion of endotracheal tube
  p9671 Continuous invasive mechanical ventilation for less than 96 consecutive hours
  p3615 Single internal mammary-coronary artery bypass
  p3961 Extracorporeal circulation auxiliary to open heart surgery
  p8872 Diagnostic ultrasound of heart
  p9904 Transfusion of packed cells
  p9907 Transfusion of other serum
  p9672 Continuous invasive mechanical ventilation for 96 consecutive hours or more
  p331 Spinal tap
  p3893 Venous catheterization, not elsewhere classified
  p966 Enteral infusion of concentrated nutritional substances
  p3995 Hemodialysis
  p9915 Parenteral infusion of concentrated nutritional substances
  p8856 Coronary arteriography using two catheters
  p9955 Prophylactic administration of vaccine against other diseases
  p3891 Arterial catheterization
  p9390 Non-invasive mechanical ventilation
  p9983 Other phototherapy
  p640 Circumcision
  p3722 Left heart cardiac catheterization
  p8853 Angiocardiography of left heart structures
  p3723 Combined right and left heart cardiac catheterization
  p5491 Percutaneous abdominal drainage
  p3324 Closed (endoscopic) biopsy of bronchus
  p4513 Other endoscopy of small intestine
Table 4: The map between ICD codes and diseases/procedures