Deep models have become the de facto approach for prediction in many applications such as image classification and machine translation (e.g. [5, 50]), and seem poised to advance prediction in real-world domains [36, 21, 19]. However, many practitioners are still reluctant to adopt deep models because their predictions are difficult to interpret. Without interpretability, humans are unable to incorporate their domain knowledge and effectively audit predictions.
In this work, we shall seek a specific form of interpretability known as human-simulability. A human-simulable model is one in which a human user can “take in input data together with the parameters of the model and in reasonable time step through every calculation required to produce a prediction”. For example, small decision trees with only a few nodes are easy for humans to simulate and thus understand. Human-simulability is valuable in many domains. In particular, despite advances in deep learning for clinical decision support (e.g. [36, 11, 8]), the clinical community remains skeptical (and rightfully so) of machine learning systems. The black-box nature of neural networks prevents the checks-and-balances and quality control that we expect from healthcare providers. Meanwhile, a simulable model would enable clinicians to audit predictions easily: they can manually inspect changes to outputs under perturbed inputs, check substeps against their expert knowledge, and reason about external factors influencing prediction like systemic bias in the data. Similar needs for simulability exist in many decision-critical domains such as disaster response or recidivism prediction.
Despite the appeal and need for human-simulability, many popular models are not simulable. Even simple deep models like multi-layer perceptrons with a few dozen units can have far too many parameters and connections for a human to easily step through (successive matrix multiplications quickly become difficult to think about). Richer families of neural networks such as those for sequences are essentially impossible for humans to simulate. However, with added non-linearities and many more free parameters, these rich families often allow for significantly more accurate predictions than a small decision tree. Thus, the primary question we consider in this work is the following: Is it possible for a powerful model such as a deep network to be human-simulable, or at least frequently human-simulable?
Simulability is a rather strict definition for interpretability as it requires full transparency in prediction. As such, current work on the interpretability of black-box models struggles to balance being both simulable and faithful to the model. For instance, Craven and Shavlik [craven1996extracting] train decision trees that mimic the predictions of a fixed, pre-trained neural network. Other post-hoc interpretations typically evaluate the sensitivity of predictions to local perturbations of inputs or the input gradient [45, 48, 1, 32, 16]. While post-hoc interpretations come in many sophisticated forms (others use programs to explain a model’s predictions as a post-hoc step, or learn decision sets based on a learned model), it is difficult to simplify the complex logic of an unregularized neural network to a simulable (simple) tree, set, or program. As a result, many of these methods only explain local behavior or give a lower-resolution (noisy) depiction of global logic. In general, the problem of distilling the decision function of a trained and unregularized neural network to a simple family of decision functions is somewhat ill-posed: unregularized neural networks have no incentive to be simulable or to satisfy any other notion of human-interpretability. Instead, they will learn complex decision boundaries fit to succeed at the target task. Trying to enforce interpretability post-hoc must understandably make strong assumptions that over-simplify the model’s logic.
In contrast, we begin with the observation that, since deep models are well known to often have multiple optima of similar predictive accuracy, one might hope to directly find “more interpretable” minima with equal predictive accuracy. In other words, if we consider interpretability from the very start, i.e. add an “interpretability term” to the objective function, it might be possible to train neural networks to be both performant and simulable. In general, however, the field of optimizing deep models for interpretability remains nascent. In this vein, Ross et al. [ross2017right] penalize input sensitivity to features marked as less relevant, while Lei et al. [lei2016rationalizing] train deep models that make predictions from text and simultaneously highlight contiguous subsets of words, called a “rationale,” to justify each prediction. Unfortunately, while both works optimize deep models to expose relevant features, these lists of features alone are not sufficient to simulate the prediction. We draw a stark distinction between explanation and simulation: the former may describe interpretable features whereas the latter requires defining both features and a procedure for translating them into output. In the following, we introduce two contributions: we first discuss how to optimize deep models to expose prediction logic (not just features) using decision trees, and second, how to generalize this method to incorporate human prior knowledge.
To optimize for interpretability, we must define an objective function that finds deep models that are both accurate and simulable. To do this, we introduce the notion of tree-regularization.
Specifically, we define a novel model-complexity penalty function that favors model optima whose decision boundaries can be well-approximated by small decision trees. In effect, this penalizes models that would require many calculations to simulate predictions. Similar to many popular regularizers such as L2 or L1, the tree regularizer is a function on the weights of the neural network. Several of our technical contributions surround making this regularizer differentiable such that it is compatible with stochastic gradient descent. Experimentally, we first exemplify how this technique can be used to train simple multi-layer perceptrons to have tree-like decision boundaries. We then focus on time-series applications and show that gated recurrent unit (GRU) models trained with strong tree-regularization reach a high-accuracy-at-low-complexity sweet spot that is not possible with any strength of L1 or L2 regularization. Furthermore, we will show that the decision trees produced during training can be used as tools for human simulation: they act as distillations of the deep model and can be given to domain experts. Choosing several real-world applications, we demonstrate these features of our approach on a speech recognition task and two medical treatment prediction tasks for patients with sepsis in the intensive care unit (ICU) and for patients with human immunodeficiency virus (HIV). Throughout, we also show that standalone decision trees as a baseline are noticeably less accurate than our tree-regularized deep models.
Granularity of Explanation
Thus far, we have implicitly assumed that there exists an optimum for a deep model that is simulable while maintaining high performance. For many domains, this may not be true: we may rely on the complexity of a deep model, where any strong regularization greatly increases error. In such cases, it may not be possible to have a model that is both accurate and well-approximated by a simple decision tree. To remedy this, we consider regional explanations that constrain the model independently across a partitioning of the input space. Coincidentally, this form of explanation is consistent with those of humans, whose models are typically context-dependent. For example, physicians in the intensive care unit do not expect treatment rules to be the same across different categories of patients. Constraining each region to be interpretable allows the deep model more flexibility than a global constraint, while still revealing prediction logic that can generalize to nearby inputs (in contrast to works on local explanation [45, 47, 46], which cannot indicate whether the same logic revealed for an input can be used for nearby inputs, an ambiguity that can lead to mistaken assumptions and poor decisions). In other words, we assume that even the most complex decision boundaries can be decomposed into an ensemble of simpler regional boundaries, each of which can be well-approximated by a decision tree. Furthermore, in many domains like medicine, human experts have very good intuitions for how to partition the input space. For example, an intensivist may care for patients in the surgical unit differently than patients in non-surgical units. By generalizing tree regularization to support regions, we can incorporate prior knowledge from domain experts to train simulable models.
While a straightforward conceptual leap, optimizing for simulable explanations across many regions poses a difficult technical challenge, facing issues with differentiability, efficiency, and a delicate balance of constraints between regions of varying size and complexity. In the methods, we will describe a computationally tractable and reliable approach to do so. Specifically, we show how to jointly train a deep model that both has high accuracy and is regionally simulable, and introduce innovations for stability in optimization. We first present a few synthetic experiments to build intuition and then, revisiting the clinical domain, we demonstrate that regional tree regularization achieves better performance while learning a much simpler decision function than any other regularizer.
2 Related work
Given a trained black-box model, many approaches exist to explain what the model has learned. Some works expose the features a representation encodes but not the logic. Others [2, 26] provide an informative set of examples that summarize the system. Model distillation compresses a source network into a smaller target neural network. However, even a small neural model may not be interpretable. Activation maximisation of neural networks tries to find input patterns that produce the maximum response for a quantity of interest. However, a set of input patterns is not necessarily adequate to simulate a model’s predictions. Similarly, Layerwise-Relevance Propagation [7, 4] produces a heatmap of relevant information for prediction based on aggregating the weights of a neural network. Again, learning a heatmap of the important information for predicting outcomes does not always enable human simulability, since we cannot necessarily step through each calculation that produces a decision.
In contrast, local approaches provide an explanation for a specific input. Ribeiro et al. [ribeiro2016should] show that using the weights of a sparse linear model, one can explain the decisions of a black-box model in a small area near a fixed data point. This captures the intuition that even nonlinear functions are locally linear. Similarly, instead of a linear model, Singh et al. [singh2016programs] and Koh and Liang [koh2017understanding] output a simple program or an influence function, respectively. Other approaches have used input gradients (which can be thought of as infinitesimal perturbations) to characterize the local space [33, 47]. However, the notion of a local region in these works is both very small and often implicit; it does not match human notions of context: a user may have difficulty knowing when local explanations apply and how they generalize to nearby inputs.
Optimizing for Interpretability
While there is little work on optimizing models for interpretability, there are some related threads. The first is model compression, which trains smaller models that perform similarly to large, black-box models (e.g. [buciluǎ2006model, 23, 6, 22]). Other efforts specifically train very sparse networks via L1 penalties or even binary neural networks [51, 44] with the goal of faster computation. Edge and node regularization is commonly used to improve prediction accuracy [14, 40], and recently Hu et al. [hu2016harnessingLogic] improve prediction accuracy by training neural networks so that predictions match a small list of known domain-specific first-order logic rules. Sometimes these regularizations, which all smooth or simplify decision boundaries, can have the effect of also improving interpretability. However, there is no guarantee that they will do so; we emphasize that specifically training deep models to have easily simulable decision boundaries is (to the best of our knowledge) novel.
3 Background and Models
We consider supervised learning tasks given datasets of $N$ labeled examples, $\mathcal{D} = \{(x_n, y_n)\}_{n=1}^{N}$, where each example (indexed by $n$) has an input feature vector $x_n \in \mathcal{X}^{P}$ and a target output vector $y_n \in \mathcal{Y}^{Q}$. $P$ and $Q$ are the dimensionalities. For example, we will sometimes write $x_{n,p}$, using $p$ to indicate indexing into the vector. We shall assume the targets $y_n$ are binary, though it is simple to extend to other types. When modeling time-series, each example sequence $n$ contains $T_n$ timesteps indexed by $t$, each of which has a feature vector $x_{nt}$ and an output $y_{nt}$. Formally, we write $x_n = [x_{n1}, \ldots, x_{nT_n}]$ and $y_n = [y_{n1}, \ldots, y_{nT_n}]$. Each value $y_{nt}$ could be a prediction about the next timestep (e.g. the character at time $t+1$) or some other task-related annotation (e.g. if the patient became septic at time $t$).
We will primarily consider two kinds of deep models: multi-layer perceptrons and recurrent neural networks. That said, our approach is compatible with any architecture.
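A multi-layer perceptron (MLP) makes predictions $\hat{y}_n$ of the target $y_n$ via a function $f$ such that $\hat{y}_n = f(x_n; W)$, where the vector $W$ represents all parameters of the network. Given a data set $\mathcal{D}$, our goal is to learn the optimal parameters $W^*$ to minimize the objective

$$W^* = \arg\min_{W} \; \lambda \Psi(W) + \sum_{n=1}^{N} \mathrm{loss}\big(y_n, f(x_n; W)\big). \tag{1}$$

For binary targets $y_n$, the logistic loss (binary cross entropy) is an effective choice for $\mathrm{loss}(\cdot, \cdot)$. The regularization term $\Psi(W)$, weighted by the strength $\lambda$, can represent L1 or L2 penalties (e.g. [14, 20, 40]) or our new family of regularizers.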
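To make the objective concrete, the following is a minimal NumPy sketch of a regularized logistic loss for a small MLP. The function names, the `tanh` hidden activation, and the choice to penalize weight matrices but not biases are assumptions made for illustration; the text does not specify them.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def mlp_predict(X, params):
    """Forward pass. `params` is a list of (weight, bias) pairs.
    Hidden layers use tanh (an assumption); the output is a sigmoid."""
    h = X
    for W, b in params[:-1]:
        h = np.tanh(h @ W + b)
    W, b = params[-1]
    return sigmoid(h @ W + b).ravel()

def objective(X, y, params, lam=1e-2, penalty="l2"):
    """lam * Psi(W) + sum_n loss(y_n, y_hat_n) with the logistic loss."""
    p = np.clip(mlp_predict(X, params), 1e-12, 1 - 1e-12)
    loss = -np.sum(y * np.log(p) + (1 - y) * np.log(1 - p))
    # Penalize weight matrices only (illustrative choice).
    flat = np.concatenate([W.ravel() for W, _ in params])
    reg = np.sum(np.abs(flat)) if penalty == "l1" else np.sum(flat ** 2)
    return loss + lam * reg
```

Swapping `reg` for an estimate of tree complexity is the only change tree regularization would require in this objective.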
Figure 1: Architecture diagrams for (a) gated recurrent units (GRU) and (b) a GRU and hidden Markov model (HMM) hybrid. The orange triangle indicates the output used in surrogate training for tree regularization.
Recurrent Neural Networks with Gated Recurrent Units.
A recurrent neural network (RNN) takes as input an arbitrary-length sequence $x_n = [x_{n1}, \ldots, x_{nT_n}]$ and produces a “hidden state” sequence $h_n = [h_{n1}, \ldots, h_{nT_n}]$ of the same length as the input. Each hidden state vector $h_{nt}$ at timestep $t$ represents a location in a (possibly low-dimensional) “state space” with $K$ dimensions ($K$ is often chosen as a hyperparameter). RNNs perform a sequential nonlinear embedding of the form $h_{nt} = f(x_{nt}, h_{n,t-1}; W)$ in the hope that the state space location $h_{nt}$ is a useful summary statistic for making predictions of the target $y_{nt}$ at timestep $t$. As written, $f$ is called a transition function parameterized by $W$. Many different variants of the transition function architecture have been proposed to solve the challenge of capturing long-term dependencies. In this paper, we use gated recurrent units (GRUs), which are simpler than other alternatives such as long short-term memory units (LSTMs). While GRUs are convenient, any differentiable RNN architecture is compatible with our new tree-regularization approach.
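As review, we describe the evolution of a single GRU sequence, dropping the sequence index $n$ for readability. The GRU transition function $f$ produces the state vector $h_t = [h_{t1}, \ldots, h_{tK}]$ (let $T$ denote the number of timesteps) from a previous state $h_{t-1}$ and an input vector $x_t$, via the following feed-forward architecture:

$$
\begin{aligned}
\text{output state:} \quad & h_{tk} = (1 - z_{tk})\, h_{t-1,k} + z_{tk}\, \tilde{h}_{tk} && (2) \\
\text{candidate state:} \quad & \tilde{h}_{tk} = \tanh\big(V^{h}_{k} x_t + U^{h}_{k} (r_t \odot h_{t-1})\big) && (3) \\
\text{update gate:} \quad & z_{tk} = \sigma\big(V^{z}_{k} x_t + U^{z}_{k} h_{t-1}\big) && (4) \\
\text{reset gate:} \quad & r_{tk} = \sigma\big(V^{r}_{k} x_t + U^{r}_{k} h_{t-1}\big) && (5)
\end{aligned}
$$

The internal network nodes include candidate state gates $\tilde{h}$, update gates $z$ and reset gates $r$, which have the same cardinality as the state vector $h$. Reset gates allow the network to forget past state vectors when set near zero via the logistic sigmoid nonlinearity $\sigma$; this gating critically adds multiplicative expressivity to the model class. Update gates allow the network to either pass along the previous state vector unchanged or use the new candidate state vector instead. This architecture is diagrammed in Figure 1.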
The predicted probability of the binary target $y_t$ for timestep $t$ is a sigmoid transformation of the state at time $t$, $\hat{y}_t = \sigma(w^{\top} h_t)$. Here, the weight vector $w \in \mathbb{R}^{K}$ represents the parameters of this individual output layer. We denote the parameters for the entire GRU-RNN model as $W = (w, U, V)$, concatenating all component parameters. We can train GRU-RNN timeseries models (hereafter often just called GRUs) via the following loss minimization objective, which shares many similarities with the MLP’s loss (Eqn. 1):

$$W^* = \arg\min_{W} \; \lambda \Psi(W) + \sum_{n=1}^{N} \sum_{t=1}^{T_n} \mathrm{loss}\big(y_{nt}, \hat{y}_{nt}(x_n, W)\big) \tag{6}$$

where again $\Psi(W)$ defines a regularization cost, and $W^*$ represents the optimal parameters.
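As an illustration of the GRU recursion and output layer described above, here is a small NumPy sketch of one forward pass over a single sequence. It assumes the standard GRU form without bias terms; the dictionary keys (`Vz`, `Uz`, and so on) and the zero initial state are naming and initialization choices made for this example only.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def gru_step(x_t, h_prev, p):
    """One GRU transition: returns h_t given x_t and h_{t-1}."""
    z = sigmoid(p["Vz"] @ x_t + p["Uz"] @ h_prev)              # update gate
    r = sigmoid(p["Vr"] @ x_t + p["Ur"] @ h_prev)              # reset gate
    h_tilde = np.tanh(p["Vh"] @ x_t + p["Uh"] @ (r * h_prev))  # candidate state
    return (1.0 - z) * h_prev + z * h_tilde                    # output state

def gru_predict(X, p):
    """Per-timestep probabilities sigma(w^T h_t) for a (T, P) sequence X."""
    h = np.zeros(p["Uz"].shape[0])  # assumed zero initial state
    probs = []
    for x_t in X:
        h = gru_step(x_t, h, p)
        probs.append(sigmoid(p["w"] @ h))
    return np.array(probs)
```

With the update gate near zero the previous state is passed along unchanged, and with the reset gate near zero the candidate state ignores the past, matching the roles described in the text.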
Hidden Markov Models with Stochastic Gradient Descent.
Besides recurrent neural networks, hidden Markov models (HMMs) are another class of sequence models that are commonly used to describe stochastic processes. Often (as with RNNs), we are given a sequence of observed variables $x_1, \ldots, x_T$ and wish to derive a sequence of latent (or hidden) variables $z_1, \ldots, z_T$. We assume each latent variable $z_t$ can take one of $K$ discrete states. In practice, these latent variables can be interpreted as an unsupervised clustering over the observed sequence. For our purposes, one can view the HMM as a stochastic RNN (added noise), making it a probabilistic generative model. To be tractable, the HMM makes a set of simplifying assumptions. The free parameters of an HMM define a prior, the probability distribution over states for timestep 0; a transition matrix, which specifies a probability distribution over states for timestep $t$ given the state at timestep $t-1$; and an emission matrix, which specifies a probability distribution over (possibly continuous) observations at timestep $t$ given only the latent state at timestep $t$. Critically, this setup makes the Markov assumption: all information required to make a decision at timestep $t$ is present at timestep $t-1$.
In our setting, we also have a sequence of known outputs, $y_1, \ldots, y_T$. In some sense, we are not interested in the latent states themselves but in using them to classify an observation into an output. If we decide upfront to specify a simple classifier on top of the latent variables (such as logistic regression), then we explicitly write the joint distribution over latents, observations, and outputs as:
where are the parameters specifying the prior, transition, and emission probabilities; are the parameters used in logistic regression; , the posterior distribution over states at timestep ;
represents a Sigmoid function. Therefore, we can train the HMM with stochastic gradient descent using the objective:
where contain all trainable parameters from a high-dimensional space of parameters . In other words, because we only desire maximum-a-posteriori (MAP) inference, we never need to sample from any of the distributions and therefore can differentiate this objective with standard techniques. Note that this is quite similar to the forward pass in the forward-backward algorithm.
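One plausible reading of this setup, offered only as a sketch, is forward filtering followed by logistic regression on the filtered state posteriors. All names below (`pi`, `A`, `eta`, `log_lik`) are assumptions for illustration and are not taken from the source equations; the prior is applied at the first observed timestep for simplicity.

```python
import numpy as np

def forward_beliefs(log_lik, pi, A):
    """Filtered posteriors over states for each timestep.

    log_lik[t, k] = log p(x_t | z_t = k)      (emission term, assumed given)
    pi[k]         = prior over the first state
    A[i, j]       = p(z_t = j | z_{t-1} = i)  (rows sum to one)
    """
    T, K = log_lik.shape
    beliefs = np.zeros((T, K))
    pred = pi
    for t in range(T):
        # Subtract the max before exponentiating for numerical stability.
        unnorm = pred * np.exp(log_lik[t] - log_lik[t].max())
        beliefs[t] = unnorm / unnorm.sum()
        pred = beliefs[t] @ A  # one-step-ahead state prediction
    return beliefs

def hmm_predict(log_lik, pi, A, eta):
    """Sigmoid of a linear function of the belief state at each timestep."""
    b = forward_beliefs(log_lik, pi, A)
    return 1.0 / (1.0 + np.exp(-(b @ eta)))
```

Every operation here is deterministic and differentiable in the parameters, which is what allows gradient-based training without sampling.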
Modeling the Residuals of a Hidden Markov Model
One strength of the HMM is that it is a fairly interpretable model. Often, the discrete latent states have contextual meaning, such that we can analyze the predictions of the HMM as conditioned completely on its state. However, for complex domains, $K$ discrete states (even for large $K$) might not be able to fully capture the true decision function, resulting in high prediction error. One option is to add a recurrent neural network, which is known to be high-performing but uninterpretable, to model the residual errors when predicting the target outputs using the HMM belief (latent) states. If we can properly penalize the complexity of the deep model, then high-quality predictions do not come at the price of a less interpretable model. In practice, the GRU and HMM can be trained jointly, with the parameters of each model kept independent. We call this model a GRU-HMM and use it in several experiments. Figure 1(b) recaps the model architecture.
4 (Decision) Tree-Regularization
As presented in Eqns. 1 and 6, the regularizer $\Psi$ is arbitrary. Common choices include $L_2$ norms to manage the sizes of $W$ and $L_1$ norms to manage the sparsity of $W$. We now come to our core contribution: we replace $\Psi$ with a novel tree-regularizer, denoted $\Omega(W)$, that encourages the model to be simulable. Specifically, we shall encourage our deep models to be well-approximated by (small) decision trees. For clarity, we refer to the deep neural network that we are trying to regularize as the target neural model or target network.
To do so, we first fit a binary decision tree which accurately reproduces the target network’s thresholded binary predictions $\hat{y}_n$ given input $x_n$. The accuracy parameter is always kept fixed, so that the tree is forced to model the network well. Next, we penalize the network based on the complexity of the learnt tree: a simple decision function can be explained with only a few branches, whereas a complex function may need exceedingly large trees. With this in mind, we quantify complexity as the average decision path length (shorthand APL): the average number of decision nodes that must be touched to make a prediction for an input (i.e. the number of nodes from root to leaf). We compute the average with respect to some designated reference dataset of example inputs from the training set. Thus, our regularizer is

$$\Omega(W) \triangleq \mathrm{APL}\big(\{x_n\}_{n=1}^{N},\, f(\cdot\,; W),\, m\big) \tag{9}$$

where the APL function is detailed in Algorithm 1; $f(\cdot\,; W)$ represents the neural model; and $m$ is a hyperparameter for training decision trees that controls the minimum number of training examples needed to define a leaf node. This definition of APL generalizes to the case where the input data represents a timeseries. Algorithm 1 requires two subroutines, TrainTree and PathLength. Firstly, TrainTree trains a binary decision tree to accurately reproduce the provided labeled examples (recall that the labels are the network’s predictions $\hat{y}_n$). For this we use the DecisionTree module distributed in Python’s scikit-learn, which greedily chooses splits that reduce Gini impurity. Generally, the runtime cost of this module scales superlinearly with the number of examples and linearly with the number of features, for a total complexity of $\mathcal{O}(P N \log N)$. In practice, we found that with , , fitting a decision tree takes 15.3 microseconds. These trees can give probabilistic predictions at each leaf. Next, PathLength counts how many nodes are needed to map a specific input to an output (leaf) node in the provided decision tree (this is done programmatically by storing traversals).
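The APL computation can be sketched with scikit-learn as follows. This is a minimal illustration rather than the authors' Algorithm 1: the function name, the 0.5 threshold, the default `min_samples_leaf`, and the convention of counting only internal decision nodes (excluding the leaf) are assumptions made here.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

def average_path_length(X, predict_proba, min_samples_leaf=5,
                        threshold=0.5, seed=0):
    """Fit a tree to the model's thresholded predictions and return
    (average number of decision nodes per input, fitted tree)."""
    # Labels are the target model's own binary predictions, not ground truth.
    y_hat = (predict_proba(X) >= threshold).astype(int)
    tree = DecisionTreeClassifier(min_samples_leaf=min_samples_leaf,
                                  random_state=seed)
    tree.fit(X, y_hat)
    # decision_path returns a sparse (n_samples, n_nodes) indicator of the
    # nodes each sample visits, including its leaf.
    visited = np.asarray(tree.decision_path(X).sum(axis=1)).ravel()
    # Subtract one so that only decision (non-leaf) nodes are counted.
    return float(np.mean(visited - 1)), tree
```

For a model whose predictions depend on a single threshold of one feature, the fitted tree has one split and the function returns an APL of 1; a constant model yields a single-leaf tree with an APL of 0.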
We consider average path length a good proxy for simulability because human simulation requires stepping through every calculation required to make a prediction. APL exactly counts the number of true-or-false boolean calculations needed to make an average prediction, assuming the model is a binary decision tree. In contrast, a metric such as the total number of nodes might penalize more accurate trees that have short paths for most examples but need more involved logic for a few outliers. Although APL is a sensible choice, a few technical innovations are required to optimize the APL loss efficiently.
Making Tree Regularization Differentiable
Training decision trees is not differentiable, and thus the tree regularization loss from Equation 9 is not differentiable with respect to the network parameters $W$ (unlike standard regularizers such as $L_1$ or $L_2$). While one could resort to derivative-free optimization techniques, e.g. search algorithms, gradient descent has been an extremely fast and robust way of training neural networks.
A key technical contribution of our work is introducing and training a surrogate regularization function $\hat{\Omega}(W)$ to map each parameter vector $W$ of the target neural model to an estimate of the APL. Our approximate function $\hat{\Omega}$ is implemented as a standalone multi-layer perceptron network and is, critically, differentiable. Let vector $\xi$ denote the trainable parameters of this chosen MLP surrogate. We can train $\hat{\Omega}$ to be a good estimator by minimizing a squared error loss function:

$$\min_{\xi} \; \sum_{j=1}^{J} \big(\Omega(W_j) - \hat{\Omega}(W_j; \xi)\big)^2 + \epsilon \|\xi\|_2^2 \tag{10}$$

where each $W_j$ is an instance of the entire set of parameters for the target neural model, $\epsilon > 0$ is a regularization strength, and we assume we have a dataset of $J$ known parameter vectors and their associated true APLs: $\{W_j, \Omega(W_j)\}_{j=1}^{J}$. This dataset can be assembled using the candidate parameter vectors obtained at every gradient step while training our target neural model. Importantly, one can train the surrogate function $\hat{\Omega}$ in parallel with our network. In Figure 2(a), we show evidence that our surrogate predictor $\hat{\Omega}$ tracks the true average path length as we train the target predictor.
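A minimal NumPy sketch of such a surrogate is given below, assuming a one-hidden-layer `tanh` MLP trained by full-batch gradient descent on a mean squared error with an L2 penalty. The class name, layer size, learning rate, and the choice to penalize only the weight matrices are illustrative assumptions; `grad_wrt_input` shows why differentiability matters, since it supplies the gradient of the estimated APL with respect to the target model's weights.

```python
import numpy as np

class SurrogateMLP:
    """Maps a flattened weight vector W to an estimate of its APL."""

    def __init__(self, in_dim, hidden=25, eps=1e-4, lr=1e-2, seed=0):
        rng = np.random.default_rng(seed)
        self.A = rng.normal(0.0, 1.0 / np.sqrt(in_dim), (in_dim, hidden))
        self.b = np.zeros(hidden)
        self.c = rng.normal(0.0, 1.0 / np.sqrt(hidden), hidden)
        self.d = 0.0
        self.eps, self.lr = eps, lr

    def predict(self, W):
        return np.tanh(W @ self.A + self.b) @ self.c + self.d

    def fit(self, W, apl, steps=500):
        n = len(W)
        for _ in range(steps):
            H = np.tanh(W @ self.A + self.b)
            err = H @ self.c + self.d - apl
            # Gradients of mean squared error + eps * (||A||^2 + ||c||^2).
            g_c = 2.0 * H.T @ err / n + 2.0 * self.eps * self.c
            g_d = 2.0 * err.mean()
            dpre = 2.0 * np.outer(err, self.c) * (1.0 - H ** 2) / n
            g_A = W.T @ dpre + 2.0 * self.eps * self.A
            g_b = dpre.sum(axis=0)
            self.A -= self.lr * g_A
            self.b -= self.lr * g_b
            self.c -= self.lr * g_c
            self.d -= self.lr * g_d
        return self

    def grad_wrt_input(self, w):
        """d(APL estimate)/dw for a single weight vector w."""
        h = np.tanh(w @ self.A + self.b)
        return self.A @ (self.c * (1.0 - h ** 2))
```

During target-network training, `grad_wrt_input` would be scaled by the regularization strength and added to the gradient of the prediction loss.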
Figure 2: Effects of parameter augmentation and random restarts (retraining). The blue line shows the true APL of the decision tree at each epoch. All other lines show the predicted APL using the surrogate MLP. By augmenting and restarting, we significantly improve the ability of the surrogate model to track the changes in the ground truth.
Training the Surrogate Loss
In this section, we describe a few more considerations to improve surrogate quality. Firstly, even moderately sized neural models can have parameter vectors with thousands of dimensions. Our labeled dataset for surrogate training, $\{W_j, \Omega(W_j)\}_{j=1}^{J}$, will only have one example from each target network training iteration. Even with small batch sizes (more gradient steps), this dataset is too small. Thus, in early iterations, we will have only a few examples from which to learn a good surrogate function $\hat{\Omega}(W)$. We resolve this challenge by augmenting our training set with additional examples: we randomly sample weight vectors $W$ and calculate the true APL $\Omega(W)$, and we also perform several random restarts (initializing parameters with different random seeds) on the unregularized target network and use those weights in our training set.
A second challenge arises later in training: as the model parameters shift away from their initial values, parameters from earlier in optimization may not be as relevant in characterizing the current decision function of the target neural model. In practice, this is a function of the learning rate: a large step size will quickly render even recent parameters ineffective for training a surrogate. To address this, for each epoch, we use examples only from a fixed window of the most recent iterations, whose length is chosen empirically. Using examples from a fixed window of iterations also speeds up training. Figure 2(b) compares the importance of these heuristics for efficient and accurate training. Empirically, data augmentation for stabilizing surrogate training allows us to scale to neural networks with hundreds of nodes. MLPs and GRUs of this size are already sufficient for many real problems, such as those we encounter in healthcare domains.
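The two heuristics above (a sliding window over recent iterates and augmentation with extra weight vectors) can be sketched as a small container class. The class and method names, the window length, and the Gaussian sampling of random weights are hypothetical choices for illustration; `apl_fn` stands in for whatever routine computes the true APL of a weight vector.

```python
from collections import deque
import numpy as np

class SurrogateDataset:
    """Holds (weight vector, true APL) pairs for surrogate training."""

    def __init__(self, window=100):
        # Only the most recent `window` training iterates are retained.
        self.recent = deque(maxlen=window)
        self.extra = []  # augmented examples, kept for the whole run

    def add_iterate(self, W, apl):
        self.recent.append((np.array(W, dtype=float), float(apl)))

    def augment(self, apl_fn, dim, n_samples, scale=1.0, seed=0):
        """Add randomly sampled weight vectors labeled with their true APL."""
        rng = np.random.default_rng(seed)
        for _ in range(n_samples):
            W = rng.normal(0.0, scale, size=dim)
            self.extra.append((W, float(apl_fn(W))))

    def arrays(self):
        pairs = list(self.recent) + self.extra
        Ws = np.stack([W for W, _ in pairs])
        apls = np.array([a for _, a in pairs])
        return Ws, apls
```

Weights obtained from random restarts of the unregularized network could be added through `add_iterate` on a separate instance or appended to `extra` in the same way.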
5 Demonstration: A Tree-Regularized MLP and RNN
We start by exploring two simple domains intended to build intuition for the tree regularization method. We first test the regularizer on MLPs in a two-dimensional classification task followed by a second prediction task with sequential data.
Tree-Regularized MLP: Noisy Parabola
We first show a binary classification task as demonstration. We call this task the 2D Parabola problem, because as Figure 3(a) shows, the training data consists of 2D input points whose two-class decision boundary is roughly shaped like a parabola. The true decision function is defined by .
We sampled 500 input points uniformly within the unit square and labeled those above the decision function as positive. To make it easy for models to overfit to more complex decision boundaries, we flipped 10% of the points in a region near the boundary. A random 30% were held out for testing. For the classifier, we train a 3-layer MLP with 100 first-layer nodes, 100 second-layer nodes, and 10 third-layer nodes. This MLP is intentionally overly expressive to encourage overfitting and expose the impact of different forms of regularization: our proposed tree regularization, an $L_1$ penalty on the weights, and an $L_2$ penalty on the weights. For each regularization function, we train models at many different regularization strengths $\lambda$ chosen to explore the full range of decision boundary complexities possible under each technique. For tree regularization, we model a surrogate with a one-hidden-layer MLP with 25 units. The surrogate is intentionally chosen to be small, with few parameters. In practice, we bias towards simpler surrogate networks to ensure faster training; additionally, too complex a surrogate would no longer preserve interpretability. The objective in Equation 1 was optimized via Adam gradient descent using a batch size of 100 and a learning rate of 1e-3 for 250 epochs. These hyperparameters were set via cross validation using grid search.
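The data-generation recipe can be sketched as follows. The parabola coefficients, the width of the band that counts as "near the boundary", and the reading of "flipped 10%" as flipping 10% of the points inside that band are all assumptions for this example, since the source does not reproduce these details here.

```python
import numpy as np

def make_parabola_data(n=500, flip_frac=0.10, band=0.1,
                       test_frac=0.30, seed=0):
    rng = np.random.default_rng(seed)
    X = rng.uniform(0.0, 1.0, size=(n, 2))
    # Assumed decision function; the coefficients are illustrative only.
    boundary = 5.0 * (X[:, 0] - 0.5) ** 2 + 0.4
    y = (X[:, 1] > boundary).astype(int)
    # Flip a fraction of the labels within a band around the boundary.
    near = np.flatnonzero(np.abs(X[:, 1] - boundary) < band)
    n_flip = int(round(flip_frac * len(near)))
    flip = rng.choice(near, size=n_flip, replace=False)
    y[flip] = 1 - y[flip]
    # Random 70/30 train/test split.
    perm = rng.permutation(n)
    n_test = int(round(test_frac * n))
    test, train = perm[:n_test], perm[n_test:]
    return (X[train], y[train]), (X[test], y[test])
```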
To evaluate model simulability, we use APL. Since Algorithm 1 can compute the APL for any fixed deep model given its parameters, we use it to measure decision boundary complexity under any regularization, including $L_1$ or $L_2$. Figure 4(b) shows each trained model as a single point in a 2D fitness space: the x-axis measures model complexity with APL, and the y-axis measures AUC (area under the ROC curve) prediction performance. These results show that simple $L_1$ or $L_2$ regularization does not produce models with both small node count and good predictions at any value of the regularization strength $\lambda$. As expected, large values of $\lambda$ for $L_1$ and $L_2$ only produce far-too-simple linear decision boundaries with poor accuracies. In contrast, our proposed tree regularization directly optimizes the MLP to have simple tree-like boundaries at high $\lambda$ values, which can still yield good predictions. The lower panes of Figure 4 show these boundaries. Our tree regularization is uniquely able to create axis-aligned functions, because decision trees by definition parameterize functions with axis-aligned splits. Critically, these axis-aligned functions require very few nodes but are more effective than their $L_1$ and $L_2$ counterparts.
Tree-Regularized GRU: Signal-and-noise HMM
Next, we analyze the performance of tree regularization on synthetic timeseries data. We generated a toy dataset of sequences, each with timesteps. Each timestep has a data vector $x_{nt}$ of 14 binary features and a single binary output label $y_{nt}$. The data comes from two separate HMM processes. First, a “signal” HMM generates the first 7 data dimensions from 5 well-separated states. Second, an independent “noise” HMM generates the remaining 7 data dimensions from a different set of 5 states. The transition and emission matrices for both HMMs are shown in Fig. 6. The probabilities were chosen to make it difficult for a new HMM to learn. Each timestep’s output label $y_{nt}$ is produced by a rule involving both the signal HMM’s generated observations and the signal HMM’s hidden state: the target is 1 at timestep $t$ only if both the first signal state is active and the first observation is turned on. We deliberately designed the generation so that neither logistic regression on inputs alone nor a GRU model that makes predictions from hidden states alone can perfectly separate this data.
As with the MLP, each regularizer (tree, L2, L1) is applied to the output node of the GRU across a range of strength parameters (see orange triangle in Figure 1). In training, we used 25 hidden dimensions for GRU models and 5 states for the HMM component of the GRU-HMM. All other choices are identical to the 2D Parabola setting.
Figure 5 compares the performance of regularized GRU and GRU-HMM models on the signal-and-noise HMM dataset. Since we can no longer easily visualize the decision boundary, we rely on plots like Figure 5(c,d) to measure regularization effectiveness. Many of the same patterns from the 2D Parabola experiments emerge here: tree-regularized GRU models achieve much higher (held-out) AUC at lower APL. Further, $L_1$ and $L_2$ are quite unreliable at high regularization strengths, doing worse than a decision tree at low APL. All regularized models converge to the same performance as APL approaches 0 (random choice) and infinity (unregularized). Additionally, we include results for the GRU-HMM (d), whose performance is lower-bounded by the performance of a standalone HMM (notice the scale of the y-axis). However, as before, tree regularization on the “GRU component” of the GRU-HMM quickly reaches near-maximum performance with small APL (around 5). We hypothesize this is largely due to the compactly expressive nature of axis-aligned decision boundaries. Finally, Figure 5(a,b) shows two “distilled” decision trees that are used to approximate the deep model in the last epoch of training. We can see that for small regularization strengths (a), the distilled tree is large and difficult to interpret. For larger strengths (b), the tree recovers the true generative process: predict positive output if and only if “x == 1 and s == 1 and s == 0”. The first component (x == 1) represents the first observation being 1; the second component (s == 1 and s == 0) represents the first state being active (recall that the emission distribution for this state is [.5 .5 .5 .5 0 ]). A decision tree like this can be given to a human to help describe what mappings the deep model has learned. Critically, smaller decision trees are very easy to simulate.
6 Applications: Real-World Timeseries Data
Having explored a few synthetic environments, we now evaluate the tree regularizer on several real-world timeseries models in speech recognition and two sectors of healthcare. For each experiment below, we will compare a tree-regularized GRU with an identical GRU regularized with L1 or L2. We will also include a decision tree baseline where a tree classifier is fit directly on the observations. Additionally, we will compare the GRU results with GRU-HMM performance to gauge any benefits of residual training. For optimization, we use Adam with a learning rate of 1e-3, a batch size of 256, decision tree hyperparameter , train for 300 epochs, surrogate datasets of size , and retrain every 25 steps. Like above, we measure performance with AUC and simulability with APL for all models. Before sharing results, we briefly describe each task and domain.
We tested our approach on several real-world tasks: predicting medical outcomes of hospitalized septic patients, HIV therapy outcome prediction, and predicting stop phoneme groups from a selection of English speech recordings. To normalize scales, we independently standardized input features via z-scoring. Like in the demonstrations above, we compare tree regularization to L1 and L2 baselines. Additionally, we compare a tree-regularized deep network to a decision tree classifier.
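As a minimal sketch of the per-feature z-scoring mentioned above, the snippet below standardizes each column independently. The function name `zscore_columns` and the toy values are illustrative assumptions; in practice the mean and standard deviation would presumably be estimated on the training split only.

```python
from statistics import mean, pstdev

def zscore_columns(rows):
    """Standardize each feature (column) independently to mean 0 and std 1."""
    columns = list(zip(*rows))
    mus = [mean(c) for c in columns]
    # Guard against constant features, whose standard deviation is zero.
    sigmas = [pstdev(c) or 1.0 for c in columns]
    return [[(v - m) / s for v, m, s in zip(row, mus, sigmas)] for row in rows]

# Hypothetical toy data: two features measured on very different scales.
toy = [[12.0, 100.0], [18.0, 300.0], [24.0, 200.0]]
standardized = zscore_columns(toy)
```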
Sepsis Critical Care (ICU): We study timeseries data for 11 786 septic ICU patients from the public MIMIC III dataset . We observe at each hour (timestep) a data vector of 35 vital signs and lab results as well as a label vector of 5 binary outcomes. Hourly data measures continuous input features such as respiration rate (RR), blood oxygen levels (paO2), fluid levels, and more. Hourly binary labels include whether the patient died in hospital, whether the patient died after 90 days, and if mechanical ventilation was applied. Models are trained to predict all 5 output dimensions concurrently from one shared embedding. The average sequence length is 15 hours. 7 070 patients are used in training, 1 769 for validation, and 2 947 for test.
HIV Therapy Outcome (HIV): We make use of the EuResist Integrated Database  for 53,236 patients diagnosed with HIV. We consider 4-6 month intervals (corresponding to hospital visits) as time steps. Each data vector has 40 features, including blood counts, viral load measurements and lab results. Each output vector has 15 binary labels, including whether a therapy was successful in reducing viral load to below detection limits, if therapy caused CD4 blood cell counts to drop to dangerous levels (indicating AIDS), or if the patient suffered adherence issues to medication. The average sequence length is 14 steps. 37 618 patients are used for training; 7 986 for testing, and 7 632 for validation.
Phonetic Speech (TIMIT): Timeseries data containing broadband recordings of 630 speakers of eight major dialects of American English reading ten phonetically rich sentences . Each sentence contains time-aligned phonetic transcriptions of 60 phonemes. We focus on the problem of distinguishing stop phonemes (those that stop the flow of air, such as “b”, “d”, or “g”) from non-stops. Each timestep has one binary output indicating whether a stop phoneme occurs or not. There are 26 continuous features for each input vector representing the Mel-frequency cepstral coefficients and derivatives of the acoustic signal. There are 6 303 sequences, which we split into 3 697 for training, 925 for validation, and 1 681 for testing. The average length is 614 tokens in a sequence.
6.2 Results and Analysis
The results on ICU, HIV, and TIMIT share many consistent characteristics. We summarize the many experiments with analysis on common patterns and provide a few takeaways.
Tree-regularized models have fewer nodes than other forms of regularization.
Across tasks, we see that in the target regime of small decision trees (low APLs), our proposed regularization achieves higher prediction quality (higher AUCs). In the signal-and-noise HMM task, tree regularization (green line in Figure 5(d)) achieves AUC values near 0.9 when its trees have an average path length of 10. Similar models with L1 or L2 regularization reach this AUC only with trees that are nearly double in complexity (APL over 25). On both SEPSIS (Figure 7) and TIMIT (Figure 7(a)), we see considerable gains in accuracy over other regularizers (AUC differences of 0.05 to 0.15) for path lengths of 20-30. On the HIV task in Figure 7(b), we see AUC differences of between 0.03 and 0.15 for path lengths of 10-15. Similarly, on the other HIV outcomes in Figures 7(c)-7(d), we see AUC differences of between 0.03 and 0.09 for path lengths of 20-30. These gains are particularly useful in determining how to administer subsequent therapies. More specifically, in domains where human-simulability is required, these increases in accuracy in the small-complexity regime can mean the difference between models that provide value on a task and models that are unusable, either because their performance is too poor or because they are uninterpretable. We emphasize that across all tasks, standalone decision trees (marked by yellow dots in line plots) cannot reach this high-accuracy, low-complexity sweet spot, suggesting that tree regularization still enables neural networks to be nonlinear.
|Dataset||Fidelity|
|SEPSIS (In-Hospital Mortality)||0.8144|
|SEPSIS (90-Day Mortality)||0.8845|
|SEPSIS (Mech. Vent.)||0.9008|
|SEPSIS (Median Vaso.)||0.9166|
|SEPSIS (Max Vaso.)||0.9260|
|HIV (CD4 below 200)||0.8426|
|HIV (Therapy Success)||0.8761|
|HIV (Poor Adherence)||0.9014|
|HIV (AIDS Onset)||0.9344|
Our learned decision-tree-like boundaries are interpretable.
Recall that a consequence of tree regularization is a distillation of the deep model as a decision tree. Across all tasks, these trees, which mimic the predictions of tree-regularized deep models, are small enough to simulate by hand and help users grasp the model’s nonlinear prediction logic. We have already seen this to be the case for the signal-and-noise HMM task. Similarly, in Figure 7, we show decision trees for two sepsis prediction tasks. We consulted a clinical expert on sepsis treatment, who noted that the trees helped him understand what the models might be doing and thus determine if he would trust the deep model. For example, he said that using FiO2, RR, CO2 and paO2 to predict need for mechanical ventilation (Figure 6(f)) was sensible, as these all measure breathing quality. In contrast, the in-hospital mortality tree (Figure 6(b)) predicts that some young patients with no organ failure have high mortality rates while other young patients with organ failure have low mortality. These counter-intuitive results led to hypotheses about how uncaptured variables impact the training process. Such reasoning would not be possible from simple sensitivity analyses of the deep model. Moreover, our distilled trees for HIV, such as those in Figure 7(d), are also interpretable. We observe that the baseline viral load and number of prior treatment lines are crucial factors in predicting whether a patient will suffer adherence issues. This is consistent with several medical studies which show that patients with higher viral loads at baseline tend to have faster disease progression, and hence have to take several drug cocktails to potentially combat resistance. This typically makes it more difficult for these patients to adhere to the medication.
Practical runtimes for tree regularization are less than twice that of simpler L2.
While our tree-regularized GRU with 10 states takes 3 977 seconds per epoch on TIMIT, an equivalent L2-regularized GRU takes 2 116 seconds per epoch. Thus, our new method costs less than twice the baseline even when the path-length surrogate is serially computed. Because the surrogate will in general be a much smaller model than the target neural model, we expect one could get much smaller per-epoch times by parallelizing the creation of training pairs and the surrogate training. Additionally, the 3 977 seconds includes the time needed to train the surrogate. In practice, we do this sparingly, only once every 25 epochs, yielding an amortized per-epoch cost of 2 191 seconds. More exhaustive runtime results with standard deviations over 10 epochs are in Table 2.
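The amortized figure can be checked with a short calculation. This is only one reading of the numbers: it assumes the L2-regularized epoch time is a fair proxy for a tree-regularized epoch without surrogate retraining, so that the difference between the two timings is the surrogate overhead. The function name is hypothetical.

```python
def amortized_epoch_seconds(base_seconds, full_seconds, retrain_every):
    """Per-epoch cost when surrogate overhead is paid once every `retrain_every` epochs."""
    surrogate_overhead = full_seconds - base_seconds
    return base_seconds + surrogate_overhead / retrain_every

# Timings quoted in the text (seconds per epoch on TIMIT).
cost = amortized_epoch_seconds(base_seconds=2116, full_seconds=3977, retrain_every=25)
ratio = 3977 / 2116  # unamortized slowdown relative to the L2 baseline
print(f"amortized: {cost:.2f} s/epoch, unamortized ratio: {ratio:.2f}x")
```

Under this assumption the amortized cost comes out within a second of the reported 2 191 seconds, and the unamortized ratio stays below two.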
Figure 9 (caption, partial): Low regularization causes high variance in tree size and shape. Sub-figures (d-f) show three of many variations.
Decision trees are stable over multiple optimization runs.
When tree regularization is strong (high ), the decision trees trained to match the predictions of deep models are stable. For both signal-and-noise and Sepsis tasks, multiple runs from different random restarts have nearly identical tree shape and size, perhaps differing by a few nodes. This stability is crucial to building trust in our method. On the signal-and-noise task (), 7 of 10 independent runs with random initializations resulted in trees of exactly the same structure, and the others closely resembled them, sharing the same subtrees and features. On the other hand, with weak regularization (small ), variability in the distilled decision trees is high. See Figure 9 for example trees under strong (a-c) and weak (d-f) regularization.
Target neural models are faithful to decision trees.
Fidelity is defined by  as the percentage of examples where the prediction of the target network and the decision tree agree. Thus, fidelity is a measurement of how faithful the deep network is to the distilled tree. A fidelity of 1 would indicate perfect agreement, in which the neural network has learned exactly the axis-aligned boundaries of a tree. In some sense, a fidelity of 1 is undesirable as we hope the deep network can make use of nonlinearity on the examples that a simulable tree would struggle with. Table 4 shows that the fidelity is high but not perfect, ranging from 0.80 to 0.94 across datasets.
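Fidelity as described here is simply an agreement rate between two prediction vectors. A minimal sketch follows; the function name and the toy predictions are illustrative assumptions, not values from our experiments.

```python
def fidelity(network_preds, tree_preds):
    """Fraction of examples on which the network and the distilled tree agree."""
    if len(network_preds) != len(tree_preds) or not network_preds:
        raise ValueError("prediction lists must be non-empty and of equal length")
    agreements = sum(n == t for n, t in zip(network_preds, tree_preds))
    return agreements / len(network_preds)

# Hypothetical binary predictions on ten examples; they disagree on one.
net = [1, 0, 1, 1, 0, 0, 1, 0, 1, 1]
tree = [1, 0, 1, 0, 0, 0, 1, 0, 1, 1]
score = fidelity(net, tree)
```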
The deep residual GRU-HMM can achieve high AUC with less complexity.
In Figure 10, we show the performance of jointly training the residual model, GRU-HMM, which combines an HMM with a tree-regularized GRU to improve its predictions. Here, the ideal APL is zero, indicating only the HMM makes predictions (only the GRU output node is regularized). For small APLs, the GRU-HMM substantially improves the original HMM’s predictions and has simulability gains over earlier GRUs. On the mechanical ventilation task, the GRU-HMM requires an APL of only 28 to reach AUC of 0.88, while the GRU alone with the same number of states requires a path length of 60 to reach the same AUC. This suggests that jointly-trained deep residual models may provide even better interpretability.
7 Regionally Faithful Explanations with Expert Priors
Global summaries such as L1, L2, or even tree regularization as presented above face a tough trade-off between human-simulability and being faithful to the underlying model. For instance, if we require a minimum fidelity of 0.95, it simply may not be possible to fit a faithful decision tree that is also human-simulable. In our experiments so far, we have been fortunate but there is little guarantee that such a tree must exist. More generally, for a complex enough domain (or for particularly difficult examples), it is again unreasonable to assume that a decision tree can be small, bushy and performant. In such a case, tree regularization of a deep network may not be able to find a good compromise between accuracy and complexity. To get the best of both worlds, we will need a finer-grained definition of interpretability. Doing so might help find a new wealth of minima with high AUC and low APL (i.e. powerful yet simulable).
In this extension, we take advantage of the fact that domain experts may already have notions about how regions of the input space operate differently. For example, a clinical intensivist may already cognitively consider patients in the surgical intensive care unit (ICU) as different from patients in the cardiac ICU. Analogously, biologists may be happy with different models for classifying diseases in deciduous versus in coniferous plants. In fact, this way of partitioning thinking into independent compartments is a very general phenomenon. Cognitive science literature tells us that people build context-dependent models of the world; they do not expect the same rule to apply in all circumstances .
Using this intuition, we divide the input space into exclusive regions. We assume that this division is available a priori via domain knowledge. In fact, this is a good opportunity to inject human beliefs into training the model. Formally, this translates into exclusive regions , where . We denote the observed dataset belonging to region as . Thus, we shall apply a regionally-faithful regularization that encourages the target neural model to be “simple” in every region (where a region corresponds to a human context). This partitioning of the input space into regions allows a regularized neural model to approximate very complex decision boundaries with components that are still simple in each region, thereby remaining simulable. We emphasize that our regional explanations are distinct from local explanations (e.g. ): the latter concern behavior within an ε-ball around a single data point and make no claims about general behavior across data points. In contrast, regional explanations are faithful over an entire region .
As a preview, Figure 11 highlights the distinctions between global, local, and regional tree regularization on a two-dimensional toy dataset where the true decision boundary is divided in half at . We see that global explanations (b) lack information about the input space and have to choose from a large set of possible solutions, converging to a different boundary. On the other hand, local explanations (c) produce simple boundaries around each data point but fail to capture global relationships, resulting in a complex overall decision function. Finally, regional explanations (d) over two regions divided at 0.4 share the benefits of (b) and (c), converging to the true boundary.
7.1 Regional Tree Regularization Objective
We now formally introduce regional tree regularization, which will require that the target neural model is well-approximated by a separate compact decision tree in every region. In contrast, we will rename the tree regularizer presented above as global tree regularization. Regionally simple decision boundaries are particularly hard to achieve with global tree regularization as the global APL metric may allow some human-relevant regions to be complex as long as most are simple. In particular, global tree regularization has an incentive to “ignore” simpler regions in order to minimize the regularization term (i.e. trivially predicting a single label). In many contexts, this behavior is undesirable. For example, if a clinician splits their patients by severity of illness, regularizing for simple global explanations can completely ignore a group of patients, rendering the machine learning system useless. To address this, we define our regional tree regularization as follows. First, let the APL for region be:
where the average path length, APL can be computed with Algorithm 2 (note that the target network and its parameters are the same for all regions , meaning a strong sharing of parameters across regions). For all future instances of computing APL, we use Algorithm 2, not Algorithm 1. We will elaborate on this distinction later. Note that is equivalent to global tree regularization as presented above. Next, to ensure that some regions cannot be made simple at the expense of others, we penalize only the most complex region:
in other words, an L∞ norm over . The choice of the L∞ norm produces significantly different (and desirable) behavior than if we had simply used, for example, the L1 norm (or sum) over . Regularizing the sum of is equivalent to simply regularizing APL in a global tree that first branches by region. In contrast, as a nonlinear regularizer, L∞ keeps all regions simple (i.e. low APL), while not penalizing regions that are already simple. We show an example of this effect in Figure 12: (a) shows a toy dataset with two regions (split by the black line): the left has a simple decision boundary dividing the region in half; the right has a more complex boundary. (b) and (c) then show two minima using L1 regional tree regularization. In both cases, one of the regions collapses to a trivial decision boundary (predicting all one label) to minimize the overall sum of APLs. On the other hand, since L∞ is sparse, simple regions are not included in the objective, resulting in a more “balanced” regularization between regions (see d and e).
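The contrast between penalizing the sum and penalizing the maximum of the regional APLs can be illustrated numerically. The sketch below uses made-up APL values for two regions (they are assumptions for illustration, not measurements): a "collapsed" solution in which one region predicts a single label while the other grows more complex, and a "balanced" solution in which both regions stay moderately simple.

```python
def sum_penalty(regional_apls):
    """L1-style regional penalty: total APL over all regions."""
    return sum(regional_apls)

def max_penalty(regional_apls):
    """L-infinity-style regional penalty: APL of the most complex region only."""
    return max(regional_apls)

# Hypothetical per-region APLs for two candidate minima.
balanced = [4.0, 6.0]    # both regions moderately simple
collapsed = [0.0, 9.0]   # one region trivial, the other more complex

print("sum:", sum_penalty(balanced), sum_penalty(collapsed))   # sum prefers collapsed
print("max:", max_penalty(balanced), max_penalty(collapsed))   # max prefers balanced
```

With these numbers the sum is lower for the collapsed solution, while the maximum is lower for the balanced one, which mirrors the behavior described for Figure 12.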
However, gradient descent with Equation 13 has several challenges. For example, both the max operator and the regional APL functions are non-differentiable. In the following, we describe how we address these challenges as well as concerns over optimization stability.
7.2 Gradient-based optimization with SparseMax
Gradient-based optimization of our proposed regularizer in Equation 13 is challenging because the max operator is not differentiable. Further, common differentiable approximations like softmax are dense (they include non-zero contributions from all regions), which makes it difficult to focus on the most complex regions as max does (using a dense approximation of max would suffer from the same problems as using an L1 norm). Instead, we use the recently-proposed SparseMax transformation , which can focus on the most problematic regions (setting others to zero contribution) while remaining smooth and differentiable almost everywhere. Intuitively, SparseMax corresponds to a Euclidean projection of an input vector with one entry per region (one APL per region) onto a vector of the same length whose non-negative entries sum to one (i.e. the probability simplex). When the projection lands on a boundary of the simplex (which is likely), the resulting vector will be sparse. Efficient implementations of this projection are well-known  (see Algorithm 3), as are Jacobians for automatic differentiation . We refer to using SparseMax as L regional tree regularization (we call using the sum of the APLs L1 regional tree regularization).
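For intuition, the simplex projection can be written in a few lines using the standard sort-and-threshold procedure. This is a plain-Python sketch rather than a reproduction of Algorithm 3, and the way the resulting weights are combined with the APLs in `sparsemax_penalty` (a weighted sum) is an assumption made for illustration; the APL values are also hypothetical.

```python
def sparsemax(z):
    """Euclidean projection of the vector z onto the probability simplex."""
    z_sorted = sorted(z, reverse=True)
    cumulative = 0.0
    tau = 0.0
    for k, z_k in enumerate(z_sorted, start=1):
        cumulative += z_k
        if 1.0 + k * z_k > cumulative:
            tau = (cumulative - 1.0) / k  # threshold for a support of size k
        else:
            break  # the support is a prefix of the sorted entries
    return [max(v - tau, 0.0) for v in z]

def sparsemax_penalty(regional_apls):
    """Assumed relaxation of max: APLs weighted by their sparsemax weights."""
    weights = sparsemax(regional_apls)
    return sum(w * a for w, a in zip(weights, regional_apls))

apls = [9.0, 8.6, 3.0]        # hypothetical per-region APLs
weights = sparsemax(apls)     # the simplest region receives exactly zero weight
penalty = sparsemax_penalty(apls)
```

Here the two most complex regions share the weight while the simple region drops out of the objective entirely, unlike a softmax, which would assign it a small but non-zero weight.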
7.3 Differentiable Regional Tree Regularization Loss
The regional APL is not differentiable as derivatives cannot flow through CART (the common method for training decision trees). To circumvent this, we again employ surrogate loss functions that map a parameter vector to an estimate of , the APL in region . This process is identical to global tree regularization but only for observations lying in region . Each surrogate has its own parameters . Specifically, we fit each by minimizing a mean squared error loss,
for all where is sampled from a dataset of known parameter vectors and their true APLs: . This dataset can be assembled using the candidate vectors obtained over gradient steps while training the target model . For regions, we curate one such dataset for each surrogate model.
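To make the surrogate fitting step concrete, the sketch below fits a surrogate by gradient descent on the mean squared error between predicted and true APLs. It deliberately swaps in a linear surrogate and a tiny synthetic dataset of (parameter vector, APL) pairs; both are simplifying assumptions, as are the names `fit_linear_surrogate`, `xi`, and `bias`. In our setting each surrogate is a small MLP and the true APLs come from fitted decision trees.

```python
def predict_apl(xi, bias, w):
    """Linear surrogate: estimated APL for the parameter vector w."""
    return sum(x * v for x, v in zip(xi, w)) + bias

def mse(xi, bias, weights, apls):
    return sum((predict_apl(xi, bias, w) - a) ** 2 for w, a in zip(weights, apls)) / len(apls)

def fit_linear_surrogate(weights, apls, lr=0.05, steps=2000):
    """Minimize the MSE between surrogate outputs and true APLs by gradient descent."""
    n, dim = len(weights), len(weights[0])
    xi, bias = [0.0] * dim, 0.0
    for _ in range(steps):
        residuals = [predict_apl(xi, bias, w) - a for w, a in zip(weights, apls)]
        grad_xi = [2.0 / n * sum(r * w[j] for r, w in zip(residuals, weights)) for j in range(dim)]
        grad_bias = 2.0 / n * sum(residuals)
        xi = [x - lr * g for x, g in zip(xi, grad_xi)]
        bias -= lr * grad_bias
    return xi, bias

# Synthetic stand-in for one region's dataset of parameter vectors and true APLs.
region_weights = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 1.0]]
region_apls = [1.0 + 2.0 * w[0] + 3.0 * w[1] for w in region_weights]
xi, bias = fit_linear_surrogate(region_weights, region_apls)
final_error = mse(xi, bias, region_weights, region_apls)
```

With several regions, the same routine would be run once per region on that region's own dataset, giving one set of surrogate parameters per region.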
The ability of each surrogate to stay faithful is a function of many factors. For global tree regularization (above), we used a fairly simple strategy for training a surrogate and found it sufficient; we find that especially when there are multiple surrogates to be maintained, sophistication is needed to keep the gradients accurate and the variances low. We describe these innovations in the next section.
7.4 Innovations for Optimization Stability
Optimizing multiple surrogate networks is a delicate operation. We found that depending on hyperparameters, the regional surrogates were unable to accurately predict the APL, causing regularization to fail. Further, repeated runs also often found different minima, making regional tree regularization feel unreliable. In short, it presents a much more difficult technical challenge than training a single surrogate as in global tree regularization. Below, we list optimization innovations that are essential to stabilize training, identify consistent minima, and get good APL prediction—all of which enabled robust regional tree regularization.
|Experiment||Mean MSE||Max MSE|
|No data aug.||0.069||0.987|
|With data aug.||0.015||0.298|
Data augmentation makes for a robust surrogate.
Especially for regional explanations, relatively small changes in the underlying model can mean large changes for the pattern in a specific region. As such, the surrogates need to be retrained frequently (e.g. every 50 gradient steps). The practice used in global tree regularization of computing the true APL for a dataset of the most recent is insufficient to learn the mapping from a thousand-dimensional weight vector to the APL. Using stale (very old) from previous epochs, however, would result in a poor surrogate model given outdated information. Previous heuristics as in random restarts or arbitrarily sampling random weights introduced more noise than signal. Thus, we supplement the dataset with randomly sampled weight vectors from the convex hull defined by the recent weights. Specifically, to generate a new , we sample from a Dirichlet distribution with categories and form a new parameter as a convex combination of the elements in . For each of these samples, we compute its true APL to train the surrogate. Table 15 shows this to reduce noise.
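The augmentation step can be sketched as follows: draw mixing coefficients from a Dirichlet distribution and use them to form a convex combination of recently seen weight vectors. The concentration parameter (`alpha=1.0`), the helper names, and the toy two-dimensional weights are assumptions for illustration. The Dirichlet draw is built from normalized Gamma samples, which is a standard construction.

```python
import random

def sample_dirichlet(num_categories, alpha=1.0, rng=random):
    """Draw one sample from a symmetric Dirichlet via normalized Gamma draws."""
    draws = [rng.gammavariate(alpha, 1.0) for _ in range(num_categories)]
    total = sum(draws)
    return [d / total for d in draws]

def sample_from_convex_hull(recent_weights, rng=random):
    """New parameter vector as a random convex combination of recent ones."""
    mix = sample_dirichlet(len(recent_weights), rng=rng)
    dim = len(recent_weights[0])
    return [sum(m * w[j] for m, w in zip(mix, recent_weights)) for j in range(dim)]

rng = random.Random(0)  # fixed seed so the sketch is reproducible
recent = [[0.0, 1.0], [2.0, 3.0], [4.0, -1.0]]  # hypothetical recent weight vectors
augmented = [sample_from_convex_hull(recent, rng) for _ in range(5)]
# Each augmented vector would then be labeled with its true APL
# (by fitting a decision tree) and added to the surrogate's training set.
```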
Decision trees should be pruned.
Given a dataset, , even with a fixed seed, there are many decision trees that can fit . One can always add additional subtrees that predict the same label as the parent node, thereby not affecting performance. This invariance again introduces difficulty in learning a surrogate model. To remedy this, we use reduced error pruning, which removes any subtree that does not affect performance as measured on a portion of not used in TrainTree. Note that line 4 in Algorithm 2 is not in the original tree regularization algorithm. Intuitively, pruning collapses the set of possible trees describing a single classifier to a singleton.
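A compact sketch of reduced error pruning is given below. The nested-dictionary tree representation (with a stored `majority` label at each internal node), the tie-breaking rule (prune whenever held-out error does not increase), and the toy data are assumptions made for illustration and are not taken from Algorithm 2.

```python
def predict(node, x):
    """Follow splits until a leaf is reached."""
    while "leaf" not in node:
        node = node["left"] if x[node["feature"]] <= node["threshold"] else node["right"]
    return node["leaf"]

def errors(node, rows):
    return sum(predict(node, x) != y for x, y in rows)

def prune(node, rows):
    """Bottom-up reduced error pruning using held-out (x, y) rows reaching this node."""
    if "leaf" in node:
        return node
    f, t = node["feature"], node["threshold"]
    left_rows = [(x, y) for x, y in rows if x[f] <= t]
    right_rows = [(x, y) for x, y in rows if x[f] > t]
    pruned = dict(node, left=prune(node["left"], left_rows),
                  right=prune(node["right"], right_rows))
    leaf = {"leaf": node["majority"]}
    # Replace the subtree by a leaf if held-out error does not get worse.
    return leaf if errors(leaf, rows) <= errors(pruned, rows) else pruned

def count_nodes(node):
    return 1 if "leaf" in node else 1 + count_nodes(node["left"]) + count_nodes(node["right"])

# Toy tree whose right subtree splits again but predicts label 1 on both sides.
tree = {"feature": 0, "threshold": 0.5, "majority": 1,
        "left": {"leaf": 0},
        "right": {"feature": 1, "threshold": 0.5, "majority": 1,
                  "left": {"leaf": 1}, "right": {"leaf": 1}}}
held_out = [([0.2, 0.1], 0), ([0.3, 0.9], 0), ([0.8, 0.2], 1), ([0.9, 0.7], 1)]
pruned_tree = prune(tree, held_out)
```

The redundant split in the right subtree is collapsed into a single leaf, while the informative root split is kept, so two trees that describe the same classifier end up with the same shape and hence the same APL.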
Decision trees should be trained deterministically.
CART is a common algorithm to train a decision tree. However, it has poor complexity in the number of features as it enumerates over all unique values per dimension. To scale efficiently, many open-source implementations (e.g. Scikit-Learn ) randomly sample a small subset of features. As such, independent training instances can lead to different decision trees of varying APL. For tree regularization, unexplained variance in APL means difficulty in training the surrogate model, since the function from model parameters to APL is no longer many-to-one. The error is compounded when there are many surrogates. To remedy this, we fix the random seed that governs the choice of features. As an example, Figure 14 shows the high variance of decision boundaries from a randomized treatment of fitting decision trees (a-d) on a very sparsely sampled data set, leading to higher error in surrogate predictions (Table 15). Setting the seed removes this variance.
A large learning rate will lead to thrashing.
As mentioned before, with many regions, small changes in the deep model can already have large effects on a region. If the learning rate is large, each gradient step can lead to a dramatically different decision boundary than the previous one. Thus, the function that each surrogate must learn is no longer continuous. Empirically, we found large learning rates to lead to thrashing, or oscillating between high and low APL where the surrogate is effectively memorizing the APL from the last epoch (with poor generalization to new ).
These optimization innovations are crucial for learning with regional tree regularization. Without them, optimization is very unstable, resulting in undesirable minima. Figure 16 shows a few examples in a synthetic dataset: without data augmentation (c), there are not enough examples to fully train each surrogate, resulting in poor estimates of in which we converge to the same minima as no regularization (b); without pruning and fixing seeds, the path lengths vary due to randomness in fitting a decision tree, which can lead to over- or under- estimating the true APL. As shown in (d), this leads to strange decision boundaries. Finally, (e) shows the effect of large learning rates that leads to thrashing, resulting in a trivial decision boundary in efforts to minimize the loss. Only with the optimization innovations (f), do we converge to a properly regularized decision boundary.
8 Demonstration: Five Rectangles Dataset
To build intuition, we present experiments in a toy setting: We define a ground-truth classification function composed of five rectangles (height of 0.5 and width of 1) in concatenated along the x-axis to span the domain of . The first three rectangles are centered at (shifted slightly downwards) while the remaining two rectangles are centered at (shifted slightly upwards). The training dataset is intended to be sparse, containing only 250 points with the labels of 5% of points randomly flipped to introduce noise and encourage overfitting. In contrast, the test dataset is densely sampled without noise. This is intended to model real-world settings where regional structure is only partially observable from an empirical dataset. It is exactly in these contexts that prior knowledge can be helpful.
|Model||Test Acc.||Test APL|
|Global Tree ()||0.8454||6.3398|
|L Regional Tree ()||0.9168||10.1223|
|L Regional Tree ()||0.9308||8.1962|
Figure 17 shows the learned decision boundary with (b) no regularization, (c) L2 regularization, (d) global tree regularization, and (e,f) regional tree regularization. As global regularization is restricted to penalizing all data points evenly, it fails to find the happy medium between being too complex or too simple. In other words, increasing the regularization strength quickly causes the target neural model to collapse from a complex nonlinear decision boundary to a single axis-aligned boundary. As shown in (d), this fails to capture any structure imposed by the five rectangles. (It might be possible to capture the true structure, in a simple domain such as this, with very careful tuning of the hyperparameters in global tree regularization. However, this is difficult to do consistently and regional tree regularization presents a much easier solution.) Similarly, if we increase the strength of L2 regularization even slightly from (c), the model collapses to the trivial solution of predicting entirely one label. Only regional tree regularization (e,f) is able to model the up-and-down curvature of the true decision function. With high , L regional tree regularization produces a more axis-aligned decision boundary than its L equivalent, primarily because we can regularize complex regions more harshly without collapsing simpler regions. Knowledge of the region divisions provides a model with prior information about underlying structure in the data; we should expect that with such information, a regionally regularized model can better prevent itself from over- or underfitting. We train for 500 epochs with a learning rate of 4e-3, a minibatch size of 32, retrain the surrogate function every epoch (a loop over the full training dataset) and sample 1000 weights from the convex hull each time. Decision trees were trained with . 
Table 3 compares metrics between the different regularizations: although the regional tree regularization is slightly more complex than global tree regularization, it comes with a large increase in accuracy.
9 Application: UC Irvine Prediction Tasks
Having seen a synthetic dataset, we transition to more realistic machine learning settings. Without loss of generality, we focus on feedforward networks, or MLPs. The same ideas of regional explanation using decision trees can be trivially extended to sequential models (like the GRU used above) or convolutional models. For the experiments below, we set the target neural model to a 6 layer MLP with 128, 128, 128, 64, 64, and
dimensional hidden layers respectively. The final layer contains a node for each output dimension. We use leaky ReLU nonlinearities in between each layer. Each surrogate remains a very shallow MLP.
9.1 Evaluation Metrics
We wish to compare models with global and regional explanations. However, given , and are not directly comparable: subtly, the APL of a global tree is often an overestimate for data points in a single region. To reconcile this, for any globally regularized model, we separately compute as an evaluation criterion. In this context,
is used only for evaluation; it does not appear in the objective nor training. We do the same for baseline models, L2 regularized models, and unregularized models. From this point on, if we refer to average path length (e.g. Test APL, APL, path length) outside of the objective, we are referring to the evaluation metric,.
We apply regional tree regularization to a suite of four popular machine learning datasets from the UC Irvine repository . We briefly provide context for each dataset and show results comparing the regularization methods in effectiveness. We choose a generic method for defining regions to showcase the wide applicability of regional regularization: we use to fit a k-means clustering model with . Each example is then assigned a number, . We define .
Bank Marketing (Bank): 45,211 rows collected from marketing campaigns for a bank . Each example has 17 features describing a recipient of the campaign (age, education, etc). There is one binary output indicating whether the recipient subscribed.
MAGIC Gamma Telescope (Gamma): 19,020 samples from a simulator of high energy Gamma particles in a Cherenkov telescope. There are 11 input features for afterimages of photon pulses, and one binary output discriminating between signal and background.
Adult Income (Adult): 48,842 data points with 14 input features (age, sex, etc.), and a binary output indicating if an individual’s income exceeds $50,000 per year .
Wine Quality (Wine): 4,898 examples describing wine from Portugal. Each row has a quality score from 0 to 10 and eleven variables based on physicochemical tests for acidity, sugar, pH, etc. We binarize the target where a positive label indicates a score of at least 5.
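The generic region definition used for these datasets, clustering the inputs with k-means and treating each cluster as a region, can be sketched in plain Python. This is a minimal stand-in using Lloyd's algorithm with a deterministic initialization (the first k points); the initialization, the number of iterations, and the toy points are assumptions, and a library implementation would normally be used instead.

```python
def nearest(point, centers):
    """Index of the center closest to `point` in squared Euclidean distance."""
    return min(range(len(centers)),
               key=lambda c: sum((p - q) ** 2 for p, q in zip(point, centers[c])))

def kmeans(points, k, iters=20):
    """Lloyd's algorithm with the first k points as initial centers (an assumption)."""
    centers = [list(p) for p in points[:k]]
    for _ in range(iters):
        labels = [nearest(p, centers) for p in points]
        for c in range(k):
            members = [p for p, lab in zip(points, labels) if lab == c]
            if members:  # keep the old center if a cluster becomes empty
                centers[c] = [sum(col) / len(members) for col in zip(*members)]
    return centers, [nearest(p, centers) for p in points]

# Hypothetical two-dimensional inputs forming two well-separated groups.
points = [[0.0, 0.1], [5.0, 5.1], [0.2, 0.0], [5.2, 4.9], [0.1, 0.2], [4.9, 5.0]]
centers, region_of = kmeans(points, k=2)

# Group example indices by region; each group gets its own tree and surrogate.
regions = {}
for index, region in enumerate(region_of):
    regions.setdefault(region, []).append(index)
```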
In each dataset, the target neural model is trained for 500 epochs with 1e-4 learning rate using Adam  and a minibatch size of 128. We train under 20 different between 0.0001 and 10.0. We do not do early stopping to preserve overfitting effects. We use 250 samples from the convex hull and retrain every 50 gradient steps. We set for Wine and otherwise. Figure 18 (a-d) compare L2, global tree, and regional tree regularization with varying strengths. The points plotted show minima from 3 independent runs. We include three baselines: an unregularized model, a decision tree trained on and, a set of trees with one for each region (we call this: regional decision tree). For baseline trees, we vary where a higher is a more regularized decision tree.
Some patterns are apparent. First, an unregularized model (black) does poorly due to overfitting to a complex decision boundary, as the neural network is over-parameterized. Second, we find that L2 is not a desirable regularizer for simulability as it is unable to find many minima in the low APL region (see Gamma, Adult, and Wine under roughly 5 APL). Any increase in regularization strength quickly causes the target neural model to decay to an F1 score of 0, in other words, one that predicts a single label. We see similar behavior with global tree regularization, suggesting that finding low complexity minima is challenging under global constraints. Third, regional tree regularization achieves the highest test accuracy in all datasets. We find that in the lower APL area, regional explanations surpass global explanations in performance. For example, in Bank, Gamma, Adult, and Wine, we can see this at 3-6, 4-7, 5-8, and 3-4 APL respectively. This suggests, like in the toy example, that it is easier to regularize groups rather than the entire input space as a whole. In fact, unlike global regularization, models constrained regionally are able to reach a wealth of minima in the low APL area. Lastly, we note that with high regularization strengths, regional tree regularization mostly converges in performance with regional decision trees, which is sensible as the neural network prioritizes distillation over performance.
10 Application: Sepsis (ICU)
We revisit the Sepsis Critical Care dataset, only this time we apply regional tree regularization and compare to other regularizers, including global tree regularization.
APL for multiple outputs.
Previous datasets had only 1 binary output while Critical Care has 5. Fortunately, the definition of APL generalizes: compute the APL for each output dimension, and take the sum as the measure of complexity. This requires fitting trees.
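As a minimal sketch of the multi-output definition above, the following computes the average path length (APL) of one tree per output and sums the results. The `Node` class and the hand-built trees are hypothetical stand-ins for trees that would in practice be fit to the model's predictions, and counting one unit per decision node visited is an assumption about how path length is measured.

```python
from dataclasses import dataclass
from typing import List, Optional, Sequence


@dataclass
class Node:
    """Hypothetical binary decision tree node; `feature is None` marks a leaf."""
    feature: Optional[int] = None
    threshold: float = 0.0
    left: Optional["Node"] = None
    right: Optional["Node"] = None
    label: int = 0


def path_length(tree: Node, x: Sequence[float]) -> int:
    """Number of decision nodes visited before x reaches a leaf."""
    depth, node = 0, tree
    while node.feature is not None:
        node = node.left if x[node.feature] <= node.threshold else node.right
        depth += 1
    return depth


def average_path_length(tree: Node, X: List[Sequence[float]]) -> float:
    """Mean path length of the examples in X through one tree."""
    return sum(path_length(tree, x) for x in X) / len(X)


def multi_output_apl(trees: List[Node], X: List[Sequence[float]]) -> float:
    """Sum of per-output APLs, one tree per output dimension."""
    return sum(average_path_length(t, X) for t in trees)


# Two hand-built trees, standing in for two output dimensions.
stump = Node(feature=0, threshold=0.5, left=Node(label=0), right=Node(label=1))
deeper = Node(
    feature=0, threshold=0.5,
    left=Node(label=0),
    right=Node(feature=1, threshold=0.5, left=Node(label=0), right=Node(label=1)),
)
X = [[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]]
total_apl = multi_output_apl([stump, deeper], X)
```

Here the stump contributes an APL of 1.0 and the deeper tree 1.5, so the summed complexity is 2.5.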
We explore two methods of defining regions, both suggested by ICU physicians. The first defines three regions by sequential organ failure assessment (SOFA), a summary statistic that has historically been used for predicting ICU mortality. Using , the groups are defined by more than one standard deviation below the mean, within one standard deviation of the mean, and more than one standard deviation above the mean. Intuitively, each group should encapsulate a very different type of patient. The second method clusters patients by their careunit into five groups: MICU (medical), SICU (surgical), TSICU (trauma surgical), CCU (cardiac non-surgical), and CSRU (cardiac surgical). Again, patients who undergo surgery should behave differently from those with less invasive operations.
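The first region definition can be sketched as a simple thresholding rule. This is an illustration only: the function name is hypothetical, the use of the population standard deviation is an assumption, and scores exactly one standard deviation from the mean are assigned to the middle group by assumption.

```python
import statistics
from typing import List, Sequence


def sofa_regions(scores: Sequence[float]) -> List[int]:
    """Assign each patient to one of three regions by SOFA score.

    0: more than one standard deviation below the mean
    1: within one standard deviation of the mean
    2: more than one standard deviation above the mean
    """
    mean = statistics.mean(scores)
    std = statistics.pstdev(scores)  # population std; an assumption
    regions = []
    for s in scores:
        if s < mean - std:
            regions.append(0)
        elif s > mean + std:
            regions.append(2)
        else:
            regions.append(1)
    return regions


# Made-up scores for illustration only.
example_regions = sofa_regions([0, 5, 5, 5, 10])
```

For the careunit definition no thresholding is needed, since each patient already carries one of the five categorical labels.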
Figure 19 compares different regularization schemes against baseline models for SOFA regions (a-d) and careunit regions (e-h). Overall, the patterns we discussed in the UCI datasets are consistent in this application. We especially highlight the inability (across the board) of global explanation to find many low complexity solutions. For example, in Figure 19 (a,c,e), the minima from global constraints stay very close to the unregularized minima. In other cases (f, g), global regularization finds very poor optima: reaching low accuracy with high APL. In contrast, region regularization consistently finds a good compromise between complexity and performance. In each subfigure, we can point to a span of APL at which the pink curve is much higher than all others. These results are from three runs, each with 20 different strengths.
Distilled decision trees.
A consequence of tree regularization is that every minimum is associated with a set of trained trees. We can extract the trees that best approximate the target neural model and rely on them for explanation. Figure 19 (i,j) shows an example of two trees predicting ventilation, taken from a low APL, high AUC minimum of a regional tree regularized model. We note that the composition of the trees is different, suggesting that they each capture a decision function biased to a region. Moreover, we can see that while Figure 19 (i) mostly predicts 0, Figure 19 (j) mostly predicts 1; this agrees with our intuition that SOFA scores are correlated with risk of mortality. Figure 19 (k,l) shows similar findings for sedation. If we were to capture this behavior with a single decision tree, we would either lose granularity or be left with a very large tree.
Feedback from physicians.
We presented a set of 9 distilled trees from regional tree regularized models (1 for each output and SOFA region) to an expert intensivist for interpretation. Broadly, he found the regions beneficial, as they allowed him to connect the model to his cognitive categories of patients, including those unlikely to need interventions. He verified that for predicting ventilation, GCS (mental status) should have been a key factor, and for predicting vasopressor use, the logic supported cases when vasopressors would likely be used versus other interventions (e.g. fluids if urine output is low). He was also able to make requests: for example, he asked if the effect of oxygen could have been a higher branch in the tree to better understand its effects on ventilation choices, and, noticing the similarities between the sedation and ventilation trees, pointed out that they were correlated and suggested defining new regions by both SOFA and ventilation status.
We highlight that this kind of reasoning about what the model is learning and how it can be improved is very valuable. Very few notions of interpretability in deep models offer the level of granularity and simulatability that regional tree explanations do.
11 Application: EuResist (HIV)
We again revisit the HIV dataset to compare global and regional explanations.
Defining regions in HIV.
We define regions based on the advice of medical experts, using a patient’s degree of immunosuppression at baseline (known as CDC staging). The groups are defined as: <200 cells/mm³, 200-300 cells/mm³, 300-500 cells/mm³, and >500 cells/mm³. This choice of regions should characterize patients based on the initial severity of their infection; the lower the initial cell count, the more severe the infection.
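A minimal sketch of this grouping, assuming the four regions are separated at 200, 300, and 500 cells/mm³. The function name is hypothetical, and treating each lower bound as inclusive is an assumption, since the text does not say which group a boundary value belongs to.

```python
from bisect import bisect_right

# Region boundaries in cells/mm^3, taken from the staging groups in the text.
CD4_BOUNDARIES = [200, 300, 500]


def cd4_region(baseline_count: float) -> int:
    """Map a baseline cell count to a region index.

    0: below 200, 1: 200-300, 2: 300-500, 3: 500 and above.
    Boundary values fall into the higher group (an assumption).
    """
    return bisect_right(CD4_BOUNDARIES, baseline_count)


# Made-up counts for illustration only.
example = [cd4_region(c) for c in (150, 250, 450, 800)]
```

Lower region indices correspond to lower initial cell counts and therefore to more severe infection at baseline.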
Figure 20 compares different regularization schemes against baseline models across levels of immunosuppression. Overall, regional tree regularization produces more accurate predictions and provides simpler explanations across all outputs. For the case of predicting patient mortality in Fig 20(a), we tend to find more suitable optima across different patient groupings and can provide better regional explanations for these patients as a result. Here, we observe that patients with lower levels of immunosuppression tend to have lower risk of mortality. We also observe that patients with lower immunity at baseline are more likely to progress to AIDS. Similar inferences can be made for the other outputs. In each subfigure, we reiterate that there is a span of APL at which the pink curve is much higher than all others.
Distilled decision trees.
We extract decision trees that approximate the target model for multiple minima and use these as explanations. Fig 20 (e-g) show three trees where we have low APL and high AUC minima from a regional tree regularized model. Again, the trees look significantly different based on the decision function in a particular region. In particular, we observe that lower levels of immunity at baseline are associated with higher viral loads (lower viral suppression) and higher risk of mortality.
Feedback from physicians.
The trees were shown to a physician specializing in HIV treatment. He was able to simulate the model’s logic, and confirmed our observations about relationships between viral loads and mortality. In addition, he noted that when patients have lower baseline immunity, the trees for mortality contain several more drugs. This is consistent with medical knowledge, since patients with lower immunity tend to have more severe infections, and require more aggressive therapies to combat drug resistance.
12 Analysis for Regional Tree Regularization
We now summarize a few important outcomes from the regional experiments:
The most effective minima are found in the low APL, high AUC regime.
The ideal model is one that is both highly performant and simulable. This translates to high F1/AUC scores near medium APL. An APL that is too large would be hard for an expert to understand; one that is too small would be too restrictive, resulting in no benefit from using a deep model. Across all experiments, we see that L region regularization is most adept at finding low APL and high AUC minima.
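One way to make this trade-off concrete is to keep only the minima that are not dominated in both complexity and performance, then pick the best one within a complexity budget. The sketch below is an illustration rather than the procedure used in the experiments; the function names and the (APL, AUC) pairs are hypothetical.

```python
from typing import List, Optional, Tuple

Point = Tuple[float, float]  # (APL, AUC) of one trained model


def pareto_front(points: List[Point]) -> List[Point]:
    """Return points for which no other point has lower-or-equal APL and higher AUC."""
    front: List[Point] = []
    best_auc = float("-inf")
    # Sort by APL ascending; break ties by AUC descending.
    for apl, auc in sorted(points, key=lambda p: (p[0], -p[1])):
        if auc > best_auc:
            front.append((apl, auc))
            best_auc = auc
    return front


def best_within_budget(points: List[Point], max_apl: float) -> Optional[Point]:
    """Highest-AUC point whose APL does not exceed `max_apl`, or None."""
    candidates = [p for p in points if p[0] <= max_apl]
    return max(candidates, key=lambda p: p[1]) if candidates else None


# Made-up minima for illustration only.
minima = [(2.0, 0.6), (5.0, 0.8), (5.0, 0.7), (9.0, 0.75), (12.0, 0.85)]
front = pareto_front(minima)
choice = best_within_budget(minima, max_apl=6.0)
```

With these made-up values, the model at APL 9.0 is discarded because a simpler model already reaches a higher AUC, and a budget of 6.0 selects the model at APL 5.0.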
Global and local regularization are two extreme forms of regional regularization.
With a single region, the full training dataset is contained in that region, which enforces global explainability. With one region per data point, every data point has its own region, i.e. local explainability.
Regularized deep models outperform trees.
Comparing regional tree-regularized models and regional decision trees, the former reach much higher AUC at equal APL.
Regional tree regularization produces regionally faithful decision trees.
Table 4 shows the fidelity of a deep model to its distilled tree. A score of 1.0 indicates that both models learned the same decision function. With a fidelity of 89%, the regularized model is “simple” in most cases, but can take advantage of deep nonlinearity on difficult examples.
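A minimal sketch of a fidelity score, assuming it is measured as the fraction of examples on which the deep model and its distilled tree predict the same label. The text only states that 1.0 means identical decision functions, so this agreement-rate definition and the function name are assumptions.

```python
from typing import Sequence


def fidelity(model_preds: Sequence[int], tree_preds: Sequence[int]) -> float:
    """Fraction of examples where the two models predict the same label."""
    if len(model_preds) != len(tree_preds):
        raise ValueError("prediction lists must have the same length")
    if not model_preds:
        raise ValueError("need at least one prediction")
    agree = sum(m == t for m, t in zip(model_preds, tree_preds))
    return agree / len(model_preds)


# Made-up predictions for illustration only.
score = fidelity([1, 0, 1, 1], [1, 0, 0, 1])
```

Under this definition, the examples on which the two predictions disagree are those where the deep model departs from the tree's simple decision function.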
Regional tree regularization is not computationally expensive.
Over 100 trials on Sepsis, an L2 model takes sec. per epoch; a global tree model takes sec. and sec. to (1) sample 1000 convex samples, (2) compute APL for , (3) train a surrogate model for 100 epochs; a regional tree model takes sec. and sec. for (1), (2), and training 5 surrogates. The increase in base cost is due to the extra forward pass through surrogate models to predict APL. The surrogate cost(s) are customizable depending on the size of , the number of training epochs, and the frequency of re-training. If is large, we need not re-train each surrogate. The choice of which regions to prioritize can be treated as a bandit problem.
Distilled decision trees are interpretable by domain experts.
We asked physicians in Critical Care and HIV to analyze the distilled decision trees from regional regularization. They were able to quickly understand the learned decision function per region, suggest improvements, and verify the logic.
Optimizing surrogates is much faster and more stable than gradient-free methods.
We tried alternative optimization methods that do not require differentiating through the training of a decision tree: (1) estimating gradients by perturbing inputs, and (2) search algorithms like Nelder-Mead. However, we found these methods to be either unreasonably expensive or easily stuck in local minima, depending on initialization.
Sparsity over regions is important.
We experimented with different “dense” norms: L, L, and a softmax approximation to L, all of which faced issues where regions with simpler decision boundaries a priori were over-regularized to trivial decision functions. Only with L (i.e. sparsemax) did we avoid this problem. As a consequence, in toy examples, we observe that sparsemax finds minima with more axis-aligned boundaries. In real world studies, we find sparsemax to lead to better performance in low/mid APL regimes.
Interpretability is a bottleneck preventing widespread acceptance of deep learning. We have introduced a family of novel tree-regularization techniques that encourage the complex decision boundaries of any differentiable model to be well-approximated by human-simulable functions, allowing domain experts to quickly understand and approximately compute what the model is doing. Overall, our training procedure is robust and efficient. Across three complex, real-world domains (HIV treatment, sepsis treatment, and human speech processing), our tree-regularized models provide gains in prediction accuracy in the regime of simpler, human-simulable models. Finally, we showed how to extend tree regularization to region-specific approximations of a loss, where experts can add prior knowledge about the structure of their domain. More broadly, our general training procedure could apply tree-regularization or other procedure-regularization to a wide class of popular models, helping us move beyond sparsity toward models humans can easily simulate and thus trust.
-  (2016) Auditing black-box models for indirect influence. In ICDM, Cited by: §1.
-  (2018) HIGHLIGHTS: summarizing agent behavior to people. In Proc. of the 17th International conference on Autonomous Agents and Multi-Agent Systems (AAMAS), Cited by: §2.
-  (2016) Blackbox and derivative-free optimization: theory, algorithms and applications. Springer. Cited by: §4.
-  (2015) On pixel-wise explanations for non-linear classifier decisions by layer-wise relevance propagation. PloS one 10 (7), pp. e0130140. Cited by: §2.
-  (2014) Neural machine translation by jointly learning to align and translate. arXiv preprint arXiv:1409.0473. Cited by: §1.
-  (2015) Bayesian dark knowledge. In NIPS, Cited by: §2.
-  (2016) Layer-wise relevance propagation for deep neural network architectures. In Information Science and Applications (ICISA) 2016, pp. 913–922. Cited by: §2.
-  (2015) Deep computational phenotyping. In KDD, Cited by: §1.
-  (2017) Machine learning and prediction in medicine - beyond the peak of inflated expectations. N Engl J Med 376 (26), pp. 2507–2509. Cited by: §1.
-  (2014) Learning phrase representations using RNN encoder–decoder for statistical machine translation. In EMNLP, Cited by: §3.
-  (2016) Doctor AI: predicting clinical events via recurrent neural networks. In Machine Learning for Healthcare Conference, Cited by: §1.
-  (1996) Extracting tree-structured representations of trained networks. In NIPS, Cited by: Table 4, §6.2, Table 2.
-  (2017) UCI machine learning repository. University of California, Irvine, School of Information and Computer Sciences. Cited by: §9.2.
Improving generalization performance using double backpropagation. IEEE Transactions on Neural Networks 3 (6), pp. 991–997. Cited by: §2, §3.
-  (2008) Efficient projections onto the l 1-ball for learning in high dimensions. In Proceedings of the 25th international conference on Machine learning, pp. 272–279. Cited by: §7.2.
-  (2009) Visualizing higher-layer features of a deep network. Technical report Technical Report 1341, Department of Computer Science and Operations Research, University of Montreal. Cited by: §1.
-  (2017) Distilling a neural network into a soft decision tree. arXiv preprint arXiv:1711.09784. Cited by: §2.
-  (1993) TIMIT acoustic-phonetic continuous speech corpus. Linguistic Data Consortium 10 (5). Cited by: 3rd item.
-  (2017) Predicting intervention onset in the icu with switching state space models. AMIA Summits on Translational Science Proceedings 2017, pp. 82. Cited by: §1.
-  (2016) Deep learning. MIT Press. Cited by: §1, §3, §4.
-  (2016) Development and validation of a deep learning algorithm for detection of diabetic retinopathy in retinal fundus photographs. Jama 316 (22), pp. 2402–2410. Cited by: §1.
-  (2015) Learning both weights and connections for efficient neural network. In NIPS, Cited by: §2.
-  (2015) Distilling the knowledge in a neural network. arXiv preprint arXiv:1503.02531. Cited by: §2.
-  (1997) Long short-term memory. Neural computation 9 (8), pp. 1735–1780. Cited by: §3.
-  (2016) MIMIC-III, a freely accessible critical care database. Scientific Data 3. Cited by: 1st item.
-  (2014) The bayesian case model: a generative approach for case-based reasoning and prototype classification. In Advances in Neural Information Processing Systems, pp. 1952–1960. Cited by: §2.
-  (2014) Adam: a method for stochastic optimization. arXiv preprint arXiv:1412.6980. Cited by: §5, §9.2.
Scaling up the accuracy of naive-bayes classifiers: a decision-tree hybrid. In KDD, Vol. 96, pp. 202–207. Cited by: §9.2.
-  (2012) ImageNet classification with deep convolutional neural networks. In NIPS, Cited by: §1.
-  (2016) Interpretable decision sets: a joint framework for description and prediction. In KDD, Cited by: §1.
-  (2016) The mythos of model interpretability. In ICML Workshop on Human Interpretability in Machine Learning, Cited by: §1.
-  (2016) An unexpected unity among methods for interpreting model predictions. arXiv preprint arXiv:1611.07478. Cited by: §1.
-  (2008) Visualizing data using t-sne. Journal of machine learning research 9 (Nov), pp. 2579–2605. Cited by: §2.
-  (2016) From softmax to sparsemax: a sparse model of attention and multi-label classification. In International Conference on Machine Learning, pp. 1614–1623. Cited by: §7.2.
Explanation in artificial intelligence: insights from the social sciences. Artificial Intelligence. Cited by: §1, §2, §7.
-  (2016) Deep patient: an unsupervised representation to predict the future of patients from the electronic health records. Scientific Reports 6 (26094). Cited by: §1, §1.
-  (2018) Methods for interpreting and understanding deep neural networks. Digital Signal Processing 73, pp. 1–15. Cited by: §2.
-  (2015) Inceptionism: going deeper into neural networks. Google Research Blog. Retrieved June 20 (14), pp. 5. Cited by: §2.
-  (2014) A data-driven approach to predict the success of bank telemarketing. Decision Support Systems 62, pp. 22–31. Cited by: §9.2.
-  (2017) Automatic node selection for deep neural networks using group lasso regularization. In ICASSP, Cited by: §2, §3.
-  (2005) Interim WHO clinical staging of HIV/AIDS and HIV/AIDS case definitions for surveillance: African region. Technical report Geneva: World Health Organization. Cited by: §11.
-  (2011) Scikit-learn: machine learning in Python. Journal of Machine Learning Research 12, pp. 2825–2830. Cited by: §4.
-  (2011) Scikit-learn: machine learning in python. Journal of machine learning research 12 (Oct), pp. 2825–2830. Cited by: §7.4.
-  (2016) XNOR-Net: imageNet classification using binary convolutional neural networks. In ECCV, Cited by: §2.
-  (2016) Why should I trust you?: explaining the predictions of any classifier. In KDD, Cited by: §1, §1, §7.
-  (2017) Right for the right reasons: training differentiable models by constraining their explanations. In IJCAI, Cited by: §1.
-  (2016) Grad-cam: why did you say that?. arXiv preprint arXiv:1611.07450. Cited by: §1, §2.
-  (2016) Grad-CAM: visual explanations from deep networks via gradient-based localization. arXiv preprint arXiv:1610.02391v3. Cited by: §1.
-  (2016) Programs as black-box explanations. arXiv preprint arXiv:1611.07579. Cited by: §1.
-  (2014) Sequence to sequence learning with neural networks. In NIPS, Cited by: §1.
-  (2017) How to train a compact binary neural network with high accuracy?. In AAAI, Cited by: §2.
-  (2012) Predicting response to antiretroviral treatment by machine learning: the euresist project. Intervirology 55 (2), pp. 123–127. Cited by: 2nd item.
-  (2016) L1-regularized neural networks are improperly learnable in polynomial time. In ICML, Cited by: §2.