Improved Patient Classification with Language Model Pretraining Over Clinical Notes

Jonas Kemp, et al. (Google). 09/06/2019.

Clinical notes in electronic health records contain highly heterogeneous writing styles, including non-standard terminology and abbreviations. Using these notes in predictive modeling has traditionally required preprocessing (e.g. taking frequent terms or topic modeling) that removes much of the richness of the source data. We propose a pretrained hierarchical recurrent neural network model that parses minimally processed clinical notes in an intuitive fashion, and show that it improves performance for multiple classification tasks on the Medical Information Mart for Intensive Care III (MIMIC-III) dataset, increasing top-5 recall to 89.7% for primary diagnosis classification and AUPRC to 35.2% for multilabel (all-diagnoses) classification, compared to models that treat the notes as an unordered collection of terms or that are trained without pretraining. We also apply an attribution technique to several examples to identify the words and the nearby context that the model uses to make its predictions, and show the importance of the words' context.


1 Introduction

With the rapid deployment of electronic health records (EHRs) in the US, clinicians routinely enter patient data electronically, mostly in unstructured, free-text clinical notes. Some key details of the patient’s clinical assessment and medical history are stored almost exclusively in these notes, making them an important source of information for a number of downstream applications, such as clinical trial recruitment, billing, and predictive modeling. However, certain characteristics of clinical notes make automated parsing a challenge: clinicians employ many non-standard, ambiguous shorthand phrases (for example, “af” for “afebrile” or “atrial fibrillation”) and organize notes in unpredictable ways. These challenges make secondary use of free-text EHR data difficult, often requiring manual chart abstraction by trained staff to pull key information from notes. Traditional natural language processing (NLP) techniques, which often rely on hand-crafted rules (Taggart et al., 2018), grammatical assumptions, and feature engineering techniques such as parse trees or dictionaries, can be difficult to apply in this messy, irregular data regime.

In practice, machine learning models tend to make more use of structured fields such as medications and diagnoses that can be straightforwardly extracted from the EHR (Pencina et al., 2016). Clinical notes are often ignored outright (Lipton et al., 2015; Choi et al., 2016; Esteban et al., 2016; Nickerson et al., 2016; Pham et al., 2016; Choi et al., 2017; Che et al., 2018; Cheng et al., 2016; Nguyen et al., 2017), and models that do use notes frequently reduce them to an unordered set of words (Marafino et al., 2018; Jacobson and Dalianis, 2016; Rajkomar et al., 2018) or topics (Miotto et al., 2016; Suresh et al., 2017). This ignores many subtleties of language and context, which can have a large impact on the meaning of the text in a note. For example, consider the snippet “no family hx of diabetes; discharge diagnosis of cva.” A word-level approach might represent this text as follows: [“cva”, “diabetes”, “diagnosis”, “discharge”, “family”, “hx”, “no”, “of”, “of”]. In this representation, the context phrases “no family hx” and “discharge diagnosis” are no longer associated with “diabetes” or “cva”, respectively. Such context is necessary to accurately determine which of the two conditions applies to the patient described in the note.

Recent advances in deep learning have led to major improvements in a wide variety of NLP applications (Johnson et al., 2017; Devlin et al., 2018). Building on this work, we propose a model employing sequential, hierarchical, and pretraining (SHiP) techniques from deep NLP to improve EHR predictive models by automatically learning to extract relevant information from clinical notes. Specifically, our model employs a modified hierarchical attention network (Yang et al., 2016) to read clinical notes in sequence within the larger sequence of the patient’s medical history, preferentially attending to relevant portions of the text in each note. To enrich our model’s learned representation, we augment our training procedure with language model pretraining (Dai and Le, 2015): before optimizing the model for the prediction task, we train an auxiliary objective such that, for each word in the note, the notes-level model learns to predict the next word. To our knowledge, the effectiveness of language model pretraining has not been previously demonstrated for hierarchical classification models.

Our model reads clinical notes without assuming any particular layout, medical vocabulary, writing style or language rules in the text. By maintaining the sequential order of the words and abbreviations in the text, the model’s predictions can be informed by context that cannot be captured using keywords alone. We evaluate our model on standard classification tasks for EHRs, including identifying discharge diagnoses and predicting mortality risk, and compare the performance of this model against existing state-of-the-art baselines for these tasks (Rajkomar et al., 2018). We also evaluate the sensitivity of the model’s outputs to different phrases in the text using deep learning attribution methods (Sundararajan et al., 2017).

2 Cohort

We developed our models using patient data from the Medical Information Mart for Intensive Care (MIMIC-III) database (Johnson et al., 2016; Pollard and Johnson, 2016), a research dataset of medical records collected from critical care patients at the Beth Israel Deaconess Medical Center between 2001 and 2012. We represented patients’ medical histories as a time series according to the Fast Healthcare Interoperability Resources (FHIR) specification, as described in previous work (Rajkomar et al., 2018). The study cohort included all patients in MIMIC-III hospitalized for at least 24 hours.

Train & validation Test
Number of patients 40,511 4,439
Number of hospital admissions* 51,081 5,598
 Gender, n (%)
  Female 22,468 (44.0) 2,548 (45.5)
  Male 28,613 (56.0) 3,050 (54.5)
 Age, median (IQR) 62 (32) 62 (33)
 Hospital discharge service, n (%)
  General medicine 21,350 (41.8) 2,354 (42.1)
  Cardiovascular 10,965 (21.5) 1,175 (21.0)
  Obstetrics 7,123 (13.9) 803 (14.3)
  Cardiopulmonary 4,459 (8.7) 519 (9.3)
  Neurology 4,282 (8.4) 457 (8.2)
  Cancer 2,217 (4.3) 223 (4.0)
  Psychiatric 28 (0.1) 4 (0.1)
  Other 657 (1.3) 63 (1.1)
 Previous hospitalizations, n (%)
  None 40,362 (79.0) 4,415 (78.9)
  One 6,427 (12.6) 721 (12.9)
  Two to five 3,681 (7.2) 397 (7.1)
  Six or more 611 (1.2) 65 (1.2)
 Discharge location, n (%)
  Home 28,991 (56.8) 3,095 (55.3)
  Skilled nursing facility 6,878 (13.5) 794 (14.2)
  Rehab 5,757 (11.3) 653 (11.7)
  Other healthcare facility 3,830 (7.5) 448 (8.0)
  Expired 4,420 (8.7) 462 (8.3)
  Other 1,205 (2.4) 146 (2.6)
 Number of discharge ICD-9 codes, median (IQR)** 9 (8) 9 (8)
* These numbers reflect the full set of admissions used for the mortality and ICD-9 prediction tasks. For primary CCS prediction, we excluded 1.3% of these admissions, for which the primary diagnosis corresponded to a non-billable ICD-9 code.
** Includes only billable ICD-9 codes.
Table 1: Descriptive statistics for patient cohort.

2.1 Data Extraction and Feature Choices

The set of features we extracted from the patient records comprised basic encounter information (admission type, status, and source), diagnosis and procedure codes, medication orders, quantitative observations (lab results and vital signs), and free-text clinical notes. For each continuous feature, values were standardized to Z-scores using the mean and standard deviation from the training set, with any outliers more than 10 standard deviations from the mean capped to a score of ±10.
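As a minimal sketch of this standardization step (assuming training-set statistics have already been computed; function and argument names are illustrative, not from the authors' codebase):

```python
import numpy as np

def standardize(values, train_mean, train_std, cap=10.0):
    """Convert raw continuous values (e.g. lab results, vital signs) to Z-scores
    using training-set statistics, capping outliers more than `cap` standard
    deviations from the mean."""
    z = (np.asarray(values, dtype=np.float64) - train_mean) / train_std
    return np.clip(z, -cap, cap)

# Example: heart-rate observations standardized with hypothetical training stats.
print(standardize([72.0, 180.0, 1000.0], train_mean=80.0, train_std=15.0))
```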

3 Methods

3.1 Classification tasks

For each hospitalization, we developed models for the following tasks using the full history up to the specified time in the current admission (including all past hospitalizations).

Inpatient mortality prediction

Whether the patient died during the current admission (defined as a discharge disposition of “expired”). Predicted 24 hours after the patient’s admission.

Primary discharge diagnosis

The Clinical Classifications Software (CCS) (Elixhauser, 1996) code associated with the patient’s primary discharge diagnosis. One of 236 mutually exclusive labels. Predicted at the moment of discharge.

All discharge diagnoses

The full set of ICD-9 (Slee, 1978) billing codes associated with the patient’s discharge diagnoses, including any of 6,448 possible labels. Predicted at the moment of discharge.

3.2 Model Architecture

All models in our experiments shared a core embedding scheme and top-level LSTM architecture described in previous work (Rajkomar et al., 2018), differing only in how they handled text from clinical notes. For each discrete feature type other than notes in the patient timeline (e.g. diagnosis codes), individual tokens were “embedded,” or represented as low-dimensional vectors that were randomly initialized and then trained jointly with the model. To reduce sequence length, observations were grouped into coarse-grained timesteps of one hour in length, which we refer to as unordered “bags,” and embeddings or values for observations of the same feature within the same bag were averaged; additionally, all observations occurring prior to the most recent 1000 timesteps (i.e. more than 1000 hours before the time of prediction) were grouped into a single bag. The averaged embeddings for each discrete feature, as well as the standardized values for each continuous feature, were then concatenated into a single representation of each hourly timestep in the patient history.
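The sketch below is a rough illustration of this bagging scheme (not the authors' implementation): embedded discrete observations are grouped into hourly bags, embeddings of the same feature within a bag are averaged, and the per-feature averages are concatenated into one vector per timestep. All names and data layouts are hypothetical.

```python
import numpy as np
from collections import defaultdict

def build_timestep_inputs(events, embeddings, feature_types, max_steps=1000):
    """Group embedded observations into hourly 'bags', average embeddings of the
    same feature within each bag, and concatenate per-feature averages into one
    vector per timestep. Observations older than `max_steps` hours before the
    prediction time are collapsed into a single bag.

    events: iterable of (hours_before_prediction, feature_type, token_id)
    embeddings: dict feature_type -> {token_id: 1-D np.ndarray}
    Returns an array of shape [max_steps + 1, sum of feature embedding dims],
    ordered from oldest to most recent timestep.
    """
    bags = defaultdict(list)  # (hour_bucket, feature_type) -> list of embeddings
    for hours_ago, ftype, token in events:
        bucket = min(int(hours_ago), max_steps)  # old history shares one bag
        bags[(bucket, ftype)].append(embeddings[ftype][token])

    dims = {f: len(next(iter(embeddings[f].values()))) for f in feature_types}
    steps = []
    for bucket in range(max_steps, -1, -1):  # oldest bag first, most recent last
        parts = []
        for ftype in feature_types:
            vecs = bags.get((bucket, ftype))
            parts.append(np.mean(vecs, axis=0) if vecs else np.zeros(dims[ftype]))
        steps.append(np.concatenate(parts))
    return np.stack(steps)
```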

We fed the embedded sequence to a long short-term memory (LSTM) network, a type of recurrent neural network that computes activations $h_t$ at each timestep as a nonlinear function of the current input embedding $x_t$ and the previous hidden state $h_{t-1}$ and cell state $c_{t-1}$, according to the following equations (Hochreiter and Schmidhuber, 1997):

$$i_t = \sigma(W_i x_t + U_i h_{t-1} + b_i) \qquad (1)$$
$$f_t = \sigma(W_f x_t + U_f h_{t-1} + b_f) \qquad (2)$$
$$o_t = \sigma(W_o x_t + U_o h_{t-1} + b_o) \qquad (3)$$
$$c_t = f_t \circ c_{t-1} + i_t \circ \tanh(W_c x_t + U_c h_{t-1} + b_c) \qquad (4)$$
$$h_t = o_t \circ \tanh(c_t) \qquad (5)$$

Here, $\sigma$ denotes the sigmoid function, $\tanh$ denotes the hyperbolic tangent function, and $\circ$ denotes the Hadamard (elementwise) product. The final hidden state of the LSTM was passed through a feedforward output layer to generate predictions, with a sigmoid activation for mortality or ICD-9 prediction or a softmax activation for primary CCS prediction. We trained the model to minimize the cross-entropy loss on the ground-truth labels.
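For concreteness, a minimal NumPy sketch of one timestep of the cell defined by Equations 1-5 (parameter layout and names are illustrative, not the paper's implementation):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x_t, h_prev, c_prev, W, U, b):
    """One LSTM timestep following Eqs. 1-5. W, U, b are dicts holding the
    input, recurrent, and bias parameters for the input (i), forget (f),
    output (o), and candidate cell (c) transformations."""
    i_t = sigmoid(W["i"] @ x_t + U["i"] @ h_prev + b["i"])  # Eq. 1
    f_t = sigmoid(W["f"] @ x_t + U["f"] @ h_prev + b["f"])  # Eq. 2
    o_t = sigmoid(W["o"] @ x_t + U["o"] @ h_prev + b["o"])  # Eq. 3
    c_t = f_t * c_prev + i_t * np.tanh(W["c"] @ x_t + U["c"] @ h_prev + b["c"])  # Eq. 4
    h_t = o_t * np.tanh(c_t)                                # Eq. 5
    return h_t, c_t
```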

Figure 1: Top-level model architecture. The patient record is represented as a timeline with parallel streams of observations of different feature types (e.g. medications, procedures, clinical notes). Within each feature type, individual elements are embedded, and the embeddings are aggregated into hourly “bags” and averaged together. Bagged embeddings for all features are concatenated into a single vector at each hourly timestep, and this embedded sequence is fed into an LSTM to predict outcomes of interest. The representation for notes is computed differently between our baseline and proposed models (Figure 2).

We compared the following variants of this model in our experiments:

No notes

The free-text notes were not included, and the model was trained exclusively on the other elements of the record.

Bag-of-words (BOW)

Notes were included in the record, but treated just like any other discrete feature. We tokenized text at the word level, converted it to lowercase, and stripped all punctuation. Individual words were then embedded, and all word embeddings within each hourly bag (which may contain several notes) were simply averaged together, ignoring word ordering and nearby context. We included a variant of this model that also uses bigrams, or pairs of adjacent words: the bigram strings were hashed into a fixed number of buckets (chosen relative to the unigram vocabulary size), and then embedded, bagged, and concatenated with the bagged unigrams.
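A small sketch of this bag-of-words representation with hashed bigrams, under the assumption that tokens are already integer ids and that embedding matrices are given; the hashing scheme shown here is illustrative:

```python
import numpy as np
import zlib

def note_bag_of_words(tokens, unigram_emb, bigram_emb, num_buckets):
    """Average unigram embeddings for a note, plus averaged embeddings of
    hashed bigrams, concatenated into one order-insensitive vector.

    tokens:      list of integer token ids (already lowercased, de-punctuated)
    unigram_emb: np.ndarray [vocab_size, d], indexed by token id
    bigram_emb:  np.ndarray [num_buckets, d], indexed by hashed bigram bucket
    """
    uni = np.mean(unigram_emb[tokens], axis=0)
    pairs = [f"{a}_{b}" for a, b in zip(tokens, tokens[1:])]
    buckets = [zlib.crc32(p.encode()) % num_buckets for p in pairs]
    bi = (np.mean(bigram_emb[buckets], axis=0)
          if buckets else np.zeros(bigram_emb.shape[1]))
    return np.concatenate([uni, bi])
```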

Hierarchical LSTM

For each note, we embedded the individual words as in the bag-of-words model, but maintained the sequential order of embeddings within the note. We fed the embedded notes to a second LSTM, which reads the terms sequentially to generate a context-sensitive vector representation at each word. For the notes, we experimented with both unidirectional and bidirectional LSTMs (see Appendix): the latter processes the sequence in both the forward and reverse directions and concatenates the hidden states at each timestep, so that each output from the LSTM incorporates both previous and future context. We computed the final output vector for each note by aggregating the hidden states for each word according to a learned attention weighting, which places higher weight on the portions of the notes that are most important for the downstream prediction. Specifically, we used a slightly modified version of the hierarchical attention network (Yang et al., 2016): the model computes the dot product of a query vector $q$ with each hidden state $h_t$ and normalizes via a softmax function to obtain the attention weighting over the sequence, augmented with an additional prior embedding vector $p$ and corresponding scalar bias weight $b$ (where $q$, $p$, and $b$ are learned jointly with the model during training):

$$\alpha = \mathrm{softmax}\big(q^\top h_1, \ldots, q^\top h_T, \; b\big), \qquad v = \sum_{t=1}^{T} \alpha_t h_t + \alpha_{T+1}\, p \qquad (6)$$

As before, outputs for individual notes within the same hourly bag were averaged together, and the bagged note vectors were concatenated with the remaining feature vectors for input into the record-level LSTM.
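A minimal sketch of the attention pooling in Equation 6 as reconstructed above: dot-product scores between the query and each hidden state are normalized jointly with the scalar prior bias, and the prior embedding receives the remaining attention mass. All names are illustrative.

```python
import numpy as np

def softmax(scores):
    e = np.exp(scores - np.max(scores))
    return e / e.sum()

def attend_with_prior(hidden_states, query, prior_embedding, prior_bias):
    """Attention pooling over note-LSTM outputs (Eq. 6, as reconstructed here).

    hidden_states:   [T, d] outputs of the notes LSTM
    query:           [d] learned query vector q
    prior_embedding: [d] learned prior vector p
    prior_bias:      scalar learned bias b
    """
    scores = hidden_states @ query                    # q . h_t for each position t
    weights = softmax(np.append(scores, prior_bias))  # softmax over T + 1 scores
    return weights[:-1] @ hidden_states + weights[-1] * prior_embedding
```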

SHiP (Sequential, Hierarchical and Pretrained)

For this variant of the hierarchical LSTM, we augmented the standard training procedure with unsupervised language model pretraining (Dai and Le, 2015) for the note-level LSTM: before optimizing the cross-entropy loss, we trained an auxiliary objective such that, for each word in the note, the forward LSTM learned to predict the next word (and if bidirectional, the backward LSTM learned to predict the previous word). Note that while the mortality models were restricted to the first 24 hours of data from the current hospitalization (as well as any data from previous admissions) during training and evaluation, we pretrained these models over the full set of notes from the hospitalization up to discharge. We also found that the hierarchical mortality models performed best when, after pretraining, the notes LSTM weights were then frozen during the standard training phase.
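As a sketch of this pretraining objective (hypothetical helpers, not the authors' code): at each position the forward LSTM is trained to predict the next token, and the backward LSTM, if present, is trained analogously on the reversed sequence, using a cross-entropy loss over the vocabulary.

```python
import numpy as np

def language_model_targets(token_ids):
    """Build (input, target) pairs for next-word pretraining of the notes LSTM.

    The forward LSTM reads tokens[:-1] and is trained to predict tokens[1:];
    a backward LSTM reads the reversed sequence and is trained the same way,
    which amounts to predicting the previous word in the original order.
    """
    tokens = np.asarray(token_ids)
    forward = (tokens[:-1], tokens[1:])
    reversed_tokens = tokens[::-1]
    backward = (reversed_tokens[:-1], reversed_tokens[1:])
    return forward, backward

def next_word_cross_entropy(logits, targets):
    """Average cross-entropy given per-position vocabulary logits [T, V]."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -np.mean(log_probs[np.arange(len(targets)), targets])
```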

For memory and performance reasons, in all hierarchical models we restricted the maximum amount of text used in the notes LSTM, keeping only the most recent tokens per record (across all notes) up to a fixed limit and discarding any additional leading tokens. We tuned the level of truncation on the validation set: a smaller limit was sufficient for training mortality models, but we increased the limit for both diagnosis tasks and for pretraining. For comparison, we also trained variants of all models with notes as the only feature available to the model, excluding the other elements of the record.
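A small sketch of this truncation, under the assumption that a record's notes are ordered oldest to newest and that tokens are dropped from the front (oldest end) of the record:

```python
def truncate_record_tokens(notes, max_tokens):
    """Keep only the most recent `max_tokens` tokens across all of a record's
    notes, discarding leading (oldest) tokens first.

    notes: list of token lists, ordered oldest to newest.
    """
    total = sum(len(note) for note in notes)
    to_drop = max(0, total - max_tokens)
    kept = []
    for note in notes:
        if to_drop >= len(note):
            to_drop -= len(note)     # drop this note entirely
            continue
        kept.append(note[to_drop:])  # trim the front of the first partially kept note
        to_drop = 0
    return kept
```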

Figure 2: Comparison of alternative notes representations. In both cases, each term in each note is represented as an embedding vector. The standard bag-of-words approach builds a representation from these individual word embeddings aggregated without regard to ordering. By contrast, our hierarchical model reads the embedded note sequentially, and the attention mechanism preferentially selects outputs corresponding to the parts of the note most relevant for prediction.

3.3 Attribution Methods

To compute attribution scores over the text of notes, we use the path-integrated gradients technique (Sundararajan et al., 2017). For clarity in these attributions, we ran a notes-only model over only the selected note, omitting the rest of the notes in the patient’s record. For a note with word embedding $x_i$ for each word $i$, we define the gradient attribution $g_i$ as the gradient of the model output $F_c$ for the class of interest $c$ (e.g. the patient’s true diagnosis) with respect to the word embedding $x_i$:

$$g_i = \frac{\partial F_c(x)}{\partial x_i} \qquad (7)$$

The integrated gradients attribution, then, is the integral of these gradient attributions along the straight-line path between a baseline word embedding $x_i'$ (which we choose to be the zero vector) and the learned word embedding $x_i$. If we approximate the integral using $m$ steps, this can be computed as:

$$\mathrm{IG}_i = (x_i - x_i') \circ \frac{1}{m} \sum_{k=1}^{m} \frac{\partial F_c\!\left(x' + \tfrac{k}{m}(x - x')\right)}{\partial x_i} \qquad (8)$$

In practice, we approximate the integral with a fixed number of steps $m$. This score offers a first-order approximation of the effect of each input on the relevant model output. Integrated gradients are straightforward to implement and satisfy certain properties not guaranteed by standard gradient attribution (Sundararajan et al., 2017). The use of gradient-based attribution in text models is still an active area of research.
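The sketch below approximates Equation 8 with a zero baseline, given a gradient function supplied by the model (e.g. via automatic differentiation); the default of 50 steps is an illustrative choice, not necessarily the value used in the paper.

```python
import numpy as np

def integrated_gradients(grad_fn, word_embeddings, steps=50):
    """Approximate path-integrated gradients for one note (Eq. 8).

    grad_fn: callable mapping a [num_words, d] embedding matrix to the gradient
             of the output for the class of interest, with the same shape.
    word_embeddings: [num_words, d] learned embeddings for the note's words.
    Returns one attribution score per word (summed over embedding dimensions).
    """
    x = np.asarray(word_embeddings, dtype=np.float64)
    accumulated = np.zeros_like(x)
    for k in range(1, steps + 1):
        accumulated += grad_fn((k / steps) * x)  # gradient at a point on the path
    avg_grad = accumulated / steps
    return (x * avg_grad).sum(axis=1)            # (x - zero baseline) * average gradient
```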

4 Results

4.1 Training and Evaluation Approach

We split our patient cohort according to patient ID into 80% train, 10% validation, and 10% test splits. For each task, we tuned learning rate, LSTM size, and dropout hyperparameters for the record-level LSTM using the full-feature BOW models. We transferred these hyperparameters to the SHiP models, and then performed additional tuning of the learning rate and of the corresponding size and dropout parameters for the notes-level LSTM. Models were optimized using Adam (Kingma and Ba, 2015); the dropout techniques applied included standard input and hidden-layer dropout (Srivastava et al., 2014), variational input and hidden-layer dropout and vocabulary dropout (Gal and Ghahramani, 2016), and zoneout (Krueger et al., 2016). For the SHiP models, all dropout techniques were applied during both pretraining and training. We used a Gaussian process bandit optimization algorithm (Desautels et al., 2014) to search for and select hyperparameters maximizing performance for each task on the validation set. Metrics for model selection included AUROC (for mortality and ICD-9) and top-5 recall (for CCS). For predicting ICD-9 codes, a multilabel task, we computed a weighted AUROC, where the AUROC for each label is averaged according to the label’s prevalence.

Following hyperparameter tuning, we trained each model five times from different random initializations for each task. We used early stopping to select the best model for each run according to validation set performance, and we report the mean and standard deviation of the test set performance over all five runs. We also compute and report the statistical significance of the change in performance between the best baseline and best hierarchical model for each task, and between the best hierarchical models with and without pretraining, using a two-tailed Welch’s t-test. All models were implemented in TensorFlow 1.12 (Abadi et al., 2016) and trained on Nvidia Tesla P100 GPUs. Evaluation metrics and statistical tests were calculated using scikit-learn 0.20 (Pedregosa et al., 2011).
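As an illustration of this evaluation setup, the snippet below computes a prevalence-weighted AUROC with scikit-learn and a two-tailed Welch's t-test over per-run metrics (SciPy's implementation is used here for the t-test); all data are placeholders, not the paper's results.

```python
import numpy as np
from scipy.stats import ttest_ind
from sklearn.metrics import roc_auc_score

rng = np.random.default_rng(0)

# Toy multilabel targets and scores standing in for ICD-9 labels and predictions.
y_true = rng.integers(0, 2, size=(1000, 20))
y_score = rng.random((1000, 20))

# Per-label AUROCs averaged by label prevalence ("weighted" average).
weighted_auroc = roc_auc_score(y_true, y_score, average="weighted")

# Two-tailed Welch's t-test comparing two models' metrics over five runs each.
runs_baseline = [0.842, 0.845, 0.840, 0.843, 0.841]  # placeholder values
runs_ship = [0.897, 0.894, 0.899, 0.896, 0.898]      # placeholder values
t_stat, p_value = ttest_ind(runs_ship, runs_baseline, equal_var=False)

print(f"weighted AUROC = {weighted_auroc:.3f}, Welch p = {p_value:.4f}")
```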

4.2 Model Performance

Model Mortality Primary CCS All ICD-9
AUROC AUPRC Top-5 Recall F1 Weighted AUROC AUPRC
No notes - 0.869 (0.001) 0.449 (0.006) 0.793 (0.004) 0.521 (0.008) 0.869 (0.001) 0.297 (0.001)
Bag-of-words Unigrams (notes only) 0.832 (0.003) 0.383 (0.004) 0.849 (0.002) 0.591 (0.004) 0.880 (0.001) 0.328 (0.002)
Unigrams (all features) 0.880 (0.001) 0.479 (0.008) 0.842 (0.002) 0.587 (0.002) 0.880 (0.001) 0.321 (0.001)
Unigrams and bigrams (all features) 0.872 (0.002) 0.460 (0.005) 0.829 (0.001) 0.585 (0.003) 0.878 (0.001) 0.315 (0.001)
Hierarchical (without pretraining) Notes only 0.825 (0.003) 0.351 (0.003) 0.850 (0.001) 0.606 (0.003) 0.887 (0.002) 0.345 (0.005)
All features 0.876 (0.003) 0.471 (0.006) 0.812 (0.014) 0.555 (0.020) 0.869 (0.004) 0.291 (0.010)
SHiP Notes only 0.825 (0.004) 0.353 (0.005) 0.897* (0.003) 0.667* (0.006) 0.891 (0.001) 0.352 (0.001)
All features 0.882 (0.001) 0.479 (0.007) 0.887 (0.003) 0.660 (0.004) 0.889 (0.002) 0.332 (0.016)
* Statistically significant difference compared to the best hierarchical model without pretraining.
† Statistically significant difference compared to the best bag-of-words model.
Table 2: Model performance results on the tasks of interest. Each model was trained five times from different random initializations on each task, and we report the mean (standard deviation) test set performance for each metric over the five runs. Best values for each metric are bolded.

Table 2 compares the performance of all model variants on the selected prediction tasks. The SHiP models significantly improved over the BOW baselines on the two diagnosis tasks (under a two-tailed Welch’s t-test): for CCS prediction, the best SHiP model improved top-5 recall by 4.8 percentage points and F1 by 7.6 percentage points over the best BOW model; for ICD-9 prediction, weighted area under the ROC curve (AUROC) increased by 1.1 percentage points and area under the precision-recall curve (AUPRC) increased by 2.4 percentage points. For mortality prediction, we saw a negligible benefit from the SHiP architecture, with only a 0.2 percentage point increase in AUROC and no change in AUPRC.

The SHiP models also improved over the corresponding hierarchical attention networks without pretraining. For mortality, pretraining the all-features model increased AUROC by 0.6 percentage points and AUPRC by 0.8 percentage points; for primary CCS, pretraining the notes-only model increased top-5 recall by 4.7 percentage points and F1 by 6.1 percentage points; for all ICD-9, pretraining the notes-only model increased weighted AUROC by 0.4 percentage points and AUPRC by 0.7 percentage points.

4.3 Qualitative Analysis

Figure 3 shows examples of path-integrated gradients attribution from a primary CCS model, over discharge summaries from different patients. We observe that the SHiP model frequently concentrates on just one or a few important phrases, even in very long notes. The choice of phrase is often informed by the nearby context: for example, we can see that the SHiP model is consistently most sensitive to the clinically-relevant words following the phrase “discharge diagnoses.” In fact, in the first and second notes, the patient’s diagnosis is restated elsewhere in the text in a less relevant context (e.g. stating that the patient has “no family history” of diabetes), but the model is sensitive only to the instance where the discharge context is made explicit. The bag-of-words model, by contrast, is incapable of making such contextual distinctions, and is generally more sensitive to key words and phrases throughout the text. We attempted a similar analysis using standard gradients attribution, but found the outputs to be noisier and less interpretable for both models.

Figure 3: Visualization of integrated gradients attribution over excerpts from patient discharge summaries. For each excerpt (patient’s primary diagnosis listed at left), the left column shows attribution from the bag-of-words baseline, and the right column attribution from the SHiP model. Below each word is the value of the attribution computed for that word, where a higher absolute value indicates greater importance. Red boxes highlight the patient’s stated diagnosis in the text, while blue boxes indicate relevant pieces of nearby context where applicable.

5 Discussion and Related Work

The results of this study demonstrate that SHiP, a novel combination of hierarchical modeling of clinical notes with language model pretraining, can improve discharge diagnosis classification over previous state-of-the-art models, with only minimal preprocessing of text. SHiP models process clinical notes in a way that is more sensitive to the context and structure of language compared to other common approaches, which often reduce notes to a set of keywords. Hierarchical recurrent networks (Yang et al., 2016; Chung et al., 2016; Hwang and Sung, 2017; Yu et al., 2016; Meng et al., 2017) and pretraining methods (Dai and Le, 2015; Devlin et al., 2018; Yang et al., 2019) have each individually proven successful in a wide variety of general NLP applications; here, we show the utility of these methods applied jointly, and specifically within a clinical context.

This work builds on recent literature on applying deep learning techniques to analysis of electronic health records data (Shickel et al., 2017). Many of the previous advances in deep learning for clinical NLP have relied on more standard convolutional or recurrent architectures rather than hierarchical approaches (Jagannatha and Yu, 2016; Vani et al., 2017; Mullenbach et al., 2018; Gehrmann et al., 2018; Chokwijitkul et al., 2018). However, a number of studies have recently begun applying hierarchical models to clinical text (Gao et al., 2017; Baumel et al., 2017; Samonte et al., 2018; Liu et al., 2018; Newman-Griffis and Zirikly, 2018) as well as other data modalities (Sha and Wang, 2017; Phan et al., 2018). These existing hierarchical text models do not utilize language model pretraining, generally relying on more limited techniques such as pretraining to approximate a weighted bag of word embeddings (Gao et al., 2017) or using pretrained word embeddings only (Liu et al., 2018; Newman-Griffis and Zirikly, 2018). Our results show that language model pretraining can substantially improve hierarchical models, and may in some cases be required to outperform strong bag-of-words baselines. This pretraining method may allow the hierarchical model to better learn long-term dependencies between words in a note and better use contextual information. The sequential processing of the note and path-integrated gradients also allow improved visualization of the parts of each note most relevant to a particular prediction. Future research might explore how this hierarchical pretraining framework could be extended to other features with sequential structure, such as vital signs.

This study also touches upon a broader question of when notes provide additional predictive value, relative to other parts of the medical record. Previous studies have found mixed evidence: for example, TF-IDF weighted unigrams extracted from notes were shown to have only moderate discriminative utility (and less than other elements of the clinical record) for predicting readmission (Walsh and Hripcsak, 2014); while applying an LSTM to notes represented as restricted bags of words was shown to improve performance on other tasks such as predicting diagnosis or length of stay (Boag et al., 2018). Our experiments also suggest a task-dependent pattern in the predictive value of notes. For all-cause mortality risk, we found that notes provided less predictive value compared especially to quantitative signals like labs or vitals. On the other hand, the SHiP models delivered clear improvements on the diagnosis classification tasks, likely because notes often contain rich diagnostic information – for example, discussions of differential diagnosis with qualifiers that cannot be easily captured in other forms. The difference was more pronounced for the primary CCS task; while ICD-9 prediction also benefited, the task of predicting several (possibly noisy) labels per patient is likely harder.

Our study has some important limitations. First, our analysis was limited to a single ICU patient population, and should be validated using other patient cohorts from other health centers. However, we note that nothing about our modeling approach is site-specific, and indeed its design should directly accommodate the particular note-writing habits at any institution. Second, although our proposed architecture can jointly model multiple data modalities, we found that our diagnosis models performed best with notes alone, and were prone to overfitting when provided additional features. Similarly, we hypothesized that adding bigrams to the BOW models might offer a simple approach to increasing the linguistic context available to the model, but found that this harmed generalization across tasks. Despite employing several common regularization techniques, we consistently observed a wider generalization gap (between training and validation set performance) when training the diagnosis models using more features. We hypothesize that it may be difficult to avoid overfitting to extraneous features and that our method may be better suited to larger patient cohorts, but this phenomenon invites further investigation. Third, our experiments included only a small subset of possible tasks; future research on a wider range of tasks might advance understanding of when and why notes are useful for prediction. Fourth, our experimental setup tests prediction at a single time point, when in fact in deployment the model will likely need to be adapted for continuous prediction. This remains an open area of research.

6 Conclusion

We demonstrate the effectiveness of techniques from deep NLP, particularly language model pretraining and hierarchical attention networks, for improved modeling of clinical notes. Our work provides a flexible and general approach that can be readily applied to clinical text from any source, for any modeling task where the unstructured information in the text is critical to understanding the outcome of interest.

Acknowledgments

We thank Nissan Hajaj and Xiaobing Liu for developing the core framework used to implement our models. We thank Gerardo Flores, Kathryn Rough, and Kun Zhang for providing assistance with our data processing and evaluation pipelines. We thank Kai Chen, Michael Howell, and Denny Zhou for their comments and feedback on this manuscript.

References

Appendix

AUROC AUPRC
Pretrained to 24 hours 0.881 (0.001) 0.478 (0.005)
Pretrained to discharge 0.882 (0.001) 0.479 (0.007)
Table 3: Comparison of test set results for all-feature SHiP mortality models pretrained to different time thresholds. Reporting mean (standard deviation) over five runs from random initialization.
Mortality Primary CCS All ICD-9
AUROC AUPRC Top-5 Recall F1 Weighted AUROC AUPRC
Unidirectional 0.895 0.490 0.888 0.651 0.887 0.342
Bidirectional 0.896 0.497 0.896 0.663 0.878 0.326
Table 4: Comparison of validation set results for SHiP models using unidirectional vs. bidirectional LSTMs.
Hyperparameters Mortality Primary CCS All ICD-9
BOW SHiP BOW SHiP BOW SHiP
Training Learning rate 0.00015 0.00011 0.00369 0.00067 0.00369 0.00048
Batch size 128 16 128 16 128 16
Pretraining steps N/A 30,000 N/A 30,000 N/A 40,000
Gradient clip norm 37.5 37.5 0.125 0.125 0.125 0.125
Variational vocabulary dropout* 0.001 0.229 0.273 0.396 0.273 0.273
Record LSTM Hidden units 379 518 518
Input dropout 0.466 0.246 0.246
Hidden dropout 0.045 0.136 0.136
Variational input dropout 0.034 0.071 0.071
Variational hidden dropout 0.090 0.122 0.122
Zoneout 0.268 0.437 0.437
Notes LSTM Bidirectional? N/A Yes N/A Yes N/A No
Hidden units 350 325 780
Input dropout 0.052 0.019 0.340
Hidden dropout 0.175 0.391 0.238
Variational input dropout 0.176 0.291 0.156
Variational hidden dropout 0.061 0.085 0.103
Zoneout 0.312 0.336 0.387
* The variational vocabulary dropout rate is shared across all features. For baseline models with bigrams, we increased the dropout rate on the notes vocabulary only to 0.75.
Table 5: Model hyperparameters. For the same task, all non-hierarchical models shared the BOW hyperparameters, and all hierarchical models shared the SHiP hyperparameters, except where noted. All models were trained using the Adam optimizer with its default constants ($\beta_1 = 0.9$, $\beta_2 = 0.999$, $\epsilon = 10^{-8}$).