Learning Interpretable Models with Causal Guarantees

by Carolyn Kim, et al.

Machine learning has shown much promise in helping improve the quality of medical, legal, and economic decision-making. In these applications, machine learning models must satisfy two important criteria: (i) they must be causal, since the goal is typically to predict individual treatment effects, and (ii) they must be interpretable, so that human decision makers can validate and trust the model predictions. There has recently been much progress along each direction independently, yet the state-of-the-art approaches are fundamentally incompatible. We propose a framework for learning causal interpretable models from observational data that can be used to predict individual treatment effects. Our framework can be used with any algorithm for learning interpretable models. Furthermore, we prove an error bound on the treatment effects predicted by our model. Finally, in an experiment on real-world data, we show that the models trained using our framework significantly outperform a number of baselines.




1 Introduction

Machine learning is increasingly being used to help inform consequential decisions in healthcare, law, and finance. In these applications, the goal is often to predict the effect of some intervention (called a treatment effect), e.g., the efficacy of a drug on a given patient (Consortium, 2009; Kim et al., 2011; Bastani & Bayati, 2015; Henry et al., 2015), the probability that a defendant in a court case is a flight risk (Kleinberg et al., 2017), or the probability that an applicant will repay a loan (Hardt et al., 2016). There are two important properties that these machine learning models must satisfy: (i) they must be causal (Rubin, 2005; Pearl, 2010), and (ii) they must be interpretable.

First, to predict treatment effects, our model must predict outcomes when the world is modified in some way (called counterfactual outcomes). For example, to predict the efficacy of a drug on a patient, we need to know the patient's outcome both when given the drug and when not given the drug. One way to predict counterfactual outcomes is to use randomized controlled trials (RCTs): by randomly assigning individuals to treatment and control groups, we can ensure that the model generalizes to predicting counterfactual outcomes. Indeed, RCTs are frequently used to estimate average treatment effects (e.g., whether the drug is effective for the population as a whole). However, they are typically unsuitable for predicting individual treatment effects (ITEs; also known as heterogeneous treatment effects or conditional average treatment effects): models of ITEs have many more parameters, so they require much more training data than RCTs typically provide. Yet the promise of machine learning is exactly to predict ITEs, which can be used to tailor decisions to specific individuals.

Instead, we consider the more common approach of predicting counterfactual outcomes based on observational data. Unlike in RCT data, individuals in observational data are selected into treatment and control groups by unknown mechanisms (Rubin, 2005; Shalit et al., 2017). For example, in observational data, sicker patients are more likely to receive drugs. Thus, our model may incorrectly conclude that drugs are ineffective, since individuals who do not take drugs are healthier than those who do. The problem is that supervised learning can only guarantee predictive performance on data drawn from the same distribution as the training data, but counterfactual outcomes do not satisfy this assumption.

To make progress, we have to make assumptions about the distribution of the observational data. Several algorithms along these lines have been proposed, including honest trees (Athey & Imbens, 2016), causal forests (Wager & Athey, 2017), propensity score weighting (Austin, 2011), instrumental variables (Wooldridge, 2015), and causal representations (Johansson et al., 2016; Shalit et al., 2017).

Second, the learned model must be interpretable, i.e., a human domain expert (e.g., a doctor) must be able to validate the model. Interpretability is important since there are often defects in the training data that cause the model to make preventable errors; indeed, such issues have been shown to arise in practice (Caruana et al., 2015; Ribeiro et al., 2016; Bastani et al., 2017). Learning interpretable models is particularly important when there may be causal issues, since there is often no way to validate the assumptions made by causal learning algorithms. For example, many approaches assume strong ignorability, which says that, conditioned on the covariates, treatment assignment is independent of the potential outcomes (i.e., there are no unobserved confounders). However, this assumption often fails in practice (Louizos et al., 2017). Interpretability provides a way for experts to identify such causal issues.

Many algorithms have been proposed for learning interpretable models, including decision trees (Breiman, 2017; Bastani et al., 2017), sparse linear models (Tibshirani, 1996; Ustun & Rudin, 2016), generalized additive models (Lou et al., 2012; Caruana et al., 2015), rule lists (Wang & Rudin, 2015; Yang et al., 2017; Angelino et al., 2017), decision sets (Lakkaraju et al., 2016), and programs (Ellis et al., 2015; Verma et al., 2018; Valkov et al., 2018; Ellis et al., 2018).

Thus, while there has been a variety of work on learning causal models and on learning interpretable models, there has been relatively little work on designing algorithms that are capable of achieving both desirable properties.

Our contributions.

We propose a general framework for learning interpretable models with causal guarantees. In particular, given any supervised learning algorithm $\mathcal{A}$ for learning interpretable models, our framework converts $\mathcal{A}$ into an algorithm for learning interpretable models $\hat f$ that predict the ITE of an individual with covariates $x$. Furthermore, we provide guarantees on the performance of the models learned in this way.

We build on recent work on causal representations (Johansson et al., 2016; Shalit et al., 2017), a general framework for converting any supervised learning algorithm into an algorithm for learning models that predict ITEs. Their key idea is to first learn a causal representation $\Phi: \mathcal{X} \to \mathcal{R}$, where $\mathcal{R}$ is an embedding space. Intuitively, $\Phi$ is designed to eliminate the bias from using observational data. They then use the supervised learning algorithm to train a model $h$ on the embedding $\{(\Phi(x_i), t_i, y_i)\}_{i=1}^n$ of the original dataset $Z = \{(x_i, t_i, y_i)\}_{i=1}^n$, where $t_i$ is the treatment and $y_i$ is the outcome. Finally, assuming strong ignorability, they prove bounds on the error of the following model for predicting ITEs:

$$\hat\tau(x) = h(\Phi(x), 1) - h(\Phi(x), 0).$$

The reason we cannot directly use their approach is that the causal representation $\Phi$ is uninterpretable. In particular, their approach would use the interpretable learning algorithm $\mathcal{A}$ to train an interpretable model $h$ on $\{(\Phi(x_i), t_i, y_i)\}_{i=1}^n$. However, the composed model $h(\Phi(x), t)$ remains uninterpretable since $\Phi$ is uninterpretable: the inputs to $h$ are the uninterpretable features $\Phi(x)$.

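We propose a solution to this problem inspired by model compression (Bucilua et al., 2006; Hinton et al., 2015). First, we use the approach of (Shalit et al., 2017) to learn an uninterpretable representation $\Phi$ and model $h$. We refer to the function $f^*$ defined by $f^*(x, t) = h(\Phi(x), t)$ as the oracle model. Then, we use $\mathcal{A}$ to learn an interpretable model $\hat f$ that approximates $f^*$; i.e., for some distribution $\mathcal{D}$ of our choosing,

$$\hat f = \mathcal{A}\big(\{(x_i, t_i, f^*(x_i, t_i))\}_{i=1}^m\big), \qquad (1)$$

where $(x_i, t_i) \sim \mathcal{D}$ are i.i.d. samples. Then, we propose to use $\hat\tau(x) = \hat f(x, 1) - \hat f(x, 0)$ to predict ITEs.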

It remains to choose $\mathcal{D}$ in (1). We make a simple and intuitive choice: the distribution over covariate-treatment pairs that would have been induced by running an RCT (which we call the RCT distribution), where treatments $t$ are assigned at random, independently of the covariates $x$. This choice amounts to using $f^*$ to label the unobserved counterfactual for each covariate in the original observational dataset, and then running $\mathcal{A}$ on the combined dataset to train $\hat f$.

Intuitively, since the RCT distribution does not suffer from selection bias, $\hat f$ should perform well as long as $f^*$ performs well and $\hat f$ is a good approximation of $f^*$ on the RCT distribution. Indeed, under these conditions, we prove a performance guarantee for $\hat f$ analogous to the one available for the causal representations approach (Johansson et al., 2016; Shalit et al., 2017). Finally, in an experimental study, we show how our approach can be used to improve the performance of a wide range of interpretable models.

Related work.

There has been prior work proposing the "honest tree" algorithm for learning decision trees for predicting ITEs (Athey & Imbens, 2016). This work builds on the CART algorithm (Breiman, 2017); in particular, they reduce the bias of CART by using different subsets of the training data to estimate the internal nodes and the leaf nodes. In contrast, our framework can be used to convert any interpretable learning algorithm into one for learning models for predicting ITEs. Furthermore, unlike their work, our approach comes with provable performance guarantees. Finally, we show in our experiments that our approach can substantially outperform theirs.

There has also been work using interpretability to identify causal issues in learned predictive models (Caruana et al., 2015; Ribeiro et al., 2016; Bastani et al., 2017). However, there is currently no way to fix these causal issues except by having an expert manually correct the model.

Finally, there has been a wide range of work using an uninterpretable oracle model to guide the learning of an interpretable model (Lakkaraju et al., 2017; Bastani et al., 2017; Verma et al., 2018; Frosst & Hinton, 2017; Bastani et al., 2018). Our work is the first to leverage this approach in the context of learning causal models.

2 Preliminaries

In this section, we give background on causal inference for estimating individual treatment effects (ITEs). Then, we summarize the approach of causal representations proposed in (Johansson et al., 2016; Shalit et al., 2017), as well as a bound they prove on the estimation error for their approach.

Potential outcomes framework.

We begin by describing the Rubin-Neyman potential outcomes framework (Rubin, 2005). Suppose we have a set of units, and we want to estimate the efficacy of a treatment for a given unit. We assume that each unit is associated with a covariate vector $x \in \mathcal{X}$ (e.g., encoding patient-specific characteristics such as their healthcare history). Each unit is either assigned to the control group (denoted $t = 0$) or to the treatment group (denoted $t = 1$). Furthermore, each unit is associated with two potential outcomes: the outcome $Y_0$ if the unit is assigned to control (i.e., $t = 0$), and the outcome $Y_1$ if the unit is assigned to treatment (i.e., $t = 1$). The object of interest is the treatment effect $Y_1 - Y_0$, which tells the decision maker whether the unit would experience a better outcome under the treatment or under the control.

For example, units may be patients, and covariates may be patient-specific features such as biomarkers and healthcare history. The treatment may be prescribing a drug to the patient (so the control is not prescribing the drug). Then, $Y_1$ may be how quickly the patient recovers when prescribed the drug, and $Y_0$ how quickly the patient recovers without the drug. The treatment effect $Y_1 - Y_0$ then measures whether the drug helps the patient recover more quickly. Ideally, the patient would only be given the drug if $Y_1 - Y_0 > 0$.

Formally, each unit is associated with a tuple of random variables $(x, t, Y_0, Y_1)$. We assume that the covariate vector $x$ takes values in $\mathcal{X} \subseteq \mathbb{R}^d$, and the potential outcomes $Y_0, Y_1$ take values in $\mathcal{Y} \subseteq \mathbb{R}$ (of course, the treatment $t$ takes values in $\{0, 1\}$). Furthermore, we assume that for each unit, this tuple is drawn i.i.d. from a distribution $p(x, t, Y_0, Y_1)$.

The fundamental challenge in causal inference is that for each unit, we observe either $Y_0$ or $Y_1$, but never both; in particular, for each unit, we only observe $y = Y_t$, the outcome under the treatment the unit actually received.

Definition 2.1.

The observed outcome $Y_t$ is the factual outcome, and the unobserved outcome $Y_{1-t}$ is the counterfactual outcome.

For example, if we give a patient the drug, we cannot observe what would have happened without the drug.

Thus, we can only estimate treatment effects on average over multiple units. If we average over the entire population, we obtain the average treatment effect (ATE)

$$\text{ATE} = \mathbb{E}[Y_1 - Y_0].$$

However, the ATE does not yield any information about the efficacy of treatment on an individual unit. Instead, our goal is to estimate the efficacy of a treatment for individual units based on their covariates.

Definition 2.2.

The individual treatment effect (ITE) is

$$\tau(x) = \mathbb{E}[Y_1 - Y_0 \mid x].$$

To estimate the ITE, we make the following standard assumption about the treatment assignment mechanism (Johansson et al., 2016; Shalit et al., 2017).

Assumption 2.3.

We assume that the treatment assignment is strongly ignorable, i.e.,

$$(Y_0, Y_1) \perp\!\!\!\perp t \mid x.$$

For example, this assumption rules out the possibility that treatment is given ($t = 1$) only to units for which $Y_1 > Y_0$. We also make the standard assumption that each unit has a nonzero probability of being assigned to each of the control and treatment groups.

Assumption 2.4.

We assume that for all $x \in \mathcal{X}$,

$$0 < p(t = 1 \mid x) < 1.$$

For example, this assumption rules out the possibility that we never observe $Y_1$ for a particular $x$.

Our goal is to obtain an estimate $\hat\tau(x)$ of the ITE $\tau(x)$. A natural metric is our accuracy in predicting $\tau(x)$ for a unit chosen at random from the distribution $p$.

Definition 2.5.

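The expected precision in estimation of heterogeneous effect (PEHE) (Hill, 2011) is

$$\varepsilon_{\text{PEHE}}(\hat\tau) = \int_{\mathcal{X}} \big(\hat\tau(x) - \tau(x)\big)^2\, p(x)\, dx.$$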
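When covariates can be sampled from $p(x)$ and the true ITE is known (as with simulated outcomes), the integral can be approximated by an average over samples. A minimal sketch, assuming scalar covariates; the two effect functions are hypothetical:

```python
from typing import Callable, Sequence


def pehe(tau_hat: Callable[[float], float],
         tau: Callable[[float], float],
         xs: Sequence[float]) -> float:
    """Monte Carlo estimate of eps_PEHE(tau_hat): the mean of
    (tau_hat(x) - tau(x))**2 over covariates xs assumed drawn i.i.d. from p(x)."""
    return sum((tau_hat(x) - tau(x)) ** 2 for x in xs) / len(xs)


def true_effect(x: float) -> float:
    return 0.2                      # hypothetical constant ITE


def predicted_effect(x: float) -> float:
    return 0.2 + x                  # hypothetical estimate that drifts with x


print(pehe(predicted_effect, true_effect, [0.0, 0.5, 1.0]))  # about (0 + 0.25 + 1) / 3
```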


Causal representations.

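Now, we describe the causal representations approach to estimating $\tau$ (Johansson et al., 2016; Shalit et al., 2017). Suppose that we have observational data $Z = \{(x_i, t_i, y_i)\}_{i=1}^n$ that we want to use to estimate $\tau$. One way to do so is to estimate

$$\mu_1(x) = \mathbb{E}[Y_1 \mid x] \qquad \text{and} \qquad \mu_0(x) = \mathbb{E}[Y_0 \mid x],$$

and then use $\hat\tau(x) = \hat\mu_1(x) - \hat\mu_0(x)$. Naïvely, we can use supervised learning to fit one model to predict $y_i$ on samples for which $t_i = 1$, yielding an estimate $\hat\mu_1$, and a second model to predict $y_i$ on samples for which $t_i = 0$, yielding an estimate $\hat\mu_0$.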

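This approach corresponds to fitting $\hat\mu_1$ on samples $x \sim p(x \mid t = 1)$, and fitting $\hat\mu_0$ on samples $x \sim p(x \mid t = 0)$. However, when evaluating the PEHE, we are also concerned with the errors of $\hat\mu_1$ and $\hat\mu_0$ on the counterfactual distributions $p(x \mid t = 0)$ and $p(x \mid t = 1)$, respectively; i.e., when fitting $\hat\mu_1$, we also need samples $x \sim p(x \mid t = 0)$, and when fitting $\hat\mu_0$, we also need samples $x \sim p(x \mid t = 1)$. Otherwise, our estimate $\hat\tau$ may be poor.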
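A minimal sketch of this naive two-model approach, assuming one-dimensional covariates and ordinary least squares as the supervised learner (both are illustrative choices, not part of the method described here). $\hat\mu_1$ only sees treated covariates and $\hat\mu_0$ only sees control covariates, so each model's value on the other group's covariates is pure extrapolation:

```python
from typing import Callable, List, Sequence, Tuple


def fit_line(xs: Sequence[float], ys: Sequence[float]) -> Callable[[float], float]:
    """One-dimensional ordinary least squares: y ≈ a + b * x."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    sxx = sum((x - mx) ** 2 for x in xs)
    sxy = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    b = sxy / sxx
    a = my - b * mx
    return lambda x: a + b * x


def naive_two_model(data: List[Tuple[float, int, float]]) -> Callable[[float], float]:
    """Fit mu1_hat on treated units only and mu0_hat on control units only,
    then return tau_hat(x) = mu1_hat(x) - mu0_hat(x)."""
    treated = [(x, y) for x, t, y in data if t == 1]
    control = [(x, y) for x, t, y in data if t == 0]
    mu1_hat = fit_line(*zip(*treated))
    mu0_hat = fit_line(*zip(*control))
    return lambda x: mu1_hat(x) - mu0_hat(x)


# Hypothetical data: treated units have large x, control units have small x.
data = [(0.0, 0, 1.0), (0.2, 0, 0.8), (0.4, 0, 0.6),
        (0.6, 1, 0.6), (0.8, 1, 0.4), (1.0, 1, 0.2)]
tau_hat = naive_two_model(data)
print(tau_hat(0.5))
```

Here the extrapolation happens to be exact because the hypothetical outcomes are exactly linear in $x$. With a misspecified model, the errors of $\hat\mu_1$ on control covariates (and of $\hat\mu_0$ on treated covariates) are never measured by the training loss.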

Thus, the error contains a term that comes from the discrepancy between the factual and counterfactual distributions. More precisely, by strong ignorability,

$$p(x, Y_1 \mid t = 0) = p(Y_1 \mid x)\, p(x \mid t = 0).$$

Comparing this with

$$p(x, Y_1 \mid t = 1) = p(Y_1 \mid x)\, p(x \mid t = 1),$$

we observe that the difference between these factual and counterfactual distributions is captured by the difference between the distributions $p(x \mid t = 1)$ and $p(x \mid t = 0)$.

Definition 2.6.

The distribution of control units is $p^{t=0}(x) = p(x \mid t = 0)$, and the distribution of treated units is $p^{t=1}(x) = p(x \mid t = 1)$.

For this source of error to be small, we need $p^{t=1}$ to be similar to $p^{t=0}$. However, for observational data, unlike RCT data, these distributions are given to us and are not ones that we can choose.

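As proposed by (Johansson et al., 2016; Shalit et al., 2017), one solution is to split the prediction problem into two steps: (i) learn a representation $\Phi: \mathcal{X} \to \mathcal{R}$ for some embedding space $\mathcal{R}$, and (ii) fit a predictive model $h: \mathcal{R} \times \{0, 1\} \to \mathcal{Y}$ on $\Phi(x)$ rather than on $x$. Then, we can bound the error coming from the discrepancy between $p^{t=1}$ and $p^{t=0}$ by the discrepancy between the corresponding distributions of $\Phi(x)$ for treated and control units.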

Assumption 2.7.

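The representation $\Phi: \mathcal{X} \to \mathcal{R}$ is a twice-differentiable, one-to-one function. Without loss of generality, we assume that $\mathcal{R}$ is the image of $\mathcal{X}$ under $\Phi$, so that we can define an inverse $\Psi: \mathcal{R} \to \mathcal{X}$ satisfying $\Psi(\Phi(x)) = x$.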

Next, we define the distributions on induced by the distributions of treated units and of control units.

Definition 2.8.

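For $r \in \mathcal{R}$, define $p_\Phi^{t=1}(r)$ to be the density at $r$ of $\Phi(x)$ for $x \sim p^{t=1}$, and define $p_\Phi^{t=0}(r)$ to be the density at $r$ of $\Phi(x)$ for $x \sim p^{t=0}$.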

In other words, $p_\Phi^{t=1}$ is the distribution of treated units on $\mathcal{R}$ induced by $\Phi$, and $p_\Phi^{t=0}$ is the distribution of control units on $\mathcal{R}$ induced by $\Phi$.

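We can now combine the estimates of $\mu_1$ and $\mu_0$ into a single function. In particular, consider hypotheses of the form $f: \mathcal{X} \times \{0, 1\} \to \mathcal{Y}$, where we estimate $\mu_1(x)$ by $f(x, 1)$ and $\mu_0(x)$ by $f(x, 0)$. We are interested in the case where $f$ is derived from an estimator $h: \mathcal{R} \times \{0, 1\} \to \mathcal{Y}$.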

Definition 2.9.

Given a representation $\Phi$, we say a hypothesis $f$ factors through $\Phi$ if there exists $h$ such that $f(x, t) = h(\Phi(x), t)$ for all $x \in \mathcal{X}$ and $t \in \{0, 1\}$.

Then, we consider the following estimate of $\tau$:

Definition 2.10.

The treatment effect estimate of the hypothesis $f$ for a unit with covariate $x$ is

$$\hat\tau_f(x) = f(x, 1) - f(x, 0).$$

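We let $\varepsilon_{\text{PEHE}}(f) = \varepsilon_{\text{PEHE}}(\hat\tau_f)$. When $f$ factors through a representation $\Phi$, i.e., $f(x, t) = h(\Phi(x), t)$, we let $\varepsilon_{\text{PEHE}}(h, \Phi) = \varepsilon_{\text{PEHE}}(f)$.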

Bound on causal error.

Our goal is to bound $\varepsilon_{\text{PEHE}}(f)$. We describe a bound on $\varepsilon_{\text{PEHE}}(h, \Phi)$ proven in (Shalit et al., 2017) for approaches to estimating the ITE based on causal representations. We have two derived loss functions, one corresponding to the factual loss $\varepsilon_F$ and another corresponding to the counterfactual loss $\varepsilon_{CF}$. (We assume throughout that we are using the squared loss.)

Definition 2.11.

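Given $f$, the expected loss for the unit and treatment pair $(x, t)$ is

$$\ell_f(x, t) = \int_{\mathcal{Y}} \big(y - f(x, t)\big)^2\, p(Y_t = y \mid x)\, dy,$$

and the expected factual and counterfactual losses of $f$ are

$$\varepsilon_F(f) = \sum_{t \in \{0, 1\}} \int_{\mathcal{X}} \ell_f(x, t)\, p(x, t)\, dx, \qquad \varepsilon_{CF}(f) = \sum_{t \in \{0, 1\}} \int_{\mathcal{X}} \ell_f(x, t)\, p(x, 1 - t)\, dx.$$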

We break up the factual loss into two parts based on the following definition.

Definition 2.12.

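The expected factual treated and control losses are

$$\varepsilon_F^{t=1}(f) = \int_{\mathcal{X}} \ell_f(x, 1)\, p^{t=1}(x)\, dx, \qquad \varepsilon_F^{t=0}(f) = \int_{\mathcal{X}} \ell_f(x, 0)\, p^{t=0}(x)\, dx.$$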

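It follows immediately that

$$\varepsilon_F(f) = u\, \varepsilon_F^{t=1}(f) + (1 - u)\, \varepsilon_F^{t=0}(f),$$

where $u = p(t = 1)$ is the marginal probability of treatment.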
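The decomposition of the factual loss follows from factoring the joint density as $p(x, t) = p(t)\, p(x \mid t)$, so that $p(x, 1) = u\, p^{t=1}(x)$ and $p(x, 0) = (1 - u)\, p^{t=0}(x)$ with $u = p(t = 1)$. Written out step by step:

```latex
\begin{aligned}
\varepsilon_F(f)
  &= \int_{\mathcal{X}} \ell_f(x, 1)\, p(x, 1)\, dx + \int_{\mathcal{X}} \ell_f(x, 0)\, p(x, 0)\, dx \\
  &= u \int_{\mathcal{X}} \ell_f(x, 1)\, p^{t=1}(x)\, dx + (1 - u) \int_{\mathcal{X}} \ell_f(x, 0)\, p^{t=0}(x)\, dx \\
  &= u\, \varepsilon_F^{t=1}(f) + (1 - u)\, \varepsilon_F^{t=0}(f).
\end{aligned}
```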

One term in the bound on $\varepsilon_{\text{PEHE}}(h, \Phi)$ from (Shalit et al., 2017) quantifies the quality of $\Phi$ through the discrepancy between the two distributions $p_\Phi^{t=1}$ and $p_\Phi^{t=0}$. We use the following metric to measure this discrepancy:

Definition 2.13.

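Suppose we have two probability distributions $p$ and $q$ on $\mathcal{R}$. Given a family $G$ of functions $g: \mathcal{R} \to \mathbb{R}$, the integral probability metric (IPM) between them is

$$\mathrm{IPM}_G(p, q) = \sup_{g \in G} \left| \int_{\mathcal{R}} g(r)\, \big(p(r) - q(r)\big)\, dr \right|.$$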

To obtain guarantees, we require the following assumption on the function family :

Assumption 2.14.

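The family $G$ satisfies

$$\frac{1}{B_\Phi}\, \ell_f\big(\Psi(\cdot), t\big) \in G \qquad \text{for } t \in \{0, 1\}$$

for some $B_\Phi > 0$, where $f(x, t) = h(\Phi(x), t)$.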

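Then, one desirable property of the representation $\Phi$ is for $\mathrm{IPM}_G(p_\Phi^{t=1}, p_\Phi^{t=0})$ to be small. The other term in the bound on the error $\varepsilon_{\text{PEHE}}(h, \Phi)$ comes from the variances of $Y_0$ and $Y_1$.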


Definition 2.15.

Given a distribution $q$ on $\mathcal{X} \times \{0, 1\}$, we denote the counterfactual density of $q$ by $q_{CF}$, defined by $q_{CF}(x, t) = q(x, 1 - t)$.

Definition 2.16.

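Given a distribution $q$ on $\mathcal{X} \times \{0, 1\}$, the expected variances of $Y_0$ and $Y_1$ with respect to $q$ are

$$\sigma_{Y_t}^2(q) = \int_{\mathcal{X}} \int_{\mathcal{Y}} \big(y - \mu_t(x)\big)^2\, p(Y_t = y \mid x)\, q(x, t)\, dy\, dx \qquad (t \in \{0, 1\}),$$

where $\mu_t(x) = \mathbb{E}[Y_t \mid x]$. Furthermore, writing $p$ for the joint distribution $p(x, t)$, we let

$$\sigma_Y^2 = \min_{t \in \{0, 1\}} \min\big\{\sigma_{Y_t}^2(p),\ \sigma_{Y_t}^2(p_{CF})\big\}.$$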

We have the following bound on (Shalit et al., 2017):

Theorem 2.17.

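For any $f$ factored as $f(x, t) = h(\Phi(x), t)$ for some $h$,

$$\varepsilon_{\text{PEHE}}(f) \le 2\left(\varepsilon_F^{t=0}(f) + \varepsilon_F^{t=1}(f) + B_\Phi\, \mathrm{IPM}_G\big(p_\Phi^{t=1}, p_\Phi^{t=0}\big) - 2\sigma_Y^2\right).$$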

This theorem shows that the error of our estimate of $\tau$ can be bounded by two terms. The first term, $\varepsilon_F^{t=0}(f) + \varepsilon_F^{t=1}(f)$, captures the error due to the test error of $f$ on the observational dataset. The second term, $B_\Phi\, \mathrm{IPM}_G(p_\Phi^{t=1}, p_\Phi^{t=0})$, captures the error due to the mismatch between the distributions of treated units and of control units in the embedding space.

3 Interpretable Models for Individual Treatment Effect Estimation

Our learning framework can convert any algorithm for learning interpretable models in the supervised setting into an algorithm for learning interpretable models to predict individual treatment effects. Recall that the key issue with applying the causal representations approach is that we cannot simply train an interpretable model $h$ on the causal representation $\Phi(x)$: the representation function $\Phi$ is uninterpretable, so the composed model $h(\Phi(x), t)$ is uninterpretable.

Learning algorithm.

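Algorithm 1 Learning interpretable models with causal guarantees.
Input: factual observations $Z = \{(x_i, t_i, y_i)\}_{i=1}^n$
1. $f^* \leftarrow \text{LearnCR}(Z)$
2. $\hat p_{\text{RCT}} \leftarrow$ the RCT distribution derived from the empirical covariate distribution $\hat p_X$ of $Z$
3. $Z_{\text{RCT}} \leftarrow \{(x_i, t, f^*(x_i, t)) : 1 \le i \le n,\ t \in \{0, 1\}\}$
4. $\hat f \leftarrow \mathcal{A}(Z_{\text{RCT}})$
Output: $\hat f$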

We propose an approach where we first train an uninterpretable oracle model $f^*$ using the causal representation approach, and then train an interpretable model $\hat f$ to approximate $f^*$. In particular, we prove that using our approach, as long as $\hat f$ closely approximates $f^*$, we can obtain a bound on the error of $\hat f$ analogous to Theorem 2.17.

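Let $\mathcal{F}$ be the space of interpretable models considered by $\mathcal{A}$. Given observations $Z$ drawn from the distribution $p$, our goal is to learn an interpretable model $\hat f \in \mathcal{F}$ for which we can provide causal guarantees. Let

$$\mathcal{Z} = \bigcup_{n \in \mathbb{N}} \big(\mathcal{X} \times \{0, 1\} \times \mathcal{Y}\big)^n$$

be the set of datasets of any finite size (i.e., of size $n$ for some $n \in \mathbb{N}$). Suppose we have a learning algorithm $\mathcal{A}: \mathcal{Z} \to \mathcal{F}$ for interpretable models, i.e., given a dataset $Z = \{(x_i, t_i, y_i)\}_{i=1}^n \in \mathcal{Z}$, $\mathcal{A}$ (usually approximately) solves the supervised learning problem

$$\hat f = \operatorname*{arg\,min}_{f \in \mathcal{F}} \sum_{i=1}^n \big(f(x_i, t_i) - y_i\big)^2. \qquad (3)$$

We use $\mathcal{A}(Z)$ to denote the model returned by $\mathcal{A}$.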

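In addition, suppose we also have an oracle model $f^*: \mathcal{X} \times \{0, 1\} \to \mathcal{Y}$ that is not interpretable (so $f^* \notin \mathcal{F}$), but whose associated estimate $\hat\tau_{f^*}$ of $\tau$ is good. We assume that $f^*$ is learned using the causal representation approach described in Section 2; in particular, that it factors as $f^*(x, t) = h(\Phi(x), t)$.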

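Our approach is to train $\hat f$ to approximate $f^*$, i.e., $\hat f = \mathcal{A}(Z^*)$, where $Z^* = \{(x, t, f^*(x, t)) : (x, t) \in S\}$ for some set $S$ of covariate-treatment pairs. The key question is how to choose $S$ so that $\hat f$ produces a good estimate of $\tau$, i.e., so that $\varepsilon_{\text{PEHE}}(\hat f)$ is small.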

Intuitively, when we have control over the treatment assignment (e.g., in a randomized controlled trial), a good strategy is to assign treatments uniformly at random. In particular, consider the following distribution:

Definition 3.1.

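Given a distribution $q$ on $\mathcal{X}$, the RCT distribution derived from $q$ is the distribution $q_{\text{RCT}}$ on $\mathcal{X} \times \{0, 1\}$ defined by

$$q_{\text{RCT}}(x, t) = \tfrac{1}{2}\, q(x).$$

In other words, the random variables $(x, t)$ have joint distribution $q_{\text{RCT}}$ if $x$ is distributed as $q$ and $t \sim \text{Bernoulli}(1/2)$ is independent of $x$.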

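Letting $\hat p_X$ be the empirical distribution over the covariates $\{x_i\}_{i=1}^n$, we show below that the RCT distribution $\hat p_{\text{RCT}}$ derived from $\hat p_X$ is a good candidate for $\mathcal{D}$. In particular, with this choice, we can prove a bound on $\varepsilon_{\text{PEHE}}(\hat f)$ analogous to Theorem 2.17.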

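Given an observational dataset $Z = \{(x_i, t_i, y_i)\}_{i=1}^n$, our algorithm (shown in Algorithm 1) first uses the causal representations approach to learn an oracle model $f^*$ based on $Z$ that has provable guarantees on $\varepsilon_{\text{PEHE}}(f^*)$ (the subroutine LearnCR). Then, our algorithm constructs the distribution $\hat p_{\text{RCT}}$, where $\hat p_X$ is the empirical distribution of covariates in $Z$. Next, our algorithm uses $f^*$ to label the points in the support of $\hat p_{\text{RCT}}$, producing the dataset $Z_{\text{RCT}} = \{(x_i, t, f^*(x_i, t)) : 1 \le i \le n,\ t \in \{0, 1\}\}$; this step amounts to using $f^*$ to label the unobserved counterfactual for each covariate in $Z$. Finally, our algorithm runs the interpretable learning algorithm $\mathcal{A}$ on the training set $Z_{\text{RCT}}$ and returns the result $\hat f = \mathcal{A}(Z_{\text{RCT}})$.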
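A minimal Python sketch of the distillation steps of Algorithm 1, under stated assumptions: covariates are scalars; the oracle is passed in already trained (LearnCR, i.e., fitting CFR-Net, is not implemented, and `toy_oracle` is a hypothetical stand-in); and `fit_stump` is a hypothetical one-split regression tree playing the role of the interpretable learner $\mathcal{A}$. All names are illustrative.

```python
from typing import Callable, Dict, List, Tuple

Row = Tuple[float, int, float]               # (x, t, y) with scalar x
Model = Callable[[float, int], float]        # f(x, t)


def build_rct_dataset(factual: List[Row], oracle: Model) -> List[Row]:
    """Label both treatments of every observed covariate with the oracle f*,
    i.e. enumerate the support of the empirical RCT distribution.
    The observed outcomes are discarded: every label comes from f*."""
    return [(x, t, oracle(x, t)) for x, _, _ in factual for t in (0, 1)]


def fit_stump(data: List[Row]) -> Model:
    """Toy stand-in for the interpretable learner A: for each treatment value,
    one threshold split on x chosen to minimize squared error."""
    params: Dict[int, Tuple[float, ...]] = {}
    for t in (0, 1):
        pts = sorted((x, y) for x, tt, y in data if tt == t)
        overall = sum(y for _, y in pts) / len(pts)
        best = (float("inf"), float("inf"), overall, overall)  # (sse, thr, left, right)
        for i in range(1, len(pts)):
            if pts[i - 1][0] == pts[i][0]:
                continue                                  # no threshold between equal x
            thr = (pts[i - 1][0] + pts[i][0]) / 2
            left = [y for x, y in pts if x <= thr]
            right = [y for x, y in pts if x > thr]
            lm, rm = sum(left) / len(left), sum(right) / len(right)
            sse = sum((y - lm) ** 2 for y in left) + sum((y - rm) ** 2 for y in right)
            if sse < best[0]:
                best = (sse, thr, lm, rm)
        params[t] = best[1:]

    def model(x: float, t: int) -> float:
        thr, left_mean, right_mean = params[t]
        return left_mean if x <= thr else right_mean

    return model


def distill(factual: List[Row], oracle: Model,
            learner: Callable[[List[Row]], Model] = fit_stump) -> Model:
    """Steps 2-4 of Algorithm 1, given an already-trained oracle f*."""
    return learner(build_rct_dataset(factual, oracle))


def toy_oracle(x: float, t: int) -> float:
    """Hypothetical oracle standing in for a trained CFR-Net."""
    return (1.0 if x > 0.5 else 0.0) + 0.3 * t


factual = [(0.1, 0, 0.0), (0.3, 1, 0.3), (0.7, 0, 1.0), (0.9, 1, 1.3)]
f_hat = distill(factual, toy_oracle)
print(f_hat(0.2, 1) - f_hat(0.2, 0))   # estimated ITE at x = 0.2
```

Because `learner` is a parameter, any supervised learner with the same signature could be substituted, which mirrors the claim that the framework works with any interpretable learning algorithm.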

Model                  IHDP √ε_PEHE                IHDP ε_ATE                  Jobs R_pol                  Jobs ε_ATT
                       Ours          Baseline      Ours          Baseline      Ours          Baseline      Ours          Baseline
CFR-Net                0.926 ± 0.02                0.271 ± 0.01                0.235 ± 0.02                0.086 ± 0.03
CART (depth 6)         3.668 ± 0.17  4.305 ± 0.20  0.485 ± 0.03  0.679 ± 0.04  0.241 ± 0.01  0.271 ± 0.02  0.086 ± 0.03  0.067 ± 0.02
CART (depth 5)         3.824 ± 0.18  4.436 ± 0.21  0.492 ± 0.02  0.725 ± 0.05  0.241 ± 0.01  0.280 ± 0.02  0.086 ± 0.03  0.069 ± 0.02
CART (depth 4)         4.086 ± 0.19  4.605 ± 0.22  0.530 ± 0.03  0.717 ± 0.05  0.241 ± 0.01  0.281 ± 0.02  0.086 ± 0.03  0.064 ± 0.01
CART (depth 3)         4.462 ± 0.21  4.930 ± 0.23  0.585 ± 0.03  0.795 ± 0.05  0.241 ± 0.01  0.285 ± 0.02  0.086 ± 0.03  0.067 ± 0.02
Honest Tree (depth 6)  3.694 ± 0.17  4.086 ± 0.19  0.481 ± 0.02  0.483 ± 0.03  0.235 ± 0.02  0.223 ± 0.01  0.086 ± 0.03  0.073 ± 0.02
Honest Tree (depth 5)  3.760 ± 0.17  4.098 ± 0.19  0.488 ± 0.02  0.486 ± 0.03  0.235 ± 0.02  0.216 ± 0.01  0.086 ± 0.03  0.074 ± 0.02
Honest Tree (depth 4)  3.875 ± 0.18  4.128 ± 0.19  0.498 ± 0.02  0.488 ± 0.03  0.235 ± 0.02  0.223 ± 0.02  0.086 ± 0.03  0.084 ± 0.02
Honest Tree (depth 3)  4.090 ± 0.19  4.237 ± 0.20  0.535 ± 0.03  0.498 ± 0.03  0.235 ± 0.02  0.236 ± 0.01  0.086 ± 0.03  0.080 ± 0.02
LASSO                  5.725 ± 0.26  5.777 ± 0.26  0.671 ± 0.04  0.942 ± 0.05  0.235 ± 0.02  0.226 ± 0.02  0.086 ± 0.03  0.080 ± 0.02
Kernel Ridge           2.077 ± 0.09  3.190 ± 0.14  0.361 ± 0.02  0.562 ± 0.02  0.235 ± 0.02  0.234 ± 0.02  0.086 ± 0.03  0.077 ± 0.02
GBM                    1.845 ± 0.09  2.799 ± 0.14  0.352 ± 0.02  0.453 ± 0.03  0.241 ± 0.01  0.223 ± 0.02  0.086 ± 0.03  0.080 ± 0.02
Random Forest          2.905 ± 0.14  3.653 ± 0.19  0.439 ± 0.02  0.621 ± 0.04  0.241 ± 0.01  0.239 ± 0.01  0.086 ± 0.03  0.073 ± 0.02

Table 1: We show results comparing our approach to a baseline estimator for a number of model families on the IHDP and Jobs datasets. For each value, we show the mean ± the standard error. We bold the better of the two values between ours and the baseline.

Bound on causal error.

We prove that as long as $\hat f$ is close to $f^*$ on the distribution $p_{\text{RCT}}$ derived from $p_X$, where $p_X$ is the true covariate distribution, then $\varepsilon_{\text{PEHE}}(\hat f)$ is small.

Definition 3.2.

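The relative error of $\hat f$ to $f^*$ is

$$\varepsilon_{\text{rel}}(\hat f, f^*) = \sum_{t \in \{0, 1\}} \int_{\mathcal{X}} \big(\hat f(x, t) - f^*(x, t)\big)^2\, p_{\text{RCT}}(x, t)\, dx.$$

In other words, $\varepsilon_{\text{rel}}(\hat f, f^*)$ captures the test error of $\hat f$ relative to the oracle model $f^*$. Now, we can bound the generalization error $\varepsilon_{\text{PEHE}}(\hat f)$ by a combination of $\varepsilon_{\text{rel}}(\hat f, f^*)$ and the bound on $\varepsilon_{\text{PEHE}}(f^*)$.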

Theorem 3.3.

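For any function $\hat f: \mathcal{X} \times \{0, 1\} \to \mathcal{Y}$, and any function $f^*$ factored as $f^*(x, t) = h(\Phi(x), t)$ for some $h$, we have

$$\varepsilon_{\text{PEHE}}(\hat f) \le 8\, \varepsilon_{\text{rel}}(\hat f, f^*) + 4\left(\varepsilon_F^{t=0}(f^*) + \varepsilon_F^{t=1}(f^*) + B_\Phi\, \mathrm{IPM}_G\big(p_\Phi^{t=1}, p_\Phi^{t=0}\big) - 2\sigma_Y^2\right).$$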

We give a proof in Appendix A. Our bound has three terms. The first term captures the test error of $\hat f$ relative to $f^*$. The other two terms come from Theorem 2.17: the second term is the test error of $f^*$ on the observational dataset, and the third term captures the error due to the mismatch between the distributions of treated units and of control units in the latent representation.
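One elementary route to a bound of this shape is sketched below (the constants come from this argument alone). It assumes the squared-loss definitions above, $p_{\text{RCT}}(x, t) = \tfrac{1}{2}\, p(x)$, and the shorthand $\Delta_t(x) = \hat f(x, t) - f^*(x, t)$, and it applies $(a + b)^2 \le 2a^2 + 2b^2$ twice, using $\hat\tau_{\hat f}(x) - \hat\tau_{f^*}(x) = \Delta_1(x) - \Delta_0(x)$:

```latex
\begin{aligned}
\varepsilon_{\text{PEHE}}(\hat f)
  &= \int_{\mathcal{X}} \Big( \big[\hat\tau_{\hat f}(x) - \hat\tau_{f^*}(x)\big] + \big[\hat\tau_{f^*}(x) - \tau(x)\big] \Big)^2 p(x)\, dx \\
  &\le 2 \int_{\mathcal{X}} \big(\Delta_1(x) - \Delta_0(x)\big)^2 p(x)\, dx + 2\, \varepsilon_{\text{PEHE}}(f^*) \\
  &\le 4 \int_{\mathcal{X}} \big(\Delta_1(x)^2 + \Delta_0(x)^2\big) p(x)\, dx + 2\, \varepsilon_{\text{PEHE}}(f^*) \\
  &= 8\, \varepsilon_{\text{rel}}(\hat f, f^*) + 2\, \varepsilon_{\text{PEHE}}(f^*).
\end{aligned}
```

Applying Theorem 2.17 to the last term, $\varepsilon_{\text{PEHE}}(f^*)$, then yields the remaining terms of the bound.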

While the bound in Theorem 3.3 is stated in terms of the exact error $\varepsilon_{\text{rel}}(\hat f, f^*)$ of $\hat f$ with respect to $f^*$, it can be straightforwardly converted to a finite-sample bound using standard assumptions, e.g., that the model family $\mathcal{F}$ has finite Rademacher complexity (Bartlett & Mendelson, 2002) and that $\mathcal{A}$ solves (3) exactly. The other terms can similarly be converted into finite-sample bounds (Shalit et al., 2017).

Finally, note that we can estimate $\varepsilon_{\text{rel}}(\hat f, f^*)$ on a held-out test set $Z_{\text{test}}$ of observational data: it is simply the loss of $\hat f$ on the dataset $Z_{\text{RCT}}^{\text{test}}$ constructed from $Z_{\text{test}}$ in the same way that Algorithm 1 constructs $Z_{\text{RCT}}$ from $Z$. As discussed in (Shalit et al., 2017), the remaining terms in the bound can similarly be estimated on $Z_{\text{test}}$. Thus, we can obtain a test set estimate of the bound in Theorem 3.3.
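A minimal sketch of the held-out estimate of the relative error, assuming scalar covariates. Both model functions are hypothetical: the oracle predicts an effect of 1 everywhere, while the student ignores the treatment.

```python
from typing import Callable, Sequence

Model = Callable[[float, int], float]


def relative_error(f_hat: Model, f_star: Model, test_covariates: Sequence[float]) -> float:
    """Held-out estimate of eps_rel: the squared gap between f_hat and f_star,
    averaged over all pairs (x, t) with x a test covariate and t in {0, 1},
    i.e. the empirical RCT distribution in which each pair has weight 1 / (2n)."""
    pairs = [(x, t) for x in test_covariates for t in (0, 1)]
    return sum((f_hat(x, t) - f_star(x, t)) ** 2 for x, t in pairs) / len(pairs)


def oracle(x: float, t: int) -> float:
    return x + 1.0 * t               # hypothetical oracle f*: effect of 1 everywhere


def student(x: float, t: int) -> float:
    return x                         # hypothetical interpretable model that ignores t


print(relative_error(student, oracle, [0.0, 1.0, 2.0]))  # 0.5
```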

4 Experiments

Evaluating the performance of causal models is a challenging task, since ground truth data on individual treatment effects (ITEs) is difficult to obtain. Following previous work (Shalit et al., 2017), we evaluate our framework on the IHDP (Hill, 2011) and Jobs (LaLonde, 1986) datasets.

IHDP dataset.

We use a dataset for causal inference evaluation based on the Infant Health and Development Program (IHDP), from (Hill, 2011) and preprocessed by (Shalit et al., 2017) using the NPCI package (Hill, 2016). The dataset has 747 units (139 treated, 608 control) and 25 covariates describing children and their mothers. It contains 1000 realizations of the outcomes, with 63/27/10 train/validation/test splits. The outcomes in this dataset are simulated, so we have ground-truth values of the ITE for each unit. Using this ground truth, we can obtain a test set estimate of the error $\varepsilon_{\text{PEHE}}$ in the predicted ITE. We report the mean and standard error of $\sqrt{\varepsilon_{\text{PEHE}}}$, as well as of the absolute error in the average treatment effect (ATE),

$$\varepsilon_{\text{ATE}} = \left| \frac{1}{n} \sum_{i=1}^n \hat\tau(x_i) - \frac{1}{n} \sum_{i=1}^n \tau(x_i) \right|,$$

over the 1000 realizations. Our primary metric of interest is $\sqrt{\varepsilon_{\text{PEHE}}}$, which measures the predictive accuracy of ITEs, whereas $\varepsilon_{\text{ATE}}$ measures the predictive accuracy of the ATE.

Jobs dataset.

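We use the Jobs dataset from (Shalit et al., 2017) based on (LaLonde, 1986), where the binary outcome is employment (versus unemployment). This dataset (3212 individuals) combines data from a randomized trial (297 treated and 425 control) with data from an observational study (2490 control). A difficulty with the Jobs dataset is that we do not have ground truth on the ITEs. Instead, we use a metric proposed in (Shalit et al., 2017), which evaluates a policy that makes treatment decisions based on the predictions of $f$. In particular, recall that $f(x, t)$ is the predicted outcome for a unit with covariates $x$ and treatment $t$. We consider the policy $\pi_f$ that assigns this unit to treatment if the predicted treatment effect is positive, i.e., $\pi_f(x) = 1$ if $f(x, 1) - f(x, 0) > 0$, and $\pi_f(x) = 0$ otherwise. Then, the policy risk

$$R_{\text{pol}}(\pi_f) = 1 - \Big(\mathbb{E}[Y_1 \mid \pi_f(x) = 1]\, p(\pi_f = 1) + \mathbb{E}[Y_0 \mid \pi_f(x) = 0]\, p(\pi_f = 0)\Big)$$

measures the quality of outcomes on average over the test population. For any predictor $f$, we can estimate $R_{\text{pol}}(\pi_f)$ on the randomized subset of the Jobs data as follows:

$$\hat R_{\text{pol}}(\pi_f) = 1 - \Big(\mathbb{E}[Y_1 \mid \pi_f(x) = 1, t = 1]\, \hat p(\pi_f = 1) + \mathbb{E}[Y_0 \mid \pi_f(x) = 0, t = 0]\, \hat p(\pi_f = 0)\Big),$$

where the expectations and probabilities are empirical averages over the randomized units.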

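We also use the randomized subset to estimate the "ground truth" effect. In particular, let $T$, $E$, and $C$ be the sets of units in the treated subgroup, the randomized study, and the control subgroup, respectively (note that $T \subseteq E$). We estimate the average treatment effect on the treated by

$$\text{ATT} = \frac{1}{|T|} \sum_{i \in T} y_i - \frac{1}{|C \cap E|} \sum_{i \in C \cap E} y_i$$

and use as a second metric

$$\varepsilon_{\text{ATT}} = \left| \text{ATT} - \frac{1}{|T|} \sum_{i \in T} \big(f(x_i, 1) - f(x_i, 0)\big) \right|.$$

We report the mean and standard error of $R_{\text{pol}}$ and $\varepsilon_{\text{ATT}}$ over 10 repetitions with 56/24/20 train/validation/test splits. For this study, our primary metric of interest is $R_{\text{pol}}$, since it to some degree measures the predictive accuracy of ITEs; in contrast, similar to $\varepsilon_{\text{ATE}}$, $\varepsilon_{\text{ATT}}$ measures the predictive accuracy of a population average effect.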

Oracle model.

For $f^*$, we train a CFR-Net (Shalit et al., 2017), which has 3 fully connected exponential-linear layers for each of the representation function $\Phi$ and the prediction function $h$, with 100 units in every layer for Jobs, and 200 and 100 units in the representation and hypothesis layers, respectively, for IHDP. For IHDP, we use the mean squared loss; for Jobs, we use the logistic loss.

Figure 1: Performance (in terms of $\sqrt{\varepsilon_{\text{PEHE}}}$) of CART (left) and honest trees (right) using our approach (black, solid) and the baseline approach (red, dashed) on the IHDP dataset, as a function of the depth of the decision tree.

Interpretable models.

We evaluate the performance of our approach on a variety of models with a range of interpretability: CART trees (Breiman, 2017), honest trees (Athey & Imbens, 2016), LASSO regression (Tibshirani, 1996), kernel ridge regression (Murphy, 2012), gradient boosted models (GBMs) (Friedman, 2001), and random forests (Breiman, 2001). For each model family, we train one model using our approach, and a baseline model using only the observational data for training.

Of these models, only honest trees are designed to handle causality; however, their focus is on obtaining unbiased estimates rather than low-variance estimates. In particular, they split the dataset into two parts, using the first part to choose the splits and the second to estimate the values at the leaf nodes. This approach ensures that the estimates at the leaf nodes are unbiased, but it also greatly increases variance, since each estimate uses only half of the data.
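A simplified sketch of the honesty idea: a single split on a scalar covariate, chosen on one half of the data, with leaf effects estimated on the other half. The splitting criterion here (largest gap between the two leaves' difference-in-means effects) is a hypothetical choice for illustration and is not the criterion of Athey & Imbens (2016). Difference-in-means leaf estimates are only meaningful when treatment is randomized within each leaf, as in the hypothetical data below.

```python
import random
from typing import Callable, List, Sequence, Tuple

Row = Tuple[float, int, float]   # (x, t, y)


def diff_in_means(rows: Sequence[Row]) -> float:
    """Treated-minus-control mean outcome; 0.0 if either group is empty
    (an assumption of this sketch)."""
    y1 = [y for _, t, y in rows if t == 1]
    y0 = [y for _, t, y in rows if t == 0]
    if not y1 or not y0:
        return 0.0
    return sum(y1) / len(y1) - sum(y0) / len(y0)


def honest_stump(rows: List[Row], thresholds: Sequence[float],
                 seed: int = 0) -> Callable[[float], float]:
    """One-split 'honest' effect estimate: half of the data picks the split,
    the other half estimates the effect in each leaf."""
    shuffled = rows[:]
    random.Random(seed).shuffle(shuffled)
    half = len(shuffled) // 2
    split_part, estimate_part = shuffled[:half], shuffled[half:]

    def effect_gap(thr: float) -> float:
        left = [r for r in split_part if r[0] <= thr]
        right = [r for r in split_part if r[0] > thr]
        return abs(diff_in_means(left) - diff_in_means(right))

    thr = max(thresholds, key=effect_gap)       # uses split_part only
    left_effect = diff_in_means([r for r in estimate_part if r[0] <= thr])
    right_effect = diff_in_means([r for r in estimate_part if r[0] > thr])
    return lambda x: left_effect if x <= thr else right_effect


# Hypothetical randomized data: the effect is 0 for x <= 0.5 and 1 for x > 0.5.
rows = [(i / 100, i % 2, float(i % 2 == 1 and i / 100 > 0.5)) for i in range(100)]
tau_hat = honest_stump(rows, thresholds=[0.5, 0.25, 0.75])
print(tau_hat(0.2), tau_hat(0.8))
```

Each leaf estimate here is computed from only the second half of the data, which is the variance cost the paragraph above describes.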


We show results in Table 1. Note that we run CART and honest trees with different maximum depths; Figure 1 shows how $\sqrt{\varepsilon_{\text{PEHE}}}$ scales with depth on IHDP.


On the IHDP dataset, our approach uniformly outperforms the baseline approach in terms of $\sqrt{\varepsilon_{\text{PEHE}}}$, which measures performance on predicting ITEs. Even on predicting ATEs ($\varepsilon_{\text{ATE}}$), our approach mostly outperforms the baseline; the only exceptions are honest trees, which are interpretable models tailored towards estimating treatment effects. As discussed above, honest trees focus on reducing bias at the expense of increased variance. Otherwise, we observe the usual trends: more complex models (e.g., GBMs and random forests) outperform more interpretable models (LASSO, CART, honest trees).

On the Jobs dataset, our results were more mixed. Our approach substantially benefited CART in terms of $R_{\text{pol}}$, and slightly benefited honest trees of depth 3. However, for the remaining models (including honest trees of depth 4 and above), the baseline approach outperformed ours.

The problem is that the oracle model CFR-Net did not perform as well as even some of the simpler models; indeed, the baseline honest tree of depth 5 was the best-performing model on this dataset. In particular, we were unable to replicate the results of (Shalit et al., 2017), despite using their code and obtaining the original train/validation/test splits from the authors. The gap between our performance ($R_{\text{pol}} = 0.235$) and the one reported in (Shalit et al., 2017) ($R_{\text{pol}} = 0.21$) is not very large; however, even in their results, a number of baseline models perform very similarly to (or even better than) CFR-Net.

As a consequence, many of the models trained using our approach achieved performance equal to that of CFR-Net. Since we train our models using labels provided by CFR-Net as ground truth, we cannot expect them to do better than CFR-Net ($R_{\text{pol}} = 0.235$ and $\varepsilon_{\text{ATT}} = 0.086$). Furthermore, CFR-Net appears to have learned a relatively simple function, since LASSO and kernel ridge regression both performed exactly as well as CFR-Net when trained to imitate it; similarly, none of the CART and honest trees trained to imitate CFR-Net grew beyond depth 3.

In summary, while our approach proved less useful for the Jobs dataset, where simple models already perform as well as (or better than) more expressive models, our results on the IHDP dataset clearly demonstrate the potential for our approach to substantially improve the performance of interpretable learning algorithms used to predict ITEs.

5 Conclusion

We have proposed a general framework for learning interpretable models with causal guarantees. A number of directions remain for future work. Most importantly, like previous work, our approach relies on the strong ignorability assumption. The predominant approach to avoiding this assumption is to use instrumental variables; combining our framework with instrumental variables could enable causal guarantees without strong ignorability.


  • Angelino et al. (2017) Angelino, E., Larus-Stone, N., Alabi, D., Seltzer, M., and Rudin, C. Learning certifiably optimal rule lists. In Proceedings of the 23rd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, pp. 35–44. ACM, 2017.
  • Athey & Imbens (2016) Athey, S. and Imbens, G. Recursive partitioning for heterogeneous causal effects. Proceedings of the National Academy of Sciences, 113(27):7353–7360, 2016.
  • Austin (2011) Austin, P. C. An introduction to propensity score methods for reducing the effects of confounding in observational studies. Multivariate behavioral research, 46(3):399–424, 2011.
  • Bartlett & Mendelson (2002) Bartlett, P. L. and Mendelson, S. Rademacher and gaussian complexities: Risk bounds and structural results. Journal of Machine Learning Research, 3(Nov):463–482, 2002.
  • Bastani & Bayati (2015) Bastani, H. and Bayati, M. Online decision-making with high-dimensional covariates. 2015.
  • Bastani et al. (2017) Bastani, O., Kim, C., and Bastani, H. Interpreting blackbox models via model extraction. arXiv preprint arXiv:1705.08504, 2017.
  • Bastani et al. (2018) Bastani, O., Pu, Y., and Solar-Lezama, A. Verifiable reinforcement learning via policy extraction. In NIPS, 2018.
  • Breiman (2001) Breiman, L. Random forests. Machine learning, 45(1):5–32, 2001.
  • Breiman (2017) Breiman, L. Classification and regression trees. Routledge, 2017.
  • Bucilua et al. (2006) Bucilua, C., Caruana, R., and Niculescu-Mizil, A. Model compression. In Proceedings of the 12th ACM SIGKDD international conference on Knowledge discovery and data mining, pp. 535–541. ACM, 2006.
  • Caruana et al. (2015) Caruana, R., Lou, Y., Gehrke, J., Koch, P., Sturm, M., and Elhadad, N. Intelligible models for healthcare: Predicting pneumonia risk and hospital 30-day readmission. In Proceedings of the 21th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, pp. 1721–1730. ACM, 2015.
  • Consortium (2009) Consortium, I. W. P. Estimation of the warfarin dose with clinical and pharmacogenetic data. New England Journal of Medicine, 360(8):753–764, 2009.
  • Ellis et al. (2015) Ellis, K., Solar-Lezama, A., and Tenenbaum, J. Unsupervised learning by program synthesis. In Advances in neural information processing systems, pp. 973–981, 2015.
  • Ellis et al. (2018) Ellis, K., Ritchie, D., Solar-Lezama, A., and Tenenbaum, J. Learning to infer graphics programs from hand-drawn images. In Advances in Neural Information Processing Systems, pp. 6062–6071, 2018.
  • Friedman (2001) Friedman, J. H. Greedy function approximation: a gradient boosting machine. Annals of statistics, pp. 1189–1232, 2001.
  • Frosst & Hinton (2017) Frosst, N. and Hinton, G. Distilling a neural network into a soft decision tree. arXiv preprint arXiv:1711.09784, 2017.
  • Hardt et al. (2016) Hardt, M., Price, E., Srebro, N., et al. Equality of opportunity in supervised learning. In Advances in neural information processing systems, pp. 3315–3323, 2016.
  • Henry et al. (2015) Henry, K. E., Hager, D. N., Pronovost, P. J., and Saria, S. A targeted real-time early warning score (trewscore) for septic shock. Science translational medicine, 7(299):299ra122–299ra122, 2015.
  • Hill (2011) Hill, J. L. Bayesian nonparametric modeling for causal inference. Journal of Computational and Graphical Statistics, 2011.
  • Hill (2016) Hill, J. L. Npci: Non-parametrics for causal inference. https://github.com/vdorie/npci, 2016.
  • Hinton et al. (2015) Hinton, G., Vinyals, O., and Dean, J. Distilling the knowledge in a neural network. arXiv preprint arXiv:1503.02531, 2015.
  • Johansson et al. (2016) Johansson, F., Shalit, U., and Sontag, D. Learning representations for counterfactual inference. In International Conference on Machine Learning, pp. 3020–3029, 2016.
  • Kim et al. (2011) Kim, E. S., Herbst, R. S., Wistuba, I. I., Lee, J. J., Blumenschein, G. R., Tsao, A., Stewart, D. J., Hicks, M. E., Erasmus, J., Gupta, S., et al. The battle trial: personalizing therapy for lung cancer. Cancer discovery, 2011.
  • Kleinberg et al. (2017) Kleinberg, J., Lakkaraju, H., Leskovec, J., Ludwig, J., and Mullainathan, S. Human decisions and machine predictions. The quarterly journal of economics, 133(1):237–293, 2017.
  • Lakkaraju et al. (2016) Lakkaraju, H., Bach, S. H., and Leskovec, J. Interpretable decision sets: A joint framework for description and prediction. In Proceedings of the 22nd ACM SIGKDD international conference on knowledge discovery and data mining, pp. 1675–1684. ACM, 2016.
  • Lakkaraju et al. (2017) Lakkaraju, H., Kamar, E., Caruana, R., and Leskovec, J. Interpretable & explorable approximations of black box models. arXiv preprint arXiv:1707.01154, 2017.
  • LaLonde (1986) LaLonde, R. J. Evaluating the econometric evaluations of training programs with experimental data. The American economic review, pp. 604–620, 1986.
  • Lou et al. (2012) Lou, Y., Caruana, R., and Gehrke, J. Intelligible models for classification and regression. In Proceedings of the 18th ACM SIGKDD international conference on Knowledge discovery and data mining, pp. 150–158. ACM, 2012.
  • Louizos et al. (2017) Louizos, C., Shalit, U., Mooij, J. M., Sontag, D., Zemel, R., and Welling, M. Causal effect inference with deep latent-variable models. In Advances in Neural Information Processing Systems, pp. 6446–6456, 2017.
  • Murphy (2012) Murphy, K. P. Machine Learning: A Probabilistic Perspective. The MIT Press, 2012.
  • Pearl (2010) Pearl, J. Causal inference. In Causality: Objectives and Assessment, pp. 39–58, 2010.
  • Ribeiro et al. (2016) Ribeiro, M. T., Singh, S., and Guestrin, C. Model-agnostic interpretability of machine learning. In KDD, 2016.
  • Rubin (2005) Rubin, D. B. Causal inference using potential outcomes: Design, modeling, decisions. Journal of the American Statistical Association, 2005.
  • Shalit et al. (2017) Shalit, U., Johansson, F. D., and Sontag, D. Estimating individual treatment effect: generalization bounds and algorithms. In Proceedings of the 34th International Conference on Machine Learning (ICML), 2017.
  • Tibshirani (1996) Tibshirani, R. Regression shrinkage and selection via the lasso. Journal of the Royal Statistical Society. Series B (Methodological), pp. 267–288, 1996.
  • Ustun & Rudin (2016) Ustun, B. and Rudin, C. Supersparse linear integer models for optimized medical scoring systems. Machine Learning, 102(3):349–391, 2016.
  • Valkov et al. (2018) Valkov, L., Chaudhari, D., Srivastava, A., Sutton, C., and Chaudhuri, S. Houdini: Lifelong learning as program synthesis. In Advances in Neural Information Processing Systems, pp. 8701–8712, 2018.
  • Verma et al. (2018) Verma, A., Murali, V., Singh, R., Kohli, P., and Chaudhuri, S. Programmatically interpretable reinforcement learning. In ICML, 2018.
  • Wager & Athey (2017) Wager, S. and Athey, S. Estimation and inference of heterogeneous treatment effects using random forests. Journal of the American Statistical Association, (just-accepted), 2017.
  • Wang & Rudin (2015) Wang, F. and Rudin, C. Falling rule lists. In Artificial Intelligence and Statistics, pp. 1013–1022, 2015.
  • Wooldridge (2015) Wooldridge, J. M. Introductory econometrics: A modern approach. Nelson Education, 2015.
  • Yang et al. (2017) Yang, H., Rudin, C., and Seltzer, M. Scalable bayesian rule lists. In Proceedings of the 34th International Conference on Machine Learning-Volume 70, pp. 3921–3930. JMLR. org, 2017.

Appendix A Proof of Theorem 3.3

From the proof of Theorem 1 in Shalit et al. (2017), we have


Then, we have


where equation (5) follows from Lemma A5 of Shalit et al. (2017). Similarly,

Plugging this into equation (4), we obtain