Temporal Clustering with External Memory Network for Disease Progression Modeling

09/29/2021 ∙ by Zicong Zhang, et al. ∙ The Ohio State University

Disease progression modeling (DPM) involves using mathematical frameworks to quantitatively measure how a disease progresses in severity. DPM is useful in many ways, such as predicting health states, categorizing disease stages, and assessing patients' disease trajectories. Recently, with the wider availability of electronic health records (EHRs) and the broad application of data-driven machine learning methods, DPM has attracted much attention, yet it faces two major challenges: (i) due to the irregularity, heterogeneity, and long-term dependencies in EHRs, most existing DPM methods might not be able to provide comprehensive patient representations; (ii) many records in EHRs might be irrelevant to the target disease. Most existing models learn to automatically focus on the relevant information instead of explicitly capturing target-relevant events, which might make the learned model suboptimal. To address these two issues, we propose Temporal Clustering with External Memory Network (TC-EMNet) for DPM, which groups patients with similar trajectories to form disease clusters/stages. TC-EMNet uses a variational autoencoder (VAE) to capture the internal complexity of the input data and utilizes an external memory network to capture long-term information, both of which are helpful for producing comprehensive patient states. Last but not least, the k-means algorithm is adopted to cluster the extracted patient states to capture disease progression. Experiments on two real-world datasets show that our model demonstrates competitive clustering performance against state-of-the-art methods and is able to identify clinically meaningful clusters. The visualization of the extracted patient states shows that the proposed model can generate better patient states than the baselines.




I Introduction

With the recent development of deep learning and the accumulation of electronic health records (EHRs), a form of time-series data, there has been an increasing effort to cluster EHR data in order to discover meaningful patterns throughout longitudinal health information. Moreover, chronic diseases, such as Parkinson's disease (PD) and Alzheimer's disease (AD), can have various outcomes even within a limited number of patients. Such diseases are heterogeneous in nature and often evolve in unique patterns that trigger distinct responses to therapeutic interventions under different conditions [16]. Thus, it has become crucial to develop a disease progression modeling (DPM) system that can capture such pattern information, provide early detection of critical situations, and yield clinically helpful information to improve the quality of care.

Traditionally, DPM or disease clustering/staging has been developed by domain experts with extensive clinical experience, in which disease stages are defined separately and based solely on the values of one or a few biomarkers [8, 2]. Nevertheless, developing such a DPM system requires long-term observation and human labor, and the result is often based on known biomarkers and acknowledged covariates, which makes it difficult to develop DPM systems for diseases whose biomarkers have not been well studied. In recent years, the rapid growth of data-driven machine learning methods has motivated great effort in developing DPM models. There are two main approaches to DPM: 1) the problem is framed as a risk prediction task with label information, and the patient representation is extracted from the last layer of the model [18, 10, 22, 31, 38]; 2) the problem is framed as a traditional unsupervised patient clustering/subtyping problem, where the model is trained to separate patients into multiple groups [36, 9, 27]. Leveraging disease outcomes during training can prevent the model from forming heterogeneous clusters. However, for certain diseases, diagnosis labels are often unavailable at each patient visit due to limited knowledge of the disease. Moreover, deep learning models designed for supervised tasks may not perform well when trained in an unsupervised fashion. Therefore, there is a need for a DPM framework that can handle both situations with respect to the availability of training labels. However, most deep learning models developed for disease progression modeling suffer from the following limitations:

  • Irregularity and heterogeneity: Many diseases are heterogeneous in nature, and EHR data has high internal complexity as well. Due to the difficulty of effectively encoding various health conditions into patient representations, accurate DPM remains a challenging problem.

  • Long-term dependency: RNNs are long known to struggle with modeling long-term information, since they tend to forget earlier information when the input sequence is long. Disease progression modeling, especially for chronic diseases, requires long-term observation of patients in order to provide a comprehensive view for decision making.

  • Target awareness: Most RNN-based methods derive patient representations directly from the hidden states of the model. Such approaches neglect the contribution of target-relevant information. In fact, real-world clinical decisions made by doctors are often based on past diagnoses as well.

To address these challenges, we propose Temporal Clustering with External Memory Network (TC-EMNet) for disease progression modeling under both supervised and unsupervised settings. At each time step, TC-EMNet takes an EHR medical record as input and encodes the input features using a recurrent neural network to obtain hidden representations. TC-EMNet then samples from the hidden state to form a latent representation. Meanwhile, the hidden state is stored in a global-level memory network, which in turn outputs a memory representation based on the current memory. The memory representation is then concatenated with the current latent representation to form the patient representation at the current time step. When training labels are available, the model also employs a patient-level memory network to process label information up to the previous visit and output target-aware memory representations. We combine the memory representations from the global-level and patient-level memory networks using a calibration process. TC-EMNet is trained with a reconstruction objective under the unsupervised setting and a prediction objective under the supervised setting.

In this paper, our contributions are fourfold:

  • We propose a novel deep learning framework, namely TC-EMNet, for disease progression modeling under both supervised and unsupervised settings.

  • TC-EMNet uses a combined recurrent neural network and variational autoencoder (VAE) architecture to capture the irregularity in the data and the heterogeneous nature of diseases.

  • Under the supervised setting, TC-EMNet employs a dual memory network architecture to leverage both hidden representations from the input data and clinical diagnoses to produce accurate patient representations.

  • Experiments on two real-world datasets show that TC-EMNet yields competitive clustering performance compared with state-of-the-art methods and is able to find clinically interpretable disease clusters/stages.

The remainder of the paper is organized as follows. Section II briefly reviews existing work related to DPM, temporal clustering, and VAEs. Section III describes the technical details of the proposed model (TC-EMNet). Sections IV and V present experimental results and discussions. Finally, Section VI concludes the paper.

Fig. 1: Overview of the proposed framework. At each time stamp, the hidden representation from the encoder network is updated with the memory state to produce disease clusters/stages based on current and previous observations.

II Related Work

II-A Disease Progression Modeling

Disease progression modeling (DPM) plays a very important role in the healthcare domain, especially for chronic diseases such as Parkinson's disease (PD) and Alzheimer's disease (AD). A well-performing disease progression modeling system can not only provide early detection or diagnosis but also discover clinically meaningful patterns for certain groups of trajectories. There has been increasing effort in using traditional statistical methods to tackle disease progression problems. Most probabilistic models are based on hidden Markov models (HMMs) due to the nature of disease progression. For example, [1] derived a deep probabilistic model based on a sequence-to-sequence architecture to model progression dynamics on the UK Cystic Fibrosis registry. [36] introduced a continuous-time Markov process to learn a discrete representation of each progression state. Moreover, deep learning methods have also been developed for disease progression modeling. [19] proposed a CNN-based model to jointly learn features from MR images combined with demographic information to predict Alzheimer's disease progression patterns. [32] designed a prediction framework using generative models to forecast the distribution of patients' outcomes. DPM can be regarded as a classification problem, where diagnosis labels are leveraged in favor of model training. On the other hand, DPM can also be seen from an unsupervised perspective, where the goal is to discover potential disease states or patient subtypes throughout patients' medical history [6].

II-B Temporal Clustering

Temporal clustering, widely known as time-series clustering, is a data-driven method for clustering patients into subgroups based on time-series observations. Temporal clustering is often considered a challenging task because of the high dimensionality of the dataset and the multiple time steps for each individual. Recent advances have focused on leveraging the latent representations learned by recurrent neural networks (RNNs) for temporal clustering, motivated by the success of RNNs in modeling time-series data. Moreover, due to the emerging availability of electronic health records (EHRs), which provide large-scale and normalized context for individual patients, deep learning approaches have become capable of learning more comprehensive patterns and achieving better performance on several critical tasks. [3] introduces a time-aware mechanism into long short-term memory cells to capture progression patterns with irregular time intervals. [18] proposed an actor-critic algorithm for predictive clustering where, instead of defining a similarity measure for clustering, a cluster embedding is trained to represent each disease stage. [37] proposed an auto-encoder with attention to reconstruct relevant features for sepsis and showed that the proposed model is able to identify interpretable patient subtypes.

II-C Variational Autoencoder

The variational autoencoder (VAE) is a type of generative model that can handle complicated distributions. VAEs have been shown to be effective at modeling complex data structures and are widely adopted to solve many real-world problems, ranging from image generation to anomaly detection [4, 11, chien2017variational]. They also have several successful applications with healthcare data [29]. VAEs make weak assumptions about the generative process and can be trained end-to-end through neural networks. Training can be regarded as a generative process in which a set of data points is drawn from the learned distribution to approximate the true underlying distribution. [15] proposed to use a VAE with uncertainty-aware attention to impute missing values in electronic health data. Experiments on real-world datasets show that the VAE is able to capture the complexity of the EHR distribution. [32] leveraged the VAE framework to forecast disease states for Parkinson's disease (PD) and Alzheimer's disease (AD).

III Methodology

III-A Problem Definition

Let $\mathcal{X}$ and $\mathcal{Y}$ be the random variables for the input feature space and the label space, respectively. Here we focus on a clustering problem, where we are given a population of time-series data $\mathcal{D} = \{(\mathbf{x}_{1:T_n}^{(n)}, \mathbf{y}_{1:T_n}^{(n)})\}_{n=1}^{N}$ consisting of paired sequences of observations for $N$ patients, where $T_n$ denotes the number of time stamps at which observations are made for patient $n$.

We aim to identify $K$ clusters for the time-series data, each corresponding to a disease stage. Each cluster consists of homogeneous data samples, represented by centroids under a certain similarity measure.

III-B Method

This section presents our proposed framework. Here we discuss disease progression modeling under both supervised and unsupervised settings, where the problem requires estimating the underlying distribution of all possible disease stages. Such a DPM framework can help doctors identify meaningful characteristics both when a disease has diagnosis labels but possible underlying disease stages, and when a disease has no well-defined diagnosis labels.

The framework consists of three components: the encoder, the memory network, and the clustering network. For each patient, a recurrent neural network is deployed to encode the patient's information. The memory network controls the overall long-term information at each time stamp. Specifically, when a hidden representation is generated based on the current and previous observations at time stamp $t$, the hidden state is read by the memory network, which updates the memory storage. Next, a latent variable is drawn from the prior distribution conditioned on the hidden state generated from the memory network. Then we either yield prediction outcomes or reconstruct the current observation accordingly. We take the hidden representation from the last layer of the model for clustering.

III-B1 Encoder Network

The encoder network takes the current observation $\mathbf{x}_t$ and the hidden state $\mathbf{h}_{t-1}$ from the previous time stamp and yields a hidden representation that can interact with the external memory network. Specifically, an LSTM cell is adopted to generate and update the hidden state:

$$\mathbf{h}_t = \mathrm{LSTM}(\mathbf{x}_t, \mathbf{h}_{t-1}) \tag{1}$$

where $\mathbf{x}_t$ is the current observation at time stamp $t$ and $\mathbf{h}_{t-1}$ is the hidden state from the previous step. At each time stamp, the encoder network maps the sequence of time-series inputs to a hidden representation $\mathbf{h}_t \in \mathcal{H}$, where $\mathcal{H}$ is the subspace of latent representations. The hidden representation then interacts with the external memory network to form an accurate patient representation.
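As a concrete illustration, the per-visit recurrence in Eq. (1) can be sketched with a minimal LSTM cell. This is a self-contained NumPy sketch with arbitrarily chosen sizes; the paper's implementation uses PyTorch LSTM cells with learned weights.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

class LSTMCellSketch:
    """Minimal LSTM cell: h_t = LSTM(x_t, h_{t-1}), as in Eq. (1)."""
    def __init__(self, input_size, hidden_size, seed=0):
        rng = np.random.default_rng(seed)
        # One stacked weight matrix for the four gates (input, forget, cell, output).
        self.W = rng.normal(0, 0.1, (4 * hidden_size, input_size + hidden_size))
        self.b = np.zeros(4 * hidden_size)
        self.hidden_size = hidden_size

    def step(self, x_t, h_prev, c_prev):
        z = self.W @ np.concatenate([x_t, h_prev]) + self.b
        H = self.hidden_size
        i, f = sigmoid(z[:H]), sigmoid(z[H:2*H])
        g, o = np.tanh(z[2*H:3*H]), sigmoid(z[3*H:])
        c_t = f * c_prev + i * g   # cell state carries long-term information
        h_t = o * np.tanh(c_t)     # hidden representation passed to the memory network
        return h_t, c_t

# Encode a toy visit sequence (T visits, D features per visit).
T, D, H = 5, 8, 16
cell = LSTMCellSketch(D, H)
h, c = np.zeros(H), np.zeros(H)
for x_t in np.random.default_rng(1).normal(size=(T, D)):
    h, c = cell.step(x_t, h, c)
```

The final `h` plays the role of the hidden representation handed to the external memory network.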

III-B2 Memory Network

Long-term information plays an important role in disease progression modeling since, in the context of chronic disease, health conditions from the past affect the current disease stage of the patient. In addition, historical information should be stored in an efficient way so that it can provide useful guidance on the patient's current health state at different time stamps. To this end, we propose an external memory network to capture long-term information throughout the progression modeling process. Our proposed memory network is closely related to [30], which has had several successful applications in the field of natural language processing. Similarly, we define memory slots to represent historical information that can be extracted and retrieved at any given time stamp. At each time stamp, the hidden state from the encoder network is recorded and read into the memory network. By pushing through a series of observations, the memory network processes continuous representations with respect to each individual visit, so that a more comprehensive view of the patient can be utilized during the clustering/staging process.

Memory Reading

We denote a clinical record sequence $\{\mathbf{x}_t\}$, where $t$ stands for the index (time stamp) of a given record. In the memory network, after receiving a hidden representation $\mathbf{h}_t$ from the encoder network, the network produces an external representation $\mathbf{m}_t$ based on the reading weights of the memory slots. Specifically, $\mathbf{m}_t$ can be expressed as:

$$w_t^i = \mathrm{softmax}\big(\beta_t \, K(\mathbf{h}_t, \mathbf{M}_i)\big), \qquad \mathbf{m}_t = \sum_{i=1}^{S} w_t^i \, \mathbf{M}_i \tag{2}$$

where $S$ denotes the number of memory slots, $\mathbf{M}_i$ is the memory representation with hidden size $d_m$, $\beta_t$ is a strength vector that can be learned through the reading operation, and $K(\cdot,\cdot)$ is the cosine similarity measure. The memory reading operation is built upon the idea that not all records in the sequence contribute equally to the current health state of the patient. Hence, the weights are computed with the softmax function based on the cosine similarity between the current hidden state and all previous memories.
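The read operation of Eq. (2) is a small attention computation; a NumPy sketch (slot count, hidden size, and the scalar strength `beta` are illustrative):

```python
import numpy as np

def cosine_sim(h, M):
    # K(h, M_i): cosine similarity between the hidden state and each memory slot
    return (M @ h) / (np.linalg.norm(M, axis=1) * np.linalg.norm(h) + 1e-8)

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def memory_read(h_t, M, beta=1.0):
    """Eq. (2): softmax attention over memory slots; the read-out m_t
    is the attention-weighted sum of the slots."""
    w = softmax(beta * cosine_sim(h_t, M))
    return w @ M, w

rng = np.random.default_rng(0)
M = rng.normal(size=(4, 16))   # S=4 memory slots, hidden size 16
h_t = rng.normal(size=16)
m_t, w = memory_read(h_t, M)
```

Slots most similar to the current hidden state dominate the read-out, which is exactly the "not all records contribute equally" intuition.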

Fig. 2: Overview of the proposed memory network. Hidden states are first written into the memory cells and then read by the clustering network to produce a comprehensive representation.
Memory Writing

Memory writing stores latent representations into the memory slots. We use a fixed number of slots $S$ to denote the overall memory size. The dimension of the continuous space for each memory slot is $d_m$, and we use $d_h$ to denote the dimension of the hidden representation $\mathbf{h}_t$. The hidden state is non-linearly projected into the memory space using a matrix $\mathbf{A}$, $\tilde{\mathbf{M}}_t = \tanh(\mathbf{A}\mathbf{h}_t)$, where $\tilde{\mathbf{M}}_t$ is the new input memory representation. Memory writing aims to filter out unrelated information and store only personalized information based on the current hidden state. Mathematically, memory writing can be expressed as:

$$\mathbf{M}_t = (1 - \mathbf{g}_t) \odot \mathbf{M}_{t-1} + \mathbf{g}_t \odot \tilde{\mathbf{M}}_t \tag{3}$$

where $\mathbf{g}_t \in [0, 1]$ is a gated vector that controls the information flow between the previous and current memory vectors.
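The gated update of Eq. (3) can be sketched as follows; the projection matrix `A` and the gate parameterization `G` stand in for learned weights and are assumptions for illustration:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def memory_write(M_prev, h_t, A, G):
    """Gated memory update in the spirit of Eq. (3).

    A projects the hidden state into memory space; G parameterizes the
    write gate. Both are illustrative stand-ins for learned weights.
    """
    M_new = np.tanh(A @ h_t)   # candidate memory from the current hidden state
    g = sigmoid(G @ h_t)       # per-dimension write gate in [0, 1]
    return (1.0 - g) * M_prev + g * M_new

rng = np.random.default_rng(0)
d_h, d_m = 16, 8
A = rng.normal(size=(d_m, d_h))
G = rng.normal(size=(d_m, d_h))
M = memory_write(np.zeros(d_m), rng.normal(size=d_h), A, G)
```

With the gate near 0 a slot keeps its old content; near 1 it is overwritten by the personalized candidate.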

III-B3 Clustering Network

After obtaining the representation of the observation through the encoding network (i.e., the prior network) and updating the memory cell at the current time stamp $t$, we follow the traditional variational autoencoder (VAE) framework [20] to compute the mean and standard deviation vectors through the posterior network. We assume the output follows a Gaussian distribution. The computation process can be expressed as:

$$\boldsymbol{\mu}_t, \boldsymbol{\sigma}_t = f_{\mathrm{post}}(\mathbf{h}_t, \mathbf{x}_t) \tag{4}$$

where $\mathbf{h}_t$ is the hidden state, $\mathbf{x}_t$ is the observation at time stamp $t$, and $f_{\mathrm{post}}$ is the posterior function described by feed-forward neural networks. We then draw samples from the posterior Gaussian distribution using the reparameterization trick:

$$\mathbf{z}_t = \boldsymbol{\mu}_t + \boldsymbol{\sigma}_t \odot \boldsymbol{\epsilon}, \qquad \boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}) \tag{5}$$

where $\mathbf{z}_t$ is the latent representation and $\odot$ indicates element-wise multiplication. The reparameterization trick allows the gradient to back-propagate through the sampling process. Lastly, depending on the availability of diagnosis labels, the clustering network is trained on two different objectives. When diagnosis labels are used, the clustering network is directly trained to predict the label information:

$$\hat{\mathbf{y}}_t = f_{\mathrm{pred}}(\mathbf{z}_t) \tag{6}$$

where $f_{\mathrm{pred}}$ is a feed-forward network that outputs probabilities for each label. When diagnosis labels are not available, we train the framework to reconstruct the observation $\mathbf{x}_t$ from the latent variable $\mathbf{z}_t$ conditioned on the memory state $\mathbf{m}_t$, denoted as:

$$\hat{\mathbf{x}}_t = f_{\mathrm{dec}}([\mathbf{z}_t; \mathbf{m}_t]) \tag{7}$$

where $\hat{\mathbf{x}}_t$ is the reconstructed input, $f_{\mathrm{dec}}$ is a feed-forward network, and $[\cdot;\cdot]$ denotes concatenation. During the clustering phase, we use the Euclidean distance-based k-means algorithm on the latent variable $\mathbf{z}_t$.
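The reparameterization trick described above is easy to verify numerically: averaging many samples recovers the posterior mean, which is the property the gradient estimator relies on (sizes here are illustrative):

```python
import numpy as np

def reparameterize(mu, sigma, rng):
    """z = mu + sigma * eps, with eps ~ N(0, I): sampling becomes a
    deterministic function of (mu, sigma), so gradients can flow."""
    eps = rng.normal(size=mu.shape)
    return mu + sigma * eps

rng = np.random.default_rng(0)
mu, sigma = np.zeros(4), np.full(4, 0.5)
# Monte Carlo average of many samples should approach mu = 0.
z_mean = np.mean([reparameterize(mu, sigma, rng) for _ in range(20000)],
                 axis=0)
```

In the model, each sampled `z_t` is what k-means later clusters into disease stages.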

1:Initialize encoder and decoder network parameters ;
2:Initialize memory embedding and memory slots;
3:for every time stamp t do
4:     Compute patient hidden encoding through Encoder
5:     network via Eq. (1);
6:     Read from global-level memory network to extract
7:     recent memory representations via Eq. (2);
8:     if diagnosis is available then
9:         Read from patient-level memory network
10:         to extract recent memory representations
11:         via Eq. (2);
12:         Compute memory representation via Eq. (8);
13:     end if
14:     Compute loss via Eq. (4) - (7);
15:     Write to corresponding memory slots via Eq. (3);
16:end for
17:Update parameters by optimizing Eq. (12), (13) accordingly;
Algorithm 1 TC-EMNet

III-B4 Dual Memory Network Architecture

In a clinical setting, doctors often provide diagnosis labels based on the patient's current and past medical events. Such information can be a target health condition or a diagnosis. Under the supervised setting, when labels are available during training, we further utilize a patient-level memory network to capture diagnosis information at each visit. Compared to the global-level memory network, the patient-level memory network at the current memory slot can only access diagnoses up to the previous time stamp, namely $\mathbf{y}_{1:t-1}$. The patient-level memory network only reads and writes diagnosis information, which is later combined with the global-level memory network for clustering. We propose a calibration process to integrate the representations from the two memory networks, as follows:


where $\mathbf{m}_t^{g}$ and $\mathbf{m}_t^{p}$ are the outputs of the global-level and patient-level memory networks, respectively. This memory calibration process can be regarded as a point-wise attention mechanism.
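One plausible reading of the point-wise attention calibration is a learned per-dimension gate that decides how much of each memory stream to keep. The sigmoid-gate parameterization `W` below is an illustrative assumption, not the paper's exact Eq. (8):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def calibrate(m_g, m_p, W):
    """Point-wise gated combination of the global-level (m_g) and
    patient-level (m_p) memory read-outs. The gate g lies in [0, 1]
    per dimension, so the result is a convex combination."""
    g = sigmoid(W @ np.concatenate([m_g, m_p]))
    return g * m_g + (1.0 - g) * m_p

rng = np.random.default_rng(0)
d = 8
W = rng.normal(size=(d, 2 * d))
m_g, m_p = rng.normal(size=d), rng.normal(size=d)
m = calibrate(m_g, m_p, W)
```

Because the gate is point-wise, each dimension of the calibrated memory independently interpolates between the two streams.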

III-C Objective Function and Optimization

Here, we present our training objectives and optimization process. As mentioned in previous sections, the entire network can be trained end-to-end using maximum likelihood estimation (MLE). In order to handle the intractable marginalization with respect to the latent variable $\mathbf{z}$, we use a variational lower bound parameterized by $\phi$ to approximate the true distribution, which we assume to be Gaussian. After the memory network reading and writing, we use the latent variable $\mathbf{z}_t$ at time stamp $t$ to identify the disease stages. Here we restrict the latent variable to be a multivariate Gaussian, which enforces the same for the posterior. We learn the generative parameters $\theta$ using MLE:

$$\theta^{*} = \arg\max_{\theta} \sum_{t} \log p_{\theta}(\mathbf{x}_t) \tag{9}$$

However, the marginalization $p_{\theta}(\mathbf{x}_t) = \int p_{\theta}(\mathbf{x}_t \mid \mathbf{z}_t)\, p(\mathbf{z}_t)\, d\mathbf{z}_t$ is intractable for complicated functions (for instance, neural networks). Thus, we derive a variational lower bound (i.e., a variational Bayesian method) to approximate the logarithm of the marginal probability of the observation, as follows:

$$\log p_{\theta}(\mathbf{x}_t) \ge \mathbb{E}_{q_{\phi}(\mathbf{z}_t \mid \mathbf{x}_t)}\!\left[\log p_{\theta}(\mathbf{x}_t \mid \mathbf{z}_t)\right] - \mathrm{KL}\!\left(q_{\phi}(\mathbf{z}_t \mid \mathbf{x}_t)\,\|\,p(\mathbf{z}_t)\right) \triangleq \mathcal{L}(\theta, \phi) \tag{10}$$

where the inequality follows from Jensen's inequality, and the lower bound involves the distribution $q_{\phi}$ parameterized by $\phi$, which approximates the intractable true posterior $p_{\theta}(\mathbf{z}_t \mid \mathbf{x}_t)$. Since health-related data are often high-dimensional and follow generally more complicated distributions, we introduce the latent variable $\mathbf{z}$ to capture the internal stochasticity of the data. The entire clustering network can then be trained end-to-end using stochastic optimization techniques. After obtaining the variational lower bound, the optimization follows from the fact that the KL divergence equals the difference between the log likelihood and the variational lower bound:

$$\mathrm{KL}\!\left(q_{\phi}(\mathbf{z}_t \mid \mathbf{x}_t)\,\|\,p_{\theta}(\mathbf{z}_t \mid \mathbf{x}_t)\right) = \log p_{\theta}(\mathbf{x}_t) - \mathcal{L}(\theta, \phi) \tag{11}$$
where $\theta$ and $q_{\phi}$ represent the generative parameters and the proxy posterior, accordingly. Equality holds when the distribution $q_{\phi}$ equals the true posterior. When diagnosis labels are used during training, we use the cross-entropy loss to directly predict the outcome from the combined latent representation, denoted as:

$$\mathcal{L}_{\mathrm{sup}} = -\sum_{t} \mathbf{y}_t \log \hat{\mathbf{y}}_t \tag{12}$$

When the model is trained in an unsupervised manner, the overall objective function, combined with the reconstruction loss, becomes:

$$\mathcal{L}_{\mathrm{unsup}} = \sum_{t} \left\| \mathbf{x}_t - \hat{\mathbf{x}}_t \right\|_2^2 + \beta \,\mathrm{KL}\!\left(q_{\phi}(\mathbf{z}_t \mid \mathbf{x}_t)\,\|\,p(\mathbf{z}_t)\right) \tag{13}$$

where we use the mean square error (MSE) as the reconstruction loss and $\beta$ is a hyperparameter that prevents the VAE from the KL-vanishing problem. We adopt a linear annealing schedule for $\beta$ based on the training step, denoted as:

$$\beta = \min\!\left(1, \frac{\mathrm{step}}{\tau}\right) \tag{14}$$

where $\tau$ is a threshold value. Last but not least, we use the k-means algorithm [35] on the patient representations to perform clustering.
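Two pieces of the objective are easy to make concrete: the KL term has a closed form for a diagonal Gaussian posterior against a standard normal prior, and linear annealing of $\beta$ is a few lines. Pure-Python sketch; the exact shape of the paper's schedule is an assumption consistent with the text:

```python
import math

def kl_diag_gaussian_vs_standard(mu, sigma):
    """Closed-form KL( N(mu, diag(sigma^2)) || N(0, I) ), the term that
    appears in VAE objectives:
    0.5 * sum_i (sigma_i^2 + mu_i^2 - 1 - 2*log(sigma_i))."""
    return 0.5 * sum(s * s + m * m - 1.0 - 2.0 * math.log(s)
                     for m, s in zip(mu, sigma))

def beta_schedule(step, threshold):
    """Linear KL annealing: ramp beta from 0 to 1 over `threshold`
    training steps, then hold it at 1."""
    return min(1.0, step / float(threshold))

# KL vanishes when the posterior equals the prior and grows as it departs.
kl_zero = kl_diag_gaussian_vs_standard([0.0, 0.0], [1.0, 1.0])
kl_pos = kl_diag_gaussian_vs_standard([1.0, 0.0], [1.0, 1.0])
betas = [beta_schedule(s, 1000) for s in (0, 250, 500, 1000, 5000)]
```

Annealing keeps the KL weight small early in training so the decoder learns to use the latent variable before regularization bites, which is the standard remedy for KL vanishing.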

                         w/o label                                  with label
Model                Purity        NMI           RI            Purity        NMI           RI
RNN                  0.6799±0.00   0.1415±0.01   0.1406±0.02   0.8532±0.00   0.4020±0.01   0.3805±0.01
Bi-LSTM              0.6810±0.02   0.1540±0.02   0.1559±0.02   0.8674±0.00   0.4092±0.01   0.4042±0.02
RETAIN               0.6903±0.02   0.1787±0.01   0.1671±0.01   0.7144±0.02   0.2572±0.01   0.1838±0.03
Dipole               0.6839±0.00   0.1707±0.01   0.1452±0.00   0.8904±0.01   0.4674±0.01   0.4776±0.02
StageNet             0.6943±0.01   0.2002±0.01   0.1791±0.01   0.8513±0.01   0.4045±0.03   0.3744±0.01
AC-TPC               -             -             -             0.8214±0.03   0.3362±0.07   0.3827±0.09
VAE                  0.6651±0.02   0.1023±0.02   0.1117±0.02   0.6495±0.04   0.1718±0.05   0.1042±0.04
Memory Network       0.6887±0.02   0.1392±0.01   0.1584±0.02   0.8262±0.01   0.3603±0.01   0.3538±0.02
TC-EMNet (unsup.)    0.7040±0.01   0.1967±0.02   0.1891±0.02   0.8904±0.00   0.4679±0.01   0.4889±0.01
TC-EMNet (sup.)      -             -             -             0.9126±0.01   0.4789±0.01   0.4923±0.02
TABLE I: Results of the proposed methods and other methods on the ADNI dataset. ↓ indicates that smaller is better (0=best, 1=worst); ↑ indicates that greater is better (0=worst, 1=best).
                         w/o label                                  with label
Model                Purity        NMI           RI            Purity        NMI           RI
RNN                  0.7221±0.00   0.3089±0.01   0.3120±0.01   0.7640±0.02   0.4222±0.04   0.3663±0.03
Bi-LSTM              0.7264±0.00   0.3170±0.00   0.2976±0.01   0.7674±0.03   0.4456±0.05   0.3575±0.05
RETAIN               0.5241±0.02   0.1188±0.01   0.0619±0.01   0.7510±0.01   0.4072±0.03   0.3361±0.01
Dipole               0.7233±0.00   0.3200±0.00   0.3153±0.00   0.8033±0.01   0.4947±0.01   0.4476±0.02
StageNet             0.7252±0.01   0.3305±0.00   0.3234±0.01   0.7839±0.01   0.4700±0.03   0.3840±0.01
AC-TPC               -             -             -             0.8151±0.01   0.4984±0.03   0.5129±0.01
VAE                  0.7161±0.00   0.3576±0.01   0.3153±0.00   0.7942±0.01   0.4452±0.00   0.3782±0.01
Memory Network       0.6996±0.01   0.2809±0.01   0.2581±0.02   0.7689±0.01   0.4482±0.01   0.4597±0.01
TC-EMNet (unsup.)    0.7452±0.00   0.3773±0.00   0.3742±0.01   0.8256±0.00   0.5053±0.00   0.4823±0.01
TC-EMNet (sup.)      -             -             -             0.8339±0.00   0.5035±0.00   0.4993±0.01
TABLE II: Results of the proposed methods and other methods on the PPMI dataset. ↓ indicates that smaller is better (0=best, 1=worst); ↑ indicates that greater is better (0=worst, 1=best).

IV Experiments

We evaluated our proposed model on two real-world datasets: the Alzheimer's Disease Neuroimaging Initiative (ADNI) dataset and the Parkinson's Progression Markers Initiative (PPMI) dataset. Both datasets can be accessed on the IDA website (https://ida.loni.usc.edu/). The code can be found on GitHub (https://github.com/Ericzhang1/TC-EMNet.git).

IV-A Datasets

IV-A1 ADNI Dataset

Alzheimer's disease (AD) is a chronic neurodegenerative disease that is often associated with behavioral and cognitive impairment. ADNI is a longitudinal study that aims to explore early detection and tracking of AD based on imaging, biomarker, and genetic data collected throughout the process [14]. The dataset consists of a total of 11,651 visits over 1,346 patients at 6-month intervals. For each patient, 21 variables are collected and processed, including 16 time-varying features (brain function, cognitive tests) and 5 static features (background, demographics). Three diagnosis labels are assigned by doctors at each visit: cognitively normal (CN), mild cognitive impairment (MCI), and AD, indicating how far AD symptoms have progressed for each patient.

IV-A2 PPMI Dataset

The Parkinson's Progression Markers Initiative (PPMI) is a longitudinal study that aims to evaluate patients' progression of Parkinson's disease (PD) based on biomarkers [23]. The dataset consists of a total of 13,685 visits over 2,145 patients with irregular time intervals. For each patient, 79 features based on motor and non-motor symptoms are collected, including cognitive assessments, lab tests, demographic information, and biospecimens. Since the dataset does not provide a diagnosis label per visit for each patient, we use the Hoehn and Yahr (HY) score as the label for our evaluation. HY scores range from 0 to 5 and indicate the severity of a patient's Parkinson's symptoms. We use mean imputation and the last-occurrence-carried-forward method to impute missing values.
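The two imputation steps can be sketched as follows for a single feature over a patient's visits. The processing order (carry the last observation forward, then mean-fill any remaining leading gaps) is an assumption; the text does not specify it:

```python
def impute_locf_then_mean(series):
    """Impute one feature across visits: carry the last observed value
    forward (LOCF); fill leading gaps, where nothing can be carried,
    with the feature mean."""
    observed = [v for v in series if v is not None]
    mean = sum(observed) / len(observed) if observed else 0.0
    out, last = [], None
    for v in series:
        if v is not None:
            last = v
        out.append(last if last is not None else mean)
    return out

# Toy visit sequence with missing values (None).
visits = [None, 2.0, None, None, 5.0, None]
filled = impute_locf_then_mean(visits)
```

Here the leading gap gets the feature mean (3.5), and each later gap repeats the most recent observation.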

IV-B Baselines

We compare our proposed model to several state-of-the-art methods, ranging from vanilla RNNs to multi-layer attention models. Since we consider disease progression modeling under both supervised and unsupervised settings, we adjusted the architecture of the baseline models to fit the corresponding objective. For baselines that cannot be modified interchangeably, we do not report results under the corresponding setting. For all experiments, we use k-means clustering on the hidden representations from the last layer to report the clustering performance.

  • RNN [25]: A single RNN cell with an additional feed-forward layer. The model is trained with the cross-entropy loss or the reconstruction objective accordingly.

  • Bi-LSTM [12]: Similar to the RNN model, a bidirectional LSTM is used with a reconstruction objective. The model takes both directions of the sequence into account and has been shown to capture richer information compared to a single direction.

  • RETAIN [5]: An interpretable deep learning model based on a recurrent neural network and a reverse-time attention mechanism. RETAIN learns the importance of hospital records through attention weights. We modify the last layer of RETAIN and train the model with the prediction and reconstruction objectives accordingly.

  • Dipole [21]: An interpretable bidirectional recurrent neural network that employs an attention mechanism to leverage both past and future visits. We use the concatenation-based attention mechanism for testing and, similar to RETAIN, adjust the last layer of the model accordingly.

  • StageNet [10]: A recent risk prediction model that learns to extract disease progression patterns during training and leverages a modified LSTM cell with an attention mechanism. The progression pattern at each time stamp is re-calibrated using a convolution network.

  • AC-TPC [18]: A recent deep predictive clustering network that consists of an encoder network, a selector, and a predictor. The model is first initialized using a prediction objective and then optimized to train a cluster embedding using the actor-critic algorithm. This method cannot be trained without label information.

  • VAE [17]: A vanilla variational autoencoder model using an LSTM cell as the encoder, trained with the prediction and variational objectives respectively. Note that this baseline serves as an ablation example against our proposed method.

  • Memory Network: A vanilla global-level memory network with the reading and writing mechanisms described previously. The network reads and writes the EHR sequence directly, and the k-means algorithm is applied directly to the hidden memory representation.

  • TC-EMNet (unsup.): Unsupervised version of TC-EMNet. When training labels are not available, only the global-level memory network is used to produce memory representations. We also train this model on the prediction task as an ablation example against the supervised version of TC-EMNet.

  • TC-EMNet (sup.): Supervised version of TC-EMNet. When training labels are available, the patient-level memory network is combined with the global-level memory network to produce target-aware memory representations.

Hyperparameter Range
hidden size
latent variable size
learning rate
batch size
TABLE III: Hyperparameter Searching Space
Fig. 3: Visualization of the clusters for ADNI (first row, panels (a)-(d)) and PPMI (second row, panels (e)-(h)) using PCA: Bi-LSTM (1st column), StageNet (2nd column), Dipole (3rd column), ours (4th column).
ADNI Dataset
Cluster I RAVLT_learning Ventricles WholeBrain ICV
RAVLT_perc_forgetting RAVLT_forgetting ADAS13 RAVLT_immediate
Cluster II ICV RAVLT_perc_forgetting ADAS13 Ventricles
serial RAVLT_immediate CDRSB
Cluster III RAVLT_perc_forgetting serial ICV RAVLT_learning
Entorhinal Hippocampus Ventricles WholeBrain
PPMI Dataset
Cluster I Global Spontaneity of Movement Speech Anxious Mood Arising from Chair
Right leg Getting Out of Bed Pronation-Supination (left)
Cluster II Posture Rest tremor amplitude Dopamine Rigidity
Saliva + Drooling Anxious Mood Global Spontaneity of Movement
Cluster III Postural Stability Cognitive Impairment Rest Tremor Amplitude Pronation-Supination (left)
Dopamine Standing Rigidity
Cluster IV Pronation-Supination (left) Standing Postural Stability Chewing
Cognitive Impairment Dopamine Right Hand
Cluster V Dopamine Cognitive Impairment Hallucinations Chewing
Dressing Pronation-Supination (left) Arising from Chair
Cluster VI Rigidity Serial Rigidity Standing
Apathy Constipation Problems Cognitive Impairment Dopamine
TABLE IV: Most significant features in each cluster measured by first order gradient for ADNI and PPMI dataset.
Model # of trainable parameters
Dipole 279k
StageNet 283k
AC-TPC 143k
TABLE V: Complexity comparison between models

IV-C Model Training and Implementation Details

As mentioned previously, our proposed network is continuous and differentiable, so we can train it with stochastic optimization techniques. All neural networks in the proposed framework are feed-forward networks. We implemented our solution using PyTorch [28] and trained the model on a single NVIDIA Volta V100 GPU with 16GB of memory. We adopt gradient accumulation to deal with out-of-memory problems. We select hyperparameters through random search, as shown in Table III. For our model, we set both the hidden size and the latent variable size to 128. We adopt the Adam optimizer with a learning rate of . The model is trained with batch size for epochs. is set to . We split the dataset into and report the performance of fold cross-validation for both datasets. A detailed description of the optimization process of our proposed framework can be found in Algorithm 1. The average runtime of our proposed framework on both datasets is about 2 hours for cross-validation. For the baseline methods, we implement the RNN and Bi-LSTM models in PyTorch, adopt the implementations from PyHealth [39] for RETAIN, Dipole, and StageNet, and adopt the implementation from [18] for AC-TPC. All baseline methods share the same hyperparameter search space.
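The gradient-accumulation trick mentioned above can be illustrated with a small, self-contained sketch (plain Python with a toy quadratic loss rather than the paper's PyTorch code; all names here are illustrative): averaging the gradients of equal-sized micro-batches reproduces the gradient of the full batch, so memory-constrained training can use smaller batches without changing the effective update.

```python
def grad_w(w, xs, ys):
    # d/dw of the mean squared error (w*x - y)^2 over one (micro-)batch
    return sum(2 * x * (w * x - y) for x, y in zip(xs, ys)) / len(xs)

xs, ys = [1.0, 2.0, 3.0, 4.0], [2.0, 4.0, 6.0, 8.0]
w = 0.5

full = grad_w(w, xs, ys)  # gradient over one large batch of 4
# two accumulated micro-batches of size 2, averaged before the update
micro = (grad_w(w, xs[:2], ys[:2]) + grad_w(w, xs[2:], ys[2:])) / 2
assert abs(full - micro) < 1e-12
```

In PyTorch this corresponds to calling `backward()` on each micro-batch loss (scaled by the number of accumulation steps) and stepping the optimizer only after the last micro-batch.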

IV-D Evaluation Metrics

To evaluate the clustering performance of our model, we use the purity score (purity), normalized mutual information (NMI) [34], and the adjusted Rand index (ARI) [13]. The purity score ranges between 0 and 1, indicating the extent to which each cluster consists of a single class. NMI (0 to 1) measures the mutual information between the cluster assignments and the true labels, with 1 being perfect clustering. ARI derives from the Rand index (RI) and measures the fraction of correct cluster assignments, corrected for chance. Mathematically, the metrics can be expressed as follows:

purity = \frac{1}{N} \sum_k \max_j |\omega_k \cap c_j|, \quad NMI = \frac{I(\Omega; C)}{\sqrt{H(\Omega)\,H(C)}}, \quad ARI = \frac{RI - E[RI]}{\max(RI) - E[RI]},

where N is the total number of samples, \Omega = \{\omega_k\} and C = \{c_j\} denote the cluster assignments and the true labels respectively, I(\cdot;\cdot) is the mutual information function, H(\cdot) is the entropy, and E[RI] and RI are the expected value and the Rand index, accordingly.
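As a concrete reference, the three metrics can be implemented directly from their definitions in a few lines of plain Python (a sketch for illustration; standard implementations such as those in scikit-learn would typically be used in practice):

```python
from collections import Counter
from math import log, sqrt, comb

def purity(labels, clusters):
    """Fraction of samples belonging to the majority class of their cluster."""
    n = len(labels)
    total = 0
    for c in set(clusters):
        members = [labels[i] for i in range(n) if clusters[i] == c]
        total += Counter(members).most_common(1)[0][1]
    return total / n

def nmi(labels, clusters):
    """Mutual information normalized by the geometric mean of the entropies."""
    n = len(labels)
    pl, pc = Counter(labels), Counter(clusters)
    joint = Counter(zip(clusters, labels))
    mi = sum((v / n) * log((v / n) / ((pc[c] / n) * (pl[l] / n)))
             for (c, l), v in joint.items())
    h = lambda cnt: -sum((v / n) * log(v / n) for v in cnt.values())
    return mi / sqrt(h(pc) * h(pl))

def ari(labels, clusters):
    """Rand index adjusted for chance agreement via the contingency table."""
    n = len(labels)
    joint = Counter(zip(clusters, labels))
    sum_ij = sum(comb(v, 2) for v in joint.values())
    sum_c = sum(comb(v, 2) for v in Counter(clusters).values())
    sum_l = sum(comb(v, 2) for v in Counter(labels).values())
    expected = sum_c * sum_l / comb(n, 2)
    max_index = (sum_c + sum_l) / 2
    return (sum_ij - expected) / (max_index - expected)
```

For a perfect clustering (cluster assignments identical to the labels up to renaming), all three metrics return 1; ARI is near 0 for random assignments and can be negative.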

V Results

V-A Clustering Performance

A quantitative comparison of the clustering performance on the ADNI and PPMI datasets is shown in Tables I and II, respectively. We set the number of clusters to the number of classes/diagnoses for each dataset, i.e., for ADNI (diagnosis label) and (NHY score) for PPMI. We want the model to identify individual disease stages both when only limited knowledge about a disease is available, i.e., when class/diagnosis labels are unavailable, and when diagnosis labels are available, and thus provide insightful, interpretable information to help discover corresponding treatments for individual patients. We compare our proposed method with the aforementioned baselines in terms of clustering performance. Our method demonstrates competitive performance against all baseline methods across all evaluation metrics for both datasets. We note that it is generally difficult to identify clusters without label information, as indicated by the low NMI and RI scores. However, TC-EMNet outperforms the baselines by a large margin in terms of NMI and RI when clustering with labels. Training under the supervised setting yields significantly better clustering performance than training under the unsupervised setting, because the correlation between the diagnosis and the input features is encoded into each hidden representation. Although AC-TPC achieves a better RI on the PPMI dataset, that method relies on pre-training the model for over

epochs, which could result in the model memorizing the input data. Both Dipole and StageNet achieve comparable performance; however, it is worth mentioning that both models leverage attention over multi-layer RNNs, which introduces additional complexity. A detailed comparison of the trainable parameters is shown in Table V. Furthermore, we find that training with label information can negatively impact the RI score compared to training without labels. This phenomenon is observed for multiple baseline methods. One explanation is that directly leveraging label information overwhelms the training process: labels possess strong predictive power compared to the input features, making the model more biased towards the dominant class on imbalanced datasets; thus, RI may drop as there are more false positives and false negatives. It can also be observed that leveraging external memory effectively captures long-term information and that TC-EMNet is capable of learning the complexity of the input data. The patient-level memory network constructively binds with the global-level memory network to produce more comprehensive memory representations.
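The clustering step itself is standard k-means over the extracted patient-state vectors. A minimal pure-Python sketch (illustrative only; the initialization and iteration counts here are placeholders, not the settings used in the experiments):

```python
import random

def kmeans(points, k, iters=50, seed=0):
    """Cluster patient-state vectors (tuples of floats) into k groups."""
    rng = random.Random(seed)
    centers = [tuple(c) for c in rng.sample(points, k)]
    assign = [0] * len(points)
    for _ in range(iters):
        # assignment step: nearest center by squared Euclidean distance
        assign = [min(range(k),
                      key=lambda j: sum((p - c) ** 2
                                        for p, c in zip(pt, centers[j])))
                  for pt in points]
        # update step: recompute each center as the mean of its members
        for j in range(k):
            members = [pt for pt, a in zip(points, assign) if a == j]
            if members:
                centers[j] = tuple(sum(dim) / len(members)
                                   for dim in zip(*members))
    return assign, centers
```

Applied to the hidden states produced by the encoder and memory networks, the resulting cluster indices play the role of the discovered disease stages.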

Fig. 4: Significant feature values of the cluster centroids on the ADNI dataset. The distributions of the clusters are very different, indicating distinct subtypes.
Fig. 5: Significant feature values of the cluster centroids on the PPMI dataset. The distributions of the clusters are very different, indicating distinct subtypes.

V-B Disease Stage

To interpret the disease stages and progression patterns found by TC-EMNet, we first selected three baseline models with performance comparable to TC-EMNet and visualized the hidden representations in a low-dimensional space using PCA [24]. The results are shown in Fig. 3. We observe that, in general, most methods can produce distinct clusters for the ADNI dataset. However, for the PPMI dataset, most baseline methods fail to produce effective clusters, whereas TC-EMNet produces distinct clustering results. This shows that TC-EMNet is able to constructively model long-term information between visits in order to find effective representations. Next, we compute feature importance for every cluster based on the weights from the last layer of the network. The results are shown in Table IV. It can be observed that, for both datasets, each cluster is determined by a diverse range of features, which makes it easier to identify each patient's progression patterns through observation. We also compute the centroid values for each cluster and plot the distributions in Figs. 4 and 5 for the ADNI and PPMI datasets, respectively. For the ADNI dataset, our proposed model identifies significant features such as RAVLT_learning, RAVLT_perc_forgetting, ICV, and Ventricles. Rey's Auditory Verbal Learning Test (RAVLT) scores are helpful in testing episodic memory and are very important indicators of a patient's progression in Alzheimer's disease [26]. In particular, the learning test (RAVLT_learning) and the percent-forgetting test (RAVLT_perc_forgetting) are highly correlated and thus become crucial biomarkers for early detection of AD. It can be observed in Fig. 4 that three clusters produced by our model have wide distributions of RAVLT test values, which suggests three different patient subtypes. As for the PPMI dataset, our model finds that dopamine dysregulation syndrome (Dopamine) is a significant feature for identifying clusters.
Studies have discovered that, in clinical settings, early characterization of Dopamine can aid the treatment of motor and non-motor complications of Parkinson's disease [7]. Studies have also shown that cognitive impairment (Cognitive Impairment) is a strong indicator of Parkinson's disease; differences in cognitive impairment scores can reflect advanced progression in PD [33].
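The first-order-gradient feature importance behind Table IV can be sketched with a finite-difference approximation (a hypothetical scalar model f stands in for the network's output; this illustrates the idea, not the paper's implementation):

```python
def feature_importance(f, x, eps=1e-6):
    """Absolute first-order gradient of f at x, one score per input feature."""
    base = f(x)
    scores = []
    for i in range(len(x)):
        bumped = list(x)
        bumped[i] += eps  # perturb one feature, hold the rest fixed
        scores.append(abs((f(bumped) - base) / eps))
    return scores

# toy linear model: feature 0 dominates the output
f = lambda x: 3.0 * x[0] + 0.1 * x[1]
scores = feature_importance(f, [1.0, 1.0])
assert scores.index(max(scores)) == 0  # feature 0 is the most significant
```

Ranking features by these scores within each cluster yields tables of the form shown in Table IV.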

VI Conclusion

In this paper, we propose TC-EMNet for disease progression modeling on time-series data. TC-EMNet leverages a VAE to model data irregularity and an external memory network to capture long-term dependency. We developed TC-EMNet to perform patient clustering/subtyping under both supervised and unsupervised settings. Under the supervised setting, TC-EMNet leverages a dual memory network architecture to extract target-aware information from diagnoses to compute patient representations. Through experiments on two real-world datasets, we showed that our model outperforms state-of-the-art methods and is able to identify interpretable disease stages that are clinically meaningful. TC-EMNet yields competitive clustering performance with limited complexity. In real-world clinical settings, we hope that our model can help physicians identify patients' progression patterns and discover potential disease stages to gain more understanding of chronic and other heterogeneous diseases.


This paper was funded in part by the National Science Foundation under award number CBET-2037398.


  • [1] A. M. Alaa and M. van der Schaar (2019) Attentive state-space modeling of disease progression. In Advances in Neural Information Processing Systems, pp. 11338–11348. Cited by: §II-A.
  • [2] S. Auer and B. Reisberg (1997) The gds/fast staging system. International Psychogeriatrics 9 (S1), pp. 167–171. Cited by: §I.
  • [3] I. M. Baytas, C. Xiao, X. Zhang, F. Wang, A. K. Jain, and J. Zhou (2017) Patient subtyping via time-aware lstm networks. In Proceedings of the 23rd ACM SIGKDD international conference on knowledge discovery and data mining, pp. 65–74. Cited by: §II-B.
  • [4] C. P. Burgess, I. Higgins, A. Pal, L. Matthey, N. Watters, G. Desjardins, and A. Lerchner (2018) Understanding disentangling in β-vae. arXiv preprint arXiv:1804.03599. Cited by: §II-C.
  • [5] E. Choi, M. T. Bahadori, J. A. Kulas, A. Schuetz, W. F. Stewart, and J. Sun (2016) Retain: an interpretable predictive model for healthcare using reverse time attention mechanism. arXiv preprint arXiv:1608.05745. Cited by: 3rd item.
  • [6] J. M. Dennis, B. M. Shields, W. E. Henley, A. G. Jones, and A. T. Hattersley (2019) Disease progression and treatment response in data-driven subgroups of type 2 diabetes compared with models based on simple clinical features: an analysis using clinical trial data. The lancet Diabetes & endocrinology 7 (6), pp. 442–451. Cited by: §II-A.
  • [7] A. H. Evans and A. J. Lees (2004) Dopamine dysregulation syndrome in parkinson’s disease. Current opinion in neurology 17 (4), pp. 393–398. Cited by: §V-B.
  • [8] M. Ferrer, J. Alonso, J. Morera, R. M. Marrades, A. Khalaf, M. C. Aguar, V. Plaza, L. Prieto, and J. M. Anto (1997) Chronic obstructive pulmonary disease stage and health-related quality of life. Annals of internal Medicine 127 (12), pp. 1072–1079. Cited by: §I.
  • [9] V. Fortuin, M. Hüser, F. Locatello, H. Strathmann, and G. Rätsch (2018) Som-vae: interpretable discrete representation learning on time series. arXiv preprint arXiv:1806.02199. Cited by: §I.
  • [10] J. Gao, C. Xiao, Y. Wang, W. Tang, L. M. Glass, and J. Sun (2020) Stagenet: stage-aware neural networks for health risk prediction. In Proceedings of The Web Conference 2020, pp. 530–540. Cited by: §I, 5th item.
  • [11] I. Higgins, L. Matthey, A. Pal, C. Burgess, X. Glorot, M. Botvinick, S. Mohamed, and A. Lerchner (2016) Beta-vae: learning basic visual concepts with a constrained variational framework. Cited by: §II-C.
  • [12] Z. Huang, W. Xu, and K. Yu (2015) Bidirectional lstm-crf models for sequence tagging. arXiv preprint arXiv:1508.01991. Cited by: 2nd item.
  • [13] L. Hubert and P. Arabie (1985) Comparing partitions. Journal of classification 2 (1), pp. 193–218. Cited by: §IV-D.
  • [14] C. R. Jack Jr, M. A. Bernstein, N. C. Fox, P. Thompson, G. Alexander, D. Harvey, B. Borowski, P. J. Britson, J. L. Whitwell, C. Ward, et al. (2008) The alzheimer’s disease neuroimaging initiative (adni): mri methods. Journal of Magnetic Resonance Imaging: An Official Journal of the International Society for Magnetic Resonance in Medicine 27 (4), pp. 685–691. Cited by: §IV-A1.
  • [15] E. Jun, A. W. Mulyadi, and H. Suk (2019) Stochastic imputation and uncertainty-aware attention to ehr for mortality prediction. In 2019 International Joint Conference on Neural Networks (IJCNN), pp. 1–7. Cited by: §II-C.
  • [16] A. A. Kehagia, R. A. Barker, and T. W. Robbins (2010) Neuropsychological and clinical heterogeneity of cognitive impairment and dementia in patients with parkinson’s disease. The Lancet Neurology 9 (12), pp. 1200–1213. Cited by: §I.
  • [17] M. J. Kusner, B. Paige, and J. M. Hernández-Lobato (2017) Grammar variational autoencoder. In International Conference on Machine Learning, pp. 1945–1954. Cited by: 7th item.
  • [18] C. Lee and M. Van Der Schaar (2020) Temporal phenotyping using deep predictive clustering of disease progression. In International Conference on Machine Learning, pp. 5767–5777. Cited by: §I, §II-B, 6th item, §IV-C.
  • [19] M. Liu, J. Zhang, E. Adeli, and D. Shen (2018) Joint classification and regression via deep multi-task multi-channel learning for alzheimer’s disease diagnosis. IEEE Transactions on Biomedical Engineering 66 (5), pp. 1195–1206. Cited by: §II-A.
  • [20] M. Lopez-Martin, B. Carro, A. Sanchez-Esguevillas, and J. Lloret (2017) Conditional variational autoencoder for prediction and feature recovery applied to intrusion detection in iot. Sensors 17 (9), pp. 1967. Cited by: §III-B3.
  • [21] F. Ma, R. Chitta, J. Zhou, Q. You, T. Sun, and J. Gao (2017) Dipole: diagnosis prediction in healthcare via attention-based bidirectional recurrent neural networks. In Proceedings of the 23rd ACM SIGKDD international conference on knowledge discovery and data mining, pp. 1903–1911. Cited by: 4th item.
  • [22] T. Ma, C. Xiao, and F. Wang (2018) Health-atm: a deep architecture for multifaceted patient health record representation and risk prediction. In Proceedings of the 2018 SIAM International Conference on Data Mining, pp. 261–269. Cited by: §I.
  • [23] K. Marek, D. Jennings, S. Lasch, A. Siderowf, C. Tanner, T. Simuni, C. Coffey, K. Kieburtz, E. Flagg, S. Chowdhury, et al. (2011) The parkinson progression marker initiative (ppmi). Progress in neurobiology 95 (4), pp. 629–635. Cited by: §IV-A2.
  • [24] A. M. Martinez and A. C. Kak (2001) Pca versus lda. IEEE transactions on pattern analysis and machine intelligence 23 (2), pp. 228–233. Cited by: §V-B.
  • [25] T. Mikolov, M. Karafiát, L. Burget, J. Cernockỳ, and S. Khudanpur (2010) Recurrent neural network based language model.. In Interspeech, Vol. 2, pp. 1045–1048. Cited by: 1st item.
  • [26] E. Moradi, I. Hallikainen, T. Hänninen, J. Tohka, A. D. N. Initiative, et al. (2017) Rey’s auditory verbal learning test scores can be predicted from whole brain mri in alzheimer’s disease. NeuroImage: Clinical 13, pp. 415–427. Cited by: §V-B.
  • [27] L. Mou, P. Zhao, H. Xie, and Y. Chen (2019) T-lstm: a long short-term memory neural network enhanced by temporal information for traffic flow prediction. Ieee Access 7, pp. 98053–98060. Cited by: §I.
  • [28] A. Paszke, S. Gross, F. Massa, A. Lerer, J. Bradbury, G. Chanan, T. Killeen, Z. Lin, N. Gimelshein, L. Antiga, et al. (2019) Pytorch: an imperative style, high-performance deep learning library. Advances in neural information processing systems 32, pp. 8026–8037. Cited by: §IV-C.
  • [29] B. Shickel, P. J. Tighe, A. Bihorac, and P. Rashidi (2017) Deep ehr: a survey of recent advances in deep learning techniques for electronic health record (ehr) analysis. IEEE journal of biomedical and health informatics 22 (5), pp. 1589–1604. Cited by: §II-C.
  • [30] S. Sukhbaatar, A. Szlam, J. Weston, and R. Fergus (2015) End-to-end memory networks. arXiv preprint arXiv:1503.08895. Cited by: §III-B2.
  • [31] Z. Sun, S. Ghosh, Y. Li, Y. Cheng, A. Mohan, C. Sampaio, and J. Hu (2019) A probabilistic disease progression modeling approach and its application to integrated huntington’s disease observational data. JAMIA open 2 (1), pp. 123–130. Cited by: §I.
  • [32] X. Teng, S. Pei, and Y. Lin (2020) StoCast: stochastic disease forecasting with progression uncertainty. IEEE Journal of Biomedical and Health Informatics 25 (3), pp. 850–861. Cited by: §II-A, §II-C.
  • [33] D. Verbaan, J. Marinus, M. Visser, S. M. van Rooden, A. M. Stiggelbout, H. A. Middelkoop, and J. J. van Hilten (2007) Cognitive impairment in parkinson’s disease. Journal of Neurology, Neurosurgery & Psychiatry 78 (11), pp. 1182–1187. Cited by: §V-B.
  • [34] N. X. Vinh, J. Epps, and J. Bailey (2010) Information theoretic measures for clusterings comparison: variants, properties, normalization and correction for chance. The Journal of Machine Learning Research 11, pp. 2837–2854. Cited by: §IV-D.
  • [35] K. Wagstaff, C. Cardie, S. Rogers, S. Schrödl, et al. (2001) Constrained k-means clustering with background knowledge. In Icml, Vol. 1, pp. 577–584. Cited by: §III-C.
  • [36] X. Wang, D. Sontag, and F. Wang (2014) Unsupervised learning of disease progression models. In Proceedings of the 20th ACM SIGKDD international conference on Knowledge discovery and data mining, pp. 85–94. Cited by: §I, §II-A.
  • [37] C. Yin, R. Liu, D. Zhang, and P. Zhang (2020) Identifying sepsis subphenotypes via time-aware multi-modal auto-encoder. In Proceedings of the 26th ACM SIGKDD international conference on knowledge discovery & data mining, pp. 862–872. Cited by: §II-B.
  • [38] X. Zhang, J. Chou, J. Liang, C. Xiao, Y. Zhao, H. Sarva, C. Henchcliffe, and F. Wang (2019) Data-driven subtyping of parkinson’s disease using longitudinal clinical records: a cohort study. Scientific reports 9 (1), pp. 1–12. Cited by: §I.
  • [39] Y. Zhao, Z. Qiao, C. Xiao, L. Glass, and J. Sun (2021) PyHealth: a python library for health predictive models. arXiv preprint arXiv:2101.04209. Cited by: §IV-C.