Interpretable Local Tree Surrogate Policies

by John Mern, et al.

High-dimensional policies, such as those represented by neural networks, cannot be reasonably interpreted by humans. This lack of interpretability reduces the trust users have in policy behavior, limiting their use to low-impact tasks such as video games. Unfortunately, many methods rely on neural network representations for effective learning. In this work, we propose a method to build predictable policy trees as surrogates for policies such as neural networks. The policy trees are easily human interpretable and provide quantitative predictions of future behavior. We demonstrate the performance of this approach on several simulated tasks.





Deep reinforcement learning has achieved state-of-the-art performance in several challenging task domains mnih2015. Much of that performance comes from the use of highly expressive neural networks to represent policies. Humans are generally unable to meaningfully interpret neural network parameters, which has led to the common view of neural networks as “black-box” functions. Poor interpretability often leads to a lack of trust in neural networks and to the use of other, more transparent, though potentially lower-performing policy models. This is often referred to as the performance-transparency trade-off.

Interpretability is important for high-consequence tasks. Domains in which neural networks have already been applied, such as image classification, often do not require interpretable decisions because the consequences of mislabeling images are typically low. In many potential applications of deep reinforcement learning, such as autonomous vehicle control, erroneous actions may be costly and dangerous. In these cases, greater trust in policy decisions is typically desired before systems are deployed.

Our work is motivated by tasks in which a human interacts directly with the policy, either by approving agent actions before they are taken or by enacting recommendations directly. This is often referred to as having a human “in the loop”. An example is an automated cyber security incident response system that provides recommendations to a human analyst. In these cases, knowing the extended course of actions before committing to a recommendation can enhance the trust of the human operator.

It is difficult for humans to holistically interpret models with even a small number of interacting terms lipton2018. Neural networks commonly have several thousand parameters and non-linear interactions, making holistic interpretation infeasible. Some existing methods constrain neural network architectures and train them to learn human-interpretable features during task learning. The training and architecture constraints, however, can degrade performance compared to an unconstrained policy. A common approach is to learn transparent surrogates from the original models adadi2018. A major challenge in this approach is balancing the fidelity of the surrogate to the original with the interpretability of the surrogate. Surrogate models that are generated stochastically can additionally struggle to provide consistent representations for the same policy.

Figure 1: Surrogate Policy Tree. The figure shows a surrogate policy tree for the vaccine planning task, a finite-horizon MDP with 8 discrete actions. Each node represents an action and the states that led to it. Node borders and edge widths are proportional to the probability of encountering that node or edge during policy execution. The most likely trajectory from root to leaf is shown in blue.

In this work, we propose a method to develop transparent surrogate models as local policy trees. The resulting trees encode an intuitive plan of future actions with high fidelity to the original policy. The proposed approach allows users to specify tree size constraints and fidelity targets. The method is model-agnostic, meaning that it does not require the original policy to take any specific form. Though this work was motivated by neural networks, the proposed approach may be used with any baseline policy form.

During execution of a tree policy, the actions taken by an agent are guaranteed to be along one of the unique paths from the root. Experiments on a simple grid world task highlight the impact of algorithm parameters on tree behavior. The experiments also show that using the trees as a receding-horizon policy maintains good performance relative to the baseline model. Additional experiments on more complex infrastructure planning and cyber-physical security domains demonstrate the potential utility of this approach in real-world tasks.


Despite its importance in machine learning, there is no precise qualitative or mathematical definition of interpretability. In the literature, interpretability is commonly used to describe a model with one of two characteristics. The first is explainability, defined by miller2019 as the degree to which a human can understand the cause of a decision. The second characteristic is predictability, which is defined by kim2016 as the degree to which a human can consistently predict a model’s result. In this work, we will refer to models that are either explainable or predictable as interpretable.

Techniques for model interpretability can vary greatly. Using the taxonomy proposed by adadi2018, we characterize the methods presented in this work as model-agnostic, local surrogate methods. Model-agnostic methods are those that may be applied to any model with the appropriate mapping from input to output. Surrogate modeling techniques distill the behavior of a complex baseline model into a more transparent surrogate. Local methods provide interpretability that is valid only in the neighborhood of a target input point. As a result, they tend to maintain higher fidelity to the original model within that neighborhood than models that attempt to provide global interpretations.

Markov Decision Processes

The control tasks in this work are assumed to satisfy the Markov assumption and may be modeled as either Markov decision processes (MDPs) or partially observable Markov decision processes (POMDPs). MDPs are defined by tuples $(\mathcal{S}, \mathcal{A}, T, R, \gamma)$, where $\mathcal{S}$ and $\mathcal{A}$ are the state and action spaces, respectively, $T$ is the transition model, $R$ is the reward function, and $\gamma$ is a time discount factor. POMDPs are modeled by the same tuple with the addition of the observation space $\mathcal{O}$ and observation model $Z$. A policy $\pi$ is a function that maps a state to an action in MDPs or a history of observations to an action in POMDPs. To solve an MDP is to learn the policy that maximizes the expected sum of discounted rewards

$$ V^{\pi}(s) = \mathbb{E}\left[ \sum_{t=0}^{\infty} \gamma^{t} R(s_t, a_t) \,\middle|\, s_0 = s, \pi \right] \tag{1} $$

over all states. Learning an optimal policy is equivalent to learning the optimal action value function

$$ Q^{*}(s, a) = \mathbb{E}\left[ \sum_{t=0}^{\infty} \gamma^{t} R(s_t, a_t) \,\middle|\, s_0 = s, a_0 = a, \pi^{*} \right] \tag{2} $$

where $\pi^{*}$ is the optimal policy. When learning an action value function estimator, the effective policy is

$$ \hat{\pi}(s) = \arg\max_{a \in \mathcal{A}} \hat{Q}(s, a) \tag{3} $$

where $\hat{Q}$ is the learned approximator. These problems can be solved using a variety of methods such as dynamic programming, Monte Carlo planning, and reinforcement learning kochenderfer2022. Similarly, various function types can be used to model the policy or value function estimator. Neural networks are commonly used as policies for complex tasks.
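As a concrete illustration of the expected discounted return and of the greedy effective policy described above, the sketch below estimates $V^{\pi}(s_0)$ by averaging discounted returns over simulated episodes and picks the greedy action from a table of estimated action values. It is a minimal sketch under our own assumptions: `rollout` is a hypothetical user-supplied episode simulator, and the action values for a single state are passed as a plain dictionary.

```python
import random
from typing import Callable, Dict, Hashable, List, Sequence


def discounted_return(rewards: Sequence[float], gamma: float) -> float:
    """Sum of gamma**t * r_t along one trajectory."""
    return sum((gamma ** t) * r for t, r in enumerate(rewards))


def mc_value_estimate(rollout: Callable[[random.Random], List[float]],
                      gamma: float, n: int, seed: int = 0) -> float:
    """Monte Carlo estimate of V^pi(s0): mean discounted return over n episodes.

    `rollout` (hypothetical) simulates one episode of the policy from s0
    and returns the sequence of rewards it received.
    """
    rng = random.Random(seed)
    return sum(discounted_return(rollout(rng), gamma) for _ in range(n)) / n


def greedy_policy(q_hat: Dict[Hashable, float]) -> Hashable:
    """Effective policy of a value estimator: the action with the largest Q-hat."""
    return max(q_hat, key=q_hat.get)
```

For example, `greedy_policy({"up": 0.1, "down": 0.7})` selects `"down"`. With a stochastic simulator, the estimate converges to $V^{\pi}(s_0)$ as `n` grows.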

Related Work

There has been a wealth of prior work in the field of explainable and interpretable artificial intelligence. The book by molnar2019 provides an overview of model interpretability techniques for general machine learning methods. Methods for specific models and tasks have also been proposed; puiutta2020 provides a survey of recent work in explainable reinforcement learning. The discussion in this section is restricted to methods relevant to deep reinforcement learning. Methods requiring domain-specific representations verma2018 are not considered.

A common technique for model-agnostic surrogate modeling is to learn a surrogate model of an inherently interpretable class. A well-known example is local interpretable model-agnostic explanations (LIME) ribeiro2016. LIME learns sparse linear representations of a target policy at a specific input by training on a data set of points near the target point. While the resulting linear functions are interpretable, they are often inconsistent: small variations in the training data can result in drastically different linear functions.

Several methods have been proposed to distill neural network policies into decision tree surrogates. Linear model U-trees liu2018 and soft decision trees coppens2019 take similar approaches to tree representation and learning. These decision trees are global policy representations that map input observations to output actions by traversing a binary tree. Actions can be partially understood by inspecting the values at each internal node. Unfortunately, the understanding they provide can be limited because the mappings still pass the input observation through several layers of affine transforms and non-linear functions. Further, both methods empirically showed significant performance loss compared to the baseline policies.

Structural causal models madumal2020 also try to learn an inherently interpretable surrogate as a directed acyclic graph (DAG). The learned DAG represents relations between objects in the task environment that can be easily understood by humans. The DAG structure, however, must be provided for each task, and the fidelity is not assured.

Our work proposes tree representations that provide intuitive maps over future courses of action. When used as a policy, the realized course of action is guaranteed to be along the branches of the tree. In this way, the method is similar to methods of neural network verification that seek to provide guarantees of network outputs over a given set of inputs. liu2021 provides an overview of verification for general neural networks. Additional works have been proposed specifically for verification of neural network control policies sidrane2021.

The methods proposed in this work also resemble Monte Carlo tree search (MCTS) algorithms kocsis2006. MCTS methods search for an optimal action from a given initial state or belief by sampling trajectories using a search policy. The result of an MCTS search is a tree of trajectories reachable from the initial state. The tree trajectories, however, are not generated by a fixed policy, but by a non-stationary tree search policy. As a result, the tree cannot be used to interpret or predict future policy behavior. Trees from planners using UCB exploration auer2002 tend to grow exponentially with search depth as .

Proposed Method

We present model-agnostic methods to represent policies for both MDPs and POMDPs as interpretable tree policies. Trees are inherently more interpretable than high dimensional models like neural networks. The proposed method uses the baseline policy to generate simulations of trajectories reachable from a given initial state. The trajectories are then clustered to develop a width-constrained tree that represents the original policy with good fidelity. The methods assume that baseline policies return distributions over actions or estimates of action values for a given state. Building trees also requires a generative model of the task environment.

In addition to representing the future policy, the tree also provides useful statistics on expected performance such as likelihood of following each represented trajectory. The tree may be used as a policy during task execution with guarantees on expected behavior. We present methods to generalize to states not seen during tree construction by using the tree as a constraint on the baseline policy.

Local Tree Policy

Figure 2: Detailed Tree Views. (a) The figure shows the first three layers of the vaccine surrogate policy tree. Each node provides its action, estimated value, and probability of reaching it. (b) The figure shows the first two tree layers with the mean of the observations leading to each action node displayed.

Before describing the tree construction algorithm, we first present the local tree policies that it produces. A policy tree is a rooted polytree where nodes represent actions taken during policy execution. An example tree policy for a stochastic, fully observable task is shown in fig. 1. Each node of the tree represents an action $a_t$ taken at time $t$. The root action of each tree is the action recommended by the policy at the initial state. Each path from the tree root to a leaf node gives a trajectory of actions that the agent may take during policy execution up to some depth $d$. Trajectories leading to terminal conditions before $d$ steps result in shallower tree branches.

Policy trees are interpretable representations of the future behavior of a policy that give explicit, quantitative predictions of future trajectories. Each node provides an estimate of the probability that the action sequence up to that node will be taken, along with an estimate of the policy value. Each node also contains a set of example states or observations that would result in that action.

Figure 2 provides more detailed views of the policy tree in fig. 1. The trajectory probabilities and value estimates shown in fig. 2(a) are calculated during tree construction and may be presented for any surrogate policy. The states in the example problem are real-valued vectors, and each node in fig. 2(b) shows the mean of that node’s state values. While the mean state value is useful for understanding this problem, it may not be useful in all problems. For example, the mean value may be meaningless for problems with discrete state spaces. Methods to compactly represent a node’s set of states or observations cannot be defined generally and should instead be specified for each problem.

Trees are an intuitive choice for the local surrogate model of a control policy. Local surrogate methods seek to provide simple approximate models that are faithful to the original policy in a space around a target point. For control tasks, often only predictions of future behavior are useful. Trees provide compact surrogate models by only representing behavior in states that are forward-reachable from the current state. This is in contrast to methods that define local neighborhoods by arbitrarily perturbing the initial point ribeiro2016. In these cases, explanatory capacity is wasted on unnecessary states.

Tree Build Algorithm

Figure 3: Tree build example. (a) An initial set of particles is sampled from the given initial state or belief. These particles are advanced in the simulation using the action at the root node. Particles reaching terminal states are marked with . (b) Non-terminal particles are clustered into new action nodes. (c) Action nodes with a sufficient number of particles are advanced another time step. This process continues for each action node until all particles terminate or reach a maximum depth.

Trees are built by simulating multiple executions of the baseline policy from the given initial state or belief and clustering them into action nodes. The process is illustrated in fig. 3. Each simulation is represented by a collection of particles, one for each time step along the simulation trajectory. Particles are tuples $(s_t, r_t)$ of the state $s_t$ at time $t$ and the reward $r_t$ received from $t-1$ to $t$. For POMDPs, particles also record the observation $o_t$ and belief $b_t$.

Tree construction begins by generating a set of particles for the initial state or belief. For the initial step, all particles are assigned to the root action node $n_0$, which takes the action given by the baseline policy, $\pi(s_0)$ for MDPs or $\pi(b_0)$ for POMDPs. The particles are then advanced through one simulation time step to produce the particles for $t = 1$, as shown in fig. 3(a). The particles that did not enter terminal states are clustered into new action nodes, as shown in fig. 3(b). This process continues until a terminal state is encountered or a fixed depth limit is reached. To limit over-fitting, action nodes that do not have at least $n_{\min}$ particles are not expanded further, where $n_{\min}$ is specified by the user.

The point of entry for tree construction is the Build function, presented in algorithm 1 for MDPs. The initial particle set is created and clustered into the root action node. Each root particle for an MDP policy is initialized with the same state $s_0$. Each action node is a tuple $(P, a, C)$, where $P$ is the node’s set of particles, $a$ is the action taken from that node, and $C$ is the set of child nodes. The Build function for POMDPs follows the same procedure as for MDPs, though states and observations for the initial particle set are sampled from the initial belief $b_0$.

Algorithm 1: Build MDP
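The particle and action-node structures described above, together with the first step of Build for an MDP, can be sketched in Python as follows. This is a minimal sketch under our own assumptions: the names `Particle`, `ActionNode`, and `make_root_mdp` are hypothetical, and the recursive expansion performed by Rollout is not included here.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, List


@dataclass
class Particle:
    state: Any               # s_t
    reward: float = 0.0      # r_t, reward received from t-1 to t
    observation: Any = None  # o_t (POMDPs only)
    belief: Any = None       # b_t (POMDPs only)


@dataclass
class ActionNode:
    particles: List[Particle]  # P: particles clustered into this node
    action: Any                # a: action taken from this node
    children: List["ActionNode"] = field(default_factory=list)  # C: child nodes


def make_root_mdp(s0: Any, policy: Callable[[Any], Any], n_particles: int) -> ActionNode:
    """First step of Build for an MDP: every root particle starts in s0,
    and the root node takes the baseline action pi(s0)."""
    return ActionNode(particles=[Particle(state=s0) for _ in range(n_particles)],
                      action=policy(s0))
```

For a POMDP, the root particles would instead be sampled from the initial belief, each carrying its own state and observation, and the root action would come from the policy evaluated on that belief.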

Particles are advanced through recursive calls to the Rollout function (algorithm 2). Each time Rollout is called on an action node, it proceeds only if the node’s particle set meets the minimum size threshold and the depth limit has not been exceeded. If these conditions are met, each particle is advanced by calling the simulator with that particle’s state and the node action. The set of new particles is then grouped into new action nodes by the Cluster function. Rollout is recursively called on each new action node, and the returned node is added to the child node set $C$.

Algorithm 2: Rollout
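A compact sketch of the recursive Rollout step follows. It assumes a hypothetical generative model `step(state, action, rng)` that returns `(next_state, reward, done)`, and it replaces the Cluster function with a deliberately simple stand-in, `cluster_by_baseline`, that opens one node per action chosen by the baseline; it is not the recursive greedy clustering used by the method.

```python
import random
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Tuple


@dataclass
class Particle:
    state: Any
    reward: float = 0.0


@dataclass
class ActionNode:
    particles: List[Particle]
    action: Any
    children: List["ActionNode"] = field(default_factory=list)


# Hypothetical generative model: step(state, action, rng) -> (next_state, reward, done)
Step = Callable[[Any, Any, random.Random], Tuple[Any, float, bool]]
Policy = Callable[[Any], Any]


def cluster_by_baseline(particles: List[Particle], policy: Policy) -> List[ActionNode]:
    """Simplified stand-in for Cluster: one node per action the baseline picks."""
    groups: Dict[Any, List[Particle]] = defaultdict(list)
    for p in particles:
        groups[policy(p.state)].append(p)
    return [ActionNode(particles=ps, action=a) for a, ps in groups.items()]


def rollout(node: ActionNode, depth: int, step: Step, policy: Policy,
            n_min: int, max_depth: int, rng: random.Random) -> ActionNode:
    """Advance the node's particles one step, cluster the survivors, and recurse."""
    if len(node.particles) >= n_min and depth < max_depth:
        survivors = []
        for p in node.particles:
            s_next, r, done = step(p.state, node.action, rng)
            if not done:  # particles that terminate are not expanded further
                survivors.append(Particle(state=s_next, reward=r))
        for child in cluster_by_baseline(survivors, policy):
            node.children.append(
                rollout(child, depth + 1, step, policy, n_min, max_depth, rng))
    return node
```

Because every node's particle set is a subset of the particles that survived its parent's step, particle counts can only shrink with depth, which is what eventually stops expansion through the `n_min` check.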

The action nodes of the tree encode a deterministic policy for the sampled particles, $\pi_{\text{tree}} : \mathcal{S}_P \rightarrow \mathcal{A}$, where $\mathcal{S}_P$ is the set of states contained in the particle set. The particles are clustered into action nodes to minimize how much this policy deviates from the baseline policy while meeting constraints on tree size. To achieve this, we developed a clustering algorithm that approximately solves for the optimal clustering through recursive greedy optimization.

The recursive clustering approach is presented in algorithm 3. The algorithm progressively clusters particles into an increasing number of nodes until a distance measure between the actions assigned by the tree and the baseline policy is less than a threshold or the maximum number of nodes is exceeded.

Algorithm 3: Recursive Cluster
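The outer loop of the recursive clustering can be sketched as below. It assumes a hypothetical function `greedy(particles, k)` that returns a clustering into `k` action nodes together with its distance to the baseline policy; `delta` stands for the distance threshold and `k_max` for the node budget. Both names are ours, chosen for illustration.

```python
from typing import Any, Callable, List, Tuple

# Hypothetical signature: greedy(particles, k) -> (clusters, distance_to_baseline)
GreedyFn = Callable[[List[Any], int], Tuple[Any, float]]


def recursive_cluster(particles: List[Any], greedy: GreedyFn,
                      delta: float, k_max: int) -> Any:
    """Increase the number of action nodes k until the clustering's distance to
    the baseline policy falls below delta, or k reaches the node budget k_max."""
    k = 0
    while True:
        k += 1
        clusters, dist = greedy(particles, k)
        if dist < delta or k >= k_max:
            return clusters
```

Stopping at the first `k` that meets the threshold keeps the tree as narrow as the fidelity target allows, which is the trade-off between tree size and fidelity that the user controls.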

Clusters are assigned according to the greedy heuristic process shown in algorithm 4. In this approach, the UniqueActions function ranks the actions assigned to the particles by the baseline policy according to their frequency. An action node is created for each of the top $k$ actions, and all particles assigned that action by the baseline policy are clustered into it. Any particles not clustered to a node at the end of this process are assigned to the previously formed node that minimizes the distance between that node's action and the action assigned by the baseline policy.

Algorithm 4: Greedy Cluster
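One way to write the greedy heuristic is sketched below. It rests on two assumptions of ours: `distance(s, policy(s))` is zero (true for the value-gap distance), and the reported distance is the mean over particles, since the aggregation is not specified. Clusters are keyed by node action and hold particle states; `greedy_cluster` is a hypothetical name.

```python
from collections import Counter
from typing import Any, Callable, Dict, Hashable, List, Tuple


def greedy_cluster(states: List[Any], policy: Callable[[Any], Hashable],
                   distance: Callable[[Any, Hashable], float],
                   k: int) -> Tuple[Dict[Hashable, List[Any]], float]:
    """Open nodes for the k most frequent baseline actions, then attach every
    remaining state to the open node whose action has the smallest distance.

    Returns the clusters keyed by node action and the mean distance between
    the tree's action and the baseline action over all states.
    """
    baseline = [policy(s) for s in states]
    # UniqueActions: baseline actions ranked by how often they are chosen
    top = [a for a, _ in Counter(baseline).most_common(k)]
    clusters: Dict[Hashable, List[Any]] = {a: [] for a in top}
    total = 0.0
    for s, a_base in zip(states, baseline):
        if a_base in clusters:
            clusters[a_base].append(s)  # tree agrees with baseline: zero distance
        else:
            best = min(top, key=lambda a: distance(s, a))
            clusters[best].append(s)
            total += distance(s, best)
    return clusters, total / max(len(states), 1)
```

This function has the shape assumed for `greedy` by the recursive loop, once `policy` and `distance` are bound (for example with `functools.partial`).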

The distance metric may be any appropriate measure assigned by the user for the given policy type. For action value function policies, an effective distance measure for an action $a$ in state $s$ would be

$$ d(s, a) = \max_{a' \in \mathcal{A}} \hat{Q}(s, a') - \hat{Q}(s, a) \tag{4} $$

which intuitively gives the predicted sub-optimality of the action. For stochastic policies that output distributions over actions, the difference in action probability would be an appropriate distance.
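Both distance choices mentioned above fit in a few lines. In this sketch the baseline's output at a single state is assumed to be a dictionary from actions to values (or to probabilities), and the probability version, which compares against the most likely action, is one plausible reading of "difference in action probability" rather than a definition from the text.

```python
from typing import Dict, Hashable


def q_gap_distance(q_hat: Dict[Hashable, float], action: Hashable) -> float:
    """Predicted sub-optimality of `action`: max_a' Q(s, a') - Q(s, action)."""
    return max(q_hat.values()) - q_hat[action]


def prob_gap_distance(pi: Dict[Hashable, float], action: Hashable) -> float:
    """Probability of the most likely action minus the probability of `action`."""
    return max(pi.values()) - pi[action]
```

Both functions return zero for the action the baseline would choose, which is the property the greedy clustering relies on.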

Tree Policy Control

To use the tree as a policy during task execution, it must be able to generalize to states not encountered in the particle set of the tree. To do this, we propose a simple method that uses the baseline policy, constrained by the tree at each time step. For a tree constructed for a state $s_0$, the policy will always take the action at the root node $n_0$. For all remaining steps, the policy will return

$$ \pi_{\text{tree}}(s_t) = \arg\max_{a \in \mathcal{A}_C} \hat{Q}(s_t, a) \tag{5} $$

where $\mathcal{A}_C$ is the set of actions in the child set $C$ of the preceding action node. Intuitively, the agent will take the best action predicted by the baseline policy that is included in the tree. Building a new tree each time a leaf state is encountered allows the tree policies to be run in the loop for long or infinite-horizon problems.
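The tree-constrained control rule can be sketched as a single function. We assume here that the baseline exposes its action values at the current state as a dictionary and that the caller passes the actions of the current node's children; `tree_constrained_action` is a hypothetical name. After acting, the agent would move to the child node whose action it took.

```python
from typing import Dict, Hashable, Iterable


def tree_constrained_action(q_hat: Dict[Hashable, float],
                            child_actions: Iterable[Hashable]) -> Hashable:
    """Return the baseline's highest-valued action among the actions that appear
    in the child set of the current action node."""
    allowed = list(dict.fromkeys(child_actions))  # de-duplicate, keep order
    if not allowed:
        # Leaf node: the tree no longer constrains the agent; rebuild it here.
        raise ValueError("leaf node reached; rebuild the tree from the current state")
    return max(allowed, key=lambda a: q_hat[a])
```

For a stochastic baseline, the same rule applies with action probabilities in place of `q_hat`.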


We ran experiments to test the performance of the proposed approach. One set of experiments was conducted on a simple grid world task. These experiments were designed to quantitatively measure the effect of various algorithm parameters and environment features on tree size and fidelity. Two additional experiments on more complex tasks demonstrate the utility of the approach on real-world motivating examples. The first of these is multi-city vaccine deployment planning. In large-scale infrastructure planning tasks such as this, it is often necessary to have a reasonable prediction of all steps of the plan in order to gain stakeholder trust. The second task is a cyber security agent that provides recommendations to a human analyst to secure a network against attack. Interpretable policies are important for tasks with human oversight and cooperation.

We implemented the proposed algorithm in Python, using the distance function defined in eq. 4. Neural network training was done in PyTorch paszke2015. Source code, full experiment descriptions, and results are available in the Appendix.

Grid World

In the grid world task, an agent must navigate a discrete, 2D world to reach a goal state while avoiding trap states. The agent may take one of total actions to move between and units in any of the four cardinal directions on the grid. The agent moves in the intended direction with probability and takes a random action otherwise. To incentivize solving the problem quickly, the agent receives a cost of at each time step. The agent gets a positive reward of for reaching a goal state and a penalty of for reaching a trap state. The episode terminates when the agent reaches a goal state or when a maximum number of steps is reached. Because we know exact transition probabilities, we used discrete value iteration to learn a baseline policy sutton1998.
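Since the transition model is known, a baseline policy for a task like this can be computed by tabular value iteration. The sketch below illustrates the idea on a small slippery grid. The grid size, goal and trap positions, reward values, discount, and slip model (with probability `1 - p_success` the agent takes a uniformly random action, which may coincide with the intended one) are hypothetical choices, not the settings used in the experiments.

```python
from typing import Dict, Tuple

State = Tuple[int, int]
ACTIONS: Dict[str, Tuple[int, int]] = {"N": (0, 1), "S": (0, -1), "E": (1, 0), "W": (-1, 0)}


def value_iteration(size: int, goal: State, trap: State, p_success: float,
                    step_cost: float = -1.0, goal_reward: float = 10.0,
                    trap_reward: float = -10.0, gamma: float = 0.95,
                    tol: float = 1e-8) -> Tuple[Dict[State, float], Dict[State, str]]:
    """Tabular value iteration for a slippery grid world with goal and trap states."""
    states = [(x, y) for x in range(size) for y in range(size)]
    terminal = {goal, trap}

    def move(s: State, a: str) -> State:
        x, y = s[0] + ACTIONS[a][0], s[1] + ACTIONS[a][1]
        return (x, y) if 0 <= x < size and 0 <= y < size else s  # bump into wall

    def reward(s_next: State) -> float:
        bonus = goal_reward if s_next == goal else trap_reward if s_next == trap else 0.0
        return step_cost + bonus

    def q_value(V: Dict[State, float], s: State, a: str) -> float:
        # Intended move with probability p_success, otherwise a uniformly random action.
        probs = {b: (1.0 - p_success) / len(ACTIONS) for b in ACTIONS}
        probs[a] += p_success
        return sum(pb * (reward(move(s, b)) + gamma * V[move(s, b)])
                   for b, pb in probs.items())

    V: Dict[State, float] = {s: 0.0 for s in states}  # terminal values stay 0
    while True:
        delta = 0.0
        for s in states:
            if s in terminal:
                continue
            v_new = max(q_value(V, s, a) for a in ACTIONS)
            delta = max(delta, abs(v_new - V[s]))
            V[s] = v_new
        if delta < tol:
            break
    policy = {s: max(ACTIONS, key=lambda a: q_value(V, s, a))
              for s in states if s not in terminal}
    return V, policy
```

The returned `policy` dictionary plays the role of the baseline policy from which surrogate trees are built; the greedy extraction at the end matches the effective-policy rule used for value estimators.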

We constructed surrogate trees from the baseline policy using various environment and tree-build algorithm settings. For the environment, we swept over different values of the transition probability and the number of actions. For the tree build algorithm, we varied the total number of particles, the minimum particle count, the distance threshold, and the maximum leaf node depth. Only one parameter was varied for each test, with all others held fixed. The baseline settings for the environment were actions and a successful transition probability of . The baseline tree build parameters were 1000 total particles, 250 minimum particles, a distance threshold of 0.01, and a maximum depth of 10.

To test each tree, we ran it as a policy from the initial state until it reached a leaf node. The baseline policy was then used to complete the episode. We tested 2,500 trees for each parameter configuration. Select results are shown in table 1. The mean change in performance of the tree policy relative to the baseline is shown along with one-standard-error bounds. The average depth of leaf nodes is also shown, though standard error is omitted as each had
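The relative-change statistic could be computed along the lines of the sketch below. Normalizing each test episode's return by the absolute baseline mean is our assumption; the text does not state how the percentage was formed. The standard error is the sample standard deviation divided by the square root of the number of tests.

```python
import math
from statistics import mean, stdev
from typing import List, Tuple


def relative_change(tree_returns: List[float], baseline_mean: float) -> Tuple[float, float]:
    """Mean percent change of tree-policy returns relative to the baseline mean,
    together with its standard error (sample std / sqrt(n))."""
    changes = [100.0 * (r - baseline_mean) / abs(baseline_mean) for r in tree_returns]
    se = stdev(changes) / math.sqrt(len(changes)) if len(changes) > 1 else 0.0
    return mean(changes), se
```

With 2,500 trees per configuration, the standard error shrinks by a factor of 50 relative to the per-episode spread of the changes.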


From the transition probability sweep, we see that relative performance generally improves as the transition probability increases. At the lowest transition probability tested, both policies perform very poorly, and as a result the difference between the two is not significant. In the deterministic case, the surrogate tree perfectly represents the baseline policy. These results suggest that surrogate trees are better suited for tasks with low stochasticity.

Parameter Value Rel Change (%) Leaf Depth
Trans. Prob. 0.5
No. Actions 4
Max Depth 3
Table 1: Parameter Search Results. The change in task performance relative to the baseline policy performance and the average depth of leaf nodes in the trees are given. Values generated with the baseline settings are shown in bold.

As we increase the number of actions, the difference in performance between the baseline and the tree policy decreases. The average leaf depth of the trees also decreases with more actions. This likely explains the improved performance, as shallower trees transition back to the baseline policy sooner in the tests. Another possible explanation is that as the number of actions grows, the cost of taking a sub-optimal action decreases. For example, with four actions (one per direction), if the optimal action is not taken, then the agent moves in a completely different direction than it should. With more actions, the agent may still move in the correct direction, though by more or less distance than optimal.

For the tree building algorithm parameters, we see that as the distance threshold increases, the depth of the leaf nodes decreases. This is likely because higher thresholds lead to more aggressive node clustering and fewer branches. Similar to the trend observed when varying the number of actions, relative performance increases as the average leaf depth decreases. Conversely, as the maximum depth increases, the tree is allowed to grow deeper and the relative performance drops.

Vaccine Planning

In the vaccine deployment task, an agent decides the order in which to start vaccine distributions in cities in the midst of a pandemic outbreak. Each step, the agent picks a city in which to start a vaccine program. The spread of the disease is modeled as a stochastic SIRD model bailey1975, with portions of each city population being susceptible to infection (S), infected (I), recovered (R), or dead (D). The cities are modeled as a fully connected, weighted graph, where weight encodes the amount of traffic between city and . The state of city changes from time to as


where , , and are the mean infection, recovery, and death rates, respectively. The vaccination rate of city at time and is equal to zero until an action is taken to deploy a vaccination program to that city. The noise is sampled from a zero mean Gaussian. The effective infection exposure at city is defined as , where the sum is taken over all cities, and . Cities with closer index values will have higher weights, for example .

Each episode is initialized with up to 10% of each city infected and the remainder susceptible. One city is initialized with 25% infected. The episode concludes after five cities have had vaccine programs started. The simulation is then run until all cities have zero susceptible or infected population. The reward at each time step is , where and are parameters.

We trained a neural network policy using double DQN hasselt2010 with -step returns. The trained policy achieved an average score with one standard error bounds of over 100 trial episodes. We built trees for 100 random initial states with 2000 particles and a minimum particle threshold of 250 and tested their performance as forward policies. The trees achieved an average score of over the 100 trials, for an average performance drop of . To compare to a baseline surrogate modeling approach, we also trained and tested a LIME model ribeiro2016 with 2000 samples. The LIME average score was , for an average performance loss of .

The surrogate tree in fig. 2 provides an intuitive understanding of the learned policy. In fig. 2(b), we can see that the policy does not deploy vaccines to the most heavily infected cities first. It instead prioritizes cities with larger susceptible populations, giving the vaccine time to take effect on a larger portion of the population.

Cyber Security

Figure 4: Cyber Security Policy Tree. The figure shows the first three layers of a surrogate tree for the cyber security task. The most frequent observation from the “Clean Host” action node set is shown.

The cyber security task requires an agent to prevent unauthorized access to a secure data server on a computer network. The computer network comprises four local area networks (LANs), each of which has a local application server and ten workstations, plus a single secure data server. The compromise state of the network is not known, but may be observed through noisy alerts generated from malware scans. Workstations are networked to all others on their LAN and to the LAN application server. Servers are randomly connected in a complete graph. An attacker begins with a single workstation compromised and takes actions to compromise additional workstations and servers to reach the data server.

The defender can scan all nodes on a LAN to locate compromised nodes with probability . Compromised nodes will also generate alerts without being scanned with low probability. The defender can also scan and clean individual nodes to detect and remove compromise. The reward is zero unless the data server is compromised, in which case a large penalty is incurred. The defender was trained using Rainbow DQN hessel2018.

Automated systems such as this are often implemented with a human in the loop. Policies that can be more easily interpreted are more likely to be trusted by a human operator. A surrogate tree for the neural network is shown in fig. 4. Unlike the baseline neural network, the policy encoded by this tree can be easily interpreted. The agent will continually scan LAN 1 in most cases, and will only clean a workstation after malware has been detected.


In this work, we presented methods to construct local surrogate policy trees from arbitrary control policies. The trees are more interpretable than high-dimensional policies such as neural networks and provide quantitative estimates of future behavior. Our experiments show that, despite truncating the set of actions that may be taken at each future time step, the trees retain fidelity with their baseline policies. Experiments demonstrate the effect of various environment and algorithm parameters on tree size and fidelity in a simple grid world. Demonstrations show how surrogate trees may be used in more complex, real-world scenarios.

The action node clustering presented in this work used a heuristic search method that provided good results, but without any optimality guarantees. Future work will look at improved approaches to clustering, for example by using mixed integer program optimization. We will also explore using the scenarios simulated to construct the tree to back up more accurate value estimates and refine the resulting policy. Including empirical backups such as these may also allow calculation of confidence intervals or bounds on policy performance.