Cascading-Decision-Tree
Open-source code for paper CDT: Cascading Decision Trees for Explainable Reinforcement Learning
Deep Reinforcement Learning (DRL) has recently achieved significant advances in various domains. However, explaining the policy of RL agents remains an open problem due to several factors, one being the complexity of explaining neural network decisions. Recently, several works have used decision-tree-based models to learn explainable policies. Soft decision trees (SDTs) and discretized differentiable decision trees (DDTs) have been shown to achieve good performance while also offering explainable policies. In this work, we further improve the results for tree-based explainable RL in both performance and explainability. Our proposal, Cascading Decision Trees (CDTs), applies representation learning on the decision path to allow richer expressivity. Empirical results show that in both settings, whether CDTs are used as policy function approximators or as imitation learners to explain black-box policies, CDTs can achieve better performance with more succinct and explainable models than SDTs. As a second contribution, our study reveals limitations of explaining black-box policies via imitation learning with tree-based explainable models, due to its inherent instability.
Explainable Artificial Intelligence (XAI), especially Explainable Reinforcement Learning (XRL) (puiutta2020explainable), has been attracting more attention recently. How to interpret the action choices of reinforcement learning (RL) policies remains a critical challenge, especially given the growing trend of applying RL in domains that require transparency and safety (cheng2019end, junges2016safety). Currently, many state-of-the-art DRL agents use neural networks (NNs) as their function approximators. While NNs are considered stronger function approximators (for better performance), RL agents built on top of them generally lack interpretability (lipton2018mythos). Indeed, interpreting the behavior of NNs themselves remains an open problem in the field (montavon2018methods, albawi2017understanding).

In contrast, traditional DTs (with hard decision boundaries) are usually regarded as models with readable interpretations for humans, since humans can interpret the decision making process by visualizing the decision path. However, DTs may suffer from weak expressivity and therefore low accuracy. An early approach to softening the hard decision boundaries of DTs was the soft/fuzzy DT (abbreviated as SDT) proposed by suarez1999globally. Recently, differentiable SDTs (frosst2017distilling) have shown both improved interpretability and better function approximation, lying between traditional DTs and neural networks.
People have adopted differentiable DTs for interpreting RL policies in two slightly different settings: an imitation learning setting (coppens2019distilling, liu2018toward), in which imitators with interpretable models are learned from RL agents with black-box models, and a full RL setting (silva2019optimization), where the policy is directly represented as an interpretable model, e.g., a DT. However, the DTs in these methods only partition the raw feature space, without representation learning, which can lead to complicated combinations of partitions and may hinder both model interpretability and scalability. Even worse, some methods use axis-aligned partitions (univariate decision nodes) (wu2017beyond, silva2019optimization), which have much lower model expressivity.
In this paper, we propose Cascading Decision Trees (CDTs), which strike a balance between model interpretability and accuracy by performing adequate representation learning with interpretable models (e.g., linear models). Our experiments show that CDTs have a significantly smaller number of parameters (and a more compact tree structure) and better performance than related works. The experiments are conducted on RL tasks, in both imitation-learning and RL settings. We also demonstrate that the imitation-learning approach is less reliable for interpreting RL policies with DTs, since the imitating DTs may differ markedly across runs, which also leads to divergent feature importances and tree structures. Our code for the algorithm implementation and experiments is released at: https://github.com/quantumiracle/Cascading-Decision-Tree.
A series of works has been developed over the past two decades along the direction of differentiable DTs (irsoy2012soft, laptev2014convolutional). Recently, frosst2017distilling proposed distilling an SDT from a neural network; their approach was only tested on MNIST digit classification tasks.
wu2017beyond further proposed the tree regularization technique, which favors models whose decision boundaries are closer to those of compact DTs, to achieve interpretability. To further boost the prediction accuracy of tree-based models, two main extensions of a single SDT have been proposed: (1) ensembles of trees, and (2) the unification of NNs and DTs.

An ensemble of decision trees is a common technique for increasing the accuracy or robustness of prediction, and it can be incorporated into SDTs (rota2014neural, kontschieder2015deep, kumar2016ensemble), giving rise to neural decision forests. Since more than one tree needs to be considered during inference, this complicates interpretability. A common solution is to transform the decision forest into a single tree (sagi2020explainable).
As for the unification of NNs and DTs, laptev2014convolutional propose convolutional decision trees for feature learning from images. Adaptive Neural Trees (ANTs) (tanno2018adaptive) incorporate representation learning in the decision nodes of a differentiable tree through nonlinear transformations such as convolutional neural networks (CNNs). The nonlinear transformations of an ANT, applied not only in the routing functions of its decision nodes but also in the feature spaces, ensure strong prediction performance in classification tasks on the one hand, but hinder the interpretability of such methods on the other.
wan2020nbdt propose the neural-backed decision tree (NBDT), which transforms the final fully connected layer of an NN into a DT with induced hierarchies for ease of interpretation, but shares its convolutional backbone with normal deep NNs, yielding state-of-the-art performance on CIFAR10 and ImageNet classification tasks.
However, these advanced methods either employ multiple trees, which multiplies the number of model parameters, or heavily incorporate deep learning models like CNNs into the DTs. Their model complexity severely hinders their interpretability.
To interpret an RL agent, coppens2019distilling propose distilling the RL policy into a differentiable DT by imitating a pre-trained policy. Similarly, liu2018toward apply an imitation learning framework, but to the value function of the RL agent. They also propose Linear Model U-trees (LMUTs), which allow linear models in leaf nodes. silva2019optimization propose to apply differentiable DTs directly as function approximators for either the value function or the policy in RL. They apply a discretization process and a rule-list tree structure to simplify the trees and improve interpretability. The VIPER method proposed by bastani2018verifiable also distills NN policies into DT policies with theoretically verifiable capability, but only for imitation learning settings and nonparametric DTs.
Our proposed CDT is distinguished from other main categories of methods with differentiable DTs for XRL in the following ways: (i) Compared with SDT (frosst2017distilling), partitions in CDT happen not only in the original input space but also in transformed spaces, by leveraging intermediate features. Recent works (kontschieder2015deep, xiao2017ndt, tanno2018adaptive) have documented that this improves model capacity, and it can be further extended into hierarchical representation learning with advanced feature learning modules like CNNs (tanno2018adaptive). (ii) Compared with the work by coppens2019distilling, space partitions are not limited to axis-aligned ones (which hinder the expressivity of trees of a given depth), but are achieved with linear models of features as the routing functions. Moreover, the adopted linear models are not a restriction (only an example), and other interpretable transformations are also allowed in our CDT method. (iii) Compared with ANTs (tanno2018adaptive), our CDT method unifies the decision making process based on different intermediate features with a single decision making tree, which follows the low-rank decomposition of a large matrix with linear models. It thus greatly improves model simplicity for achieving interpretability. For a motivating example on model simplicity and interpretability in DTs, see Appendix A.
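An SDT is a differentiable DT with a probabilistic decision boundary at each node. Consider a DT of depth $d$: each node in the SDT can be represented as a weight vector $\mathbf{w}_i^j$ (with the bias as an additional dimension), where $i$ and $j$ indicate the index of the layer and the index of the node in that layer respectively, as shown in Fig. 1. The corresponding node is denoted $n_u$, where $u = 2^{i-1} + j$ (with $j = 0, \dots, 2^{i-1} - 1$) uniquely indexes the node, so that the children of $n_u$ are $n_{2u}$ and $n_{2u+1}$. The decision path for a single instance can be represented as a set of nodes $\mathcal{P} \subset \mathcal{N}$, where $\mathcal{N}$ is the set of all nodes on the tree. We have

$$\mathcal{P} = \arg\max_{\{n_{u_1}, \dots, n_{u_{d+1}}\}} \prod_{k=1}^{d} p_{u_k \to u_{k+1}}, \qquad u_1 = 1,\quad u_{k+1} \in \{2u_k,\ 2u_k + 1\},$$

where $p_{u \to u'}$ is the probability of going from node $n_u$ to node $n_{u'}$, e.g. $p_{u \to 2u} = \sigma(\mathbf{w}_u^{\top}\tilde{\mathbf{x}})$ and $p_{u \to 2u+1} = 1 - p_{u \to 2u}$ with $\tilde{\mathbf{x}} = [1; \mathbf{x}]$. Note that $p_{u \to u'}$ along the decision path will always be 1 for a hard DT (safavian1991survey). Therefore the path probability to a specific node $n_u$ is $P_u = \prod_{k} p_{u_k \to u_{k+1}}$, the product over all transitions on the path from the root to $n_u$. In the following, we refer to all DTs using probabilistic decision paths as SDT-based methods, abbreviated as SDT.

silva2019optimization propose to discretize the learned differentiable SDTs into univariate DTs to improve interpretability. Specifically, for a decision node with an $(R+1)$-dimensional vector $\mathbf{w}$, where $R$ is the input dimension (the first dimension $w_1$ is the bias term), the discretization process (i) selects the index of the largest weight dimension, $k^* = \arg\max_{k \ge 2} |w_k|$, and (ii) divides $w_1$ by $w_{k^*}$, to construct a univariate hard DT. Unless otherwise stated, the discretization process in our experiments for both SDTs and CDTs follows this procedure.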
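The SDT quantities above can be illustrated with a minimal pure-Python sketch. It assumes the heap indexing $u = 2^{i-1} + j$ (children $2u$ and $2u+1$), a left-child probability $\sigma(\mathbf{w}_u^{\top}[1;\mathbf{x}])$, and discretization by the largest-magnitude weight. The function names (`leaf_path_probs`, `discretize_node`, `hard_go_left`) are hypothetical and are not the API of the released code.

```python
import math
from typing import List, Tuple


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def go_left_prob(w: List[float], x: List[float]) -> float:
    """p_{u -> 2u} = sigmoid(w_u^T [1; x]); w[0] is the bias term."""
    return sigmoid(w[0] + sum(wk * xk for wk, xk in zip(w[1:], x)))


def leaf_path_probs(weights: List[List[float]], x: List[float]) -> List[float]:
    """Path probability of every leaf of a complete SDT.

    weights[u - 1] holds w_u of inner node n_u in heap order (root u = 1),
    so a depth-d tree has 2**d - 1 inner nodes and 2**d leaves.
    """
    n_inner = len(weights)
    probs = {1: 1.0}                    # the root is reached with probability 1
    for u in range(1, n_inner + 1):     # heap order visits parents before children
        p_left = go_left_prob(weights[u - 1], x)
        probs[2 * u] = probs[u] * p_left
        probs[2 * u + 1] = probs[u] * (1.0 - p_left)
    # Leaves are n_u for u = 2**d, ..., 2**(d+1) - 1.
    return [probs[u] for u in range(n_inner + 1, 2 * n_inner + 2)]


def discretize_node(w: List[float]) -> Tuple[int, float, float]:
    """Keep only the largest-magnitude feature weight of a node.

    Returns (feature index into x, w_1 / w_{k*}, sign of w_{k*}).
    Assumes at least one feature weight is non-zero.
    """
    k_star = max(range(1, len(w)), key=lambda k: abs(w[k]))
    return k_star - 1, w[0] / w[k_star], math.copysign(1.0, w[k_star])


def hard_go_left(node: Tuple[int, float, float], x: List[float]) -> bool:
    """Univariate test equivalent to w_1 + w_{k*} * x_{k*} > 0."""
    feature, ratio, sign = node
    return sign * (x[feature] + ratio) > 0.0
```

The leaf on the decision path is then `max(range(len(p)), key=p.__getitem__)` for `p = leaf_path_probs(weights, x)`. Keeping the sign of $w_{k^*}$ alongside the ratio $w_1 / w_{k^*}$ is a design choice of this sketch: it makes the hard test agree with the soft node's preferred branch whatever the sign of the selected weight.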
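Methods. We propose CDT as an extension of SDT that adds the capability of representation learning, as well as decision making in transformed spaces. (One motivation for the CDT method is the duplicative structure in the heuristic solution of LunarLander-v2, as discussed in Appendix B.) In a simple CDT architecture, as shown on the left of Fig. 2, a feature learning tree $\mathcal{F}$ is cascaded with a decision making tree $\mathcal{D}$. In tree $\mathcal{F}$, each decision node is a simple function of the raw feature vector $\mathbf{x}$ given learnable parameters $\mathbf{w}$: $\phi(\mathbf{x}; \mathbf{w})$, while each of its leaves is a feature representation function $\mathbf{f} = f(\mathbf{x}; \tilde{\mathbf{w}})$ parameterized by $\tilde{\mathbf{w}}$. In tree $\mathcal{D}$, each decision node is a simple function of the learned features $\mathbf{f}$ rather than of the raw features, given learnable parameters $\mathbf{w}'$: $\psi(\mathbf{f}; \mathbf{w}')$. The output distribution of $\mathcal{D}$ is another parameterized function $p(\cdot; \tilde{\mathbf{w}}')$, independent of both $\mathbf{x}$ and $\mathbf{f}$. For simplicity and interpretability, all functions are linear in our examples, but they can be extended to other interpretable models.

Specifically, the mathematical relationships based on linear functions are as follows. For an environment with input state vector $\mathbf{x} \in \mathbb{R}^{R}$ and $O$ discrete actions, suppose our CDT has intermediate features of dimension $K$ (this is the feature dimension at each leaf node, not the number of leaf nodes on $\mathcal{F}$). Writing $\tilde{\mathbf{x}} = [1; \mathbf{x}]$ so that the bias is an additional dimension, the probability of going to the left/right path at the $u$-th node on $\mathcal{F}$ is:

$$p_u^{\text{GoLeft}} = \sigma(\mathbf{w}_u^{\top}\tilde{\mathbf{x}}), \qquad p_u^{\text{GoRight}} = 1 - p_u^{\text{GoLeft}}, \qquad \mathbf{w}_u \in \mathbb{R}^{R+1} \tag{1}$$

which is the same as in SDTs. Then each leaf node $l$ on $\mathcal{F}$ has a linear feature representation function that transforms the basis of the representation space:

$$f_k^{l} = \sum_{r=1}^{R} \tilde{w}_{kr}^{l}\, x_r, \quad k = 1, \dots, K, \qquad \text{i.e. } \mathbf{f}^{l} = \tilde{W}^{l}\mathbf{x},\ \ \tilde{W}^{l} \in \mathbb{R}^{K \times R} \tag{2}$$

which gives the $K$-dimensional intermediate feature vector for each possible path. Tree $\mathcal{D}$ is also an SDT, but with the raw input replaced by the learned representation (again with $\tilde{\mathbf{f}} = [1; \mathbf{f}]$) at each node $v$ in $\mathcal{D}$:

$$p_v^{\text{GoLeft}} = \sigma(\mathbf{w}_v'^{\top}\tilde{\mathbf{f}}), \qquad p_v^{\text{GoRight}} = 1 - p_v^{\text{GoLeft}}, \qquad \mathbf{w}'_v \in \mathbb{R}^{K+1} \tag{3}$$

Finally, the output distribution is feature-independent, giving the probability mass over the $O$ output dimensions for each leaf $m$ of $\mathcal{D}$ as:

$$p_o^{m} = \frac{\exp(\tilde{w}'^{m}_{o})}{\sum_{o'=1}^{O}\exp(\tilde{w}'^{m}_{o'})}, \quad o = 1, \dots, O \tag{4}$$

Suppose we have a CDT of depth $d_1$ for $\mathcal{F}$ and depth $d_2$ for $\mathcal{D}$. The probability of going from the root of either $\mathcal{F}$ or $\mathcal{D}$ to its $l$-th leaf node satisfies the same derivation as in SDTs: $P^{l} = \prod_{u \in \mathcal{P}^{l}} p_u^{\to l}$, where $\mathcal{P}^{l}$ is the set of inner nodes on the path and $p_u^{\to l}$ is the GoLeft or GoRight probability of node $u$, whichever leads toward leaf $l$. Therefore the overall probability of the path from the root of $\mathcal{F}$ to its $l$-th leaf node and then to the $m$-th leaf node of $\mathcal{D}$ is:

$$P^{l,m} = P_{\mathcal{F}}^{l}(\mathbf{x})\; P_{\mathcal{D}}^{m}(\mathbf{f}^{l}) \tag{5}$$

Each leaf of the feature learning tree represents one possible assignment of intermediate feature values, while all leaves share the subsequent decision making tree. During inference, we simply take the leaf on $\mathcal{F}$ or $\mathcal{D}$ with the largest probability to assign values to the intermediate features (in $\mathcal{F}$) or to derive the output probabilities (in $\mathcal{D}$), which may sacrifice a little accuracy but increases interpretability. The detailed architecture of CDT, with the relationships among variables, is shown in the figures of Appendix C.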
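The inference procedure above (Eqs. 1 to 5, taking the most probable leaf in each tree) can be sketched in pure Python. This is a hypothetical minimal implementation, not the released code: it assumes linear node functions with a bias as the first weight, bias-free leaf feature matrices $\tilde{W}^{l}$, and softmax leaf logits, and the name `cdt_infer` and its list-of-lists parameter layout are illustrative only.

```python
import math
from typing import List, Tuple

Vec = List[float]


def _sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def _leaf_probs(inner: List[Vec], v: Vec) -> List[float]:
    """Leaf path probabilities of a soft tree in heap order (w[0] = bias), Eqs. (1)/(3)."""
    probs = {1: 1.0}
    for u in range(1, len(inner) + 1):
        w = inner[u - 1]
        p_left = _sigmoid(w[0] + sum(a * b for a, b in zip(w[1:], v)))
        probs[2 * u] = probs[u] * p_left
        probs[2 * u + 1] = probs[u] * (1.0 - p_left)
    return [probs[u] for u in range(len(inner) + 1, 2 * len(inner) + 2)]


def _argmax(p: List[float]) -> int:
    return max(range(len(p)), key=p.__getitem__)


def cdt_infer(f_inner: List[Vec], f_leaf_mats: List[List[Vec]],
              d_inner: List[Vec], d_leaf_logits: List[Vec],
              x: Vec) -> Tuple[Vec, int, int, float]:
    """Return (action distribution, F leaf l, D leaf m, path probability P^{l,m})."""
    p_f = _leaf_probs(f_inner, x)                 # routing in F on the raw input x
    l = _argmax(p_f)                              # most probable F leaf
    feats = [sum(a * b for a, b in zip(row, x))   # Eq. (2): f^l = W~^l x
             for row in f_leaf_mats[l]]
    p_d = _leaf_probs(d_inner, feats)             # routing in D on the learned f^l
    m = _argmax(p_d)                              # most probable D leaf
    logits = d_leaf_logits[m]
    top = max(logits)                             # numerically stable softmax, Eq. (4)
    exps = [math.exp(z - top) for z in logits]
    total = sum(exps)
    return [e / total for e in exps], l, m, p_f[l] * p_d[m]   # Eq. (5)
```

For a CartPole-sized 1+2 CDT ($R = 4$, $K = 2$, $O = 2$), `f_inner` holds one 5-vector, `f_leaf_mats` two $2 \times 4$ matrices, `d_inner` three 3-vectors and `d_leaf_logits` four 2-vectors. Note that every leaf of $\mathcal{F}$ feeds the same `d_inner` weights, which is the weight sharing that distinguishes CDT from simply stacking two independent trees.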
Model Simplicity
We analyze the simplicity of CDT compared with SDT in terms of the number of learnable parameters in the model. The reason is that, to increase interpretability, we need to simplify the tree structure or reduce the number of parameters (including weights and biases) in the tree.
We can analyze the model simplicity of CDT against a normal SDT with linear functions from a matrix decomposition perspective. Suppose we need a total of $N$ multivariate decision nodes in the $R$-dimensional raw input space $\mathcal{X}$ to successfully partition the space for high-performance prediction, which can be written as a matrix $W_x \in \mathbb{R}^{N \times R}$. CDT tries to achieve the same partitions by learning a transformation matrix $W_f \in \mathbb{R}^{K \times R}$ and a partition matrix $W_d \in \mathbb{R}^{N \times K}$ in the $K$-dimensional feature space $\mathcal{F}$, such that:

$$W_x \mathbf{x} = W_d \mathbf{f} \tag{6}$$

$$\phantom{W_x \mathbf{x}} = W_d W_f \mathbf{x}, \quad \text{i.e. } W_x = W_d W_f \tag{7}$$

Therefore, compared against a standard SDT of the same total depth, CDT reduces the number of model parameters to be learned by $NR - K(N + R)$, which is positive as long as $K < \frac{NR}{N+R}$, while keeping the model expressivity.
A detailed quantitative analysis of model parameters for CDT and SDT is provided in Appendix D.
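A quick arithmetic check of the low-rank argument: with $N$ decision nodes in an $R$-dimensional input space factorized through $K$ features ($W_x = W_d W_f$), the saving is $NR - K(N+R)$, positive exactly when $K < NR/(N+R)$. Like the matrix argument, this sketch counts only the entries of $W_x$, $W_d$ and $W_f$ (biases ignored); the helper names and the example values ($N = 15$, $R = 8$) are illustrative, not taken from the experiments.

```python
def low_rank_saving(n_nodes: int, raw_dim: int, feat_dim: int) -> int:
    """Entries of W_x (N x R) minus entries of W_d (N x K) and W_f (K x R)."""
    return n_nodes * raw_dim - feat_dim * (n_nodes + raw_dim)


def max_feat_dim(n_nodes: int, raw_dim: int) -> int:
    """Largest integer K with a strictly positive saving, i.e. K < N R / (N + R)."""
    return (n_nodes * raw_dim - 1) // (n_nodes + raw_dim)


if __name__ == "__main__":
    print(low_rank_saving(15, 8, 2))   # 120 - 2 * 23 = 74
    print(max_feat_dim(15, 8))         # 119 // 23 = 5
```

The break-even point grows with both $N$ and $R$, so the factorization pays off most when many multivariate partitions are needed in a high-dimensional input space.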
Hierarchical CDT
From the above, a simple CDT architecture as in Fig. 2, with a single feature learning model $\mathcal{F}$ and a single decision making model $\mathcal{D}$, can achieve intermediate feature learning with a significant reduction in model complexity compared with a traditional SDT. However, the intermediate features learned with $\mathcal{F}$ may sometimes be insufficient for capturing complex structures in advanced tasks, so we further extend the simple CDT architecture into more hierarchical ones. As shown on the right side of Fig. 2, two potential types of hierarchical CDT are displayed: (a) a hierarchical feature abstraction module with three feature learning models $\{\mathcal{F}_1, \mathcal{F}_2, \mathcal{F}_3\}$ cascaded before the decision module $\mathcal{D}$; (b) a parallel feature extraction module with two feature learning models $\{\mathcal{F}_1, \mathcal{F}_2\}$, whose learned features are concatenated before being input to $\mathcal{D}$.

One needs to bear in mind that as model structures become more complicated, the interpretability of the model decreases due to the loss of simplicity. Therefore, we did not apply hierarchical CDTs in our experiments, in order to maintain interpretability. However, a hierarchical structure is one of the preferred ways to retain as much simplicity as possible when trying to increase model capacity and prediction accuracy, so it can be applied when necessary.
We compare CDT and SDT in two settings for interpreting RL agents: (1) the imitation learning setting, where the RL agent with a black-box model (e.g., a neural network) to be interpreted first generates a state-action dataset for the imitators to learn from, and the interpretation is derived from the imitators; (2) the full RL setting, where the RL agent is trained directly with its policy represented by an interpretable model such as a CDT or SDT, so that the interpretation can be derived by directly inspecting that model. The environments are CartPole-v1, LunarLander-v2 and MountainCar-v0 in OpenAI Gym (brockman2016openai). The depth of a CDT is written as "$d_1+d_2$" in the following sections, where $d_1$ is the depth of the feature learning tree $\mathcal{F}$ and $d_2$ is the depth of the decision making tree $\mathcal{D}$. Each setting is trained for five runs in imitation learning and three runs in RL.
Both the fidelity and the stability of mimic models reflect their reliability as interpretable models. Fidelity is the accuracy of the mimic model with respect to the original model: it estimates how similar the mimic model's predictions are to those of the original one. However, fidelity is not sufficient for reliable interpretations, since an unstable family of mimic models will lead to inconsistent explanations of the original black-box model. Stability looks deeper into the mimic model itself by comparing several runs. Previous research (bastani2017interpreting) has investigated the fidelity and stability of decision trees as mimic models, estimating stability as the fraction of equivalent nodes in different random decision trees trained under the same settings. In our experiments, the stability analysis is conducted by comparing the tree weights of different instances in imitation learning settings.

Performance. The datasets for imitation learning are generated with heuristic agents for the CartPole-v1 and LunarLander-v2 environments, containing 10000 episodes of state-action data for each environment. See Appendix E for other training details. The results are provided in Tables 1 and 2.
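Tree Type | Depth | Discretized | Accuracy (%) | Episode Reward | # of Parameters |
---|---|---|---|---|---|
SDT | 2 | ✗ | 94.1 | 500.0 | 23 |
SDT | 2 | ✓ | 49.7 | 39.9 | 14 |
SDT | 3 | ✗ | 94.5 | 500.0 | 51 |
SDT | 3 | ✓ | 50.0 | 42.5 | 30 |
SDT | 4 | ✗ | 94.3 | 500.0 | 107 |
SDT | 4 | ✓ | 50.1 | 40.4 | 62 |
CDT (ours) | 1+2 | ✗ | 95.4 | 500.0 | 38 |
CDT (ours) | 1+2 | FL tree only | 94.4 | 500.0 | 35 |
CDT (ours) | 1+2 | DM tree only | 84.1 | 500.0 | 35 |
CDT (ours) | 1+2 | ✓ | 83.8 | 497.8 | 32 |
CDT (ours) | 2+1 | ✗ | 95.6 | 500.0 | 54 |
CDT (ours) | 2+1 | FL tree only | 92.7 | 500.0 | 45 |
CDT (ours) | 2+1 | DM tree only | 88.4 | 500.0 | 53 |
CDT (ours) | 2+1 | ✓ | 89.0 | 500.0 | 44 |
CDT (ours) | 2+2 | ✗ | 96.6 | 500.0 | 64 |
CDT (ours) | 2+2 | FL tree only | 91.6 | 500.0 | 55 |
CDT (ours) | 2+2 | DM tree only | 82.9 | 494.8 | 61 |
CDT (ours) | 2+2 | ✓ | 81.9 | 488.8 | 52 |

Table 1: Imitation learning on CartPole-v1. For CDTs, "FL tree only" and "DM tree only" mean that only the feature learning tree or only the decision making tree is discretized; ✓ means both are discretized.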
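Tree Type | Depth | Discretized | Accuracy (%) | Episode Reward | # of Parameters |
---|---|---|---|---|---|
SDT | 4 | ✗ | 85.4 | 58.2 | 199 |
SDT | 4 | ✓ | 54.8 | -237.1 | 94 |
SDT | 5 | ✗ | 87.6 | 191.3 | 407 |
SDT | 5 | ✓ | 51.6 | -93.7 | 190 |
SDT | 6 | ✗ | 88.7 | 193.4 | 823 |
SDT | 6 | ✓ | 60.2 | -172.4 | 382 |
SDT | 7 | ✗ | 88.9 | 194.2 | 1655 |
SDT | 7 | ✓ | 62.7 | -233.4 | 766 |
CDT (ours) | 2+2 | ✗ | 88.2 | 107.4 | 116 |
CDT (ours) | 2+2 | FL tree only | 78.0 | -126.9 | 95 |
CDT (ours) | 2+2 | DM tree only | 68.3 | -301.6 | 113 |
CDT (ours) | 2+2 | ✓ | 64.4 | -229.7 | 92 |
CDT (ours) | 2+3 | ✗ | 88.3 | 168.5 | 144 |
CDT (ours) | 2+3 | FL tree only | 70.2 | -9.7 | 123 |
CDT (ours) | 2+3 | DM tree only | 40.7 | -106.3 | 137 |
CDT (ours) | 2+3 | ✓ | 35.9 | -130.2 | 116 |
CDT (ours) | 3+2 | ✗ | 90.4 | 199.5 | 216 |
CDT (ours) | 3+2 | FL tree only | 72.2 | -14.2 | 167 |
CDT (ours) | 3+2 | DM tree only | 78.1 | 150.8 | 209 |
CDT (ours) | 3+2 | ✓ | 64.6 | 7.1 | 160 |
CDT (ours) | 3+3 | ✗ | 90.4 | 173.0 | 244 |
CDT (ours) | 3+3 | FL tree only | 72.0 | -55.3 | 195 |
CDT (ours) | 3+3 | DM tree only | 58.7 | -91.5 | 237 |
CDT (ours) | 3+3 | ✓ | 46.8 | -210.5 | 188 |

Table 2: Imitation learning on LunarLander-v2. For CDTs, "FL tree only" and "DM tree only" mean that only the feature learning tree or only the decision making tree is discretized; ✓ means both are discretized.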
CDTs perform consistently better than SDTs in prediction accuracy, both before and after the discretization process and across different tree depths. Additionally, for a similarly accurate model, the CDT method always has a much smaller number of parameters than SDT, which improves its interpretability, as shown in later sections. However, although better than SDTs, CDTs also suffer from performance degradation after discretization, which could lead to unstable and unexpected models. We claim that this is a general drawback of tree-based methods with soft decision boundaries in XRL with imitation-learning settings, which is studied further in the following.
Stability. To investigate the stability of imitation learners for interpreting the original agents, we measure the normalized weight vectors from different imitation-learning trees. For SDTs, the weight vectors are the linear weights on inner nodes, while for CDTs the weights of both the feature learning tree and the decision making tree are considered. Through these experiments, we aim to show how unstable the imitators are. Consider an imitator tree agent $A$; let $A'$ be another imitator tree agent trained under the same setting, $A_{\mathrm{rand}}$ a random tree agent, and $A_{\mathrm{heur}}$ the heuristic tree agent (used for generating the training dataset). The distance of tree weights between two agents is calculated with the following formula:
(8)
The distances $\mathrm{dist}(A, A')$, $\mathrm{dist}(A, A_{\mathrm{rand}})$ and $\mathrm{dist}(A, A_{\mathrm{heur}})$ are averaged over all possible $A$ and $A'$ trained with the same setting. Since we have the heuristic agent for the LunarLander-v2 environment and can transform it into a multivariate DT agent, we obtain the decision boundaries of the tree on all its nodes. We therefore also compare the decision boundaries of the heuristic tree agent $A_{\mathrm{heur}}$ with those of the learned tree agent $A$. However, we do not have an official heuristic agent for CartPole-v1 in the form of a decision tree. For the decision making trees in CDTs, we transform the weights back into the input feature space to make a fair comparison with SDT and the heuristic tree agent. The results are displayed in Table 3; all trees use intermediate features of dimension 2 for both environments. In terms of stability, CDTs generally perform similarly to SDTs, and even better on the CartPole-v1 environment.
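The back-projection of decision-making-tree weights into the input feature space follows directly from the linear features: writing the features at the selected leaf of the feature learning tree as $\mathbf{f} = \tilde{W}\mathbf{x}$, a decision node computes $w'_1 + \mathbf{w}'^{\top}_{2:}\tilde{W}\mathbf{x} = w'_1 + (\tilde{W}^{\top}\mathbf{w}'_{2:})^{\top}\mathbf{x}$. The sketch below (hypothetical function name `to_input_space`) performs this mapping; because each leaf of the feature learning tree has its own $\tilde{W}$, the mapping is per leaf. The normalization applied before computing distances is not specified here, so the sketch stops at the mapped weights.

```python
from typing import List

Vec = List[float]


def to_input_space(w_d: Vec, leaf_matrix: List[Vec]) -> Vec:
    """Map a decision-node weight vector on features f = W x back onto x.

    w_d = [bias, w'_1, ..., w'_K]; leaf_matrix is the K x R matrix W of the
    selected feature-learning leaf. Returns [bias, v_1, ..., v_R] with
    v_r = sum_k w'_k * W[k][r], so that w_d . [1; W x] == result . [1; x].
    """
    bias, feat_w = w_d[0], w_d[1:]
    n_raw = len(leaf_matrix[0])
    return [bias] + [sum(feat_w[k] * leaf_matrix[k][r] for k in range(len(feat_w)))
                     for r in range(n_raw)]
```

Once mapped, decision-node weights of a CDT live in the same $(R+1)$-dimensional space as SDT node weights, which is what makes a direct weight-distance comparison between the two model families possible.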
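Tree Type | Env | Depth | Dist. to another imitator | Dist. to a random tree | Dist. to the heuristic tree |
---|---|---|---|---|---|
SDT | CartPole-v1 | 3 | 0.21 | 0.90 | - |
SDT | LunarLander-v2 | 4 | 0.50 | 0.92 | 0.84 |
CDT (ours) | CartPole-v1 | 1+2 | 0.07 | 1.05 | - |
CDT (ours) | CartPole-v1 | 2+2 | 0.19 | 1.03 | - |
CDT (ours) | LunarLander-v2 | 2+2 | 0.63 | 1.01 | 0.98 |
CDT (ours) | LunarLander-v2 | 3+3 | 0.53 | 0.83 | 0.86 |

Table 3: Stability of imitation-learning trees, measured as average distances between normalized tree weights. CDTs are generally more stable, but still show large variances across different imitators.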
We further evaluate feature importance with at least two different methods on SDTs to demonstrate the instability of imitation learning settings for XRL (see Appendix F.1). We also display all trees (CDTs and SDTs) for both environments in Appendix F.3. Significant differences can be found between runs with the same tree structure and the same training setting, which demonstrates the unstable and non-reproducible nature of interpreting imitators instead of the original agents.
Conclusion. We claim that the current imitation-learning setting with tree-based models is not suitable for interpreting the original RL agent, based on the following evidence from our experiments: (i) The discretization process usually degrades the performance (prediction accuracy) of the agent significantly, especially for SDTs. Although CDTs alleviate the problem to a certain extent, the performance degradation is still not negligible, so the imitators cannot be expected to serve as reliable substitutes for interpreting the original agents; (ii) our stability analysis shows that different imitators display different tree structures even when they follow the same training setting on the same dataset, which leads to significantly different decision paths and local feature importance assignments.
Performance. We evaluate the learning performance of different DTs and NNs as policy function approximators in RL, as shown in Fig. 3. Every setting is trained for three runs. We use the Proximal Policy Optimization (schulman2017proximal) algorithm in our experiments. The multilayer perceptron (MLP) model is a two-layer NN with 128 hidden units. The SDT has a depth of 3 for CartPole-v1 and 4 for LunarLander-v2. The CDT has depths of 2 and 2 for the feature learning tree and decision making tree respectively on CartPole-v1, and depths of 3 and 3 on LunarLander-v2. Therefore, for each environment, the SDTs and CDTs have a similar number of model parameters, while the MLP model has at least 6 times more parameters. Detailed training settings are provided in Appendix G. From Fig. 3, we can see that CDTs outperform SDTs as policy function approximators for RL in terms of both sample efficiency and final performance, although they may not learn as fast as general MLPs, which have a significantly larger number of parameters. For the MountainCar-v0 environment, the learning performance is less stable due to the sparse reward signals and large variances in exploration. However, with CDT for policy function approximation, there are still near-optimal agents after training, with or without state normalization, as displayed in Appendix H.

Tree Depth. The depths of DTs are also investigated for both SDT and CDT, because deeper trees tend to have more model parameters and therefore place more emphasis on accuracy than on interpretability. Fig. 4 shows the learning curves of SDTs and CDTs in RL with different tree depths for the two environments, using normalized states as input; the comparison with unnormalized states is in Appendix H, with similar results. From the comparisons, we can see that deeper trees generally learn faster and reach even better final performance for both CDTs and SDTs, but CDTs are less sensitive to tree depth than SDTs.
Interpretability. We display the learned CDTs in RL settings for three environments, compared against heuristic solutions or SDTs. A heuristic solution for CartPole-v1 (provided by Zhiqing Xiao on the OpenAI Gym Leaderboard: https://github.com/openai/gym/wiki/Leaderboard) is: if , push right; otherwise, push left. As shown in Fig. 5, in our learned CDT of depth 1+2, the weights of the two-dimensional intermediate features are much larger on the last two dimensions of the observation than on the first two, so we can approximately ignore the first two dimensions due to their low importance in the decision making process. We thus get similar intermediate features for the two cases in two dimensions, which are approximately after normalization . Based on the decision making tree in the learned CDT, this gives a solution close to the heuristic one, yielding if push left otherwise push right. The original CDT before discretization and an SDT for comparison are provided in Appendix I.
For MountainCar-v0, due to the complexity in the landscape as shown in Fig. 6, interpreting the learned model is even harder. However, through CDT, we can see that the agent learns intermediate features as combinations of car position and velocity, potentially being an estimated future position or previous position, and makes action decisions based on that. The original CDT before discretization has depth 2+2 with one-dimensional intermediate features, and its structure is shown in Appendix I.
Due to the page limitation, the learned CDT as an example for LunarLander-v2 is provided in Appendix I, which also captures some important feature combinations like the angle with angular speed and X-Y coordinate relationships for decision making.
In this work, we have proposed a new architecture of differentiable DT, the Cascading Decision Tree (CDT). A simple CDT cascades a feature learning DT and a decision making DT into a single model. Our experiments show that, compared with traditional differentiable DTs (i.e., DDTs or SDTs), CDTs achieve better function approximation in both imitation learning and full RL settings with a significantly reduced number of model parameters, while better preserving tree prediction accuracy after discretization. We also qualitatively and quantitatively corroborate that SDT-based methods in the imitation learning setting may not be appropriate for achieving interpretable RL agents, due to the instability of tree structures across different imitators, even when they have similar performance. Finally, we contrast the interpretability of learned DTs in RL settings, especially for the intermediate features. Our analysis supports the view that CDTs lend themselves to further extension into hierarchical architectures with more interpretable modules, due to the richer expressivity allowed by representation learning. More work is needed to fully realize the potential of our method, including the investigation of hierarchical CDT settings and well-regularized intermediate features for further interpretability. Additionally, since the present experiments use linear transformations in the feature space, we expect non-linear transformations to be leveraged for tasks with higher complexity or continuous action spaces while preserving interpretability.
People have proposed a variety of desiderata for interpretability (lipton2018mythos), including trust, causality, transferability, informativeness, etc. Here we summarize them into two general aspects: (1) interpretable meta-variables that can be directly understood; (2) model simplicity. Understandable variables combined with simple model structures make up most of the models that humans interpret, whether through physical and mathematical principles or through intuition, which is also in accordance with the Occam’s razor principle.
Regarding model simplicity, a simple model is in most cases more interpretable than a complicated one. Different metrics can be applied to measure model complexity (murray2007reducing, molnar2019quantifying), such as the number of model parameters, model capacity, computational complexity, non-linearity, etc. There are several ways to reduce model complexity: projecting the model from a large space into a small subspace, merging the replicates in the model, etc. Feature importance (schwab2019cxplain) (e.g., estimated through model sensitivity to changes in the inputs) is one type of method for projecting a complicated model into a scalar space across feature dimensions. The CDT method proposed in this paper improves model simplicity by merging the replicates through representation learning.
For a binary classification problem, three different tree structures and their decision boundaries are compared in Fig. 7: (1) multivariate DT; (2) univariate DT; (3) differentiable rule lists (silva2019optimization). We need to define the simplicity of DT for achieving interpretability and choose which type of tree is the one we prefer.
For the first two structures, we may not be able to draw conclusions for their simplicity since it seems one has a simpler tree structure but more complex decision boundaries while the other one is the opposite. We will have another example to clarify it. But we can draw a conclusion for the second and the third ones since the structure of differentiable rule lists is simpler, as it has an asymmetric structure and the left nodes are always leaves. However, the problem of differentiable rule lists is also obvious, that it sacrifices the model capacity and therefore hurts the accuracy. For the first left node in the example, it can only choose either one of the two labels without distinctiveness, which is clearly not correct.
To clarify the choice between the first two structures, a more complicated example is provided in Fig. 8. It compares a multivariate DT and a univariate DT for a binary classification task. Here the multivariate DT clearly has a simpler structure than the univariate one. The conclusion is that, for complex cases, the multivariate tree structure has greater potential to achieve the necessary space partitioning with a simpler model structure.
As shown in Fig. 9, we found that the heuristic solution for LunarLander-v2 (in the code repository of OpenAI Gym: https://github.com/openai/gym/blob/master/gym/envs/box2d/lunar_lander.py) contains a duplicative structure after being transformed into a decision tree, and this duplicative structure can be leveraged to simplify the learning models. Specifically, the two green modules (the feature learning ones) in the tree essentially assign different values to two intermediate variables under different cases, while the grey module (the decision making one) takes the intermediate variables to select actions. Both the decision making module and the second feature learning module are used repeatedly on different branches of the tree, which forms a duplicative structure. Exploiting this repetition can help the simplicity and interpretability of the model, which motivates our CDT method for XRL.
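Consider a case where the raw input feature dimension is $R$ and the intermediate feature dimension is chosen to be $K$. We compare a CDT with two cascading trees of depths $d_1$ and $d_2$ against an SDT of depth $d_1 + d_2$. Supposing the output dimension is $O$, and noting that each inner node holds a weight vector and a bias over its input, each leaf of the feature learning tree holds $K$ linear combinations of the $R$ raw features, and each output leaf holds $O$ values, the number of parameters in the CDT is:
$$N_{\mathrm{CDT}} = \left[(R+1)(2^{d_1}-1) + K R \cdot 2^{d_1}\right] + \left[(K+1)(2^{d_2}-1) + O \cdot 2^{d_2}\right] \qquad (9)$$
while the number of parameters in the SDT is:
$$N_{\mathrm{SDT}} = (R+1)(2^{d_1+d_2}-1) + O \cdot 2^{d_1+d_2} \qquad (10)$$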
As an example for Eq. 9 and Eq. 10, consider an SDT of depth 5 and a CDT with , raw feature dimension , intermediate feature dimension , and output dimension ; we then get and . This indicates a reduction of around parameters in this case, which can significantly improve interpretability.
In another example, when , the numbers of parameters of CDT and SDT models are compared in Fig. 11, assuming for total depths ranging from 2 to 20. The ratio of the numbers of model parameters is derived as: .
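The parameter counting can be checked with a few lines of Python. This is a sketch under our own assumptions: each inner node has one weight per input feature plus a bias, each feature learning leaf holds $K \times R$ bias-free coefficients, and each output leaf holds $O$ values. The depth split `d1 = d // 2` used in the sweep, the ratio taken as $N_{\mathrm{CDT}} / N_{\mathrm{SDT}}$, and the sizes in the usage lines are hypothetical choices, not the settings behind Fig. 11.

```python
def sdt_params(depth: int, r: int, o: int) -> int:
    """SDT parameters: (r weights + 1 bias) per inner node, o outputs per leaf."""
    inner, leaves = 2 ** depth - 1, 2 ** depth
    return (r + 1) * inner + o * leaves


def cdt_params(d1: int, d2: int, r: int, k: int, o: int) -> int:
    """CDT parameters: feature learning (FL) tree plus decision making (DM) tree."""
    # FL tree: inner nodes split on the r raw features; each leaf holds
    # k x r coefficients mapping raw features to k intermediate features.
    fl = (r + 1) * (2 ** d1 - 1) + k * r * 2 ** d1
    # DM tree: an SDT of depth d2 that operates on the k intermediate features.
    dm = sdt_params(d2, k, o)
    return fl + dm


def param_ratios(r: int, k: int, o: int, depths=range(2, 21)) -> dict:
    """N_CDT / N_SDT per total depth d, assuming d1 = d // 2 and d2 = d - d1."""
    ratios = {}
    for d in depths:
        d1 = d // 2
        d2 = d - d1
        ratios[d] = cdt_params(d1, d2, r, k, o) / sdt_params(d, r, o)
    return ratios


if __name__ == "__main__":
    # Hypothetical sizes: R=8, O=4 (LunarLander-v2 observation/action dims), K=2.
    print(cdt_params(2, 3, r=8, k=2, o=4))  # 144
    print(sdt_params(5, r=8, o=4))          # 407
```

Under these assumptions the CDT can be larger than the SDT at very small depths, while the SDT's $2^{d_1+d_2}$ growth dominates as the total depth increases.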
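Tree Type | Env | Hyperparameter | Value |
Common | CartPole-v1 | learning rate | |
Common | CartPole-v1 | batch size | 1280 |
Common | CartPole-v1 | epochs | 80 |
Common | LunarLander-v2 | learning rate | |
Common | LunarLander-v2 | batch size | 1280 |
Common | LunarLander-v2 | epochs | 80 |
SDT | CartPole-v1 | depth | 3 |
SDT | LunarLander-v2 | depth | 4 |
CDT | CartPole-v1 | FL (feature learning) depth | 2 |
CDT | CartPole-v1 | DM (decision making) depth | 2 |
CDT | CartPole-v1 | # intermediate variables | 2 |
CDT | LunarLander-v2 | FL (feature learning) depth | 3 |
CDT | LunarLander-v2 | DM (decision making) depth | 3 |
CDT | LunarLander-v2 | # intermediate variables | 2 |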
The fidelity and stability of mimic models both reflect their reliability as interpretable models. Fidelity is the accuracy of the mimic model with respect to the original model, i.e., an estimate of how similar the mimic model's predictions are to those of the original. However, fidelity alone is not sufficient for reliable interpretations: an unstable family of mimic models will produce inconsistent explanations of the original black-box model. Stability examines the mimic model itself more deeply, by comparing several training runs. Previous research (bastani2017interpreting) investigated the fidelity and stability of decision trees as mimic models, estimating stability as the fraction of equivalent nodes in different random decision trees trained under the same settings. In our tests, besides comparing the tree weights of different imitators, we also measure stability through the feature importance assigned by different differentiable DT instances with the same architecture and training setting.
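A minimal sketch of the two quantities, assuming actions and feature importance vectors have already been collected on a shared set of states. Fidelity is computed as the fraction of matching actions. The stability score shown here, a mean pairwise cosine similarity between importance vectors from different runs, is only an illustrative choice of ours; it is neither the node-equivalence measure of bastani2017interpreting nor a metric defined in this paper, and all function names are hypothetical.

```python
import math
from itertools import combinations
from typing import List, Sequence


def fidelity(mimic_actions: Sequence[int], original_actions: Sequence[int]) -> float:
    """Fraction of states where the mimic model picks the same action as the original."""
    if not original_actions or len(mimic_actions) != len(original_actions):
        raise ValueError("need two non-empty action sequences of equal length")
    agree = sum(a == b for a, b in zip(mimic_actions, original_actions))
    return agree / len(original_actions)


def cosine_similarity(u: Sequence[float], v: Sequence[float]) -> float:
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0.0 or norm_v == 0.0:
        raise ValueError("importance vectors must be non-zero")
    return sum(a * b for a, b in zip(u, v)) / (norm_u * norm_v)


def importance_agreement(importance_per_run: List[Sequence[float]]) -> float:
    """Mean pairwise cosine similarity of importance vectors from runs with the same setting.

    Values near 1 mean that different runs explain the black-box policy consistently.
    """
    pairs = list(combinations(importance_per_run, 2))
    if not pairs:
        raise ValueError("need at least two runs")
    return sum(cosine_similarity(u, v) for u, v in pairs) / len(pairs)
```

A mimic family can score high on `fidelity` for every run and still score low on `importance_agreement`, which is exactly the gap between fidelity and stability described above.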
For differentiable DT methods (e.g. CDT and SDT), since the decision boundaries within each node are linear combinations of features, we can simply take the weight vector as the importance assignment for those features within each node.
Once the tree is trained, a local explanation can be derived directly from the inference process for a single instance, i.e., its decision path through the tree. A global explanation can be taken as the average of local explanations across instances, e.g., over one or several episodes in the RL setting. Below we list several ways of assigning importance values to input features with an SDT; each yields a feature importance vector with the same dimension as the decision node vectors and the input features:
For local explanation:
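I. A trivial way to assign feature importance on an SDT is simply to add up the weight vectors of all nodes on the decision path: $I(\mathbf{x}) = \sum_{n \in \mathcal{P}(\mathbf{x})} \mathbf{w}_n$, where $\mathcal{P}(\mathbf{x})$ denotes the inner nodes on the decision path of instance $\mathbf{x}$ and $\mathbf{w}_n$ is the weight vector of node $n$.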
II. The second way is a combination of the decision node vectors weighted by the confidence of each decision boundary for the specific instance. Given the soft decision boundary at each node, we assume that the more confidently a boundary places the data point into a specific region of the space, the more reliably we can assign feature importance according to that boundary. The confidence of a decision boundary can be taken as positively correlated with either the distance from the data point to the boundary or the probability of the data point falling on one side of the boundary; the latter is readily available in our setting. For each inner node $n$ on the decision path of instance $\mathbf{x}$, we define the confidence $c_n(\mathbf{x})$ as the probability of routing $\mathbf{x}$ from node $n$ to its child on the decision path. It indicates, in a probabilistic sense, how far the data point is from the middle of the soft boundary. The importance vector is therefore obtained by multiplying each decision node vector by its confidence: $I(\mathbf{x}) = \sum_{n \in \mathcal{P}(\mathbf{x})} c_n(\mathbf{x})\, \mathbf{w}_n$, where $\mathcal{P}(\mathbf{x})$ is the set of inner nodes on the decision path and $\mathbf{w}_n$ is the weight vector of node $n$.
Fig. 12 illustrates why the decision confidence (i.e., probability) is used as a weight for assigning feature importance: the probability of belonging to one category is positively correlated with the distance from the instance to the decision boundary. Therefore, when multiple boundaries partition the space (e.g., two in the figure), boundaries closer to the data point receive smaller confidence when determining feature importance, since a small perturbation can push the data point across a nearby boundary into the opposite category, making it less certain to remain in its original one.
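III. Since the tree is differentiable, we can also apply gradient-based methods for feature importance assignment: $I(\mathbf{x}) = \frac{\partial y}{\partial \mathbf{x}}$, where $y = \mathrm{SDT}(\mathbf{x})$ is the output of the tree for instance $\mathbf{x}$.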
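The three local explanation methods can be sketched in plain Python on a toy SDT. The following are our assumptions, not the paper's code: nodes are stored in heap order; `p_right = sigmoid(w·x + b)` is the probability of going to the right child; the decision path follows the more probable child at each node; and for method III the scalar output $y$ is taken as the tree's soft probability of a chosen action, differentiated with central finite differences instead of automatic differentiation to stay dependency-free. All class and function names are hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def _sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def _dot(u: Sequence[float], v: Sequence[float]) -> float:
    return sum(a * b for a, b in zip(u, v))


class SoftDecisionTree:
    """Complete binary SDT; inner node i has children 2i+1 (left) and 2i+2 (right)."""

    def __init__(self, weights, biases, leaf_logits):
        self.weights = weights          # one weight vector per inner node
        self.biases = biases            # one bias per inner node
        self.leaf_logits = leaf_logits  # one action-logit vector per leaf, left to right
        self.n_inner = len(weights)

    def p_right(self, node: int, x: Sequence[float]) -> float:
        return _sigmoid(_dot(self.weights[node], x) + self.biases[node])

    def greedy_path(self, x) -> List[Tuple[int, float]]:
        """(inner node, confidence) pairs when always following the more probable child."""
        path, node = [], 0
        while node < self.n_inner:
            p = self.p_right(node, x)
            go_right = p >= 0.5
            path.append((node, p if go_right else 1.0 - p))
            node = 2 * node + 2 if go_right else 2 * node + 1
        return path

    def action_prob(self, x, action: int) -> float:
        """Soft output: sum over leaves of P(reach leaf | x) * softmax(leaf logits)[action]."""
        total = 0.0
        for leaf, logits in enumerate(self.leaf_logits):
            node, reach = self.n_inner + leaf, 1.0
            while node > 0:  # walk up to the root, multiplying branch probabilities
                parent = (node - 1) // 2
                p = self.p_right(parent, x)
                reach *= p if node == 2 * parent + 2 else 1.0 - p
                node = parent
            m = max(logits)
            exps = [math.exp(z - m) for z in logits]
            total += reach * exps[action] / sum(exps)
        return total


def importance_sum(tree, x):
    """Method I: add up the weight vectors of the inner nodes on the decision path."""
    imp = [0.0] * len(x)
    for node, _ in tree.greedy_path(x):
        imp = [a + w for a, w in zip(imp, tree.weights[node])]
    return imp


def importance_confidence(tree, x):
    """Method II: weight each node vector by the probability of the branch taken."""
    imp = [0.0] * len(x)
    for node, conf in tree.greedy_path(x):
        imp = [a + conf * w for a, w in zip(imp, tree.weights[node])]
    return imp


def importance_gradient(tree, x, action, eps=1e-5):
    """Method III: d action_prob / dx, approximated with central finite differences."""
    grad = []
    for i in range(len(x)):
        up, down = list(x), list(x)
        up[i] += eps
        down[i] -= eps
        grad.append((tree.action_prob(up, action) - tree.action_prob(down, action)) / (2 * eps))
    return grad
```

When every branch probability on the path is close to 1, method II reduces to method I, whereas the gradient in method III shrinks toward zero because the sigmoid saturates. With an autodiff framework, method III would instead backpropagate from $y$ to $\mathbf{x}$.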
For global explanation:
We can simply average the feature importance at each time step (i.e., the local explanations) to obtain global feature importance over an episode or across episodes, where the local explanations can be derived with any of the above methods.
To test the stability of SDTs trained by imitation learning from a given agent, we compare SDT agents from different runs with each other and with the original agent. The agent being imitated is a heuristic decision tree (HDT) agent, and the evaluation metric is the feature importance assigned to each feature dimension across an episode. As described in the previous section, local feature importance can be obtained in three ways, all of which apply to both the HDT and the SDTs here. The experiments use LunarLander-v2, which has an 8-dimensional observation.
Since SDTs from different runs may predict different actions, even when trained with the same setting long enough to reach similarly high accuracies, we compare them not only online during one episode, but also on an offline dataset of states pre-collected by the HDT agent. This is intended to remove the accumulating divergence in trajectories caused by different agents taking different actions in succession, and thus to give a fairer comparison of the decision process (i.e., feature importance) on the same trajectory.
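The offline protocol and the global averaging can be sketched as follows, assuming each run is wrapped as a callable that returns a local importance vector for a state (for example, one trained SDT combined with one of the local explanation methods I to III). The per-feature population standard deviation across runs is our illustrative summary of disagreement, not a measure reported in the paper, and the function names are hypothetical.

```python
import statistics
from typing import Callable, List, Sequence, Tuple

LocalExplainer = Callable[[Sequence[float]], List[float]]


def global_importance(local_fn: LocalExplainer,
                      states: Sequence[Sequence[float]]) -> List[float]:
    """Average the local importance vectors over a sequence of states (e.g. one episode)."""
    vectors = [local_fn(s) for s in states]
    return [sum(col) / len(vectors) for col in zip(*vectors)]


def compare_runs(local_fns: Sequence[LocalExplainer],
                 offline_states: Sequence[Sequence[float]]) -> Tuple[List[List[float]], List[float]]:
    """Explain the SAME pre-collected states with every run, so trajectories cannot diverge.

    Returns each run's global importance and the per-feature spread across runs.
    """
    per_run = [global_importance(fn, offline_states) for fn in local_fns]
    spread = [statistics.pstdev(col) for col in zip(*per_run)]
    return per_run, spread
```

Evaluating all runs on one fixed list of HDT states means any remaining spread comes from the trees themselves rather than from the different trajectories their actions would have produced online.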
Different Tree Depths. First, Fig. 13 compares the feature importance (computed by adding up the node weights on the decision path) of the HDT and of learned SDTs with different depths over one online decision episode. All SDT agents are trained for 40 epochs until convergence. The accuracies of the three trees are , respectively.
Fig. 13 shows significant differences among SDTs of different depths, as well as between the SDTs and the HDT, even on an episode with the same random seed. This indicates that the depth of an SDT affects not only its prediction accuracy but also its decision making process.
Same Tree with Different Runs. We compare feature importance on an offline dataset containing the states the HDT agent encounters in one episode. All SDT agents have a depth of 5 and are trained for 80 epochs until convergence. After training, the three agents have testing accuracies of respectively. The feature importance values are evaluated with each of the approaches described above (local explanations I, II, and III) on the same offline episode, as shown in Fig. 14. In the results, local explanations I and II look similar, since most decision nodes on the decision path with the greatest probability have probability values close to 1 (i.e., close to a hard decision boundary) when moving to their child nodes.
Fig. 14 also shows considerable differences in local explanations across runs, regardless of which metric is applied, even though the SDTs have similar prediction accuracies.
In this section, we display the agents trained with CDTs and SDTs on CartPole-v1 and LunarLander-v2, before and after tree discretization, in Fig. 15, 16, 17, 18, 20, 21, and 22. Each figure contains trees trained in four runs with the same setting. Each sub-figure shows one learned tree (plus an input example and its output) with the inference path (solid lines) for the same input instance. The lines and arrows indicate the connections among tree nodes. The colors of the squares on tree nodes show the values of each node's weight vector. For feature learning trees in CDTs, the leaf nodes are colored with the feature coefficients. The output leaf nodes of both SDTs and the decision making trees in CDTs are colored with the output categorical distributions. Three color bars on the left side correspond to the inputs, the inner tree nodes, and the output leaves respectively, as shown in Fig. 15; the same convention applies to the remaining tree plots. The digits on top of each node represent the output action categories.
Across all the learned tree structures, significant differences can be seen in the weight vectors, as well as in the intermediate features of CDTs, even though the four trees share the same training setting. This leads to considerably different explanations, i.e., feature importance assignments, from different trees.
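Tree Type | Env | Hyperparameter | Value |
Common | CartPole-v1 | learning rate | |
Common | CartPole-v1 | | 0.98 |
Common | CartPole-v1 | | 0.95 |
Common | CartPole-v1 | | 0.1 |
Common | CartPole-v1 | update iteration | 3 |
Common | CartPole-v1 | hidden dimension (value) | 128 |
Common | CartPole-v1 | episodes | 3000 |
Common | CartPole-v1 | time horizon | 1000 |
Common | LunarLander-v2 | learning rate | |
Common | LunarLander-v2 | | 0.98 |
Common | LunarLander-v2 | | 0.95 |
Common | LunarLander-v2 | | 0.1 |
Common | LunarLander-v2 | update iteration | 3 |
Common | LunarLander-v2 | hidden dimension (value) | 128 |
Common | LunarLander-v2 | episodes | 5000 |
Common | LunarLander-v2 | time horizon | 1000 |
Common | MountainCar-v0 | learning rate | |
Common | MountainCar-v0 | | 0.999 |
Common | MountainCar-v0 | | 0.98 |
Common | MountainCar-v0 | | 0.1 |
Common | MountainCar-v0 | update iteration | 10 |
Common | MountainCar-v0 | hidden dimension (value) | 32 |
Common | MountainCar-v0 | episodes | 5000 |
Common | MountainCar-v0 | time horizon | 1000 |
MLP | CartPole-v1 | hidden dimension (policy) | 128 |
MLP | LunarLander-v2 | hidden dimension (policy) | 128 |
MLP | MountainCar-v0 | hidden dimension (policy) | 32 |
SDT | CartPole-v1 | depth | 3 |
SDT | LunarLander-v2 | depth | 4 |
SDT | MountainCar-v0 | depth | 3 |
CDT | CartPole-v1 | FL (feature learning) depth | 2 |
CDT | CartPole-v1 | DM (decision making) depth | 2 |
CDT | CartPole-v1 | # intermediate variables | 2 |
CDT | LunarLander-v2 | FL (feature learning) depth | 3 |
CDT | LunarLander-v2 | DM (decision making) depth | 3 |
CDT | LunarLander-v2 | # intermediate variables | 2 |
CDT | MountainCar-v0 | FL (feature learning) depth | 2 |
CDT | MountainCar-v0 | DM (decision making) depth | 2 |
CDT | MountainCar-v0 | # intermediate variables | 1 |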
To normalize the states (we found that state normalization can sometimes affect learning performance significantly, especially in RL settings), we collect 3000 episodes of samples for each environment with a well-trained policy and calculate their mean and standard deviation. During training, the mean is subtracted from each state input, which is then divided by the standard deviation.
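A minimal sketch of this normalization, assuming the collected states are already available as a list of vectors (the rollouts of the well-trained policy are not shown). The use of the population standard deviation and the small `eps` that guards constant dimensions are our assumptions; the function names are hypothetical.

```python
import statistics
from typing import List, Sequence, Tuple


def fit_normalizer(states: Sequence[Sequence[float]]) -> Tuple[List[float], List[float]]:
    """Per-dimension mean and (population) standard deviation of the collected states."""
    columns = list(zip(*states))
    means = [statistics.fmean(c) for c in columns]
    stds = [statistics.pstdev(c) for c in columns]
    return means, stds


def normalize(state: Sequence[float], means: Sequence[float],
              stds: Sequence[float], eps: float = 1e-8) -> List[float]:
    """Subtract the mean and divide by the std; eps avoids division by zero."""
    return [(s - m) / (sd + eps) for s, m, sd in zip(state, means, stds)]
```

The statistics are fitted once on the pre-collected states and then kept fixed, so every state seen during training is transformed in the same way.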
The hyperparameters for RL are provided in Table 5 for MLP, SDT, and CDT on three environments.
Fig. 23 compares the learning curves of SDTs and CDTs with different depths in the RL setting without state normalization. The results are similar to those with state normalization reported in the main text.
Fig. 24 compares MLP, SDT, and CDT as policy function approximators in RL on the MountainCar-v0 environment, displaying the learning curves of each run as well as their means and standard deviations. The MLP model has two layers with 32 hidden units. The SDT has depth 3. The CDT has depth 2 for both the feature learning tree and the decision making tree, with an intermediate feature dimension of 1. Training is less stable due to the large variance in exploration, but CDTs generally perform better than SDTs, with near-optimal agents learned in both cases.