In many applications of computer vision (e.g., medical imaging and autonomous driving), insight into the prediction process or justification for a decision is critical. Although deep learning techniques achieve state-of-the-art accuracy in these settings, they provide little insight into the resulting predictions.
In an attempt to address the loss of interpretability in deep learning, a growing body of work explores explainable predictions. Explainable computer vision often features saliency maps, which depict the parts of an image that significantly impact the final classification. However, visual explanations focus on the input to the model rather than the model itself. Therefore, these techniques provide little insight into the general behavior of the model and how it might perform on data outside the domain of the available training examples.
In this work we propose neural-backed decision trees (NBDTs) to make state-of-the-art computer vision models interpretable. NBDTs require no special architectures: any neural network for image classification can be transformed into an NBDT just by fine-tuning with a custom loss. Furthermore, NBDTs perform inference by breaking image classification into a sequence of intermediate decisions. This sequence of decisions can then be mapped to more interpretable concepts and reveal perceptually informative hierarchical structure in the underlying classes. Critically, in contrast to prior work on decision trees in computer vision, NBDTs are competitive with state-of-the-art results on CIFAR10, CIFAR100, TinyImageNet, and ImageNet, and are substantially more accurate than comparable decision-tree-based approaches while also being more interpretable.
We introduce a two-stage training procedure for NBDTs. First, we compute a hierarchy, dubbed the induced hierarchy (Fig. 1, Step 1). This hierarchy is derived from the weights of a neural network already trained on the target dataset. Second, we fine-tune the network with a custom loss designed specifically for this tree, called the tree supervision loss (Fig. 1, Step 2). This loss forces the model to maximize decision tree accuracy given a fixed tree hierarchy. We then run inference in two steps: (1) We construct features for each training image using the network backbone (Fig. 1, Step 3). Then, for each node, we compute a vector in the network’s weight space that best represents the leaves in its subtree, given the decision tree hierarchy – we refer to this vector as the representative vector. (2) Starting at the root node, each sample is sent to the child whose representative vector is most similar to the sample. We continue traversing the tree until we reach a leaf. The class associated with this leaf is our prediction (Fig. 1, Step 4). This contrasts with related work that introduces obstacles to interpretability, such as impure leaves or an ensemble of models [1, 16].
We summarize our contributions as follows:
We propose a method for running any classification neural network as a decision tree by defining a set of embedded decision rules that can be constructed from the fully-connected layer. We also design induced hierarchies that are easier for neural networks to learn.
We propose the tree supervision loss, which boosts neural network accuracy by 0.5% and produces high-accuracy NBDTs. We demonstrate that our NBDTs achieve accuracies comparable to neural networks on small, medium, and large-scale image classification datasets.
We present qualitative and quantitative evidence of semantic interpretations for our model decisions.
2 Related Works
Until the more recent success of deep learning, decision trees defined the state-of-the-art in both accuracy and model interpretability on a wide variety of learning tasks. As deep learning based techniques began to dominate other learning methods, there was a significant amount of work exploring the intersection of deep learning and decision trees. We summarize the co-evolution of deep learning and decision trees.
The study of decision tree and neural network combinations dates back three decades, to when neural networks were seeded with weights provided by decision trees [4, 5, 15]. Work converting trained neural networks to decision trees also dates back three decades [14, 17, 6, 7]. Like distillation, these works treated the neural network as an oracle, querying it to create splits.
Decision trees to neural networks. Recent work also seeds neural networks with weights provided by decision trees, with renewed interest in gradient-based approaches. These methods show empirical evidence on UCI datasets, in very feature-sparse and sample-sparse regimes.
Neural networks to decision trees. Recent work uses distillation, training a decision tree to mimic a neural network’s input-output function. All of these works evaluate on simple datasets such as UCI or MNIST, while our method is evaluated on more complex datasets such as CIFAR10, CIFAR100, and TinyImageNet.
Combining neural networks with decision trees. More recent are works combining neural networks with decision trees, scaling up inference to datasets with many high-dimensional samples. The Deep Neural Decision Forest saw performance that matched neural networks on ImageNet. However, this occurred before the inception of residual networks and sacrificed interpretability by using impure leaves and requiring a forest. Murthy et al. propose creating a new neural network for each node in the decision tree and show interpretable outputs. Ahmed et al. (NofE) modify this by sharing the backbone between all nodes but support only depth-2 trees; NofE attains ImageNet performance competitive with pre-ResNet architectures. Our method builds further on this by sharing not just the backbone but also the fully-connected layer; we furthermore show performance competitive with state-of-the-art neural networks, including residual networks, while preserving interpretability.
Instead of combining neural networks and decision trees explicitly, several works borrow ideas from decision trees for neural networks and vice versa. In particular, several redesigned neural network architectures utilize decision tree branching structures [35, 21, 34]. Whereas accuracy improves, this approach sacrifices decision tree interpretability. Others use decision trees to analyze neural network weights [39, 24]. This has the opposite drawback: it either sacrifices accuracy or does not support a mechanism for prediction. As we hypothesize and show, a high-accuracy decision tree is necessary to explain and interpret high-accuracy models. Furthermore, our competitive performance shows that accuracy and explainability need not be traded off.
Visual Explanations. An orthogonal but predominant explainability direction involves generating saliency maps, which highlight the spatial evidence behind decisions made by neural networks [30, 37, 28, 38, 27, 26, 25, 31]. White-box techniques such as Guided Backpropagation and Deconvolution [37, 28], Grad-CAM, and Integrated Gradients use the gradients of the network to determine the most salient regions in an image, while black-box techniques such as LIME and RISE determine pixel importance by perturbing the input and measuring the change in prediction. Saliency maps explain only a single image and are unhelpful when a network looks at the right thing for the wrong reasons (e.g., a bird misclassified as an airplane). On the other hand, our method expresses the model’s prior over the entire dataset and explicitly breaks down each classification into a sequence of intermediate decisions.
In this section we describe the proposed steps, illustrated in Fig. 1, for converting any classification neural network into a decision tree: (1) build an induced hierarchy (Sec. 3.2) and (2) fine-tune the model with a tree supervision loss (Sec. 3.3). For inference, (3) featurize samples with the neural network backbone and (4) run the decision rules embedded in the fully-connected layer (Sec. 3.1).
As implied by steps 3 and 4, our neural-backed decision tree (NBDT) has the exact same architecture as a standard neural network. As explained in more detail below, a subset of a fully-connected layer represents a node in the decision tree. This means that our method (1) is broadly applicable, as all classification neural network architectures are supported as-is, i.e., no architectural modifications are needed to use an NBDT, and (2) benefits from deep learning know-how, as NBDT accuracy improves as the underlying neural network accuracy improves.
3.1 Inference with Embedded Decision Rules
First, our NBDT approach featurizes each sample using the neural network backbone; the backbone consists of all neural network layers before the final fully-connected layer. Second, at each node, we take the inner product between the featurized sample and each child node’s representative vector. Note that all representative vectors are computed from the neural network’s fully-connected layer weights; thus, these decision rules are “embedded” in the neural network. Third, we use these inner products to make either hard or soft decisions, as described below.
To motivate why we use inner-products, we will first construct a degenerate decision tree that is equivalent to a fully-connected layer.
Fully-connected layer. The fully-connected layer’s weight matrix is $W \in \mathbb{R}^{k \times d}$, with one row vector $w_i$ per class. Running inference with a featurized sample $x \in \mathbb{R}^d$ is a matrix-vector product:
$$\hat{y} = Wx$$
The matrix-vector product yields inner products between $x$ and each row vector $w_i$, written as $\hat{y}_i = \langle x, w_i \rangle$. The index of the largest inner product is our class prediction.
Decision Tree. Consider a minimal tree, with a root node and $k$ child nodes. Each child node is a leaf, and each child node has a representative vector, namely a row vector $w_i$ from $W$. Running inference with a featurized sample $x$ means taking inner products between $x$ and each child node’s representative vector $w_i$, written as $\langle x, w_i \rangle$. Like the fully-connected layer, the index of the largest inner product is our class prediction. This is illustrated in Fig. 2 (B. Naive).
Even though the two computations are represented differently, both predict the class by taking the index of the largest inner product $\langle x, w_i \rangle$. We refer to this decision tree inference as running embedded decision rules.
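The equivalence between the fully-connected layer and the degenerate tree can be sketched in a few lines of plain Python; the weight rows and the featurized sample below are toy values for illustration:

```python
# Minimal sketch: a fully-connected layer and a degenerate one-level
# decision tree make the same prediction via inner products.

def inner(u, v):
    return sum(a * b for a, b in zip(u, v))

W = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]  # one row vector w_i per class
x = [0.2, 0.9]                             # featurized sample

# Fully-connected layer: argmax over the matrix-vector product Wx.
logits = [inner(x, w) for w in W]
fc_pred = logits.index(max(logits))

# Degenerate tree: the root's children are leaves whose representative
# vectors are exactly the rows w_i; pick the child with the largest
# inner product.
tree_pred = max(range(len(W)), key=lambda i: inner(x, W[i]))

assert fc_pred == tree_pred  # both computations agree
```

Both routes select class 1 here, since ⟨x, w_1⟩ = 0.9 is the largest inner product.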
We next extend the naïve decision tree beyond the degenerate case. Our decision rule requires that each child node have a representative vector. As a result, if we add a non-leaf child to the root, this non-leaf child also needs a representative vector. We naively take the non-leaf node’s representative vector to be the average of its subtree’s leaves’ representative vectors. With a more complex tree structure containing intermediate nodes, there are now two ways to run inference:
Hard Decision Tree. Compute an argmax at each node, over all children. For each node, take the child node corresponding to the largest inner product, and traverse that child node. This process selects one leaf (Fig. 2, A. Hard).
Soft Decision Tree. Compute a softmax at each node, over all children, to obtain probabilities of each child per node. For each leaf, take the probability of traversing that leaf from its parent. Then take the probability of traversing the leaf’s parent from its grandparent. Continue taking products until you reach the root. This product is the probability of that leaf and its path to the root. Tree traversal will yield one probability for each leaf. Compute an argmax over this leaf distribution, to select one leaf (Fig. 2, C. Soft).
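Both inference modes can be sketched on a toy two-level hierarchy; the tree structure and vectors below are invented for illustration, with internal nodes using the averaged leaf representatives described above:

```python
import math

# Sketch of hard vs. soft inference on a toy two-level hierarchy.

def inner(u, v):
    return sum(a * b for a, b in zip(u, v))

def softmax(scores):
    exps = [math.exp(s) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Leaves 0..3 with representative vectors (rows of the FC layer).
leaves = {0: [1.0, 0.0], 1: [0.8, 0.2], 2: [0.0, 1.0], 3: [0.1, 0.9]}
# Internal nodes average their subtree's leaf vectors.
left = [(a + b) / 2 for a, b in zip(leaves[0], leaves[1])]   # parent of 0, 1
right = [(a + b) / 2 for a, b in zip(leaves[2], leaves[3])]  # parent of 2, 3

x = [0.3, 0.7]  # featurized sample

# Hard inference: argmax at each node, then descend.
children = (0, 1) if inner(x, left) > inner(x, right) else (2, 3)
hard_pred = max(children, key=lambda i: inner(x, leaves[i]))

# Soft inference: multiply child probabilities along each root-to-leaf path,
# then argmax over the resulting leaf distribution.
p_left, p_right = softmax([inner(x, left), inner(x, right)])
p01 = softmax([inner(x, leaves[0]), inner(x, leaves[1])])
p23 = softmax([inner(x, leaves[2]), inner(x, leaves[3])])
leaf_probs = [p_left * p01[0], p_left * p01[1],
              p_right * p23[0], p_right * p23[1]]
soft_pred = leaf_probs.index(max(leaf_probs))
```

On this toy input both modes agree, but in general soft inference can recover from a borderline decision near the root, since all paths contribute to the final leaf distribution.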
This allows us to run any classification neural network as a sequence of embedded decision rules. However, simply running a standard-issue pretrained neural network in this way will result in poor accuracy. In the next section, we discuss how to maximize accuracy by fine-tuning the neural network to perform well after determining the hierarchy.
3.2 Building Induced Hierarchies
With the above inner-product decision rule, some decision tree hierarchies are intuitively easier for the network to learn. These easier hierarchies may more accurately reflect how the network attained high accuracy. To this end, we run hierarchical agglomerative clustering on the class representatives $w_i$ extracted from the fully-connected layer weights $W$. As described in Sec. 3.1, each leaf corresponds to one $w_i$ (Fig. 3, Step B), and each intermediate node’s representative vector is the average of the representatives of its subtree’s leaves (Fig. 3, Step C). We refer to this hierarchy as the induced hierarchy (Fig. 3).
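The clustering step can be sketched in plain Python. This is a minimal agglomerative procedure over toy class representatives; squared Euclidean distance between cluster means is an assumed linkage here, and the exact linkage criterion is an implementation choice:

```python
# Sketch: build an induced hierarchy by repeatedly merging the two
# closest clusters of class representatives w_i (rows of W). Each merge
# becomes an internal node whose representative is the mean of its
# leaves' vectors.

def mean(vectors):
    n = len(vectors)
    return [sum(v[d] for v in vectors) / n for d in range(len(vectors[0]))]

def dist(u, v):
    return sum((a - b) ** 2 for a, b in zip(u, v))

# Toy class representatives: classes 0/1 are close, as are classes 2/3.
clusters = {(0,): [1.0, 0.0], (1,): [0.9, 0.1],
            (2,): [0.0, 1.0], (3,): [0.1, 0.9]}

merges = []  # each entry records the leaf set of a new internal node
while len(clusters) > 1:
    pairs = [(dist(clusters[a], clusters[b]), a, b)
             for a in clusters for b in clusters if a < b]
    _, a, b = min(pairs)
    merged = a + b
    clusters[merged] = mean([clusters.pop(a), clusters.pop(b)])
    merges.append(merged)
```

The merge order recovers the intuitive grouping: {0, 1} and {2, 3} are formed first, then joined under a single root.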
We additionally conduct experiments with an alternative WordNet-based hierarchy. WordNet provides an existing hierarchy of nouns, which we leverage to relate the classes in each dataset linguistically. We find a minimal subset of the WordNet hierarchy that includes all classes as leaves, pruning redundant leaves and single-child intermediate nodes. As a result, WordNet relations provide “free” and interpretable labels for this candidate decision tree, e.g., classifying a Cat also as a Mammal and a Living Thing. To leverage this “free” source of labels, we automatically generate hypotheses for each intermediate node in an induced hierarchy by finding the earliest common ancestor of each subtree’s leaves.
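The pruning step can be sketched as follows; the toy parent-to-children table stands in for WordNet (which this sketch does not query), and the class names are invented for illustration:

```python
# Sketch: extract a minimal hierarchy from a taxonomy by keeping only
# ancestors of the target classes and collapsing single-child nodes.

taxonomy = {
    "entity": ["living_thing", "artifact"],
    "living_thing": ["mammal"],
    "mammal": ["cat", "dog"],
    "artifact": ["vehicle"],
    "vehicle": ["car"],
}
classes = {"cat", "dog", "car"}  # dataset classes, kept as leaves

def prune(node):
    """Return the minimal subtree rooted at `node` as (name, children),
    or None if it contains no target class; collapse single-child chains."""
    if node in classes:
        return (node, [])
    kept = [sub for c in taxonomy.get(node, [])
            if (sub := prune(c)) is not None]
    if not kept:
        return None
    if len(kept) == 1:  # redundant intermediate node: collapse it
        return kept[0]
    return (node, kept)

tree = prune("entity")
# "living_thing" collapses into "mammal"; "artifact"/"vehicle" collapse
# into the leaf "car", leaving a compact labeled hierarchy.
```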
3.3 Training with Tree Supervision Loss
All of the proposed decision trees above suffer from one major issue: even though the original neural network is encouraged to separate representative vectors for each class, it is not trained to separate representative vectors for each internal node. This is illustrated in Fig. 4. To remedy this, we add loss terms during training that encourage the neural network to separate representatives for internal nodes. We now explain the additional loss terms for the hard and soft decision rules in turn (Fig. 5).
For hard decision rules, we use the hard tree supervision loss. The original neural network’s loss minimizes cross entropy across the classes: for a $k$-class dataset, this is a $k$-way cross-entropy loss. Each internal node’s goal is similar: minimize cross-entropy loss across its child nodes. For a node $i$ with $c$ children, this is a $c$-way cross-entropy loss between the predicted child probabilities $\mathcal{D}(i)_{\text{pred}}$ and the child labels $\mathcal{D}(i)_{\text{label}}$. We refer to this collection of new loss terms as the hard tree supervision loss (Eq. 2). The individual cross-entropy losses for each node are scaled so that, by default, the original cross-entropy loss and the tree supervision loss are weighted equally. We test various weighting schemes in Sec. 4.2. With $N$ internal (non-leaf) nodes in the tree, we have $N + 1$ different cross-entropy loss terms – the original cross-entropy loss and $N$ hard tree supervision loss terms. The total is $\mathcal{L}_{\text{hard}}$, where:
$$\mathcal{L}_{\text{hard}} = \mathcal{L}_{\text{original}} + \frac{1}{N} \sum_{i=1}^{N} \text{CrossEntropy}\big(\mathcal{D}(i)_{\text{pred}}, \mathcal{D}(i)_{\text{label}}\big)$$
For soft decision rules, we use the soft tree supervision loss. In Sec. 3.1, we described how the soft decision tree provides a single distribution over leaves, $\mathcal{D}_{\text{pred}}$. We add a cross-entropy loss over this distribution. In total, there are 2 different cross-entropy loss terms – the original cross-entropy loss and the soft tree supervision loss term. The total is $\mathcal{L}_{\text{soft}}$, where:
$$\mathcal{L}_{\text{soft}} = \mathcal{L}_{\text{original}} + \text{CrossEntropy}\big(\mathcal{D}_{\text{pred}}, \mathcal{D}_{\text{label}}\big)$$
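The hard loss can be sketched numerically. The probabilities and labels below are toy values; averaging the per-node terms over the N internal nodes reflects the equal default weighting of the two loss parts described above:

```python
import math

# Sketch: original k-way cross entropy plus per-node cross-entropy
# terms, averaged over the N internal nodes of the tree.

def cross_entropy(pred, label_idx):
    """Negative log-probability assigned to the true index."""
    return -math.log(pred[label_idx])

# Original loss: k-way cross entropy over class probabilities.
original = cross_entropy([0.7, 0.2, 0.1], 0)

# One c-way term per internal node: (predicted child probs, true child).
node_terms = [([0.6, 0.4], 0), ([0.8, 0.1, 0.1], 0)]
N = len(node_terms)
tree_loss = sum(cross_entropy(p, y) for p, y in node_terms) / N

hard_total = original + tree_loss
```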
Supplementary materials include details of the mathematical formulation.
Our experiments obtain state-of-the-art results for decision trees on a number of image classification benchmark datasets. We report results for a variety of scenarios across models, datasets, and inference modes:
Inference modes: Soft vs. hard inference.
We also perform ablation studies on tree supervision loss weight and different hierarchies, noting that a tree supervision loss with half weight consistently boosts the base neural network’s accuracy by 0.5% across datasets.
Our decision trees achieve 97.57% accuracy on CIFAR10, 82.87% on CIFAR100, and 66.66% on TinyImageNet, preserving the accuracy of recently state-of-the-art neural networks. On CIFAR10, our soft decision tree matches WideResNet28x10 to within 0.05%. On CIFAR100, our soft decision tree achieves accuracy 0.57% higher than WideResNet28x10’s, outperforming the strongest competing decision-tree-based method (NofE) by 6.63%. On TinyImageNet, our soft decision tree achieves accuracy within 1% of WideResNet’s. Furthermore, the ResNet18 variant outperforms DNDF by 18.2%.
On ImageNet, NBDTs obtain 75.30% top-1 accuracy, outperforming the strongest competitor, NofE, by 14%. Note that we take the best competing result from any decision-tree-based method, but the strongest competitors hinder interpretability by using ensembles of models like a decision forest (DNDF, DCDJ) or by using shallow trees of only depth 2 (NofE).
4.2 Ablation Studies
Tree Supervision Loss. The tree supervision loss, as described in Sec. 3.3, boosts the accuracy of a neural network by 0.5% with a tree supervision loss weight of 0.5, when training from scratch on CIFAR100 and TinyImageNet (Table 2).
WordNet Hierarchy. Table 3
shows that the WordNet hierarchy is comparable to induced hierarchies on CIFAR but underperforms by 4.17% on TinyImageNet. This is because WordNet similarity is not indicative of visual similarity: for example, by virtue of being an animal, Bird is closer to Cat than to Plane, according to WordNet. However, the opposite is true for visual similarity: by virtue of appearing in the sky, Bird is more visually similar to Plane than to Cat.
Tree Supervision Loss Weight. As we vary the coefficient for the tree supervision loss, we note that disproportionately assigning weight to the tree supervision loss (by two orders of magnitude) significantly degrades the performance of both the neural network and the NBDT. However, our method is robust to imbalance between the two loss terms of up to an order of magnitude. We conclude the method is not hyper-sensitive to the loss coefficient (Table 4).
The explainability of a decision tree is well-established, as the final prediction can be broken into a sequence of decisions that can be evaluated independently for correctness. In the case where the input features are easily understood (e.g. medical/financial data), analyzing the rules which dictate the splits in the tree is relatively straightforward, but when the input is more complex like an image, this becomes more challenging. In this section, we perform qualitative and quantitative analysis of decisions made by intermediate tree nodes.
5.1 Explainability of Nodes’ Semantic Meanings
Since the induced hierarchy is constructed using model weights, it is not forced to split on particular attributes. While hierarchies like WordNet provide hypotheses for a node’s meaning, Fig. 6 shows that WordNet does not suffice, as the tree may split on contextual attributes such as underwater and on land. To diagnose node meanings, we perform the following 4-step test:
Make a hypothesis for the node’s meaning (e.g. Animal vs. Vehicle). This hypothesis can be computed automatically from a given taxonomy like WordNet or deduced from manual inspection of leaves for each child (Fig. 7).
Collect a dataset with new, unseen classes that test the hypothesized meaning of the node from step 1 (e.g., Elephant is an unseen Animal). Samples in this dataset are referred to as out-of-distribution samples, as they are drawn from a separate labeled dataset.
Pass samples from this dataset through the node in question. For each sample, check whether the selected child node agrees with the hypothesis.
The accuracy of the hypothesis is the percentage of samples passed to the correct child. If the accuracy is low, repeat with a different hypothesis.
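Step 4’s scoring can be sketched as follows; the class names and routing outcomes below are invented for illustration:

```python
# Sketch: score a hypothesis (e.g. Animal vs. Vehicle) by the fraction
# of out-of-distribution samples routed to the hypothesized child.

# (sample's true side under the hypothesis, child the node actually chose)
observations = [
    ("Animal", "Animal"),   # e.g. an unseen Elephant sent to the Animal child
    ("Animal", "Animal"),
    ("Animal", "Vehicle"),  # a miss
    ("Vehicle", "Vehicle"),
    ("Vehicle", "Vehicle"),
]

hits = sum(1 for expected, chosen in observations if expected == chosen)
accuracy = hits / len(observations)  # fraction passed to the correct child
```

A low score means the hypothesis should be revised and the test repeated.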
This process automatically validates WordNet hypotheses, but manual intervention is needed for hypotheses beyond WordNet. Fig. 6(a) depicts the CIFAR10 tree induced by a WideResNet28x10 model trained on CIFAR10. Our hypothesis is that the root node splits on Animal vs. Vehicle. We collect out-of-distribution images for Animal and Vehicle classes that are unseen at training time, from CIFAR100. We then compute the hypothesis’s accuracy. Fig. 6(b) shows that our hypothesis accurately predicts which child each unseen class’s samples traverse.
5.2 Sidestepping the Accuracy-Interpretability Tradeoff
Induced hierarchies cluster vectors in weight space, but classes that are close in weight space may not have similar semantic meaning. Fig. 8 depicts the trees induced by WideResNet28x10 and ResNet10, respectively. While the WideResNet induced hierarchy (Fig. 8(a)) groups semantically similar classes, the ResNet induced hierarchy (Fig. 8(b)) does not, grouping classes such as Frog, Cat, and Airplane. This disparity in semantic meaning is explained by WideResNet’s 4% higher accuracy: we believe that higher-accuracy models exhibit more semantically sound weight spaces. Thus, unlike previous work, NBDTs feature better interpretability with higher accuracy, instead of sacrificing one for the other. Furthermore, the disparity in hierarchies indicates that low-accuracy, interpretable models do not provide insight into high-accuracy decisions; an interpretable, state-of-the-art model is needed to interpret state-of-the-art neural networks.
5.3 Visualization of Tree Traversal
To interpret not only tree hierarchies but also tree traversals, we visualize the percentage of samples that pass through every node (Fig. 9). This both highlights the correct path (the most frequently traversed) and allows us to interpret common incorrect paths (Fig. 9(a)). Specifically, we can interpret attributes shared between the leaves of traversed nodes. These attributes may be Backgrounds or Scenes but could also be Color or Shape. Fig. 9(b) depicts the paths of samples that describe context: in this case, very few of the animals are recognized in a Seashore environment, while Ship is almost always seen in that environment. Fig. 9(c) depicts the paths of samples belonging to an out-of-distribution class that does not fit the attributes of the hypothesized node but maintains path consistency: in this case, Teddy tends toward the animal classes, specifically Dog, because it shares similar shape and visual features.
In this work, we propose Neural-Backed Decision Trees, removing the dichotomy between accuracy and interpretability: we narrow the accuracy gap between neural networks and decision trees to 1% on CIFAR10, CIFAR100, and TinyImageNet and to 2% on ImageNet, advancing the state of the art for interpretable methods on ImageNet by 14%, to 75.30% top-1 accuracy. Any classification neural network can be converted into an NBDT: we design decision tree structures, called induced hierarchies, that neural networks then fine-tune on using a tree supervision loss. The network is then run as a decision tree using embedded decision rules. As a fortuitous side effect, the tree supervision loss also boosts the original neural network’s accuracy by 0.5%. Furthermore, to assess semantic meaning, we automatically generate hypotheses for each node’s meaning using WordNet, then introduce a 4-step, human-in-the-loop algorithm that validates these hypotheses both qualitatively and quantitatively.
-  (2016-04) Network of experts for large-scale image categorization. Vol. 9911. Cited by: §1, §2, Table 1.
-  (2019) XOC: explainable observer-classifier for explainable binary decisions. CoRR abs/1902.01780. Cited by: Table 1.
-  (2017) Deep convolutional decision jungle for image classification. CoRR abs/1706.02003. Cited by: Table 1.
-  (1990) Initializing neural networks using decision trees. Cited by: §2.
-  (1994) Initializing neural networks using decision trees. In Proceedings of the International Workshop on Computational Learning and Natural Learning Systems, pp. 3–15. Cited by: §2.
-  (2000) Converting a trained neural network to a decision tree dectext - decision tree extractor. In ICMLA, Cited by: §2.
-  (2004-01) Decision tree extraction from trained neural networks. Cited by: §2.
-  (2009) ImageNet: A Large-Scale Hierarchical Image Database. In CVPR09, Cited by: §1, item 1.
-  UCI machine learning repository. University of California, Irvine, School of Information and Computer Sciences. Cited by: §2.
-  (2017) Distilling a neural network into a soft decision tree. CoRR abs/1711.09784. Cited by: §2.
-  Deep residual learning for image recognition. In The IEEE Conference on Computer Vision and Pattern Recognition (CVPR), Cited by: item 2.
-  (2015) Distilling the knowledge in a neural network. arXiv preprint arXiv:1503.02531. Cited by: §2.
-  (2018-10) Deep neural network initialization with decision trees. IEEE Transactions on Neural Networks and Learning Systems PP, pp. 1–10. Cited by: §2.
-  (1995) Decision-tree based neural network (extended abstract). In Machine Learning: ECML-95, Berlin, Heidelberg, pp. 295–298. Cited by: §2.
-  (1995) Initialization of neural networks by means of decision trees. Knowledge-Based Systems 8 (6), pp. 333 – 344. Note: Knowledge-based neural networks Cited by: §2.
-  (2015-12) Deep neural decision forests. In The IEEE International Conference on Computer Vision (ICCV), Cited by: §1, §2, Table 1.
-  (1999) Extracting decision trees from trained neural networks. Pattern Recognition 32 (12), pp. 1999 – 2009. Cited by: §2.
-  (2009) Learning multiple layers of features from tiny images. Technical report . Cited by: §1, §2, item 1.
-  (2015) Tiny imagenet visual recognition challenge. Cited by: §1, §2, item 1.
-  (2010) MNIST handwritten digit database. ATT Labs [Online]. Available: http://yann.lecun.com/exdb/mnist. Cited by: §2.
-  (2017) Deciding how to decide: dynamic routing in artificial neural networks. In ICML, Cited by: §2.
-  (1995-11) WordNet: a lexical database for English. Commun. ACM 38 (11), pp. 39–41. Cited by: §3.2.
-  (2016-06) Deep decision network for multi-class image classification. In The IEEE Conference on Computer Vision and Pattern Recognition (CVPR), Cited by: §2, Table 1.
-  (2018) Learning hierarchical visual representations in deep neural networks using hierarchical linguistic labels. CoRR abs/1805.07647. Cited by: §2.
-  (2018) RISE: randomized input sampling for explanation of black-box models. In Proceedings of the British Machine Vision Conference (BMVC), Cited by: §2.
-  (2016) ”Why should I trust you?”: explaining the predictions of any classifier. In Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, San Francisco, CA, USA, August 13-17, 2016, pp. 1135–1144. Cited by: §2.
-  (2017) Grad-cam: visual explanations from deep networks via gradient-based localization. In IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pp. 618–626. Cited by: §2.
-  (2013) Deep inside convolutional networks: visualising image classification models and saliency maps. arXiv preprint arXiv:1312.6034. Cited by: §2.
-  (2019) Transferring tree ensembles to neural networks. In Neural Information Processing, pp. 471–480. Cited by: §2.
-  (2014) Striving for simplicity: the all convolutional net. CoRR abs/1412.6806. Cited by: §2.
-  (2017) Axiomatic attribution for deep networks. International Conference on Machine Learning (ICML) 2017. Cited by: §2.
-  Efficientnet: rethinking model scaling for convolutional neural networks. arXiv preprint arXiv:1905.11946. Cited by: item 2.
-  (2019) Adaptive neural trees. Cited by: Table 1.
-  (2018-06) HydraNets: specialized dynamic architectures for efficient inference. In The IEEE Conference on Computer Vision and Pattern Recognition (CVPR), Cited by: §2.
-  (2018-09) Convolutional networks with adaptive inference graphs. In The European Conference on Computer Vision (ECCV), Cited by: §2.
-  (2016) Wide residual networks. arXiv preprint arXiv:1605.07146. Cited by: item 2.
-  (2014) Visualizing and understanding convolutional networks. In European Conference on Computer Vision (ECCV), pp. 818–833. Cited by: §2.
-  (2016) Top-down neural attention by excitation backprop. In European Conference on Computer Vision (ECCV), pp. 543–559. Cited by: §2.
-  (2019-06) Interpreting cnns via decision trees. In The IEEE Conference on Computer Vision and Pattern Recognition (CVPR), Cited by: §2.