Interpretability via Model Extraction

06/29/2017 ∙ by Osbert Bastani, et al.

The ability to interpret machine learning models has become increasingly important now that machine learning is used to inform consequential decisions. We propose an approach called model extraction for interpreting complex, blackbox models. Our approach approximates the complex model using a much more interpretable model; as long as the approximation quality is good, statistical properties of the complex model are reflected in the interpretable model. We show how model extraction can be used to understand and debug random forests and neural nets trained on several datasets from the UCI Machine Learning Repository, as well as control policies learned for several classical reinforcement learning problems.





We use our model extraction algorithm to interpret several supervised learning models trained on datasets from the UCI Machine Learning Repository [uci_repository], as well as a control policy learned for the cartpole problem from OpenAI Gym [openai_cartpole].

Comparison to CART

| Dataset | Task | Samples | Features | Model | Absolute: model | Absolute: ours | Absolute: CART | Relative: ours | Relative: CART |
|---|---|---|---|---|---|---|---|---|---|
| breast cancer [wolberg1990multisurface] | classification | 569 | 32 | random forest | 0.966 | **0.942** | 0.934 | **0.957** | 0.945 |
| student grade [cortez2008using] | regression | 382 | 33 | random forest | 4.47 | **4.70** | 5.10 | **0.40** | 0.64 |
| wine origin [forina1991extendible] | classification | 178 | 13 | random forest | 0.981 | **0.925** | 0.890 | **0.938** | 0.890 |
| wine origin [forina1991extendible] | classification | 178 | 13 | neural net | 0.795 | **0.755** | 0.751 | **0.913** | 0.905 |
| cartpole [cartpole_problem] | reinforcement learning | 100 | 4 | control policy | 200.0 | **190.0** | 35.6 | **86.8%** | 83.8% |

Table: Comparison of the decision tree extracted by our algorithm ("ours") to the one extracted by the baseline ("CART"). We show absolute performance on ground truth and performance relative to the model $f$. For classification (resp., regression), performance is $F_1$ score (resp., MSE) on the test set. For reinforcement learning, it is accuracy on the test set for relative performance, and estimated reward using the decision tree as the policy for absolute performance. We bold the better score between our algorithm and the baseline.

First, we compare our algorithm to a baseline that uses CART to train a decision tree approximating the model $f$ on the training set. For both algorithms, we restrict the decision tree to have 31 nodes. We show results in the comparison table. We report the test set performance of the extracted tree compared to ground truth (or, for MDPs, the estimated reward when the tree is used as a policy), as well as the relative performance compared to the model $f$ on the same test set. Note that our goal is to obtain high relative performance: a better approximation of $f$ is a better interpretation of $f$, even if $f$ has poor performance. Our algorithm outperforms the baseline on every problem instance.
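The distinction between absolute and relative performance can be illustrated with a minimal sketch. The labels and predictions below are hypothetical values chosen for illustration, not data from the experiments, and the binary F1 helper is written out by hand so the example has no dependencies.

```python
def f1_score(y_true, y_pred, positive=1):
    """Binary F1 score of y_pred measured against y_true."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == positive and p == positive)
    fp = sum(1 for t, p in pairs if t != positive and p == positive)
    fn = sum(1 for t, p in pairs if t == positive and p != positive)
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


# Hypothetical test-set labels and predictions (illustrative only).
y_true = [1, 1, 0, 0, 1, 0, 1, 0]   # ground truth
y_model = [1, 1, 0, 1, 1, 0, 0, 0]  # predictions of the complex model f
y_tree = [1, 1, 0, 1, 1, 0, 0, 1]   # predictions of the extracted tree

# Absolute performance: extracted tree scored against ground truth.
absolute = f1_score(y_true, y_tree)
# Relative performance: extracted tree scored against the model's own outputs.
relative = f1_score(y_model, y_tree)

print(f"absolute F1 = {absolute:.3f}, relative F1 = {relative:.3f}")
```

In this toy case the tree follows the model more closely than it follows the ground truth, which is the situation the relative metric is meant to reward.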

Examples of Use Cases

We show how the extracted decision trees can be used to interpret and debug models.

Use of invalid features.

Using an invalid feature is a common problem when training models. In particular, some datasets contain multiple response variables; one response should then not be used to predict another. For example, the breast cancer dataset contains two response variables indicating cancer recurrence: the length of time before recurrence and whether recurrence occurs within 24 months. This issue can be thought of as a special case of using a non-causal feature, an important problem in healthcare settings. We train a random forest $f$ to predict whether recurrence occurs within 24 months, where time to recurrence is incorrectly included as a feature. Then, we extract a decision tree with a fixed number of nodes approximating $f$, using 10 random splits of the dataset into training and test sets. The invalid feature occurred in every extracted tree, and as the top branch in 6 of the 10 trees.
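A check of this kind can be sketched as follows, assuming each extracted tree is available as a nested dictionary. The tree encoding, the feature names, and the two toy trees are hypothetical and stand in for the trees produced by the extraction algorithm.

```python
def leaf(label):
    """Build a leaf node (hypothetical encoding: leaves have no 'feature' key)."""
    return {"label": label}


def features_used(tree):
    """Return the set of feature names tested at internal nodes of a tree."""
    if "feature" not in tree:  # leaf node
        return set()
    return ({tree["feature"]}
            | features_used(tree["left"])
            | features_used(tree["right"]))


def audit(trees, suspect):
    """Count trees using `suspect` anywhere, and trees using it at the root."""
    anywhere = sum(1 for t in trees if suspect in features_used(t))
    at_root = sum(1 for t in trees if t.get("feature") == suspect)
    return anywhere, at_root


# Two toy trees; feature names and thresholds are made up for illustration.
tree_a = {"feature": "time_to_recurrence", "threshold": 24.0,
          "left": leaf(1), "right": leaf(0)}
tree_b = {"feature": "tumor_size", "threshold": 2.5,
          "left": leaf(0),
          "right": {"feature": "time_to_recurrence", "threshold": 24.0,
                    "left": leaf(1), "right": leaf(0)}}

anywhere, at_root = audit([tree_a, tree_b], "time_to_recurrence")
print(f"used in {anywhere} trees, top branch in {at_root}")
```

Running the same audit over the trees from each random split gives counts of the kind reported above.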

Use of prejudiced features.

We can use our algorithm to evaluate how a model depends on prejudiced features. For example, gender is a feature in the student grade dataset, and may be considered sensitive when estimating student performance. However, if we simply omit gender, then $f$ may reconstruct it from the remaining features. For a model $f$ trained with gender available, we show how a decision tree extracted from $f$ can be used to understand how $f$ depends on gender. Our approach does not guarantee fairness, but it can be useful for evaluating the fairness of $f$.

We extract decision trees from the random forests trained on 10 random splits of the student grade dataset. The top features are consistently grades in other classes or the number of failing grades received in the past. Gender occurs below these features (at the fourth or fifth level) in 7 of the 10 trees. We can estimate the overall effect of changing the gender label:

$$\Delta = \mathbb{E}_{x \sim P}[f(x) \mid \text{male}] - \mathbb{E}_{x \sim P}[f(x) \mid \text{female}].$$

When gender occurs, $\Delta$ is between 0.31 and 0.70 grade points (average 0.49) out of 20 total grade points. For the remaining models, $\Delta$ is between 0.11 and 0.32 (average 0.25). Thus, the extracted tree includes gender when $f$ has a relatively large dependence on gender.

Furthermore, because the extracted tree approximates $f$, we can use it to identify a subgroup of students where $f$ has particularly strong dependence on gender. In particular, points that flow to the internal node $N$ of the tree branching on gender are a subset of inputs whose label is determined in part by gender. We can use $f$ to measure the dependence on gender within this subset:

$$\Delta_N = \mathbb{E}_{x \sim P}[f(x) \mid C_{N_L}] - \mathbb{E}_{x \sim P}[f(x) \mid C_{N_R}],$$

where $N_L$ and $N_R$ are the left and right children of $N$. Also, we can estimate the fraction of students in this subset using the test set. Finally, we can measure the fraction of the overall dependence of $f$ on gender that is accounted for by the subtree rooted at $N$.

For models where gender occurs in the extracted tree, the subgroup effect size $\Delta_N$ ranged from 0.44 to 0.77 grade points, and the estimated fraction of students in this subgroup ranged from 18.3% to 39.1%. The two trees that had the largest effect size had $\Delta_N$ of 0.77 and 0.43, resp., and subgroup fractions of 39.1% and 35.7%, resp. The identified subgroup accounted for 67.3% and 65.6% of the total effect of gender, resp. Having identified a subgroup of students likely to be adversely affected, the user might be able to train a better model specifically for this subgroup. In 5 of the 7 extracted trees where gender occurs, the affected students were students with low grades, in particular, the 27% of students who scored fewer than 8.5 points in another class. This fine-grained understanding of $f$ relies on the extracted model, and cannot be obtained using feature importance metrics alone.
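The overall and subgroup effects can be estimated from a test set along the following lines. This is a minimal sketch: the rows, the stand-in `predict` function, and the choice of which child of the gender node is "left" are all assumptions made for illustration; only the 8.5-point condition is taken from the text.

```python
from statistics import mean


def gender_effect(rows, predict):
    """Overall effect: mean prediction for male rows minus mean for female rows."""
    male = [predict(r) for r in rows if r["gender"] == "male"]
    female = [predict(r) for r in rows if r["gender"] == "female"]
    return mean(male) - mean(female)


def subgroup_effect(rows, predict, reaches_node, goes_left):
    """Effect at a node N branching on gender, plus the fraction of rows reaching N."""
    subset = [r for r in rows if reaches_node(r)]
    left = [predict(r) for r in subset if goes_left(r)]
    right = [predict(r) for r in subset if not goes_left(r)]
    return mean(left) - mean(right), len(subset) / len(rows)


def predict(row):
    """Hypothetical stand-in for the model f: gender matters only for low grades."""
    bonus = 1.0 if row["other_grade"] < 8.5 and row["gender"] == "male" else 0.0
    return row["other_grade"] + bonus


# Hypothetical test set (illustrative values only).
rows = [{"gender": g, "other_grade": grade}
        for grade in (6.0, 7.0, 12.0, 14.0)
        for g in ("male", "female")]

delta = gender_effect(rows, predict)
delta_n, fraction = subgroup_effect(
    rows, predict,
    reaches_node=lambda r: r["other_grade"] < 8.5,  # path condition to node N
    goes_left=lambda r: r["gender"] == "male",      # assumed branch direction
)
print(f"overall = {delta:.2f}, subgroup = {delta_n:.2f}, fraction = {fraction:.0%}")
```

In this toy data the subgroup effect is twice the overall effect, because the dependence on gender is concentrated in the half of the rows that reach the node.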

Comparing models.

We can use the extracted decision trees to compare different models trained on the same dataset, and gain insight into why some models perform better than others. For example, random forests trained on the wine origin dataset performed very well, all achieving an $F_1$ score of at least 0.961. In contrast, the performance of the neural nets was bimodal: 5 had an $F_1$ score of at least 0.955, and the remaining had an $F_1$ score of at most 0.741. We examined the top 3 layers of the extracted decision trees, and made two observations. First, occurrence of the feature "chlorides" in the extracted tree was almost perfectly correlated with poor performance of the neural nets. This feature occurred in only one of the 10 trees extracted from random forests, and in none of the trees extracted from high performing neural nets. A weaker observation concerned the branching of the extracted tree on the feature "alcohol", which is a very important feature: it is the top branch for all but one of the 20 extracted decision trees. For the high performing models, the branch threshold tended to be higher (749.8 to 999.6) than for the poorly performing models (574.4 to 837.3). The latter observation relies on having an extracted model; feature influence metrics alone are insufficient.
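A comparison of this kind can be sketched by grouping extracted trees by the score of the model they came from, then collecting the features in their top layers and their root thresholds. The nested-dictionary tree encoding, the toy trees, the scores, the cutoff of 0.9, and the feature name "hue" are hypothetical.

```python
def features_in_top_layers(tree, depth=3):
    """Feature names tested in the top `depth` layers of a nested-dict tree."""
    if depth == 0 or "feature" not in tree:  # depth exhausted or leaf node
        return set()
    return ({tree["feature"]}
            | features_in_top_layers(tree["left"], depth - 1)
            | features_in_top_layers(tree["right"], depth - 1))


def summarize(extracted, cutoff):
    """Group trees by model score; collect top-layer features and root thresholds."""
    summary = {"high": {"features": set(), "root_thresholds": []},
               "low": {"features": set(), "root_thresholds": []}}
    for tree, score in extracted:
        group = summary["high" if score >= cutoff else "low"]
        group["features"] |= features_in_top_layers(tree)
        group["root_thresholds"].append(tree["threshold"])
    return summary


LEAF = {"label": 0}  # hypothetical leaf encoding

# Toy trees; structure and thresholds are made up for illustration.
good_tree = {"feature": "alcohol", "threshold": 900.0, "left": LEAF,
             "right": {"feature": "hue", "threshold": 1.0,
                       "left": LEAF, "right": LEAF}}
bad_tree = {"feature": "alcohol", "threshold": 600.0, "right": LEAF,
            "left": {"feature": "chlorides", "threshold": 0.1,
                     "left": LEAF, "right": LEAF}}

summary = summarize([(good_tree, 0.96), (bad_tree, 0.74)], cutoff=0.9)
for name, group in summary.items():
    print(name, sorted(group["features"]), group["root_thresholds"])
```

Comparing root thresholds requires the tree itself, which is why a per-feature importance score alone would not surface the second observation.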

Understanding control policies.

We can use the extracted decision tree to understand a control policy. For example, we extracted a small decision tree from the cartpole control policy. While its estimated reward of 152.3 is lower than for larger trees, it captures a significant fraction of the policy behavior. The tree says to move the cart to the right exactly when $(\text{pole velocity} \ge -0.286) \wedge (\text{pole angle} \ge -0.071)$. In other words, move the cart to the right exactly when the pole is already on the right relative to the cart, and the pole is also moving toward the left (or more precisely, not moving fast enough toward the right). This policy is asymmetric, focusing on states where the cart is moving to the left. An animation of the simulation shows that this bias occurs because the cart initially moves toward the left, so the portion of the state space where the cart is moving toward the right is relatively unexplored.
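Written as code, the rule read off the extracted tree is a two-condition test. The function name, argument names, and string action labels below are illustrative choices; mapping them onto a particular simulator's observation vector and action encoding is left out.

```python
def extracted_policy(pole_angle, pole_velocity):
    """Action chosen by the small extracted tree, using the thresholds from the text."""
    if pole_velocity >= -0.286 and pole_angle >= -0.071:
        return "right"
    return "left"


# A few hypothetical states (pole angle, pole velocity).
for state in [(0.0, 0.0), (-0.1, 0.0), (0.0, -0.3)]:
    print(state, "->", extracted_policy(*state))
```

The asymmetry discussed above is visible directly in the thresholds: both are negative, so the upright, motionless state maps to "right" rather than sitting on a decision boundary.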


We have proposed model extraction as an approach for interpreting blackbox models, and shown how it can be used to interpret a variety of different kinds of models. Important directions for future work include devising algorithms for model extraction using more expressive input distributions, and developing new ways to gain insight from the extracted decision trees.