Code for AAAI 2018 accepted paper: "Beyond Sparsity: Tree Regularization of Deep Models for Interpretability"
The lack of interpretability remains a key barrier to the adoption of deep models in many applications. In this work, we explicitly regularize deep models so human users might step through the process behind their predictions in little time. Specifically, we train deep time-series models so their class-probability predictions have high accuracy while being closely modeled by decision trees with few nodes. Using intuitive toy examples as well as medical tasks for treating sepsis and HIV, we demonstrate that this new tree regularization yields models that are easier for humans to simulate than simpler L1 or L2 penalties without sacrificing predictive power.
Deep models have become the de-facto approach for prediction in a variety of applications such as image classification (e.g. [Krizhevsky, Sutskever, and Hinton2012]) and machine translation (e.g. [Bahdanau, Cho, and Bengio2014, Sutskever, Vinyals, and Le2014]). However, many practitioners are reluctant to adopt deep models because their predictions are difficult to interpret. In this work, we seek a specific form of interpretability known as human-simulability. A human-simulatable model is one in which a human user can “take in input data together with the parameters of the model and in reasonable time step through every calculation required to produce a prediction” [Lipton2016]. For example, small decision trees with only a few nodes are easy for humans to simulate and thus understand and trust. In contrast, even simple deep models like multi-layer perceptrons with a few dozen units can have far too many parameters and connections for a human to easily step through. Deep models for sequences are even more challenging. Of course, decision trees with too many nodes are also hard to simulate. Our key research question is: can we create deep models that are well-approximated by compact, human-simulatable models?
The question of creating accurate yet human-simulatable models is an important one, because in many domains simulatability is paramount. For example, despite advances in deep learning for clinical decision support (e.g. [Miotto et al.2016, Choi et al.2016, Che et al.2015]), the clinical community remains skeptical of machine learning systems [Chen and Asch2017]. Simulatability allows clinicians to audit predictions easily. They can manually inspect changes to outputs under slightly-perturbed inputs, check substeps against their expert knowledge, and identify when predictions are made due to systemic bias in the data rather than real causes. Similar needs for simulatability exist in many decision-critical domains such as disaster response or recidivism prediction.
To address this need for interpretability, a number of works have been developed to assist in the interpretation of already-trained models. Craven and Shavlik (1996) train decision trees that mimic the predictions of a fixed, pretrained neural network, but do not train the network itself to be simpler. Other post-hoc interpretations typically evaluate the sensitivity of predictions to local perturbations of inputs or the input gradient [Ribeiro, Singh, and Guestrin2016, Selvaraju et al.2016, Adler et al.2016, Lundberg and Lee2016, Erhan et al.2009]. In parallel, research efforts have emphasized that simple lists of (perhaps locally) important features are not sufficient: Singh, Ribeiro, and Guestrin (2016) provide explanations in the form of programs; Lakkaraju, Bach, and Leskovec (2016) learn decision sets and show benefits over other rule-based methods.
These techniques focus on understanding already learned models, rather than finding models that are more interpretable. However, it is well known that deep models often have multiple optima of similar predictive accuracy [Goodfellow, Bengio, and Courville2016], and thus one might hope to find more interpretable models with equal predictive accuracy. Yet the field of optimizing deep models for interpretability remains nascent. Ross, Hughes, and Doshi-Velez (2017) penalize input sensitivity to features marked as less relevant. Lei, Barzilay, and Jaakkola (2016) train deep models that make predictions from text and simultaneously highlight contiguous subsets of words, called a “rationale,” to justify each prediction. While both works optimize their deep models to expose relevant features, lists of features are not sufficient to simulate the prediction.
In this work, we take steps toward optimizing deep models for human-simulatability via a new model complexity penalty function we call tree regularization. Tree regularization favors models whose decision boundaries can be well-approximated by small decision trees, thus penalizing models that would require many calculations to simulate predictions. We first demonstrate how this technique can be used to train simple multi-layer perceptrons to have tree-like decision boundaries. We then focus on time-series applications and show that gated recurrent unit (GRU) models trained with strong tree regularization reach a high-accuracy-at-low-complexity sweet spot that is not possible with any strength of L1 or L2 regularization. Prediction quality can be further boosted by training new hybrid models, GRU-HMMs, which explain the residuals of interpretable discrete HMMs via tree-regularized GRUs. We further show that the approximate decision trees for our tree-regularized deep models are useful for human simulation and interpretability. We demonstrate our approach on a speech recognition task and two medical treatment prediction tasks for patients with sepsis in the intensive care unit (ICU) and for patients with human immunodeficiency virus (HIV). Throughout, we also show that standalone decision trees as a baseline are noticeably less accurate than our tree-regularized deep models. We have released an open-source Python toolbox to allow others to experiment with tree regularization (https://github.com/dtak/tree-regularization-public).
While there is little work (as mentioned above) on optimizing models for interpretability, there are some related threads. The first is model compression, which trains smaller models that perform similarly to large, black-box models (e.g. [Buciluǎ, Caruana, and Niculescu-Mizil2006, Hinton, Vinyals, and Dean2015, Balan et al.2015, Han et al.2015]). Other efforts specifically train very sparse networks via L1 penalties [Zhang, Lee, and Jordan2016] or even binary neural networks [Tang, Hua, and Wang2017, Rastegari et al.2016] with the goal of faster computation. Edge and node regularization is commonly used to improve prediction accuracy [Drucker and Le Cun1992, Ochiai et al.2017], and recently Hu et al. (2016) improve prediction accuracy by training neural networks so that predictions match a small list of known domain-specific first-order logic rules. Sometimes these regularizations, which all smooth or simplify decision boundaries, can have the effect of also improving interpretability. However, there is no guarantee that these regularizations will improve interpretability; we emphasize that specifically training deep models to have easily-simulatable decision boundaries is (to our knowledge) novel.
We consider supervised learning tasks given datasets of $N$ labeled examples, where each example (indexed by $n$) has an input feature vector $x_n$ and a target output vector $y_n$. We shall assume the targets are binary, though it is simple to extend to other types. When modeling time series, each example sequence $n$ contains $T_n$ timesteps indexed by $t$, which each have a feature vector $x_{nt}$ and an output $y_{nt}$. Formally, we write: $x_n = [x_{n1}, \dots, x_{nT_n}]$ and $y_n = [y_{n1}, \dots, y_{nT_n}]$. Each value $y_{nt}$ could be a prediction about the next timestep (e.g. the character at time $t+1$) or some other task-related annotation (e.g. if the patient became septic at time $t$).
A multi-layer perceptron (MLP) makes predictions $\hat{y}_n$ of the target $y_n$ via a function $\hat{y}_n = \hat{y}(x_n, W)$, where the vector $W$ represents all parameters of the network. Given a data set $\{x_n, y_n\}_{n=1}^N$, our goal is to learn the parameters $W$ to minimize the objective

$$\min_W \; \lambda \Psi(W) + \sum_{n=1}^N \text{loss}\big(y_n, \hat{y}(x_n, W)\big) \tag{1}$$

For binary targets $y_n$, the logistic loss (binary cross entropy) is an effective choice. The regularization term $\Psi(W)$ can represent L1 or L2 penalties (e.g. [Drucker and Le Cun1992, Goodfellow, Bengio, and Courville2016, Ochiai et al.2017]) or our new regularization.
A recurrent neural network (RNN) takes as input an arbitrary length sequence $x_n = [x_{n1}, \dots, x_{nT_n}]$ and produces a “hidden state” sequence $h_n = [h_{n1}, \dots, h_{nT_n}]$ of the same length as the input. Each hidden state vector at timestep $t$ represents a location in a (possibly low-dimensional) “state space” with $K$ dimensions: $h_{nt} \in \mathbb{R}^K$. RNNs perform sequential nonlinear embedding of the form $h_{nt} = f(x_{nt}, h_{n,t-1})$ in the hope that the state space location is a useful summary statistic for making predictions of the target $y_{nt}$ at timestep $t$.
Many different variants of the transition function architecture have been proposed to solve the challenge of capturing long-term dependencies. In this paper, we use gated recurrent units (GRUs) [Cho et al.2014], which are simpler than other alternatives such as long short-term memory units (LSTMs) [Hochreiter and Schmidhuber1997]. While GRUs are convenient, any differentiable RNN architecture is compatible with our new tree-regularization approach.
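Below we describe the evolution of a single GRU sequence, dropping the sequence index $n$ for readability. The GRU transition function produces the state vector $h_t$ from a previous state $h_{t-1}$ and an input vector $x_t$, via the following feed-forward architecture:

$$\begin{aligned} r_t &= \sigma(W^r x_t + U^r h_{t-1} + b^r) \\ z_t &= \sigma(W^z x_t + U^z h_{t-1} + b^z) \\ \tilde{h}_t &= \tanh\big(W^h x_t + U^h (r_t \odot h_{t-1}) + b^h\big) \\ h_t &= (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t \end{aligned} \tag{2}$$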
The internal network nodes include candidate state gates $\tilde{h}_t$, update gates $z_t$ and reset gates $r_t$, which have the same cardinality as the state vector $h_t$. Reset gates allow the network to forget past state vectors when set near zero via the logistic sigmoid nonlinearity. Update gates allow the network to either pass along the previous state vector unchanged or use the new candidate state vector instead. This architecture is diagrammed in Figure 1.
The predicted probability of the binary label $y_t$ for time $t$ is a sigmoid transformation of the state at time $t$:

$$\hat{y}_t = \sigma(w^T h_t) \tag{3}$$
Here, weight vector $w \in \mathbb{R}^K$ represents the parameters of this output layer. We denote the parameters for the entire GRU-RNN model as $W$, concatenating all component parameters. We can train GRU-RNN time-series models (hereafter often just called GRUs) via the following loss minimization objective:

$$\min_W \; \lambda \Psi(W) + \sum_{n=1}^N \sum_{t=1}^{T_n} \text{loss}\big(y_{nt}, \hat{y}_{nt}(x_{n,1:t}, W)\big) \tag{4}$$

where again $\Psi(W)$ defines a regularization cost.
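The regularized per-timestep objective can be sketched as follows, assuming a list of precomputed hidden-state arrays (one per sequence) and treating the regularizer value $\Psi(W)$ as a number supplied by the caller. The L1 and L2 helpers correspond to the baseline penalties used in the experiments; all function names here are illustrative, not part of any released API.

```python
import numpy as np


def predict_probs(H, w):
    """y_hat_t = sigmoid(w^T h_t) for every row h_t of the (T, K) state matrix H."""
    return 1.0 / (1.0 + np.exp(-(H @ w)))


def binary_cross_entropy(y, p, eps=1e-12):
    p = np.clip(p, eps, 1.0 - eps)  # guard against log(0)
    return -np.sum(y * np.log(p) + (1.0 - y) * np.log(1.0 - p))


def l1_penalty(arrays):
    return sum(np.abs(a).sum() for a in arrays)


def l2_penalty(arrays):
    return np.sqrt(sum((a ** 2).sum() for a in arrays))


def regularized_objective(state_seqs, label_seqs, w, lam, psi):
    """Sum of per-timestep losses over all sequences, plus lam * Psi(W).

    state_seqs: list of (T_n, K) hidden-state arrays, one per sequence.
    label_seqs: list of length-T_n binary label arrays.
    psi: the regularizer value Psi(W), computed by the caller.
    """
    data_loss = sum(binary_cross_entropy(y, predict_probs(H, w))
                    for H, y in zip(state_seqs, label_seqs))
    return data_loss + lam * psi
```

With an all-zero output weight vector every prediction is 0.5, so each timestep contributes $\log 2$ to the data loss; this is a handy sanity check when wiring up the pieces.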
We now propose a novel tree regularization function $\Omega(W)$ for the parameters of a differentiable model, which attempts to penalize models whose predictions are not easily simulatable. Of course, it is difficult to measure “simulatability” directly for an arbitrary network, so we take inspiration from decision trees. Our chosen method has two stages: first, find a single binary decision tree which accurately reproduces the network’s thresholded binary predictions $\hat{y}_n$ given input $x_n$. Second, measure the complexity of this decision tree as the output of $\Omega(W)$. We measure complexity as the average decision path length: the average number of decision nodes that must be touched to make a prediction for an input example $x_n$. We compute the average with respect to some designated reference dataset of example inputs $D = \{x_n\}$ from the training set. While many ways to measure complexity exist, we find average path length is most relevant to our notion of simulatability. Remember that for us, human simulation requires stepping through every calculation required to make a prediction. Average path length exactly counts the number of true-or-false boolean calculations needed to make an average prediction, assuming the model is a decision tree. Total number of nodes could be used as a metric, but it might penalize more accurate trees that have short paths for most examples but need more involved logic for a few outliers.
Our true-average-path-length cost function $\Omega(W)$ is detailed in Alg. 1. It requires two subroutines, TrainTree and PathLength. TrainTree trains a binary decision tree to accurately reproduce the provided labeled examples $\{x_n, \hat{y}_n\}$. We use the decision tree implementation in Python’s scikit-learn [Pedregosa et al.2011] with post-pruning to simplify the tree. These trees can give probabilistic predictions at each leaf. (Complete decision-tree training details are in the supplement.) Next, PathLength counts how many nodes are needed to map a specific input $x_n$ to an output (leaf) node in the provided decision tree. In our evaluations, we will apply our average-decision-tree-path-length regularization, or simply “tree regularization,” to several neural models.
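A minimal sketch of the average-path-length cost, assuming scikit-learn's `DecisionTreeClassifier` as TrainTree (post-pruning is omitted here) and counting only the internal decision nodes on each root-to-leaf path as PathLength. `predict_fn` stands in for the network's probability output; it and the other function names are hypothetical.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier


def train_tree(X, y_binary, min_samples_leaf=5):
    """TrainTree: fit a tree that mimics the given binary labels (no post-pruning here)."""
    tree = DecisionTreeClassifier(min_samples_leaf=min_samples_leaf, random_state=0)
    return tree.fit(X, y_binary)


def average_path_length(tree, X):
    """Mean number of decision (internal) nodes visited per example."""
    node_indicator = tree.decision_path(X)  # sparse (n_samples, n_nodes) 0/1 matrix
    nodes_on_path = np.asarray(node_indicator.sum(axis=1)).ravel()
    return float(np.mean(nodes_on_path - 1))  # drop the leaf, keep the decisions


def tree_regularizer(predict_fn, X_ref, threshold=0.5):
    """Omega(W): mimic the model's thresholded predictions on X_ref, return the APL."""
    y_hat = (np.asarray(predict_fn(X_ref)) >= threshold).astype(int)
    tree = train_tree(X_ref, y_hat)
    return average_path_length(tree, X_ref)
```

For a model that thresholds only its first input at 0.5, the mimicking tree needs a single split, so the average path length is 1; a constant model yields a one-node tree and a path length of 0.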
Training decision trees is not differentiable, and thus $\Omega(W)$ as defined in Alg. 1 is not differentiable with respect to the network parameters $W$ (unlike standard regularizers such as the L1 or L2 norm). While one could resort to derivative-free optimization techniques [Audet and Kokkolaras2016], gradient descent has been an extremely fast and robust way of training networks [Goodfellow, Bengio, and Courville2016].
A key technical contribution of our work is introducing and training a surrogate regularization function $\hat{\Omega}(W)$ to map each candidate neural model parameter vector $W$ to an estimate of the average-path-length. Our approximate function $\hat{\Omega}$ is implemented as a standalone multi-layer perceptron network and is thus differentiable. Let vector $\xi$ denote the parameters of this chosen MLP approximator. We can train $\hat{\Omega}$ by minimizing a squared-error loss:

$$\min_\xi \sum_{j=1}^J \big(\Omega(W_j) - \hat{\Omega}(W_j; \xi)\big)^2 + \epsilon \|\xi\|_2^2 \tag{5}$$

where $W_j$ are the entire set of parameters for our model, $\epsilon > 0$ is a regularization strength, and we assume we have a dataset of $J$ known parameter vectors and their associated true path-lengths: $\{W_j, \Omega(W_j)\}_{j=1}^J$. This dataset can be assembled using the candidate $W$ vectors obtained while training our target neural model $\hat{y}(\cdot, W)$, as well as by evaluating $\Omega(W)$ for randomly generated $W$. Importantly, one can train the surrogate function $\hat{\Omega}$ in parallel with our network. In the supplement, we show evidence that our surrogate predictor $\hat{\Omega}$ tracks the true average path length as we train the target predictor $\hat{y}(\cdot, W)$.
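The surrogate objective can be sketched in plain NumPy with hand-derived gradients. This sketch assumes a 1-hidden-layer tanh MLP (25 hidden units by default, the width used in the 2D experiments) whose input is a flattened parameter vector, and it swaps in full-batch plain gradient descent for the Autograd-plus-Adam setup described in the text. The class name `PathLengthSurrogate` and its defaults are hypothetical.

```python
import numpy as np


class PathLengthSurrogate:
    """Hypothetical surrogate Omega_hat(W; xi) = v . tanh(A W + a) + c, with xi = (A, a, v, c)."""

    def __init__(self, n_params, n_hidden=25, seed=0):
        rng = np.random.default_rng(seed)
        self.A = rng.standard_normal((n_hidden, n_params)) / np.sqrt(n_params)
        self.a = np.zeros(n_hidden)
        self.v = rng.standard_normal(n_hidden) / np.sqrt(n_hidden)
        self.c = 0.0

    def predict(self, Ws):
        """Ws: (J, P) array whose rows are flattened parameter vectors W_j."""
        return np.tanh(Ws @ self.A.T + self.a) @ self.v + self.c

    def objective(self, Ws, omegas, eps):
        """Squared error to the true path lengths plus eps * ||xi||_2^2."""
        sq_norm = (self.A ** 2).sum() + (self.a ** 2).sum() + (self.v ** 2).sum() + self.c ** 2
        return np.sum((omegas - self.predict(Ws)) ** 2) + eps * sq_norm

    def fit(self, Ws, omegas, eps=1e-3, lr=1e-3, n_steps=500):
        """Full-batch plain gradient descent (not Adam) on the objective above."""
        for _ in range(n_steps):
            Z = np.tanh(Ws @ self.A.T + self.a)               # hidden activations, (J, H)
            g_out = -2.0 * (omegas - (Z @ self.v + self.c))   # d objective / d prediction, (J,)
            g_pre = np.outer(g_out, self.v) * (1.0 - Z ** 2)  # back through tanh, (J, H)
            # All gradients below use the pre-update parameter values.
            self.v -= lr * (Z.T @ g_out + 2 * eps * self.v)
            self.c -= lr * (g_out.sum() + 2 * eps * self.c)
            self.A -= lr * (g_pre.T @ Ws + 2 * eps * self.A)
            self.a -= lr * (g_pre.sum(axis=0) + 2 * eps * self.a)
        return self
```

Once fitted, `predict` is a smooth function of $W$, which is what makes the penalty $\lambda \hat{\Omega}(W)$ usable inside gradient-based training of the target model.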
Even moderately-sized GRUs can have parameter vectors $W$ with thousands of dimensions. Our labeled dataset for surrogate training, $\{W_j, \Omega(W_j)\}_{j=1}^J$, will only have one example from each target network training iteration. Thus, in early iterations, we will have only a few examples from which to learn a good surrogate function $\hat{\Omega}(W)$. We resolve this challenge by augmenting our training set with additional examples: we randomly sample weight vectors $W$ and calculate the true average path length $\Omega(W)$, and we also perform several random restarts on the unregularized GRU and use those weights in our training set.
A second challenge occurs later in training: as the model parameters $W$ shift away from their initial values, those early parameters may not be as relevant in characterizing the current decision function of the GRU. To address this, for each epoch, we use examples only from the past $E$ epochs (in addition to augmentation), where in practice $E$ is chosen empirically. Using examples from a fixed window of epochs also speeds up training. The supplement shows a comparison of the importance of these heuristics for efficient and accurate training. Empirically, data augmentation for stabilizing surrogate training allows us to scale to GRUs with 100s of nodes. GRUs of this size are sufficient for many real problems, such as those we encounter in healthcare domains.
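The bookkeeping for the surrogate's training set (augmentation pairs that are always kept, plus pairs recorded only during the most recent $E$ epochs) might look like the following sketch. The class `SurrogateDataset`, its method names, and the Gaussian sampling of random weight vectors are assumptions for illustration.

```python
from collections import deque

import numpy as np


class SurrogateDataset:
    """Training pairs (W_j, Omega(W_j)) for the surrogate.

    Augmentation pairs are always kept; pairs recorded while training the target
    model are kept only for the most recent `window` epochs (the E in the text).
    """

    def __init__(self, window):
        self.recent = deque(maxlen=window)  # one list of pairs per epoch
        self.augmented = []

    def add_augmentation(self, W, omega):
        self.augmented.append((np.asarray(W, dtype=float), float(omega)))

    def augment_random(self, omega_fn, n_samples, n_params, scale=1.0, seed=0):
        """Label randomly sampled weight vectors with their true path length."""
        rng = np.random.default_rng(seed)
        for _ in range(n_samples):
            W = scale * rng.standard_normal(n_params)
            self.add_augmentation(W, omega_fn(W))

    def start_epoch(self):
        self.recent.append([])  # the oldest epoch drops out once the window is full

    def add(self, W, omega):
        self.recent[-1].append((np.asarray(W, dtype=float), float(omega)))

    def arrays(self):
        pairs = self.augmented + [pair for epoch in self.recent for pair in epoch]
        return np.stack([W for W, _ in pairs]), np.array([o for _, o in pairs])
```

A bounded `deque` makes the sliding window automatic: appending a new epoch past the limit silently discards the oldest one, so stale parameter vectors stop influencing the surrogate.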
The number $J$ of labeled pairs used for surrogate training is set separately for toy datasets and for real-world datasets. Optimization of our surrogate objective is done via gradient descent. We use Autograd to compute gradients of the loss in Eq. (5) with respect to $\xi$, then use Adam to compute descent directions with step sizes set to 0.01 for toy datasets and 0.001 for real-world datasets.
While time-series models are the main focus of this work, we first demonstrate tree regularization on a simple binary classification task to build intuition. We call this task the 2D Parabola problem, because as Fig. 2(a) shows, the training data consists of 2D input points whose two-class decision boundary is roughly shaped like a parabola. The true decision function is a fixed parabola. We sampled 500 input points uniformly within the unit square and labeled those above the decision function as positive. To make it easy for models to overfit, we flipped 10% of the points in a region near the boundary. A random 30% were held out for testing.
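A sketch of generating a 2D Parabola style dataset under stated assumptions: `placeholder_curve` is a hypothetical parabola (not the exact decision function used in the experiments), and “a region near the boundary” is interpreted as a vertical band of half-width `band` around the curve, from which 10% of the total number of points are chosen at random and flipped.

```python
import numpy as np


def placeholder_curve(x1):
    """Hypothetical parabola, not the exact decision function from the experiments."""
    return 4.0 * (x1 - 0.5) ** 2 + 0.2


def make_parabola_data(n=500, flip_frac=0.10, band=0.1, test_frac=0.30,
                       curve=placeholder_curve, seed=0):
    rng = np.random.default_rng(seed)
    X = rng.uniform(0.0, 1.0, size=(n, 2))         # uniform in the unit square
    gap = X[:, 1] - curve(X[:, 0])
    y = (gap > 0).astype(int)                      # above the curve -> positive
    near = np.flatnonzero(np.abs(gap) < band)      # candidates close to the boundary
    n_flip = min(int(round(flip_frac * n)), near.size)
    flip = rng.choice(near, size=n_flip, replace=False)
    y[flip] = 1 - y[flip]                          # label noise invites overfitting
    perm = rng.permutation(n)
    n_test = int(round(test_frac * n))
    test, train = perm[:n_test], perm[n_test:]     # random 30% held out
    return X[train], y[train], X[test], y[test]
```

Restricting the flips to a band around the boundary keeps the noise where an overly expressive MLP is most tempted to carve out small pockets, which is what exposes the differences between regularizers.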
For the classifier, we train a 3-layer MLP with 100 first layer nodes, 100 second layer nodes, and 10 third layer nodes. This MLP is intentionally overly expressive to encourage overfitting and expose the impact of different forms of regularization: our proposed tree regularization and two baselines: an L2 penalty on the weights, $\Psi(W) = \|W\|_2$, and an L1 penalty on the weights, $\Psi(W) = \|W\|_1$. For each regularization function, we train models at many different regularization strengths $\lambda$ chosen to explore the full range of decision boundary complexities possible under each technique.
For our tree regularization, we model our surrogate $\hat{\Omega}(W)$ with a 1-hidden-layer MLP with 25 units. We find this simple architecture works well, but more complex MLPs could certainly be used on more complex problems. The objective in Eq. (1) was optimized via Adam gradient descent [Kingma and Ba2014] using a batch size of 100 and a learning rate of 1e-3 for 250 epochs, and hyperparameters were set via cross validation using grid search (see supplement for full experimental details).
Fig. 2(b) shows each trained model as a single point in a 2D fitness space: the x-axis measures model complexity via our average-path-length metric, and the y-axis measures AUC prediction performance. These results show that simple L1 or L2 regularization does not produce models with both short average path length and good predictions at any value of the regularization strength $\lambda$. As expected, large values of $\lambda$ for L1 and L2 only produce far-too-simple linear decision boundaries with poor accuracies. In contrast, our proposed tree regularization directly optimizes the MLP to have simple tree-like boundaries at high $\lambda$ values, which can still yield good predictions.
The lower panes of Fig. 2 show these boundaries. Our tree regularization is uniquely able to create axis-aligned functions, because decision trees prefer functions that are axis-aligned splits. These axis-aligned functions require very few nodes but are more effective than their L1 and L2 counterparts. The L1 boundary is sharper, whereas the L2 boundary is rounder.
We now evaluate our tree-regularization approach on time-series models. We focus on GRU-RNN models, with some later experiments on new hybrid GRU-HMM models. As with the MLP, each regularization technique (tree, L2, L1) can be applied to the output node of the GRU across a range of strength parameters $\lambda$. Importantly, Algorithm 1 can compute the average-decision-tree-path-length for any fixed deep model given its parameters, and can hence be used to measure decision boundary complexity under any regularization, including L1 or L2. This means that when training any model, we can track both the predictive performance (as measured by area-under-the-ROC-curve (AUC); higher values mean better predictions), as well as the complexity of the decision tree required to explain each model (as measured by our average path length metric; lower values mean more interpretable models). We also show results for a baseline standalone decision tree classifier without any associated deep model, sweeping a range of parameters controlling leaf size to explore how this baseline trades off path length and prediction quality. Further details of our experimental protocol are in the supplement, as well as more extensive results with additional baselines.
We generated a toy dataset of synthetic sequences with many timesteps each. Each timestep has a data vector $x_{nt}$ of 14 binary features and a single binary output label $y_{nt}$. The data comes from two separate HMM processes. First, a “signal” HMM generates the first 7 data dimensions from 5 well-separated states. Second, an independent “noise” HMM generates the remaining 7 data dimensions from a different set of 5 states. Each timestep’s output label $y_{nt}$ is produced by a rule involving both the signal data and the signal hidden state: the target is 1 at timestep $t$ only if both the first signal state is active and the first observation is turned on. We deliberately designed the generation process so that neither logistic regression with $x_{nt}$ as features nor an RNN model that makes predictions from hidden states alone can perfectly separate this data.
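The signal-and-noise generator can be sketched with two independent discrete HMMs. The sticky transition matrix, the 0.9/0.1 Bernoulli emission probabilities, and the “each state lights up its own feature” emission pattern (shifted for the noise HMM) are assumptions chosen to keep the 5 states well separated; only the dimensions and the labeling rule come from the description above. All function names are hypothetical.

```python
import numpy as np


def make_hmm(n_states=5, n_feats=7, stay=0.9, shift=0):
    """Sticky transitions; state s mostly turns on feature (s + shift) % n_feats (assumed)."""
    trans = np.full((n_states, n_states), (1.0 - stay) / (n_states - 1))
    np.fill_diagonal(trans, stay)
    patterns = np.roll(np.eye(n_states, n_feats), shift, axis=1)
    emit = np.where(patterns == 1, 0.9, 0.1)  # Bernoulli "on" probability per (state, feature)
    return trans, emit


def sample_hmm(T, trans, emit, rng):
    """Sample a state path z (shape (T,)) and binary emissions x (shape (T, n_feats))."""
    z = np.empty(T, dtype=int)
    z[0] = rng.integers(trans.shape[0])
    for t in range(1, T):
        z[t] = rng.choice(trans.shape[0], p=trans[z[t - 1]])
    x = (rng.uniform(size=(T, emit.shape[1])) < emit[z]).astype(int)
    return z, x


def make_signal_noise_data(n_seqs, T, seed=0):
    rng = np.random.default_rng(seed)
    signal_hmm, noise_hmm = make_hmm(shift=0), make_hmm(shift=2)
    X, Y = [], []
    for _ in range(n_seqs):
        z_sig, x_sig = sample_hmm(T, *signal_hmm, rng)
        _, x_noise = sample_hmm(T, *noise_hmm, rng)
        X.append(np.hstack([x_sig, x_noise]))  # 7 signal + 7 noise binary features
        # Label rule: first signal state active AND first observation on.
        Y.append(((z_sig == 0) & (x_sig[:, 0] == 1)).astype(int))
    return np.stack(X), np.stack(Y)
```

Because the label depends on the hidden signal state and not only on the observed features, a per-timestep classifier on $x_{nt}$ alone cannot recover it exactly, which mirrors the design goal stated above.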
We tested our approach on several real tasks: predicting medical outcomes of hospitalized septic patients, predicting HIV therapy outcomes, and identifying stop phonemes in English speech recordings. To normalize scales, we independently standardized each feature.
Sepsis Critical Care: We study time-series data for 11 786 septic ICU patients from the public MIMIC III dataset [Johnson et al.2016]. We observe at each hour a data vector of 35 vital signs and lab results as well as a label vector of 5 binary outcomes. Hourly data measures continuous features such as respiration rate (RR), blood oxygen levels (paO$_2$), fluid levels, and more. Hourly binary labels include whether the patient died in hospital and if mechanical ventilation was applied. Models are trained to predict all 5 output dimensions concurrently from one shared embedding. The average sequence length is 15 hours. 7 070 patients are used in training, 1 769 for validation, and 294 for test.
HIV Therapy Outcome (HIV): We use the EuResist Integrated Database [Zazzi et al.2012] for 53 236 patients diagnosed with HIV. We consider 4-6 month intervals (corresponding to hospital visits) as time steps. Each data vector has 40 features, including blood counts, viral load measurements and lab results. Each output vector has 15 binary labels, including whether a therapy was successful in reducing viral load to below detection limits, if therapy caused CD4 blood cell counts to drop to dangerous levels (indicating AIDS), or if the patient suffered adherence issues to medication. The average sequence length is 14 steps. 37 618 patients are used for training; 7 986 for testing, and 7 632 for validation.
Phonetic Speech (TIMIT): We have recordings of 630 speakers of eight major dialects of American English reading ten phonetically rich sentences [Garofolo et al.1993]. Each sentence contains time-aligned transcriptions of 60 phonemes. We focus on distinguishing stop phonemes (those that stop the flow of air, such as “b” or “g”) from non-stops. Each timestep has one binary label indicating if a stop phoneme occurs or not. Each input has 26 continuous features: the acoustic signal’s Mel-frequency cepstral coefficients and derivatives. There are 6 303 sequences, split into 3 697 for training, 925 for validation, and 1 681 for testing. The average length is 614.
The major conclusions of our experiments comparing GRUs with various regularizations are outlined below.
Across tasks, we see that in the target regime of small decision trees (low average-path lengths), our proposed tree-regularization achieves higher prediction quality (higher AUCs). In the signal-and-noise HMM task, tree regularization (green line in Fig. 3(d)) achieves AUC values near 0.9 when its trees have an average path length of 10. Similar models with L1 or L2 regularization reach this AUC only with trees that are more than double in complexity (path length over 25). On the Sepsis task (Fig. 4) we see AUC gains of 0.05-0.1 at path lengths of 2-10. On the TIMIT task (Fig. 4(a)), we see AUC gains of 0.05-0.1 at path lengths of 20-30. Finally, on the HIV CD4 blood cell count task in Fig. 4(b), we see AUC differences of between 0.03 and 0.15 for path lengths of 10-15. The HIV adherence task in Fig. 4(d) has AUC gains of between 0.03 and 0.05 in the path length range of 19 to 25, while at smaller paths all methods are quite poor, indicating the problem’s difficulty. Overall, these AUC gains are particularly useful in determining how to administer subsequent HIV therapies.
We emphasize that our tree-regularization usually achieves a sweet spot of high AUCs at short path lengths not possible with standalone decision trees (orange lines), L1-regularized deep models (red lines) or L2-regularized deep models (blue lines). In unshown experiments, we also tested elastic net regularization [Zou and Hastie2005], a linear combination of L1 and L2 penalties. We found elastic nets to follow the same trend lines as L1 and L2, with no visible differences. In domains where human-simulatability is required, increases in prediction accuracy in the small-complexity regime can mean the difference between models that provide value on a task and models that are unusable, either because performance is too poor or predictions are uninterpretable.
Across all tasks, the decision trees which mimic the predictions of tree-regularized deep models are small enough to simulate by hand (short average path lengths) and help users grasp the model’s nonlinear prediction logic. Intuitively, the trees for our synthetic task in Fig. 3(a)-(c) decrease in size as the strength $\lambda$ increases. The logic of these trees also matches the true labeling process: even the simplest tree (c) checks a relevant subset of input dimensions necessary to verify that both the first state and the first output dimension are active.
In Fig. 4, we show decision tree proxies for our deep models on two sepsis prediction tasks: mortality and need for ventilation. We consulted a clinical expert on sepsis treatment, who noted that the trees helped him understand what the models might be doing and thus determine if he would trust the deep model. For example, he said that using FiO$_2$, RR, CO$_2$ and paO$_2$ to predict need for mechanical ventilation (Fig. 3(d)) was sensible, as these all measure breathing quality. In contrast, the in-hospital mortality tree (Fig. 3(b)) predicts that some young patients with no organ failure have high mortality rates while other young patients with organ failure have low mortality. These counter-intuitive results led to hypotheses about how uncaptured variables impact the training process. Such reasoning would not be possible from simple sensitivity analyses of the deep model.
Finally, we have verified that the decision tree proxies of our tree-regularized deep models of the HIV task in Fig. 4(d) are interpretable for understanding why a patient has trouble adhering to a prescription; that is, taking drugs regularly as directed. Our clinical collaborators confirm that the baseline viral load and number of prior treatment lines, which are prominent attributes for the decisions in Fig. 4(d), are useful predictors of a patient with adherence issues. Several medical studies [Langford, Ananworanich, and Cooper2007, Socías et al.2011] suggest that patients with higher baseline viral loads tend to have faster disease progression, and hence have to take several drug cocktails to combat resistance. Juggling many drugs typically makes it difficult for these patients to adhere as directed. We hope interpretable predictive models for adherence could help assess a patient’s overall prognosis [Paterson et al.2000] and offer opportunities for intervention (e.g. with alternative single-tablet regimens).
Across datasets, we find that each tree-regularized deep time-series model has predictions that agree with its corresponding decision tree proxy in about 85-90% of test examples. Table 1 shows exact fidelity scores for each dataset. Thus, the simulatable paths of the decision tree will be trustworthy in a majority of cases.
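Fidelity, as used here, is the fraction of examples on which the deep model's thresholded prediction agrees with its decision tree proxy. A minimal sketch, assuming a 0.5 threshold on the model's predicted probabilities (the function name is illustrative):

```python
import numpy as np


def fidelity(model_probs, tree_labels, threshold=0.5):
    """Fraction of examples where the thresholded deep-model prediction matches the tree."""
    model_labels = (np.asarray(model_probs) >= threshold).astype(int)
    return float(np.mean(model_labels == np.asarray(tree_labels).astype(int)))


print(fidelity([0.9, 0.2, 0.6, 0.4], [1, 0, 0, 0]))  # 0.75
```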
While our tree-regularized GRU with 10 states takes 3977 seconds per epoch on TIMIT, a similar L2-regularized GRU takes 2116 seconds per epoch. Thus, our new method costs less than twice the baseline even when the surrogate is computed serially. Because the surrogate $\hat{\Omega}$ will in general be a much smaller model than the predictor $\hat{y}$, we expect one could get faster per-epoch times by parallelizing the creation of training pairs $\{W_j, \Omega(W_j)\}$ and the training of the surrogate $\hat{\Omega}$. Additionally, the 3977 seconds include the time needed to train the surrogate. In practice, we do this sparingly, only once every 25 epochs, yielding an amortized per-epoch cost of 2191 seconds (more runtime results are in the supplement).
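As a quick arithmetic check of the amortized figure, assume the entire per-epoch gap between the two runs is spent on surrogate training and that this work is spread over the 25 epochs between retrainings:

```python
tree_epoch_s = 3977   # tree-regularized GRU on TIMIT, epoch including surrogate training
l2_epoch_s = 2116     # comparable L2-regularized GRU
retrain_every = 25    # surrogate retrained once every 25 epochs

overhead_s = tree_epoch_s - l2_epoch_s  # 1861 s, assumed to be all surrogate work
amortized_s = l2_epoch_s + overhead_s / retrain_every
print(round(amortized_s, 2))            # 2190.44
print(tree_epoch_s / l2_epoch_s < 2)    # True: under twice the baseline
```

Under that assumption the amortized cost comes out within a second of the reported 2191 seconds.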
When tree regularization is strong (high $\lambda$), the decision trees trained to match the predictions of deep models are stable. For both signal-and-noise and sepsis tasks, multiple runs from different random restarts have nearly identical tree shape and size, perhaps differing by a few nodes. This stability is crucial to building trust in our method. On the signal-and-noise task with strong regularization, 7 of 10 independent runs with random initializations resulted in trees of exactly the same structure, and the others closely resembled them, sharing the same subtrees and features (more details in supplement).
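Table 1: Fidelity of each tree-regularized deep model to its decision tree proxy.

| Dataset (task) | Fidelity |
|---|---|
| SEPSIS (In-Hospital Mortality) | 0.81 |
| SEPSIS (90-Day Mortality) | 0.88 |
| SEPSIS (Mech. Vent.) | 0.90 |
| SEPSIS (Median Vaso.) | 0.92 |
| SEPSIS (Max Vaso.) | 0.93 |
| HIV (CD4 below 200) | 0.84 |
| HIV (Therapy Success) | 0.88 |
| HIV (Poor Adherence) | 0.90 |
| HIV (AIDS Onset) | 0.93 |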
So far, we have focused on regularizing standard deep models, such as MLPs or GRUs. Another option is to use a deep model as a residual on another model that is already interpretable: for example, discrete HMMs partition timesteps into clusters, each of which can be inspected, but their predictions might have limited accuracy. In Fig. 6, we show the performance of jointly training a GRU-HMM, a new model which combines an HMM with a tree-regularized GRU to improve its predictions (details and further results in the supplement). Here, the ideal path length is zero, indicating only the HMM makes predictions. For small average-path-lengths, the GRU-HMM improves the original HMM’s predictions and has simulatability gains over earlier GRUs. On the mechanical ventilation task, the GRU-HMM requires an average path length of only 28 to reach AUC of 0.88, while the GRU alone with the same number of states requires a path length of 60 to reach the same AUC. This suggests that jointly-trained deep residual models may provide even better interpretability.
We have introduced a novel tree-regularization technique that encourages the complex decision boundaries of any differentiable model to be well-approximated by human-simulatable functions, allowing domain experts to quickly understand and approximately compute what the more complex model is doing. Overall, our training procedure is robust and efficient; future work could continue to explore and increase the stability of the learned models as well as identify ways to apply our approach to situations in which the inputs are not inherently interpretable (e.g. pixels in an image).
Across three complex, real-world domains – HIV treatment, sepsis treatment, and human speech processing – our tree-regularized models provide gains in prediction accuracy in the regime of simpler, approximately human-simulatable models. Future work could apply tree regularization to local, example-specific approximations of a loss [Ribeiro, Singh, and Guestrin2016] or to representation learning tasks (encouraging embeddings with simple boundaries). More broadly, our general training procedure could apply tree-regularization or other procedure-regularization to a wide class of popular models, helping us move beyond sparsity toward models humans can easily simulate and thus trust.
MW is supported by the U.S. National Science Foundation. MCH is supported by Oracle Labs. SP is supported by the Swiss National Science Foundation project 51MRP0_158328. The authors thank the EuResist Network for providing HIV data for this study, and thank Matthieu Komorowski for the preprocessed sepsis data [Raghu et al.2017]. Computations were supported by the FAS Research Computing Group at Harvard and sciCORE (http://scicore.unibas.ch/) scientific computing core facility at University of Basel.
Improving generalization performance using double backpropagation. IEEE Transactions on Neural Networks 3(6):991–997.
Continuous state-space models for optimal sepsis treatment - a deep reinforcement learning approach. In Machine Learning for Healthcare Conference.
Our average path length function for determining the complexity of a deep model given its parameters (defined in the main paper in Alg. 1) assumes that we have a robust, black-box way to train binary decision trees, called TrainTree, given a labeled dataset of inputs paired with the deep model's binary predictions. For this we use the DecisionTreeClassifier class distributed in Python's scikit-learn, which chooses splits by maximizing the decrease in Gini impurity. The specific syntax we use (for reproducibility) is:
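from sklearn.tree import DecisionTreeClassifier
tree = DecisionTreeClassifier(min_samples_leaf=5)  # at least 5 training examples per leaf
tree.fit(x_train, y_train)
tree = prune_tree(tree, x_valid, y_valid)  # our pruning heuristic (Alg. 2); not part of scikit-learn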
The provided keyword options force the tree to have at least 5 examples from the training set in every leaf. We found tuning hyperparameters of the TrainTree subprocedure, such as the minimum size of a leaf node, to be important for making useful trees.
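To make the complexity measure concrete, the following minimal sketch computes an average path length for a fitted scikit-learn tree. It assumes (our assumption, not necessarily the exact convention of Alg. 1) that a path's length is the number of nodes visited from the root to the leaf, which scikit-learn's `decision_path` exposes as the nonzero entries of each row. The helper names `average_path_length` and `tree_complexity` are hypothetical, and the pruning step is omitted.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier


def average_path_length(tree, X):
    """Mean number of nodes visited from root to leaf over the rows of X.

    Assumption: a path's length counts every node on it, including the leaf.
    """
    # decision_path returns a sparse (n_samples, n_nodes) indicator matrix;
    # row i marks every node that sample i passes through.
    node_indicator = tree.decision_path(X)
    path_lengths = np.asarray(node_indicator.sum(axis=1)).ravel()
    return float(path_lengths.mean())


def tree_complexity(x_train, y_hat_train, min_samples_leaf=5):
    """Fit a tree to the deep model's binary predictions and return its APL."""
    tree = DecisionTreeClassifier(min_samples_leaf=min_samples_leaf, random_state=0)
    tree.fit(x_train, y_hat_train)
    return average_path_length(tree, x_train)
```

For example, on four 1-D points labeled `[0, 0, 1, 1]` with `min_samples_leaf=1`, the tree is a single split with two pure leaves, so every example visits two nodes and the average path length is 2.0.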
Generally, the runtime cost of scikit-learn's fitting procedure scales superlinearly with the number of examples $N$ and linearly with the number of features $D$, for a total complexity of $O(DN \log N)$. In practice, we found that with examples, features, tree construction takes 15.3 microseconds.
The pruning procedure is a heuristic to create simpler trees, summarized in Algorithm 2. After TrainTree delivers a working decision tree, we iteratively propose removing each remaining leaf node, accepting the proposal if the squared prediction error on a validation set improves. This pruning removes subtrees that do not generalize to unseen data.
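The pruning heuristic can be sketched in pure Python as follows. This is an illustration under stated assumptions, not Algorithm 2 itself: it uses a hypothetical `Node` class rather than scikit-learn's internal tree arrays, and it interprets "removing a leaf" as collapsing a pair of sibling leaves into their parent, which then predicts its stored `value` (for example, the training mean of that region).

```python
from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass
class Node:
    """Hypothetical binary tree node; it is a leaf when both children are None."""
    value: float                  # prediction used whenever this node acts as a leaf
    feature: int = -1             # split feature (internal nodes only)
    threshold: float = 0.0        # go left if x[feature] <= threshold
    left: Optional["Node"] = None
    right: Optional["Node"] = None

    def is_leaf(self) -> bool:
        return self.left is None and self.right is None


def predict(node: Node, x: Sequence[float]) -> float:
    while not node.is_leaf():
        node = node.left if x[node.feature] <= node.threshold else node.right
    return node.value


def squared_error(root: Node, X, y) -> float:
    return sum((predict(root, x) - t) ** 2 for x, t in zip(X, y))


def prunable_parents(node: Node):
    """Yield internal nodes whose two children are both leaves."""
    if node.is_leaf():
        return
    if node.left.is_leaf() and node.right.is_leaf():
        yield node
    else:
        yield from prunable_parents(node.left)
        yield from prunable_parents(node.right)


def prune(root: Node, X_valid, y_valid) -> Node:
    """Greedily collapse sibling leaves while validation squared error improves."""
    improved = True
    while improved:
        improved = False
        best = squared_error(root, X_valid, y_valid)
        for parent in list(prunable_parents(root)):
            saved = (parent.left, parent.right)
            parent.left = parent.right = None        # proposal: parent becomes a leaf
            err = squared_error(root, X_valid, y_valid)
            if err < best:                           # accept only strict improvements
                best, improved = err, True
            else:
                parent.left, parent.right = saved    # reject and restore the subtree
    return root
```

Requiring a strict improvement follows the text; accepting ties instead would prune more aggressively whenever a subtree has no effect on validation error.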
Fig. A.1 shows that our surrogate predictor tracks the true average path length as we train the target predictor on several different datasets.
In Fig. A.2, we show sample learning curves for variations of methods for approximating the average path length (also called "node count") in a decision tree. In blue is the true value. Each of the other three lines uses the same surrogate model: an MLP with 25 hidden nodes. Increasing its capacity too much (e.g., to 100 hidden nodes) leads to overfitting: the surrogate predicts the average path length extremely well for a small number of iterations, but its performance then quickly decays. With an MLP of the right capacity, four additional tricks greatly improve the accuracy of the average path length predictions: (1) weight augmentation, (2) random restarts with an unregularized model, (3) a fixed window of data, and (4) surrogate retraining.
Normally, if our differentiable model is a GRU, we compile examples using the GRU weights at every batch and calculate the true average path length for each. This dataset is used to train the surrogate model. If examples are scarce, surrogate predictions may be unstable. Augmentation addresses this by randomly sampling weight vectors and computing their true average path lengths to artificially create a larger dataset. Data scarcity is especially problematic in early epochs. In addition to augmentation, we use random restarts to separately train unregularized GRUs (each with different weight initializations) to grow a dataset of weight vectors prior to training the regularized model.
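A minimal sketch of the two data-growing tricks is shown below. The text does not specify the sampling distribution, so perturbing the current weights with isotropic Gaussian noise is our assumption; `apl_fn` stands for any routine that computes the true average path length of a weight vector, and `train_fn` is a hypothetical callable that runs one unregularized training job and returns weight snapshots.

```python
import numpy as np


def augment_weight_dataset(weights, apl_fn, num_samples, scale=0.1, seed=0):
    """Create extra (weight vector, true APL) pairs for the surrogate.

    weights: 1-D array of the current flattened model parameters.
    apl_fn:  callable mapping a weight vector to its true average path length.
    Assumption: samples are Gaussian perturbations of `weights` with std `scale`.
    """
    rng = np.random.default_rng(seed)
    noise = rng.normal(0.0, scale, size=(num_samples, weights.size))
    samples = weights[None, :] + noise
    targets = np.array([apl_fn(w) for w in samples])
    return samples, targets


def collect_restart_weights(train_fn, seeds):
    """Pool weight snapshots from independently initialised, unregularised runs.

    train_fn(seed) is hypothetical: it should return a list of flattened
    weight vectors (e.g. one per epoch) from a run started with that seed.
    """
    snapshots = []
    for seed in seeds:
        snapshots.extend(train_fn(seed))
    return np.stack(snapshots)
```

Each pooled snapshot would still need its true average path length computed before it can serve as a surrogate training target.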
As the GRU parameters move away from their initial values, examples collected in early epochs no longer describe the current state of the model. Retraining and a fixed window of data address this by re-learning the surrogate function at a fixed frequency, using only examples from the most recent epochs. In practice, the augmentation size, the retraining frequency, and the window size are all functions of the learning rate and the dataset size. See Table B.1 for exact numbers.
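The fixed window and periodic retraining can be sketched with a bounded buffer and a small regressor. This is a hedged illustration, not the authors' implementation. It uses a one-hidden-layer MLP with 25 units (the surrogate size reported for our experiments), but the tanh activation, plain full-batch gradient descent (in place of Adam), and a window counted in observations rather than epochs are simplifying assumptions. The class name `SurrogateAPL` is hypothetical.

```python
from collections import deque

import numpy as np


class SurrogateAPL:
    """MLP regressing average path length on flattened weight vectors,
    retrained every `retrain_every` observations on a fixed window of
    the most recent `window` (weights, APL) pairs."""

    def __init__(self, dim, hidden=25, window=100, retrain_every=25, seed=0):
        rng = np.random.default_rng(seed)
        self.W1 = rng.normal(0.0, 1.0 / np.sqrt(dim), (dim, hidden))
        self.b1 = np.zeros(hidden)
        self.w2 = rng.normal(0.0, 1.0 / np.sqrt(hidden), hidden)
        self.b2 = 0.0
        self.buffer = deque(maxlen=window)   # oldest examples drop out automatically
        self.retrain_every = retrain_every
        self.steps = 0

    def predict(self, W):
        h = np.tanh(W @ self.W1 + self.b1)
        return h @ self.w2 + self.b2

    def observe(self, weights, apl):
        """Record one (weights, true APL) pair and retrain on schedule."""
        self.buffer.append((np.asarray(weights, dtype=float), float(apl)))
        self.steps += 1
        if self.steps % self.retrain_every == 0:
            self.fit()

    def fit(self, epochs=500, lr=1e-2):
        """Full-batch gradient descent on mean squared error over the window."""
        X = np.stack([w for w, _ in self.buffer])
        y = np.array([a for _, a in self.buffer])
        n = len(y)
        for _ in range(epochs):
            h = np.tanh(X @ self.W1 + self.b1)
            err = h @ self.w2 + self.b2 - y             # residuals, shape (n,)
            grad_out = 2.0 * err / n                    # d(MSE)/d(prediction)
            g_w2 = h.T @ grad_out
            g_b2 = grad_out.sum()
            g_h = np.outer(grad_out, self.w2) * (1.0 - h ** 2)  # backprop through tanh
            self.W1 -= lr * (X.T @ g_h)
            self.b1 -= lr * g_h.sum(axis=0)
            self.w2 -= lr * g_w2
            self.b2 -= lr * g_b2
        return float(np.mean(err ** 2))
```

Because the buffer discards old pairs, each retraining fits only weight vectors near the model's current trajectory, which is the point of the fixed window.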
See Table B.1 for model hyperparameters for each dataset. For standard recurrent models such as HMMs or GRUs, the decision trees were trained on the input data and the predictions of the model's output node. For our deep residual GRU-HMM, the decision trees were trained on the predictions of the GRU's output node only. For both synthetic and real-world datasets, our surrogate to the tree loss is a multilayer perceptron with 1 hidden layer of 25 nodes. For each dataset, when we investigated several regularization strengths, we initialized the model weights using the same random seed. We use the Adam algorithm [Kingma and Ba2014] for all optimization.
|Dataset||Total Num. Sequences||Avg. seq. length||Learning Rate||Batch size||Minimum Leaf Sample||Post-pruned||Epochs (Model)||Epochs (Surrogate)||Retraining Freq.|
|HIV||53 236||14||1e-3||256||1 000||Y||300||5000||25||100|
|SEPSIS||11 786||15||1e-3||256||1 000||Y||300||5000||25||100|
|TIMIT||6 303||614||1e-3||256||5 000||Y||200||5000||25||100|
The training data consists of 2D input points whose two-class decision boundary is roughly shaped like a parabola. The true decision function is defined by . We sampled all 200 input points uniformly within the unit square and labeled those above the decision function as positive. To add randomness, we flipped 10% of the points in the region near the boundary between and .
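The generating process can be sketched as below. The exact decision function and the width of the flipping region are not reproduced here, so both are left as parameters: `boundary_fn` and `band` are placeholders, and any parabola passed in (such as the one in the usage note) is an arbitrary illustrative choice, not the function used in our experiments.

```python
import numpy as np


def make_parabola_data(boundary_fn, n=200, flip_frac=0.10, band=0.1, seed=0):
    """2-D toy data: points uniform in the unit square, labeled positive when
    they lie above boundary_fn(x1). A fraction `flip_frac` of the points within
    vertical distance `band` of the boundary have their labels flipped.

    boundary_fn and band are placeholders for the paper's exact choices.
    """
    rng = np.random.default_rng(seed)
    X = rng.uniform(0.0, 1.0, size=(n, 2))
    f = boundary_fn(X[:, 0])
    y = (X[:, 1] > f).astype(int)
    near = np.flatnonzero(np.abs(X[:, 1] - f) < band)   # indices close to the boundary
    n_flip = int(round(flip_frac * near.size))
    flip = rng.choice(near, size=n_flip, replace=False)
    y[flip] = 1 - y[flip]
    return X, y
```

For instance, `make_parabola_data(lambda x: 2 * (x - 0.5) ** 2 + 0.3)` yields 200 labeled points whose noise is confined to a band around that (illustrative) parabola.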
Tested values of the regularization strength parameter: 0.1, 0.5, 1, 5, 10, 25, 50, 75, 100, 250, 500, 750, 1 000, 2 500, 5 000, 7 500, 10 000, 25 000, 50 000, 75 000, 100 000
The transition and emission matrices describing the generative process used to create the signal-and-noise HMM are shown in Fig. B.1. The output at every timestep is created by concatenating a one-hot vector of an emitted state and the 7-dimensional binary input vector. We emphasize that to output 1, the HMM must be in state 1 and the first input feature must be 1.
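A sketch of how one sequence from the signal-and-noise HMM might be sampled follows. The actual transition and emission matrices are those in Fig. B.1 and are passed in as arguments here. Drawing the 7 input bits as independent Bernoulli(0.5) variables, and representing the paper's "state 1" as the 0-based index `signal_state=0`, are our assumptions.

```python
import numpy as np


def sample_signal_noise_sequence(T, trans, emit, pi, n_inputs=7, signal_state=0, seed=0):
    """Sample one sequence from a signal-and-noise HMM.

    trans: (K, K) transition matrix with rows summing to 1.
    emit:  (K, M) emission matrix over M discrete symbols.
    pi:    (K,) initial state distribution.
    Returns features (T, M + n_inputs), binary labels (T,), and hidden states (T,).
    """
    rng = np.random.default_rng(seed)
    K, M = emit.shape
    z = np.empty(T, dtype=int)
    z[0] = rng.choice(K, p=pi)
    for t in range(1, T):
        z[t] = rng.choice(K, p=trans[z[t - 1]])
    symbols = np.array([rng.choice(M, p=emit[k]) for k in z])
    one_hot = np.eye(M)[symbols]                      # one-hot emitted symbol
    inputs = rng.integers(0, 2, size=(T, n_inputs))   # assumed Bernoulli(0.5) bits
    features = np.concatenate([one_hot, inputs], axis=1)
    # The label is 1 only when the hidden state is the signal state AND input bit 0 is on.
    labels = ((z == signal_state) & (inputs[:, 0] == 1)).astype(int)
    return features, labels, z
```

The conjunction in the label is what makes the task hard for models that look only at the inputs or only at the emitted symbols.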
With synthetic datasets, we explore (1, 5, 6, 10, 15, 20) GRU nodes, (5, 6, 20) HMM states, and GRU-HMMs with 5 HMM states and (1, 5, 10, 15) GRU nodes.
We explore (1, 5, 6, 10, 11, 15, 20, 25, 26, 30, 35, 50, 51, 55, 60, 75, 100) GRU nodes, (5, 6, 10, 11, 15, 20, 25, 26, 30, 35, 50, 51, 55, 60, 75, 100) HMM states, and GRU-HMMs with (5, 10, 25, 50) HMM states and (1, 5, 10, 25, 50) GRU nodes. The input features are z-scored prior to training.
We explore (1, 5, 6, 10, 11, 15, 20, 25, 26, 30, 35, 50, 51, 55, 60, 75) GRU nodes, (5, 6, 10, 11, 15, 20, 25, 26, 30, 35, 50, 51, 55, 60, 75) HMM states, and GRU-HMMs with (5, 10, 25) HMM states and (1, 5, 10, 25, 50) GRU nodes.
We explore (1, 5, 6, 10, 11, 15, 20, 25, 26, 30, 35, 50, 51, 55, 60, 75) GRU nodes, (5, 6, 10, 11, 15, 20, 25, 26, 30, 35, 50, 51, 55, 60, 75) HMM states, and GRU-HMMs with (5, 10, 25) HMM states and (1, 5, 10, 25, 50) GRU nodes. Like Sepsis, the input features are z-scored prior to training.
For the signal-and-noise HMM, Sepsis, and TIMIT, we first show expanded versions of the fitness trace plots and the tree visualizations. For Sepsis and HIV, we show the additional output dimensions not in the paper.
We also include tables of the test AUC performance for our synthetic and real datasets over a wide range of parameter settings (GRU node counts, HMM state counts, and regularization strengths). Consistent with the common wisdom of training deep models, we found that larger models, with regularization, tended to perform best.
|Model||In-Hospital Mortality||90-Day Mortality||Mechanical Ventilation||Median Vasopressor||Max Vasopressor||Total Average Path Length||Parameter Count|
|hmm (15)||0.7216||0.7282||0.8188||0.7346||0.7341||61.832||1 365|
|hmm (20)||0.7233||0.7350||0.8218||0.7371||0.7364||62.353||1 920|
|hmm (25)||0.7147||0.7321||0.8089||0.7313||0.7310||63.415||2 525|
|hmm (30)||0.7164||0.7297||0.8099||0.7316||0.7311||65.164||3 180|
|hmm (35)||0.7177||0.7237||0.8095||0.7201||0.7195||65.474||3 885|
|hmm (50)||0.7267||0.7357||0.8373||0.7335||0.7328||66.317||6 300|
|hmm (75)||0.7254||0.7361||0.8059||0.7434||0.7430||72.553||11 325|
|hmm (100)||0.7294||0.7354||0.8129||0.7408||0.7403||80.415||17 600|
|gru (10)||0.7488||0.7445||0.8892||0.7983||0.7979||58.102||1 440|
|gru (15)||0.7529||0.7450||0.8912||0.8020||0.8021||61.025||2 385|
|gru (20)||0.7535||0.7497||0.8887||0.8018||0.8017||61.214||3 480|
|gru (25)||0.7578||0.7486||0.8902||0.8113||0.8114||62.029||4 725|
|gru (30)||0.7602||0.7508||0.8927||0.8063||0.8061||72.854||6 120|
|gru (35)||0.7522||0.7483||0.8900||0.8095||0.8091||74.091||7 665|
|gru (50)||0.7431||0.7390||0.8895||0.8054||0.8051||76.543||13 200|
|gru (75)||0.7408||0.7239||0.8837||0.8006||0.8000||87.422||25 425|
|gru (100)||0.7325||0.7273||0.8781||0.7977||0.7975||94.161||41 400|
|grutree (100/0.01)||0.7276||0.7314||0.8776||0.7873||0.7867||91.797||41 400|
|grutree (100/1.0)||0.7147||0.7040||0.8741||0.7812||0.7810||82.019||41 400|
|grutree (100/8.0)||0.7232||0.7203||0.8763||0.7845||0.7840||73.767||41 400|
|grutree (100/20.0)||0.7123||0.7085||0.8733||0.7813||0.7813||65.035||41 400|
|grutree (100/70.0)||0.7360||0.7376||0.8813||0.7988||0.7986||61.012||41 400|
|grutree (100/300.0)||0.7210||0.7197||0.8681||0.7676||0.7678||54.177||41 400|
|grutree (100/2 000.0)||0.7230||0.7167||0.8335||0.7616||0.7619||48.206||41 400|
|grutree (100/5 000.0)||0.6546||0.6552||0.6752||0.6668||0.6530||26.085||41 400|
|grutree (100/7 000.0)||0.6063||0.6554||0.6565||0.6230||0.6138||20.214||41 400|
|grutree (100/8 000.0)||0.5298||0.5242||0.5025||0.5026||0.5057||13.383||41 400|
|gruhmm (1/10)||0.4007||0.6295||0.4730||0.7418||0.7419||61.041||1 517|
|gruhmm (1/25)||0.4019||0.6207||0.4773||0.7353||0.7352||65.955||4 802|
|gruhmm (1/50)||0.3999||0.6162||0.4772||0.7120||0.7121||70.534||13 277|
|gruhmm (5/5)||0.7430||0.7372||0.8798||0.8009||0.8006||47.639||1 050|
|gruhmm (5/10)||0.7408||0.7320||0.8819||0.7991||0.7988||63.627||1 845|
|gruhmm (5/25)||0.7365||0.7279||0.8776||0.7955||0.7952||68.215||5 130|
|gruhmm (5/50)||0.7222||0.7107||0.8660||0.7814||0.7811||71.572||13 605|
|gruhmm (10/5)||0.7468||0.7467||0.8949||0.8098||0.8097||50.902||1 505|
|gruhmm (10/10)||0.7490||0.7478||0.8958||0.8098||0.8096||63.522||2 300|
|gruhmm (10/25)||0.7422||0.7407||0.8916||0.8055||0.8054||70.919||5 585|
|gruhmm (10/50)||0.7254||0.7221||0.8824||0.7903||0.7903||71.297||14 060|
|gruhmm (25/5)||0.7580||0.7568||0.8941||0.8236||0.8235||51.794||3 170|
|gruhmm (25/10)||0.7592||0.7563||0.8945||0.8225||0.8225||64.223||3 965|
|gruhmm (25/25)||0.7525||0.7508||0.8912||0.8186||0.8184||72.480||7 250|
|gruhmm (25/50)||0.7604||0.7583||0.8954||0.8106||0.8103||79.127||11 025|
|gruhmm (50/5)||0.7655||0.7592||0.9006||0.8228||0.8226||64.229||6 945|
|gruhmm (50/10)||0.7648||0.7568||0.9003||0.8220||0.8219||69.281||7 740|
|gruhmm (50/25)||0.7600||0.7555||0.8981||0.8205||0.8203||85.503||11 025|
|gruhmm (50/50)||0.7412||0.7373||0.8910||0.8056||0.8055||101.637||19 500|
|gruhmmtree (50/50/0.5)||0.7432||0.7492||0.879||0.7854||0.7849||84.188||19 500|
|gruhmmtree (50/50/20.0)||0.7435||0.747||0.8826||0.7914||0.7906||77.815||19 500|
|gruhmmtree (50/50/50.0)||0.7384||0.7548||0.8914||0.7922||0.7918||71.719||19 500|
|gruhmmtree (50/50/200.0)||0.747||0.7502||0.8767||0.7832||0.7824||69.715||19 500|
|gruhmmtree (50/50/300.0)||0.7539||0.7623||0.8942||0.8092||0.8091||66.9||19 500|
|gruhmmtree (50/50/600.0)||0.7435||0.7453||0.8821||0.7909||0.7905||63.703||19 500|
|gruhmmtree (50/50/1 000.0)||0.7575||0.7502||0.8739||0.7882||0.7873||60.949||19 500|
|gruhmmtree (50/50/3 000.0)||0.7396||0.7484||0.8926||0.8013||0.8011||54.751||19 500|
|gruhmmtree (50/50/4 000.0)||0.7432||0.7511||0.8915||0.802||0.8024||44.868||19 500|
|gruhmmtree (50/50/7 000.0)||0.7308||0.7477||0.8813||0.7881||0.7882||27.836||19 500|
|gruhmmtree (50/50/9 000.0)||0.7132||0.7319||0.8261||0.7301||0.7299||0.0||19 500|
|Model||Poor Adherence||Mortality||CD4 Count < 200||Therapy Success||Total Average Path Length||Parameter Count|
|grutree (100/2 000.0)||0.7030||0.8169||0.6342||0.6627||49.839||54700|
|grutree (100/5 000.0)||0.6549||0.7582||0.6142||0.6352||23.895||54700|
|grutree (100/7 000.0)||0.6167||0.7524||0.5740||0.5634||15.283||54700|
|grutree (100/8 000.0)||0.5874||0.7412||0.5003||0.5027||7.391||54700|
|gruhmmtree (50/50/1 000.0)||0.7375||0.8951||0.8739||0.7882||48.247||30750|
|gruhmmtree (50/50/4 000.0)||0.7242||0.8461||0.8515||0.8030||14.868||30750|
|gruhmmtree (50/50/7 000.0)||0.7280||0.8462||0.8313||0.7484||1.836||30750|
|Model||AUC||Average Path Length||Parameter Count|
|hmm (25)||0.9129||57.602||1 975|
|hmm (50)||0.9189||63.752||5 200|
|hmm (75)||0.9251||71.473||9 675|
|gru (10)||0.9509||60.079||1 130|
|gru (25)||0.9547||62.051||3 950|
|gru (50)||0.9578||64.957||11 650|
|gru (75)||0.9620||68.998||23 100|
|gruhmm (5/10)||0.9575||57.6199||1 130|
|gruhmm (5/25)||0.9603||59.9925||2 465|
|gruhmm (10/5)||0.9626||57.0652||1 425|
|gruhmm (10/10)||0.9641||60.7877||1 770|
|gruhmm (10/25)||0.9651||61.0018||3 105|
|gruhmm (25/5)||0.9635||57.5288||4 245|
|gruhmm (25/10)||0.9657||60.5212||4 590|
|gruhmm (25/25)||0.9663||65.0161||5 925|
|gruhmm (50/5)||0.9676||62.2378||11 945|
|gruhmm (50/10)||0.9679||65.1191||12 290|
|gruhmm (50/25)||0.9685||67.4301||13 625|
|grutree (75/0.01)||0.9517||66.2801||23 100|
|grutree (75/0.1)||0.9466||62.4316||23 100|
|grutree (75/0.5)||0.9367||60.8764||23 100|
|grutree (75/2.0)||0.9311||58.3659||23 100|
|grutree (75/5.0)||0.9302||55.7588||23 100|
|grutree (75/10.0)||0.9288||46.6616||23 100|
|grutree (75/100.0)||0.8911||40.1123||23 100|
|grutree (75/500.0)||0.8998||28.4240||23 100|
|grutree (75/700.0)||0.8628||25.136||23 100|
|grutree (75/800.0)||0.7471||22.6671||23 100|
|grutree (75/1 000.0)||0.7082||17.1523||23 100|
|grutree (75/6 000.0)||0.5441||11.1108||23 100|
|grutree (75/7 000.0)||0.5088||8.9910||23 100|
|gruhmmtree (50/25/0.1)||0.9507||69.1110||13 625|
|gruhmmtree (50/25/1.0)||0.9465||67.5773||13 625|
|gruhmmtree (50/25/6.0)||0.9515||65.1494||13 625|
|gruhmmtree (50/25/20.0)||0.9449||64.0072||13 625|
|gruhmmtree (50/25/30.0)||0.9482||62.5406||13 625|
|gruhmmtree (50/25/70.0)||0.9460||58.0111||13 625|
|gruhmmtree (50/25/100.0)||0.9470||51.2417||13 625|
|gruhmmtree (50/25/500.0)||0.9401||42.1882||13 625|
|gruhmmtree (50/25/700.0)||0.9352||40.1281||13 625|
|gruhmmtree (50/25/1 000.0)||0.9390||38.0072||13 625|
|gruhmmtree (50/25/3 000.0)||0.9280||25.9120||13 625|
|gruhmmtree (50/25/4 000.0)||0.9311||21.7170||13 625|
|gruhmmtree (50/25/7 000.0)||0.9290||10.1122||13 625|
|gruhmmtree (50/25/9 000.0)||0.9134||1.0563||13 625|
|gruhmmtree (50/25/10 000.0)||0.9125||0.0000||13 625|
For our purposes, Hidden Markov Models (HMMs) can be viewed as stochastic RNNs that can be interpreted as probabilistic generative models. In this work, we consider an HMM that generates a latent variable sequence $z = \{z_1, \ldots, z_T\}$ via a Markov chain, where each latent $z_t$ indicates one of $K$ possible discrete states: $z_t \in \{1, \ldots, K\}$. This state sequence is then used to jointly produce the "data" $x_t$ and "outcomes" $y_t$ observed at each timestep. The joint distribution over $x$, $y$, and $z$ factorizes as:
$$p(x, y, z) = p(z_1 \mid \pi) \prod_{t=2}^{T} p(z_t \mid z_{t-1}, A) \prod_{t=1}^{T} p(x_t, y_t \mid z_t, \phi)$$
where $A$ is a transition matrix such that $A_{ij} = p(z_t = j \mid z_{t-1} = i)$, $\pi$ is the initial state distribution, and $\phi$ are the emission parameters that generate the data and outcomes. We can then apply the same objective as above for training.
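The GRU-HMM described next predicts from HMM belief states, i.e. the filtered posteriors $p(z_t \mid x_1, \ldots, x_t)$. A standard forward-filtering sketch is shown below; it takes per-timestep emission log-likelihoods as input (how they are computed depends on the emission model), a transition matrix `A` with `A[i, j]` $= p(z_t = j \mid z_{t-1} = i)$, and an initial distribution `pi`. The function name is ours.

```python
import numpy as np


def hmm_belief_states(log_lik, A, pi):
    """Forward filtering: beliefs[t, k] = p(z_t = k | x_1..x_t).

    log_lik: (T, K) array of log p(x_t | z_t = k) under the emission model.
    A:       (K, K) transitions, A[i, j] = p(z_t = j | z_{t-1} = i).
    pi:      (K,) initial state distribution.
    Also returns log p(x_1..x_T), accumulated from the normalizers.
    """
    T, K = log_lik.shape
    beliefs = np.empty((T, K))
    prior = np.asarray(pi, dtype=float)
    log_evidence = 0.0
    for t in range(T):
        m = log_lik[t].max()                 # subtract max for numerical stability
        unnorm = prior * np.exp(log_lik[t] - m)
        c = unnorm.sum()
        beliefs[t] = unnorm / c
        log_evidence += np.log(c) + m        # log p(x_t | x_1..x_{t-1})
        prior = beliefs[t] @ A               # predictive distribution for z_{t+1}
    return beliefs, log_evidence
```

Because each belief is renormalized, the recursion stays stable on long sequences such as the TIMIT utterances.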
We now consider an additional model, the GRU-HMM, designed for interpretability. The idea is to use a GRU to model the residual errors when predicting the binary target via the HMM belief states. We can further penalize the complexity of the GRU predictions via our tree regularization, so that higher-quality predictions do not come at the price of a much less interpretable model.
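One plausible way to combine the two components is to add the GRU's output logit to a logit computed from the HMM belief states, as in the sketch below. The per-state log-odds head `state_logits` is a hypothetical parameterization chosen for illustration; the supplement does not specify this exact form. With a zero residual, the prediction reduces to the HMM alone.

```python
import numpy as np


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


def gru_hmm_predict(beliefs, state_logits, gru_residual):
    """Residual combination of an HMM predictor and a GRU.

    beliefs:      (T, K) HMM belief states p(z_t | x_1..x_t).
    state_logits: (K,) hypothetical per-state log-odds that y_t = 1.
    gru_residual: (T,) logits from the GRU's output node; tree regularization
                  would be applied to these alone, leaving the HMM unconstrained.
    """
    hmm_logit = beliefs @ state_logits       # belief-weighted HMM log-odds
    return sigmoid(hmm_logit + gru_residual)
```

Keeping the GRU term additive in logit space means a heavily regularized GRU can shrink toward zero output, recovering the interpretable HMM.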
We train the deep residual model on the same suite of synthetic and real world datasets. See Tables C.1, C.2, C.4 for a comparison of GRU-HMM with vanilla GRU and HMM models under different regularization and expressiveness parameters. We can see that across the datasets, deep residual models perform around 1% better than their vanilla equivalents with roughly the same number of model parameters.
Because it is a residual model, the decision trees were trained only on the GRU output node, leaving the HMM unconstrained. See Figure D.1 for a pictorial representation. As we did for GRU models, Figures 0(b) and D.2 compare model performance as the regularization strengths for L1, L2, and tree regularization increase. We see a similar, albeit less pronounced, effect where tree regularization dominates the other methods in regions of low node count. Note the range of the AUC axis in these figures: the worst the residual model can perform is the HMM-only AUC. Figure D.3 shows the regularized trees produced by the GRU-HMM. Although they share some structure with Figure C.4, there are important distinctions that lead us to conclude that the GRU in a residual model performs a different role than when trained alone.
Table E.1 shows the wall time for training one epoch of each of the models presented in this paper on each of the datasets. Note that the wall times for GRU-TREE and GRU-HMM-TREE include the cost of surrogate training. If the surrogate is retrained infrequently, its amortized cost should be small.
|Dataset||Model||Epoch Time (Sec.)|
In the paper, we noted that decision trees are stable over multiple runs. Here, we show that on the signal-and-noise HMM dataset, 10 independent runs with random initializations produce either the same or comparable trees. Additionally, we show that with weak regularization, the variability of the learned decision trees is high. Figures F.1 and F.2 include examples of such trees on the signal-and-noise dataset. Similar results are found for real-world datasets.
With low regularization, the variance in tree size and shape is high.