In this new era of Big Data, large volumes of patient medical data are continuously being collected and are becoming increasingly available for research. Intelligent analysis of such large-scale medical data can uncover valuable insights complementary to existing medical knowledge and improve the quality of care delivery. Among the various kinds of medical data available, longitudinal Electronic Health Records (EHR), which comprehensively capture patient health information over time, have proven to be one of the most important data sources for such studies. EHRs are routinely collected from clinical practice, and the richness of the information they contain provides significant opportunities to apply AI techniques to extract nuggets of insight. Over the years, many researchers have postulated various temporal models of EHR for tasks such as early identification of heart failure sun2012combining , readmission prediction xiao2018readmission , and acute kidney injury prediction tomavsev2019clinically . For such analysis to be of practical use, the models should provide support for generating interpretations or post-hoc explanations. While the necessary properties of interpretations and explanations are still being debated jain2019attention , it is generally desirable to ascertain the importance of past events on model predictions at a particular time point. Furthermore, despite their initial success, applications of RNN models to EHR also suffer from the inherent difficulty of identifying and controlling the temporal contents that should be memorized by these RNN models.
Contemporaneously, we have also witnessed tremendous architectural advances for temporal models that are aimed at better generalization capabilities. In particular, memory networks NTMGraves ; DNC ; Sukhbaatar2015a are an exciting class of architectures that aim to separate the process of learning the operations and the operands by using an external block of memory to memorize past events from the data. Such networks have been extensively applied to different problems and were found to generalize well DNC . However, there have been only a limited number of applications of memory networks for clinical data modeling prakash2017condensed ; le2018dual . One of the primary obstacles is the inherently difficult problem of identifying important past events, due to the diffused manner in which such networks store past events in memory. While TARDIS ; TARDIS2 have explored the possibilities of using explicit memories that can store past events exactly and have found varying degrees of success, such models are difficult to train. The discontinuities arising from the discrete operations either necessitate learning with high levels of supervision, such as REINFORCE with appropriate reward shaping, or are handled using stochastic reparameterization under annealing routines and have to deal with high variance in gradients.
In this paper, we propose EBmRNN: a novel explicit-blurred memory architecture for longitudinal EHR analysis. Our model is inspired by the well-known Atkinson-Shiffrin model of human memory Atkinson:1968 . Our key contributions are as follows:
We propose a partitioning of external memory of generic memory networks into a blurred-explicit memory architecture that supports better interpretability and can be trained with limited supervision.
We evaluate the model over classification problems on longitudinal EHR data. Our results show EBmRNN achieves accuracies comparable to state-of-the-art architectures.
We discuss the support for interpretations inherent in EBmRNN and analyze the same over the different tasks.
Memory networks are a special class of Recurrent Neural Networks that employ external memory banks to store computed results. The separation between operands and operators provided by such architectures has been shown to increase network capacity and/or help generalize over datasets. However, the operations involved are in general highly complex and render such networks very difficult to interpret. Our proposed architecture is shown in Figure 1. The architecture is inspired by the Atkinson-Shiffrin model of cognition and is composed of three parts:
a controller (e.g., an LSTM network) that processes inputs sequentially and produces a candidate memory representation at each time point, along with control vectors to manage the external memory. Mathematically, it can be expressed as follows:
an ‘explicit’ memory bank, where the generated candidate memory representation is stored. Depending on the outputs of a controlling read gate, the candidate memory can be stored explicitly or passed on to the blurred memory. When it is stored explicitly and the bank is already full, an older memory is removed based on its information content and passed on to the blurred memory bank. To update the memory explicitly, we discretely select the index by making use of the Gumbel-Softmax trick as shown below:
is a network-learnt usage estimate. is a hyper-parameter capturing the effect of current reads on the slots and is a one-hot encoded weight vector over memory slots.
The memory passed on to the blurred memory bank is diffused according to the control vectors and stored as high level concepts.
To generate outputs at time , the architecture makes use of a read gate to select the memories stored in explicit and blurred memory that are useful at that time point.
where, and are the blurred and explicit memories. is the read gate output and is the final output. The full model description is presented in the Appendix.
Experimental setup: We evaluated the performance of the proposed EBmRNN on the publicly available MIMIC III (Medical Information Mart for Intensive Care) data set mimic3 . The data includes vital signs, medications, laboratory measurements, observations and clinical notes. For this paper, we focused on the structured data fields and followed the MIMIC III benchmark proposed in mimic3bench to construct cohorts for specific learning tasks of great interest to the critical care community, namely ‘in-hospital mortality’, ‘decompensation’, and ‘phenotype’ classification. To estimate the effectiveness of the EBmRNN scheme, we compared it with the following baseline algorithms: logistic regression using the features used in mimic3bench , and a variant of EBmRNN that does not have access to blurred memory, hereafter referred to as EmRNN. Comparison with EmRNN allows the training to proceed via a direct path to explicit memories and hence lets us estimate its effect more accurately. EmRNN is completely interpretable, while EBmRNN is interpretable to the limit allowed by the complexities of the problem. Details on the exact cohort definitions and constructions are provided in mimic3bench . More details on the tasks are also presented in the Appendix.
All the models were trained for 100 epochs. We used the recommended settings for the baseline methods from mimic3bench . In this paper, we wanted to understand the relative importance of the memory banks and as such chose to study how the network uses the two different memory banks under similar capacity conditions. For EmRNN and EBmRNN, the hyperparameters such as the memory size (), controller hidden size (), and the number of reads () were set using the validation set for each of the different datasets. While can be learned during the training process, following past work, we used a fixed value of . We chose a -layered GRU with dropout as our controller and the models were trained using SGD with momentum () along with gradient clipping. Table 1 shows the AUC-ROC for the different tasks. Overall, we note that EBmRNN is on par with or outperforms each of the baselines on each of the tasks. Song et al. song2018attend found success with a large multi-layered transformer-based model, which can be considered the state of the art across all architectures. It is interesting to note that our results, using a single layer of memory, are comparable to those of the many-layered transformer approach, which suggests the efficiency of the proposed architecture.
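As a reminder of what the reported metric measures, the following is a minimal sketch of a rank-based AUC-ROC computation in plain Python. The function names `auc_roc` and `macro_auc_roc` are illustrative and not taken from the paper or the benchmark code. Macro-averaging over labels is only one common choice for the multi-label phenotyping task; the text above does not state which averaging was used.

```python
def auc_roc(labels, scores):
    """Rank-based AUC-ROC for a binary task.

    Equals the probability that a randomly chosen positive example
    receives a higher score than a randomly chosen negative example.
    """
    positives = [s for y, s in zip(labels, scores) if y == 1]
    negatives = [s for y, s in zip(labels, scores) if y == 0]
    if not positives or not negatives:
        raise ValueError("AUC-ROC needs at least one positive and one negative")
    wins = 0.0
    for p in positives:
        for n in negatives:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5  # ties count as half a win
    return wins / (len(positives) * len(negatives))


def macro_auc_roc(label_matrix, score_matrix):
    """Unweighted mean of per-label AUC-ROC (assumed averaging scheme)."""
    n_labels = len(label_matrix[0])
    per_label = [
        auc_roc([row[j] for row in label_matrix],
                [row[j] for row in score_matrix])
        for j in range(n_labels)
    ]
    return sum(per_label) / n_labels
```

The quadratic pairwise loop is chosen for readability; for data sets of the size used here, a sort-based implementation from an established library would be the practical choice.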
In the subsequent paragraphs, we discuss the key insights derived from
How to interpret EBmRNN? To analyze the interpretability inherent in the model, we picked a patient for each of the tasks under consideration. We used a trained model with slots and allowing reads to generate the predictions. As mentioned before, the explicit memory allows complete traceability of inputs by storing each input in a distinct memory slot. Figure 2 depicts the contents of the explicit memory over time, discretized into 1-hour intervals. Such slot utilization patterns provide an insight into the contents recognized by the network as being important for the task at hand. Furthermore, the plots also show that the model is able to explicitly remember far-off time points for an extended period before caching them into the blurred memory space.
How to interpret the influence of explicit memory? In addition to the exact memory contents, we can also analyze the importance of the explicit memory for specific tasks by analyzing the control for the read gate over time. Figure 3 shows the temporal progression of the read gate for the patients from the previous analysis for three distinct tasks. Interestingly, we can see the model using different patterns of usage for different tasks. While the network assigns almost equal importance to both banks for in-hospital mortality, it places high importance on explicit memory for phenotyping. This is also consistent with the improved performance of EmRNN on the phenotyping task.
Why do we need the blurred memory? Given the interpretability provided by the explicit memory, it may be tempting to avoid the use of blurred memory in favor of EmRNN. As our results indicate, such a model can perform well for certain tasks. However, for tasks such as “in-hospital mortality”, the blurred memory provides the network with additional capacity. Also, from a practical point of view, we found EmRNN difficult to train: in spite of the Gumbel-Softmax reparameterization trick, the gradients frequently exploded and required higher supervision. On the other hand, the presence of the blurred bank helped the training by providing a more tractable path. If the use case places a higher value on interpretability, we recommend either using a smaller blurred memory bank or applying relative regularization to the read gates for the blurred component.
In this work, we have introduced EBmRNN, a memory network architecture able to mimic human memory models by combining sensory, explicit and long term memories for classification tasks. The proposed scheme achieves state-of-the-art levels of performance while being more interpretable, especially when explicit memories are utilized more. Our future work will aim at presenting such interpretations via an end-to-end system following a user-centered design approach.
- (1) Jimeng Sun et al. Combining knowledge and data driven insights for identifying risk factors using electronic health records. In AMIA Annual Symposium Proceedings, volume 2012, page 901. American Medical Informatics Association, 2012.
- (2) Cao Xiao, Tengfei Ma, Adji B Dieng, David M Blei, and Fei Wang. Readmission prediction via deep contextual embedding of clinical concepts. PloS one, 13(4):e0195024, 2018.
- (3) Nenad Tomašev, Xavier Glorot, Jack W Rae, Michal Zielinski, Harry Askham, Andre Saraiva, Anne Mottram, Clemens Meyer, Suman Ravuri, Ivan Protsyuk, et al. A clinically applicable approach to continuous prediction of future acute kidney injury. Nature, 572(7767):116, 2019.
- (4) Sarthak Jain and Byron C Wallace. Attention is not explanation. In Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers), pages 3543–3556, 2019.
- (5) Alex Graves, Greg Wayne, and Ivo Danihelka. Neural turing machines. CoRR, abs/1410.5401, 2014.
- (6) Alex Graves et al. Hybrid computing using a neural network with dynamic external memory. Nature, 538(7626):471–476, October 2016.
- (7) Sainbayar Sukhbaatar, Arthur Szlam, Jason Weston, and Rob Fergus. End-To-End Memory Networks. In Proceedings of the International Conference on Neural Information Processing Systems (NIPS), mar 2015.
- (8) Aaditya Prakash, Siyuan Zhao, Sadid A Hasan, Vivek V Datla, Kathy Lee, Ashequl Qadir, Joey Liu, and Oladimeji Farri. Condensed memory networks for clinical diagnostic inferencing. In AAAI, pages 3274–3280, 2017.
- (9) Hung Le, Truyen Tran, and Svetha Venkatesh. Dual memory neural computer for asynchronous two-view sequential learning. In Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining, pages 1637–1645. ACM, 2018.
- (10) Çaglar Gülçehre, Sarath Chandar, and Yoshua Bengio. Memory augmented neural networks with wormhole connections. CoRR, abs/1701.08718, 2017.
- (11) Caglar Gulcehre, Sarath Chandar, Kyunghyun Cho, and Y Bengio. Dynamic neural turing machine with continuous and discrete addressing schemes. Neural Computation, 30:1–28, 01 2018.
- (12) R. C. Atkinson and R. M. Shiffrin. Human memory: A proposed system and its control processes. In K. W. Spence and J. T. Spence (Eds.), The Psychology of learning and motivation: Advances in research and theory (vol. 2), pages 89–105, 1968.
- (13) Alistair E.W. Johnson, Tom J. Pollard, Lu Shen, Li-wei H. Lehman, Mengling Feng, Mohammad Ghassemi, Benjamin Moody, Peter Szolovits, Leo Anthony Celi, and Roger G. Mark. Mimic-iii, a freely accessible critical care database. Scientific Data, 3:sdata201635, 2016.
- (14) Hrayr Harutyunyan, Hrant Khachatrian, David C. Kale, and Aram Galstyan. Multitask learning and benchmarking with clinical time series data. CoRR, abs/1703.07771, 2017.
- (15) Huan Song, Deepta Rajan, Jayaraman J Thiagarajan, and Andreas Spanias. Attend and diagnose: Clinical time series analysis using attention models. In Thirty-Second AAAI Conference on Artificial Intelligence, 2018.
Appendix A Data Description
The dataset for each of the tasks is described below:
In-Hospital Mortality Prediction: This task is a classification problem where the learning algorithm is asked to predict mortality using the first 48 hours of data collected on the patient for each ICU stay. All ICU stays for which the length of stay is unknown or less than 48 hours have been discarded from the study. Following exactly the benchmark cohort constructions proposed in mimic3bench , we were left with 17,903 ICU stays for training and 3,236 ICU stays for testing.
Decompensation Prediction: This task is a binary classification problem. Decompensation is synonymous with a rapid deterioration of health, typically linked to very serious complications and prompting “track and trigger” initiatives by the medical staff. There are many ways to define decompensation. We adopt the approach used in mimic3bench to represent the decompensation outcome as a binary variable indicating whether the patient will die in the next 24 hours. Consequently, data for each patient is labeled every hour with this binary outcome variable. The resulting data set for this task consists of 2,908,414 training instances and 523,208 testing instances, as reported in mimic3bench , with a decompensation rate of 2.06%.
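The hourly labeling described above can be sketched as follows. This is only an illustration: the function name, the use of integer hour offsets, and the boundary convention (death strictly after hour t and no later than t + 24) are assumptions, and the benchmark code in mimic3bench should be consulted for the exact rules.

```python
def decompensation_labels(stay_hours, death_hour=None, horizon=24):
    """Return one binary label per hour of an ICU stay.

    stay_hours : number of hourly time steps observed for the stay
    death_hour : hour offset of death from ICU admission, or None if
                 the patient survived (hypothetical representation)
    horizon    : look-ahead window in hours (24 in the task definition)
    """
    labels = []
    for t in range(1, stay_hours + 1):
        # Positive when death falls in the window (t, t + horizon].
        dies_within_horizon = (
            death_hour is not None and 0 < death_hour - t <= horizon
        )
        labels.append(1 if dies_within_horizon else 0)
    return labels
```

Because every hour of every stay yields one instance, the number of instances is far larger than the number of ICU stays, which is in line with the instance counts quoted above.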
Phenotype Classification: This task is a multi-label classification problem where the learning algorithm attempts to classify 25 common ICU conditions, including 12 critical ones such as respiratory failure and sepsis and 8 chronic comorbidities such as diabetes and metabolic disorders. This classification is performed at the ICU stay level, resulting in 35,621 instances in the training set and 6,281 instances in the testing set.
For each patient, the input data consists of an hourly vector of features containing average vital signs (e.g., heart rate, diastolic blood pressure), health assessment scores (e.g., Glasgow Coma Scale) and various patient-related demographics.
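To make the notion of an hourly vector of averaged measurements concrete, the sketch below bins irregularly timed measurements of a single signal into hourly averages. The function name and the `(time_in_hours, value)` event format are hypothetical, and hours without a measurement are returned as `None` because the text above does not describe how missing values are imputed.

```python
from collections import defaultdict


def hourly_average(events, n_hours):
    """Average (time_in_hours, value) measurements into hourly bins.

    Returns a list of length n_hours; bins without any measurement
    are left as None (imputation is outside the scope of this sketch).
    """
    buckets = defaultdict(list)
    for time_in_hours, value in events:
        if not 0 <= time_in_hours < n_hours:
            continue  # ignore measurements outside the observation window
        buckets[int(time_in_hours)].append(value)
    return [
        sum(buckets[h]) / len(buckets[h]) if h in buckets else None
        for h in range(n_hours)
    ]
```

Applying this per signal and concatenating the results with the demographic fields would give one feature vector per hour, which is the input format assumed in the rest of the appendix.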
Appendix B Model Description
B.1 Explicit-Blurred Memory Augmented RNN
Let us denote the sequence of observations as , where is the length of the sequence and . Similarly, let us denote the set of desired outputs as . To model from , is fed sequentially to the proposed EBmRNN with parameters and hyper-parameters that will be defined below.
In EBmRNN, we split the conventional memory network architecture into two banks: (a) an explicit memory bank () and (b) a blurred or diffused memory bank(). Figure 1 shows a high level overview of the EBmRNN cell at time . This cell has access to an explicit memory bank to persist past events discretely. denotes the capacity of the memory and is the dimensionality of each memory slot. This cell also has access to a blurred or diffused memory where abstractions of important salient features from past observations are stored.
Observations at time are fed to this recurrent cell to produce an output based on the current input , the external explicit and blurred memories and . summarizes information extracted from both and that is deemed relevant for the generation of the output . is designed to contain enough abstraction of past observations seen by EBmRNN, including the current input so that specific tasks can generate a desired using only a shallow network outside of the cell. This design choice helps the interpretability of the model as it facilitates linking to memories in pointing explicitly to inputs , while still retaining the expressiveness of a blurred memory. Analyzing how EBmRNN is using provides a natural way to track how attentive EBmRNN is to input data stored in while analyzing EBmRNN’s focus on enables us to track the importance of long term dependencies. Details on how is computed are presented in the next subsection.
In addition to and , there are three primary components controlling the functioning of the cell:
The controller (), that senses inputs to EBmRNN and maps these inputs into control signals for the management of all read and write operations to the memory banks.
The read gate controlling read accesses to the memory banks from control signals emitted by the controller.
The write gate controlling writes into the memory banks from control signals emitted by the controller.
In the remainder of this section, we describe these three components in detail.
B.1.1 The Controller
At each time point , the controller receives the current input and generates several outputs to manage and with appropriate read and write instructions sent to the read and write gates. As it receives , the controller updates its hidden state based on the past output of the cell , its past hidden state and current input . In addition to updating its hidden state , the controller emits two keys and to be used by the read gate to control access to memory contents from and . To control write operations, the controller also produces a representation of the input that will be consumed by the write gate. represents information from that is a candidate for a write into and . The controller also produces , an erase weight vector that will be consumed by the write gate to forget content from . In this work, we model the controller with standard recurrent neural network architectures such as Gated Recurrent Units (GRU) or Long Short Term Memory networks (LSTM). The operations of the controller are summarized below:
B.1.2 The Read Gate and Read Operations
The read gate enforces read accesses from and by consuming and and comparing these keys against the content of the two memory banks and . Using this addressing scheme, the following weight vectors over the memories are computed as follows:
denotes an appropriate distance function between the key vectors and the memory locations. For our purpose, we use the cosine similarity measure as a distance function. and . To ensure discrete access, weights are required to be one-hot encoded vectors. While Softmax is a natural choice for soft selection of indices for , it is not applicable to the hard selection required for . Gumbel-Softmax is a newer paradigm that is applicable in this context, in contrast to alternatives such as top-K Softmax that can introduce discontinuities. Gumbel-Softmax uses a stochastic re-parameterization scheme to avoid the non-differentiabilities that arise from making discrete choices during normal model training. We use the straight-through optimization procedure, which allows the network to make discrete decisions on the forward pass while estimating the gradients on the backward pass using Gumbel-Softmax. More details on this scheme can be found in TARDIS .
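The two addressing modes can be illustrated with a small plain-Python sketch. All names (`soft_address`, `hard_address`, `temperature`) are illustrative rather than the paper's notation, and no automatic differentiation is involved: the code only shows the forward computation. In a straight-through setup, the one-hot vector would be used on the forward pass while gradients would be taken through the relaxed weights.

```python
import math
import random


def cosine_similarity(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm > 0 else 0.0


def softmax(logits):
    peak = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def soft_address(key, memory):
    """Soft weights over slots, as used for the blurred bank."""
    return softmax([cosine_similarity(key, slot) for slot in memory])


def hard_address(key, memory, temperature=1.0, rng=random):
    """One-hot weights over slots via a Gumbel-Softmax sample.

    Returns (one_hot, relaxed): the discrete choice and the relaxed
    weights that a straight-through estimator would differentiate.
    """
    logits = [cosine_similarity(key, slot) for slot in memory]
    # Gumbel(0, 1) noise; the floor on u avoids log(0).
    noise = [-math.log(-math.log(max(rng.random(), 1e-12))) for _ in logits]
    relaxed = softmax([(l + g) / temperature for l, g in zip(logits, noise)])
    winner = max(range(len(relaxed)), key=relaxed.__getitem__)
    one_hot = [1.0 if i == winner else 0.0 for i in range(len(relaxed))]
    return one_hot, relaxed
```

Lower values of the assumed `temperature` parameter push the relaxed weights closer to one-hot, at the cost of higher gradient variance, which is the usual motivation for annealing schedules.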
The read vectors and from each of the banks are computed as follows:
and belong both to . We combine the two content reads from the two banks using a gate as follows:
while . The final output from EBmRNN can then be produced from a shallow layer that combines the contribution from the two memory banks represented by :
Equation 7 ensures that the network can learn to produce its desired output using information from either memory bank. The gated value controls the relative effect of the blurred and explicit memories on the output. On one hand, higher average values of would ensure that the network relies more on explicit memories and is, as such, easier to interpret. On the other hand, lower values of cause the network to rely more on blurred memories and make it harder to interpret. Depending on the learning task at hand, there could be an interesting trade-off between learning performance and interpretability that can be controlled by this gating scheme. In fact, one could introduce a hyper-parameter in Equation 7 to control this trade-off between and .
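A minimal sketch of the read and gating steps is given below. It assumes a scalar gate in [0, 1] that weights the explicit read, with the complement weighting the blurred read; this matches the description that higher gate values mean more reliance on explicit memory, but the exact parameterization in the model may differ (for example, a vector-valued gate). The function names are illustrative.

```python
def read(weights, memory):
    """Weighted sum of memory slots, giving one read vector."""
    dim = len(memory[0])
    return [
        sum(w * slot[d] for w, slot in zip(weights, memory))
        for d in range(dim)
    ]


def gated_read(gate, explicit_read, blurred_read):
    """Convex combination of the two reads (assumed scalar gate).

    gate close to 1 -> output dominated by the explicit bank
    gate close to 0 -> output dominated by the blurred bank
    """
    return [
        gate * e + (1.0 - gate) * b
        for e, b in zip(explicit_read, blurred_read)
    ]


def mean_gate(gates):
    """Average gate value over time, a simple summary of how much a
    prediction trajectory relied on explicit memory."""
    return sum(gates) / len(gates)
```

With a one-hot weight vector, `read` returns the content of exactly one explicit slot, which is what makes the explicit contribution traceable to a single past input.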
The read operations are repeated times to generate hops from the memory.
B.1.3 The Write Gate and Write Operations
Once memories are read, the controller updates the memory banks for the next state. At each time point, the controller generates the memory representations, , for the input . The update strategies for the two banks are slightly different, and we start by describing the explicit bank update.
Explicit memory update: As long as the explicit bank is not full, newer memories are simply appended to it and the update equation can be given as:
Once the entire memory is filled up, the network needs to learn to forget less important memory slots to generate a filtered explicit memory and update the memory following equation 9. From an information theoretic intuition, more information can be retained by the network by sustaining a higher entropy within the memory banks. The network learns the importance of the old memories with respect to new memory candidate content as follows:
. Equation 10 only uses the content to generate the importance of the memory locations. Specifically, interpreting these values of in terms of retention probabilities, locations with dissimilar contents will have higher retention probability, thereby forcing the network to store discriminative content in the explicit memory.
Past research has also shown that usage-based addressing can significantly improve the expressiveness of the network. We follow the scheme proposed by TARDIS and make use of an auxiliary variable that tracks a moving average of past read values for each memory location of . The final write vector, along with the usage updates, is given as:
. is a hyper-parameter capturing the effect of current reads on the slots.
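One simple way to track a moving average of past reads per slot is an exponential moving average, sketched below. This specific form and the name `alpha` for the hyper-parameter are assumptions made for illustration; the update used in the model follows TARDIS and may differ in its details.

```python
def update_usage(usage, read_weights, alpha=0.1):
    """Exponential moving average of read weights per explicit slot.

    usage        : previous usage value for each slot
    read_weights : read weights over the slots at the current step
    alpha        : assumed hyper-parameter controlling how strongly
                   the current reads affect the usage estimate
    """
    return [
        (1.0 - alpha) * u + alpha * w
        for u, w in zip(usage, read_weights)
    ]
```

Under this form, a slot that is read repeatedly sees its usage value approach 1, while a slot that is never read decays towards 0.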
Although other addressing mechanisms have been proposed in the literature, we chose this setting for model simplicity and also to better capture the desirable properties of EHR applications.
The explicit bank is then updated by removing the slot with the highest value of ( from slot ) and replacing its content with . At that time, we also reset the usage value for the slot (i.e. ).
Similar to the read operations, is a one-hot encoded vector. The equations for the popped memory and the subsequent update of the explicit memory are given below:
where represents a matrix of all ones and represents the same for a dimensional vector.
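The append-or-replace behaviour of the explicit bank can be sketched with plain Python lists. The function name and argument layout are hypothetical, the one-hot write vector is assumed to come from the Gumbel-Softmax selection described earlier, and resetting the usage of the overwritten slot to zero is an assumption about the reset value.

```python
def update_explicit(explicit, usage, candidate, write_one_hot, capacity):
    """Update the explicit bank in place and return the popped memory.

    While the bank has free capacity, the candidate is appended and
    None is returned. Once the bank is full, the slot selected by the
    one-hot write vector is popped and overwritten with the candidate;
    the popped content is returned so that it can be passed on to the
    blurred bank.
    """
    if len(explicit) < capacity:
        explicit.append(list(candidate))
        usage.append(0.0)
        return None
    slot = max(range(capacity), key=write_one_hot.__getitem__)
    popped = explicit[slot]
    explicit[slot] = list(candidate)
    usage[slot] = 0.0  # assumed reset value for the overwritten slot
    return popped
```

Because each slot always holds exactly one candidate representation, the bank can be traced back to the time points whose inputs produced those candidates.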
Blurred memory update: The blurred memories are used to represent past events with more abstract concepts that can capture long term dependencies. The memory bank provides a place for memories forgotten from the explicit bank to be stored in a more abstract sense. also allows EBmRNN to track and access a higher dimensional construct of current memory representation.
We generate a candidate blurred memory using the following equation:
We generate write-vectors using a formulation similar to equation 10 by replacing the Gumbel-Softmax with a Softmax. The final update equation for the blurred memory can then be given as follows:
where is an erase weight generated by the controller.
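For intuition, the sketch below shows an erase-then-add update in the style of Neural Turing Machines, where soft write weights select how strongly each slot is modified and an erase vector controls how much of the old content is forgotten. This specific form is an assumption used for illustration and is not claimed to be the exact update equation of the model.

```python
def update_blurred(blurred, write_weights, erase, candidate):
    """Erase-then-add update of a blurred memory bank (assumed form).

    blurred       : list of slots, each a list of floats
    write_weights : soft weights over slots (e.g. from a Softmax)
    erase         : erase weights in [0, 1], one per slot dimension
    candidate     : candidate blurred memory, one value per dimension
    """
    return [
        [m * (1.0 - w * e) + w * c
         for m, e, c in zip(slot, erase, candidate)]
        for slot, w in zip(blurred, write_weights)
    ]
```

Since the write weights are soft, a single candidate is spread over several slots, which is what makes the blurred bank harder to trace back to individual inputs than the explicit bank.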