NBDT: Neural-Backed Decision Trees

04/01/2020 ∙ by Alvin Wan, et al. ∙ Boston University

Deep learning is being adopted in settings where accurate and justifiable predictions are required, ranging from finance to medical imaging. While there has been recent work providing post-hoc explanations for model predictions, there has been relatively little work exploring more directly interpretable models that can match state-of-the-art accuracy. Historically, decision trees have been the gold standard in balancing interpretability and accuracy. However, recent attempts to combine decision trees with deep learning have resulted in models that (1) achieve accuracies far lower than that of modern neural networks (e.g. ResNet) even on small datasets (e.g. MNIST), and (2) require significantly different architectures, forcing practitioners to pick between accuracy and interpretability. We forgo this dilemma by creating Neural-Backed Decision Trees (NBDTs) that (1) achieve neural network accuracy and (2) require no architectural changes to a neural network. NBDTs achieve accuracy within 1% of neural networks on CIFAR10, CIFAR100, and TinyImageNet, using the recently state-of-the-art WideResNet, and within 2% of EfficientNet on ImageNet. This yields state-of-the-art explainable models on ImageNet, with NBDTs improving the baseline by ∼14% to 75.30% top-1 accuracy. Furthermore, we show interpretability of our model's decisions both qualitatively and quantitatively via a semi-automatic process. Code and pretrained NBDTs can be found at https://github.com/alvinwan/neural-backed-decision-trees.


1 Introduction

In many applications of computer vision (e.g., medical imaging and autonomous driving) insight into the prediction process or justification for a decision is critical. Whereas deep learning techniques achieve state-of-the-art accuracy in these settings, they provide little insight into the resulting predictions.

In an attempt to address the loss of interpretability in deep learning, a growing body of work explores explainable predictions. Explainable computer vision often features saliency maps, which depict the parts of an image that significantly impact the final classification. However, visual explanations focus on input to the model rather than the model itself. Therefore these techniques provide little insight into the general behavior of the model and how it might perform on data outside the domain of the available training examples.

In this work we propose neural-backed decision trees (NBDTs) to make state-of-the-art computer vision models interpretable. These NBDTs require no special architectures: any neural network for image classification can be transformed into an NBDT just by fine-tuning with a custom loss. Furthermore, NBDTs perform inference by breaking image classification into a sequence of intermediate decisions. This sequence of decisions can then be mapped to more interpretable concepts and reveal perceptually informative hierarchical structure in the underlying classes. Critically, in contrast to prior work on decision trees in computer vision, NBDTs are competitive with state-of-the-art results on CIFAR10 [18], CIFAR100 [18], TinyImageNet [19], and ImageNet [8] and are substantially (up to 18%) more accurate than comparable decision-tree-based approaches while also being more interpretable.

We introduce a two-stage training procedure for NBDTs. First, we compute a hierarchy, dubbed the induced hierarchy (Fig. 1, Step 1). This hierarchy is derived from the weights of a neural network already trained on the target dataset. Second, we fine-tune the network with a custom loss designed specifically for this tree, called the tree supervision loss (Fig. 1, Step 2). This loss forces the model to maximize decision tree accuracy given a fixed tree hierarchy. We then run inference in two steps: (1) We construct features for each training image using the network backbone (Fig. 1, Step 3). Then, for each node, we compute a vector in the network’s weight space that best represents the leaves in its subtree, given the decision tree hierarchy – we refer to this vector as the representative vector. (2) Starting at the root node, each sample is sent to the child whose representative vector is most similar to the sample. We continue picking and traversing the tree until we reach a leaf. The class associated with this leaf is our prediction (Fig. 1, Step 4). This contrasts with related work that introduces obstacles for interpretability such as impure leaves [16] or an ensemble of models [1, 16].

We summarize our contributions as follows:

  1. We propose a method for running any classification neural network as a decision tree by defining a set of embedded decision rules that can be constructed from the fully-connected layer. We also design induced hierarchies that are easier for neural networks to learn.

  2. We propose a tree supervision loss, which boosts neural network accuracy by 0.5% and produces high-accuracy NBDTs. We demonstrate that our NBDTs achieve accuracies comparable to neural networks on small, medium, and large-scale image classification datasets.

  3. We present qualitative and quantitative evidence of semantic interpretations for our model decisions.

2 Related Works

Until the more recent success of deep learning, decision trees defined the state-of-the-art in both accuracy and model interpretability on a wide variety of learning tasks. As deep learning based techniques began to dominate other learning methods, there was a significant amount of work exploring the intersection of deep learning and decision trees. We summarize the co-evolution of deep learning and decision trees.

The study of decision tree and neural network combinations dates back three decades, when neural networks were seeded with weights provided by decision trees [4, 5, 15]. Work converting neural networks to decision trees also dates back three decades [14, 17, 6, 7]. Like distillation [12], these works treat the neural network as an oracle, querying it to create splits.

Decision trees to neural networks. Recent work also seeds neural networks with weights provided by decision trees [13], with renewed interest in gradient-based approaches [29]. These methods show empirical evidence on UCI datasets [9], in very feature-sparse and sample-sparse regimes.

Neural networks to decision trees. Recent work [10] uses distillation, training a decision tree to mimic a neural network’s input-output function. All of these works evaluate on simple datasets such as UCI [9] or MNIST [20], while our method is evaluated on more complex datasets such as CIFAR10 [18], CIFAR100 [18], and TinyImageNet [19].

Combining neural networks with decision trees. More recent are works combining neural networks with decision trees, scaling up inference to datasets with many high-dimensional samples. The Deep Neural Decision Forest [16] matched neural network performance on ImageNet. However, this predates residual networks and sacrifices model interpretability by using impure leaves and requiring a forest. Murthy et al. [23] propose creating a new neural network for each node in the decision tree and show interpretable outputs. Ahmed et al. [1] (NofE) modify this by sharing the backbone between all nodes but support only depth-2 trees; NofE achieves ImageNet performance competitive with pre-ResNet architectures. Our method further builds on this by sharing not just the backbone but also the fully-connected layer; we furthermore show performance competitive with state-of-the-art neural networks, including residual networks, while preserving interpretability.

Instead of combining neural networks and decision trees explicitly, several works borrow ideas from decision trees for neural networks and vice versa. In particular, several redesigned neural network architectures utilize decision-tree-style branching structures [35, 21, 34]. While accuracy improves, this approach sacrifices decision tree interpretability. Others use decision trees to analyze neural network weights [39, 24]. This has the opposite drawback: it either sacrifices accuracy or does not support a mechanism for prediction. As we hypothesize and show, a high-accuracy decision tree is necessary to explain and interpret high-accuracy models. Furthermore, our competitive performance shows that accuracy and explainability need not be traded off.

Visual Explanations. An orthogonal but predominant explainability direction involves generating saliency maps, which highlight the spatial evidence used for decisions made by neural networks [30, 37, 28, 38, 27, 26, 25, 31]. White-box techniques such as Guided Backpropagation [30], Deconvolution [37, 28], Grad-CAM [27], and Integrated Gradients [31] use the gradients of the network to determine the most salient regions in an image, while black-box techniques such as LIME [26] and RISE [25] determine pixel importance by perturbing the input and measuring the change in prediction. Saliency maps only explain a single image and are unhelpful when a network is looking at the right thing for the wrong reasons (e.g. a bird misclassified as an airplane). On the other hand, our method expresses the model’s prior over the entire dataset and explicitly breaks down each classification into a sequence of intermediate decisions.

3 Method

In this section we describe the proposed steps, illustrated in Fig. 1, for converting any classification neural network into a decision tree: (1) Build an induced hierarchy (Sec. 3.2), (2) fine-tune the model with a tree supervision loss (Sec. 3.3). For inference, (3) featurize samples with the neural network backbone, and (4) run decision rules embedded in the fully-connected layer (Sec. 3.1).

As implied by steps 3 and 4, our neural-backed decision tree (NBDT) has the exact same architecture as a standard neural network. As explained below in more detail, a subset of a fully-connected layer represents a node in the decision tree. This means that our method (1) is broadly applicable, as all classification neural network architectures are supported as-is, i.e. no architectural modifications are needed to use NBDTs, and (2) benefits from deep learning know-how, as NBDT accuracy improves with the underlying neural network's accuracy.

3.1 Inference with Embedded Decision Rules

First, our NBDT approach featurizes each sample using the neural network backbone; the backbone consists of all neural network layers before the final fully-connected layer. Second, at each node, we take the inner product between the featurized sample $x$ and each child node’s representative vector. Note that all representative vectors are computed from the neural network’s fully-connected layer weights; thus, these decision rules are “embedded” in the neural network. Third, we use these inner products to make either hard or soft decisions, described below.
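To make the featurization concrete, below is a minimal sketch of splitting a network into a backbone and a fully-connected layer using a torchvision ResNet18; the choice of model and the variable names are our own illustration, not the released NBDT code.

```python
import torch
import torchvision.models as models

# Any classification network works; a torchvision ResNet18 is used purely as an
# example. In practice, this would be a model pretrained on the target dataset.
model = models.resnet18().eval()

# The backbone is every layer before the final fully-connected layer.
backbone = torch.nn.Sequential(*list(model.children())[:-1])
fc = model.fc  # fc.weight holds one representative vector per class

with torch.no_grad():
    images = torch.randn(4, 3, 224, 224)  # stand-in for a batch of images
    x = backbone(images).flatten(1)       # featurized samples, shape (4, 512)
    scores = x @ fc.weight.t()            # inner products with each class vector
```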

To motivate why we use inner-products, we will first construct a degenerate decision tree that is equivalent to a fully-connected layer.

  1. Fully-connected layer. The fully-connected layer’s weight matrix is $W \in \mathbb{R}^{d \times k}$, whose columns $w_1, \dots, w_k$ are the class representative vectors. Running inference with a featurized sample $x \in \mathbb{R}^d$ is a matrix-vector product:

    $\hat{y} = W^\top x = [\langle w_1, x \rangle, \dots, \langle w_k, x \rangle]^\top$     (1)

    The matrix-vector product yields inner products between $x$ and each $w_i$, written as $\hat{y}_i = \langle w_i, x \rangle$. The index of the largest inner product is our class prediction.

  2. Decision Tree. Consider a minimal tree, with a root node and $k$ child nodes. Each child node is a leaf, and each child node has a representative vector, namely a vector $w_i$ taken from $W$. Running inference with a featurized sample $x$ means taking inner products between $x$ and each child node’s representative vector $w_i$, again written as $\hat{y}_i = \langle w_i, x \rangle$. Like the fully-connected layer, the index of the largest inner product is our class prediction. This is illustrated in Fig. 2 (B. Naive) and verified numerically in the snippet below.
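The equivalence between the two computations above can be checked numerically. The toy snippet below uses random weights and the notation from Eq. 1, where each column of $W$ is a class representative vector; it is only a sanity check, not part of the method.

```python
import numpy as np

rng = np.random.default_rng(0)
d, k = 512, 10                       # feature dimension, number of classes
W = rng.normal(size=(d, k))          # column i is the representative vector w_i
x = rng.normal(size=d)               # a featurized sample

# 1. Fully-connected layer: matrix-vector product, then argmax (Eq. 1).
pred_fc = int(np.argmax(W.T @ x))

# 2. Naive decision tree: the root scores each leaf child by <w_i, x>
#    and picks the child with the largest inner product.
pred_tree = max(range(k), key=lambda i: float(W[:, i] @ x))

assert pred_fc == pred_tree          # the two formulations agree
```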

Figure 2: Hard and Soft Decision Trees. Tree B is the naïve tree with one root and $k$ leaves. This is identical to a fully-connected layer: the tree takes inner products between the input $x$ and each leaf’s representative vector $w_i$, then picks the leaf corresponding to the largest inner product. Tree A is the “hard” extension of the naïve tree: each node picks the child node with the largest inner product and visits that node next, continuing until a leaf is reached. Tree C is the “soft” extension, where each node instead returns probabilities, as normalized inner products, over its children. For each leaf, compute the probability of its path to the root and pick the leaf with the highest probability.

Even though the two computations are represented differently, both predict the class by taking the index of the largest inner product $\hat{y}_i = \langle w_i, x \rangle$. We refer to this decision tree inference as running embedded decision rules.

We next extend the naïve decision tree beyond the degenerate case. Our decision rule requires that each child node has a representative vector $w_i$. As a result, if we add a non-leaf child to the root, this non-leaf child would need a representative vector. We naively take the non-leaf’s representative vector to be the average of the representative vectors of all leaves in its subtree. With a more complex tree structure containing intermediate nodes, there are now two ways to run inference:

  1. Hard Decision Tree. Compute an argmax at each node, over all children. For each node, take the child node corresponding to the largest inner product, and traverse that child node. This process selects one leaf (Fig. 2, A. Hard).

  2. Soft Decision Tree. Compute a softmax at each node, over all children, to obtain probabilities of each child per node. For each leaf, take the probability of traversing that leaf from its parent. Then take the probability of traversing the leaf’s parent from its grandparent. Continue taking products until you reach the root. This product is the probability of that leaf and its path to the root. Tree traversal will yield one probability for each leaf. Compute an argmax over this leaf distribution, to select one leaf (Fig. 2, C. Soft).

This allows us to run any classification neural network as a sequence of embedded decision rules. However, simply running a standard-issue pretrained neural network in this way will result in poor accuracy. In the next sections, we discuss how to maximize accuracy by constructing a suitable hierarchy and then fine-tuning the neural network to perform well with it.
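Both inference modes can be summarized with a small tree data structure. The sketch below is our own illustration, assuming a Node that stores a representative vector, a list of children, and a class label for leaves; it is not the paper's released implementation.

```python
import numpy as np
from dataclasses import dataclass, field

@dataclass
class Node:
    vector: np.ndarray                     # representative vector for this node
    children: list = field(default_factory=list)
    label: int = None                      # class index if this node is a leaf

def hard_inference(root, x):
    """Greedily follow the child with the largest inner product until a leaf."""
    node = root
    while node.children:
        node = max(node.children, key=lambda c: float(c.vector @ x))
    return node.label

def soft_inference(root, x):
    """Accumulate root-to-leaf path probabilities (softmax over children at
    each node) and return the leaf with the highest path probability."""
    path_probs = {}

    def descend(node, prob):
        if not node.children:
            path_probs[node.label] = prob
            return
        scores = np.array([c.vector @ x for c in node.children])
        probs = np.exp(scores - scores.max())
        probs /= probs.sum()               # softmax over this node's children
        for child, p in zip(node.children, probs):
            descend(child, prob * p)

    descend(root, 1.0)
    return max(path_probs, key=path_probs.get)
```

Hard inference commits to an argmax at every node, whereas soft inference defers the argmax until all leaf path probabilities are available.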

3.2 Building Induced Hierarchies

With the above inner-product decision rule, some decision tree hierarchies are intuitively easier for the network to learn than others. These easier hierarchies may more accurately reflect how the network attained high accuracy. To this end, we run hierarchical agglomerative clustering on the class representatives extracted from the fully-connected layer weights $W$. As described in Sec. 3.1, each leaf’s representative vector is one column $w_i$ of $W$ (Fig. 3, Step B), and each intermediate node’s representative vector is the average of the representative vectors of its subtree’s leaves (Fig. 3, Step C). We refer to this hierarchy as the induced hierarchy (Fig. 3).
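One way to realize this construction is sketched below with SciPy's agglomerative clustering over the columns of $W$; the Ward linkage is our assumption, since the text specifies only hierarchical agglomerative clustering on the class representatives.

```python
import numpy as np
from scipy.cluster.hierarchy import linkage, to_tree

def induce_hierarchy(W):
    """W: (d, k) fully-connected weights; column i is class i's representative."""
    leaves = W.T                                   # one row per class
    Z = linkage(leaves, method='ward')             # assumption: Ward linkage
    root, nodes = to_tree(Z, rd=True)

    def representative(node):
        # Average of the representative vectors of the leaves in this subtree.
        leaf_ids = node.pre_order(lambda n: n.id)
        return leaves[leaf_ids].mean(axis=0)

    return root, {n.id: representative(n) for n in nodes}
```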

We additionally conduct experiments with an alternative WordNet-based hierarchy. WordNet [22] provides an existing hierarchy of nouns, which we leverage to linguistically relate the classes in each dataset. We find a minimal subset of the WordNet hierarchy that includes all classes as leaves, pruning redundant leaves and single-child intermediate nodes. As a result, WordNet relations provide “free” and interpretable labels for this candidate decision tree, e.g. classifying a Cat also as a Mammal and a Living Thing. To leverage this “free” source of labels, we automatically generate hypotheses for each intermediate node in an induced hierarchy by finding the earliest common ancestor of each subtree’s leaves.
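For the WordNet hierarchy, the pruning described above could be realized as sketched below with NLTK; taking the first noun synset for each class name and returning nested (name, children) tuples are our simplifying assumptions.

```python
from collections import defaultdict
from nltk.corpus import wordnet as wn   # requires nltk.download('wordnet')

def wordnet_hierarchy(class_names):
    """Minimal WordNet subtree over the classes, with single-child nodes pruned."""
    children = defaultdict(set)
    leaves = {}
    for name in class_names:
        synset = wn.synsets(name, pos=wn.NOUN)[0]   # assumption: first sense
        leaves[synset] = name
        for path in synset.hypernym_paths():        # root-to-synset chains
            for parent, child in zip(path, path[1:]):
                children[parent].add(child)

    def prune(node):
        if node in leaves:                           # a dataset class
            return (leaves[node], [])
        kept = [prune(c) for c in children[node]]
        if len(kept) == 1:                           # drop single-child nodes
            return kept[0]
        return (node.name(), kept)

    return prune(wn.synset('entity.n.01'))
```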

Figure 3: Building Induced Hierarchies. Step A. Load the weights of a pre-trained neural network’s final fully-connected layer, with weight matrix $W$. Step B. Use each column $w_i$ of $W$ as the representative vector for the corresponding leaf node. For example, the red column from A is assigned to the red leaf in B. Step C. Use the average of each pair of leaves’ representative vectors as the parent’s representative vector. For example, the red and purple vectors in B are averaged to make the blue parent vector in C. Step D. For each ancestor, take the subtree it roots and average the representative vectors of all leaves in that subtree; that average is the ancestor’s representative vector. In this figure, the ancestor is the root, so its representative vector is the average of all leaf vectors $w_1, \dots, w_k$.

3.3 Training with Tree Supervision Loss

Figure 4: Pathological Trees. In the plots, one cluster of points is marked with a green circle and the other with yellow. The center of each circle is given by the average of its two gray points. The corresponding decision tree is drawn to the right of each plot. A: Given a point, the decision tree’s root will pick the child node with the closest representative vector (green or yellow point). Note that all samples for class 4 (red) are closer to the wrong parent (yellow) than to the right one (green). This is because A attempts to cluster 2 with 4 and 1 with 3. Thus, it will be difficult for the neural network to obtain high accuracy, as it needs to move all points drastically to separate the yellow and green points. B: With the same points, this tree clusters 1 with 2 and 3 with 4, leading to more separable clusters. Note that the decision boundary (dotted line) in B features a much larger margin with respect to the green and yellow points. Thus, it is far easier for the neural network to classify points correctly with this tree.

All of the proposed decision trees above suffer from one major issue: even though the original neural network is encouraged to separate representative vectors for each class, it is not trained to separate representative vectors for each internal node. This is illustrated in Fig. 4. To remedy this issue, we add loss terms during training that encourage the neural network to separate representatives for internal nodes. We now explain the additional loss terms for the hard and soft decision rules in turn (Fig. 5).

Figure 5: Tree Supervision Loss has two variants: Hard Tree Supervision Loss (A) defines a cross entropy term per node. This is illustrated with the blue box for the blue node and the orange box for the orange node. The cross entropy is taken over the child node probabilities. The green node is the label leaf. The dotted nodes are not included in the path from the label to the root, so they do not have a defined loss. Soft Tree Supervision Loss (B) defines a cross entropy loss over all leaf probabilities. The probability of the green leaf is the product of the probabilities of the decisions leading up to the root; the probabilities of the other leaves are defined similarly. Each leaf probability is represented with a colored box. The cross entropy is then computed over this leaf probability distribution, represented by the colored boxes sitting directly adjacent to one another.

For hard decision rules, we use the hard tree supervision loss. The original neural network’s loss minimizes cross entropy across the classes: for a $k$-class dataset, this is a $k$-way cross entropy loss, $\mathcal{L}_{\text{original}}$. Each internal node’s goal is similar: minimize the cross entropy loss across its child nodes. For a node $i$ with $c$ children, this is a $c$-way cross entropy loss between the predicted child probabilities $\mathcal{D}(i)_{\text{pred}}$ and the child labels $\mathcal{D}(i)_{\text{label}}$; nodes not on the path from the label leaf to the root contribute no loss (Fig. 5). We refer to this collection of new loss terms as the hard tree supervision loss (Eq. 2). The individual cross entropy losses for each node are scaled so that the original cross entropy loss and the tree supervision loss are weighted equally by default; we test various weighting schemes in Sec. 4.2. If we assume $N$ nodes in the tree, excluding leaves, then we have $N + 1$ different cross entropy loss terms – the original cross entropy loss and $N$ hard tree supervision loss terms. The total loss is $\mathcal{L}_{\text{original}} + \mathcal{L}_{\text{hard}}$, where

$\mathcal{L}_{\text{hard}} = \frac{1}{N} \sum_{i=1}^{N} \text{CrossEntropy}\big(\mathcal{D}(i)_{\text{pred}}, \mathcal{D}(i)_{\text{label}}\big)$     (2)

For soft decision rules, we use the soft tree supervision loss. In Sec. 3.1, we described how the soft decision tree provides a single distribution over leaves, $\mathcal{D}_{\text{pred}}$. We add a cross entropy loss over this distribution against the label distribution $\mathcal{D}_{\text{label}}$. In total, there are 2 different cross entropy loss terms – the original cross entropy loss and the soft tree supervision loss term. The total loss is $\mathcal{L}_{\text{original}} + \mathcal{L}_{\text{soft}}$, where

$\mathcal{L}_{\text{soft}} = \text{CrossEntropy}\big(\mathcal{D}_{\text{pred}}, \mathcal{D}_{\text{label}}\big)$     (3)

Supplementary materials include details of the mathematical formulation.
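As a concrete illustration, the following sketches how the soft tree supervision loss could be added to a standard training step, reusing the Node structure from the inference sketch in Sec. 3.1 (with PyTorch tensors as representative vectors). The helper names, the numerical floor, and the single tree-loss coefficient are our assumptions.

```python
import torch
import torch.nn.functional as F

def leaf_distribution(node, x, prob=None):
    """Differentiable distribution over leaves for a batch of featurized
    samples x (shape: batch x d). Returns {class index: path probability}."""
    if prob is None:
        prob = torch.ones(x.shape[0], device=x.device)
    if not node.children:
        return {node.label: prob}
    scores = torch.stack([x @ c.vector for c in node.children], dim=1)
    probs = F.softmax(scores, dim=1)            # soft decision at this node
    out = {}
    for i, child in enumerate(node.children):
        out.update(leaf_distribution(child, x, prob * probs[:, i]))
    return out

def nbdt_soft_loss(logits, x, labels, root, num_classes, tree_weight=1.0):
    """Original cross entropy plus the soft tree supervision loss (Eq. 3)."""
    dist = leaf_distribution(root, x)
    leaf_probs = torch.stack([dist[c] for c in range(num_classes)], dim=1)
    tree_loss = F.nll_loss(torch.log(leaf_probs + 1e-12), labels)
    return F.cross_entropy(logits, labels) + tree_weight * tree_loss
```

Because the representative vectors are themselves built from the fully-connected layer's weights, the tree supervision term backpropagates into the same parameters as the original loss during fine-tuning.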

4 Experiments

Our experiments obtain state-of-the-art results for decision trees on a number of image classification benchmark datasets. We report results on a variety of different scenarios across models, datasets and inference modes:

  1. Datasets: CIFAR10[18], CIFAR100[18], TinyImageNet[19], ImageNet[8]

  2. Models: ResNet[11], recently state-of-the-art WideResNet[36], EfficientNet[32]

  3. Inference modes: Soft vs. hard inference.

We also perform ablation studies on tree supervision loss weight and different hierarchies, noting that a tree supervision loss with half weight consistently boosts the base neural network’s accuracy by 0.5% across datasets.

4.1 Results

Our decision trees achieve 97.57% on CIFAR10, 82.87% on CIFAR100, and 66.66% on TinyImageNet, preserving the accuracy of recently state-of-the-art neural networks. On CIFAR10, our soft decision tree matches WideResNet28x10 to within 0.05%. On CIFAR100, our soft decision tree achieves accuracy 0.57% higher than WideResNet28x10’s, outperforming the highest competing decision-tree-based method (NofE) by 6.63%. On TinyImageNet, our soft decision tree achieves accuracy within 1% of WideResNet’s. Furthermore, the ResNet18 variant outperforms DNDF by 18.2%.

On ImageNet, NBDTs obtain 75.30% top-1 accuracy, outperforming the strongest competitor, NofE, by ∼14%. Note that we take the best competing results for any decision-tree-based method, but the strongest competitors hinder interpretability: they either use ensembles of models like a decision forest (DNDF, DCDJ) or use shallow trees of only depth 2 (NofE).

Method Backbone CIFAR10 CIFAR100 TinyImageNet ImageNet
NN WideResnet28x10 97.62% 82.09% 67.65%
ANT-A* 93.28%
XOC 93.12% 60.77%
NofE ResNet56 76.24%
DDN 90.32% 68.35%
DCDJ NiN 69.0%
NBDT-H (Ours) WideResnet28x10 97.55% 82.21% 64.39%
NBDT-S (Ours) WideResnet28x10 97.57% 82.87% 66.66%
NN EfficientNet-ES 77.23%
NofE AlexNet 61.29%
NBDT-H (Ours) EfficientNet-ES 74.79%
NBDT-S (Ours) EfficientNet-ES 75.30%
NN ResNet18 94.97% 75.92% 64.13%
DNDF ResNet18 94.32% 67.18% 44.56%
NBDT-H (Ours) ResNet18 94.50% 74.29% 61.60%
NBDT-S (Ours) ResNet18 94.76% 74.92% 62.74%
Table 1: Results On all CIFAR10, CIFAR100, TinyImageNet, and ImageNet datasets, NBDT outperforms competing decision-tree-based methods, even uninterpretable variants such as a decision forest, by up to 18%. On CIFAR10, CIFAR100, and TinyImageNet, NBDTs largely stay within 1% of neural network performance. We italicize the neural network’s accuracy and bold the best-performing decision-tree-based accuracy. Our baselines are either taken directly from the original papers or improved using a modern backbone: Deep Neural Decision Forest (DNDF updated with ResNet18) [16], Explainable Observer-Classifier (XOC) [2], Deep Convolutional Decision Jungle (DCDJ) [3], Network of Experts (NofE) [1], Deep Decision Network (DDN) [23], and Adaptive Neural Trees (ANT) [33].

4.2 Ablation Studies

Tree Supervision Loss. The tree supervision loss, as described in Sec. 3.3, boosts the accuracy of a neural network by 0.5% with tree supervision loss weight of 0.5, when training from scratch on CIFAR100 and TinyImageNet (Table 2).

Dataset Backbone NN NN+TSL Δ
CIFAR100 WideResnet28x10 82.09% 82.63% +0.59%
CIFAR100 ResNet18 75.92% 76.20% +0.28%
CIFAR100 ResNet10 73.36% 73.98% +0.62%
TinyImageNet ResNet18 64.13% 64.61% +0.48%
TinyImageNet ResNet10 61.01% 61.35% +0.34%
Table 2: Tree Supervision Loss. The original neural network’s accuracy increases by ∼0.5% for CIFAR100 and TinyImageNet across a number of models, after training with the soft tree supervision loss.

WordNet Hierarchy. Table 3 shows that WordNet is comparable to induced hierarchies on CIFAR but underperforms by 4.17% on TinyImageNet. This is because WordNet similarity is not indicative of visual similarity: for example, by virtue of being an animal, Bird is closer to Cat than to Plane, according to WordNet. However, the opposite is true for visual similarity: by virtue of being in the sky, Bird is more visually similar to Plane than to Cat.

Dataset Backbone Original WordNet Induced
CIFAR10 ResNet10 93.61% 93.65% 93.32%
CIFAR100 ResNet10 73.36% 71.79% 71.70%
TinyImageNet ResNet10 61.01% 52.33% 56.50%
Table 3: WordNet Hierarchy. We compare the WordNet hierarchy with the induced hierarchy. All results use a ResNet10 backbone with tree supervision loss weight of 10. Both inference and tree supervision losses are hard.

Tree Supervision Loss Weight. As we vary the coefficient for the tree supervision loss, we note that disproportionately assigning weight to the tree supervision loss (by two orders of magnitude) significantly degrades the performance of both the neural network and the NBDT. However, our method is robust to imbalance between the two loss terms of up to an order of magnitude. We conclude the method is not hyper-sensitive to the loss coefficient (Table 4).

Dataset Method Accuracy for increasing tree supervision loss weight →
CIFAR10 ResNet18 94.97% 94.91% 94.44% 93.82% 91.91%
CIFAR10 NBDT-H 94.50% 94.06% 93.94% 92.28 %
CIFAR100 ResNet18 75.92% 76.20% 75.78% 75.63% 73.86%
CIFAR100 NBDT-H 66.84% 69.49% 73.23% 72.05%
TinyImageNet ResNet18 64.13% 64.61% 63.90% 63.98% 63.11%
TinyImageNet NBDT-H 43.05% 58.25% 56.25% 58.89%
Table 4: Tree Supervision Loss Weight. Below, the columns correspond to increasing values of the coefficient $\omega$ for the hard tree supervision loss. All NBDT-H trees use the ResNet18 backbone with hard inference. Note that $\omega = 0$ is simply the original neural network.

5 Explainability

The explainability of a decision tree is well-established, as the final prediction can be broken into a sequence of decisions that can be evaluated independently for correctness. In the case where the input features are easily understood (e.g. medical/financial data), analyzing the rules which dictate the splits in the tree is relatively straightforward, but when the input is more complex like an image, this becomes more challenging. In this section, we perform qualitative and quantitative analysis of decisions made by intermediate tree nodes.

5.1 Explainability of Nodes’ Semantic Meanings

(a) WordNet Tree
(b) Induced Tree
Figure 6: Tree Visualization of 10 classes from TinyImageNet using (a) the WordNet hierarchy and (b) the induced tree from a trained ResNet10 model.

Since the induced hierarchy is constructed using model weights, it is not forced to split on particular attributes. While hierarchies like WordNet provide hypotheses for a node’s meaning, Fig. 6 shows that WordNet doesn’t suffice, as the tree may split on contextual attributes such as underwater and on land. To diagnose node meanings, we perform the following 4-step test:

  1. Make a hypothesis for the node’s meaning (e.g. Animal vs. Vehicle). This hypothesis can be computed automatically from a given taxonomy like WordNet or deduced from manual inspection of leaves for each child (Fig. 7).

  2. Collect a dataset with new, unseen classes that test the hypothesized meaning of the node from step 1 (e.g. Elephant is an unseen Animal). Samples in this dataset are referred to as out-of-distribution samples, as they are drawn from a separate labeled dataset.

  3. Pass samples from this dataset through the node in question. For each sample, check whether the selected child node agrees with the hypothesis.

  4. The accuracy of the hypothesis is the percentage of samples passed to the correct child. If the accuracy is low, repeat with a different hypothesis. (Steps 3 and 4 are sketched in code below.)
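A minimal sketch of steps 3 and 4, using the same assumed Node structure as in Sec. 3.1: route each out-of-distribution sample through the node in question and count how often the selected child matches the hypothesis.

```python
import numpy as np

def hypothesis_accuracy(node, featurized_samples, expected_child_index):
    """Fraction of out-of-distribution samples routed to the child predicted by
    the hypothesis (e.g. unseen Animals should reach the hypothesized Animal child)."""
    correct = 0
    for x in featurized_samples:      # each x is a backbone feature vector
        scores = [float(child.vector @ x) for child in node.children]
        correct += int(np.argmax(scores) == expected_child_index)
    return correct / len(featurized_samples)
```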

(a)
(b)
Figure 7: A Node’s semantic meaning. (a) CIFAR10 Tree Visualization of a WideResNet28x10 model. (b) Classifications of the hypothesized Animal/Vehicle node on samples of unseen classes of Vehicles (top) and Animals (bottom).

This process automatically validates WordNet hypotheses, but manual intervention is needed for hypotheses beyond WordNet. Fig. 7(a) depicts the CIFAR10 tree induced by a WideResNet28x10 model trained on CIFAR10. Our hypothesis is that the root node splits on Animal vs. Vehicle. We collect out-of-distribution images for Animal and Vehicle classes that are unseen at training time, from CIFAR100. We then compute the hypothesis’ accuracy. Fig. 7(b) shows that our hypothesis accurately predicts which child each unseen class’s samples traverse.

5.2 Sidestepping the Accuracy-Interpretability Tradeoff

Induced hierarchies cluster vectors in weight space, but classes that are close in weight space may not have similar semantic meaning: Fig. 8 depicts the trees induced by WideResNet28x10 and ResNet10. While the WideResNet induced hierarchy (Fig. 8(a)) groups semantically-similar classes, the ResNet induced hierarchy (Fig. 8(b)) does not, grouping classes such as Frog, Cat, and Airplane. This disparity in semantic meaning is explained by WideResNet’s 4% higher accuracy: we believe that higher-accuracy models exhibit more semantically-sound weight spaces. Thus, unlike previous work, NBDTs feature better interpretability with higher accuracy, instead of sacrificing one for the other. Furthermore, the disparity in hierarchies indicates that low-accuracy, interpretable models do not provide insight into high-accuracy decisions; an interpretable, state-of-the-art model is needed to interpret state-of-the-art neural networks.

(a) WideResNet28x10
(b) ResNet10
Figure 8: CIFAR10 induced hierarchies for (a) WideResNet (97.62% acc) and (b) ResNet (93.64% acc), with automatically-generated WordNet hypotheses for node meanings.

5.3 Visualization of Tree Traversal

To interpret not only tree hierarchies but also tree traversals, we visualize the percentage of samples that pass through every node (Fig. 9). This highlights the correct path (the most frequently traversed) and allows us to interpret common incorrect paths (Fig. 9(a)). Specifically, we can interpret attributes shared between the leaves of traversed nodes; these attributes may be Backgrounds or Scenes but could also be Color or Shape. Fig. 9(b) depicts the paths of samples that describe context: in this case, very few of the animals are recognized in a Seashore environment, while Ship is almost always seen in that environment. Fig. 9(c) depicts the paths of samples belonging to an out-of-distribution class that does not fit the attributes of the hypothesized node but maintains path consistency: in this case, Teddy tends toward the animal classes, specifically Dog, because it shares similar shape and visual features.

(a)
(b)
(c)
Figure 9: Visualization of path traversal frequency for three different classes. (a) In-Distribution Class: Horse uses samples of a class found in training. Hypothesized meanings of intermediate nodes are from WordNet. (b) Context Class: Seashore uses samples unseen at training time, indicating reliance on context. (c) Confusing Class: Teddy uses samples that identify edge cases in node meanings.

6 Conclusions

In this work, we propose Neural-Backed Decision Trees, removing the dichotomy between accuracy and interpretability: narrowing the accuracy gap between neural networks and decision trees to within 1% on CIFAR10, CIFAR100, and TinyImageNet and to within 2% on ImageNet, and advancing the state of the art for interpretable methods by ∼14% on ImageNet to 75.30% top-1 accuracy. Any classification neural network can be converted into an NBDT: we design decision tree structures called induced hierarchies that neural networks then fine-tune on, using a tree supervision loss. The network is then run as a decision tree using embedded decision rules. As a fortuitous side effect, the tree supervision loss also boosts the original neural network’s accuracy by 0.5%. Furthermore, to assess semantic meaning, we automatically generate hypotheses for each node’s meaning using WordNet, then introduce a 4-step, human-in-the-loop algorithm that validates these hypotheses both qualitatively and quantitatively.

References

  • [1] K. Ahmed, M. Baig, and L. Torresani (2016-04) Network of experts for large-scale image categorization. Vol. 9911. Cited by: §1, §2, Table 1.
  • [2] S. Alaniz and Z. Akata (2019) XOC: explainable observer-classifier for explainable binary decisions. CoRR abs/1902.01780. Cited by: Table 1.
  • [3] S. Baek, K. I. Kim, and T. Kim (2017) Deep convolutional decision jungle for image classification. CoRR abs/1706.02003. Cited by: Table 1.
  • [4] A. Banerjee (1990) Initializing neural networks using decision trees. Cited by: §2.
  • [5] A. Banerjee (1994) Initializing neural networks using decision trees. In Proceedings of the International Workshop on Computational Learning and Natural Learning Systems, pp. 3–15. Cited by: §2.
  • [6] O. Boz (2000) Converting a trained neural network to a decision tree dectext - decision tree extractor. In ICMLA, Cited by: §2.
  • [7] D. Dancey, D. McLean, and Z. Bandar (2004-01) Decision tree extraction from trained neural networks. Cited by: §2.
  • [8] J. Deng, W. Dong, R. Socher, L.-J. Li, K. Li, and L. Fei-Fei (2009) ImageNet: A Large-Scale Hierarchical Image Database. In CVPR09, Cited by: §1, item 1.
  • [9] D. Dua and C. Graff (2017) UCI machine learning repository. University of California, Irvine, School of Information and Computer Sciences. Cited by: §2, §2.
  • [10] N. Frosst and G. E. Hinton (2017) Distilling a neural network into a soft decision tree. CoRR abs/1711.09784. Cited by: §2.
  • [11] K. He, X. Zhang, S. Ren, and J. Sun (2016-06) Deep residual learning for image recognition. In The IEEE Conference on Computer Vision and Pattern Recognition (CVPR), Cited by: item 2.
  • [12] G. Hinton, O. Vinyals, and J. Dean (2015) Distilling the knowledge in a neural network. arXiv preprint arXiv:1503.02531. Cited by: §2.
  • [13] K. Humbird, L. Peterson, and R. McClarren (2018-10) Deep neural network initialization with decision trees. IEEE Transactions on Neural Networks and Learning Systems PP, pp. 1–10. Cited by: §2.
  • [14] I. Ivanova and M. Kubat (1995) Decision-tree based neural network (extended abstract). In Machine Learning: ECML-95, Berlin, Heidelberg, pp. 295–298. Cited by: §2.
  • [15] I. Ivanova and M. Kubat (1995) Initialization of neural networks by means of decision trees. Knowledge-Based Systems 8 (6), pp. 333 – 344. Note: Knowledge-based neural networks Cited by: §2.
  • [16] P. Kontschieder, M. Fiterau, A. Criminisi, and S. Rota Bulo (2015-12) Deep neural decision forests. In The IEEE International Conference on Computer Vision (ICCV), Cited by: §1, §2, Table 1.
  • [17] R. Krishnan, G. Sivakumar, and P. Bhattacharya (1999) Extracting decision trees from trained neural networks. Pattern Recognition 32 (12), pp. 1999 – 2009. Cited by: §2.
  • [18] A. Krizhevsky (2009) Learning multiple layers of features from tiny images. Technical report . Cited by: §1, §2, item 1.
  • [19] Y. Le and X. Yang (2015) Tiny imagenet visual recognition challenge. Cited by: §1, §2, item 1.
  • [20] Y. LeCun, C. Cortes, and C. Burges (2010) MNIST handwritten digit database. ATT Labs [Online]. Available: http://yann.lecun.com/exdb/mnist 2. Cited by: §2.
  • [21] M. McGill and P. Perona (2017) Deciding how to decide: dynamic routing in artificial neural networks. In ICML, Cited by: §2.
  • [22] G. A. Miller (1995-11) WordNet: a lexical database for english. Commun. ACM 38 (11), pp. 39–41. External Links: ISSN 0001-0782, Link, Document Cited by: §3.2.
  • [23] V. N. Murthy, V. Singh, T. Chen, R. Manmatha, and D. Comaniciu (2016-06) Deep decision network for multi-class image classification. In The IEEE Conference on Computer Vision and Pattern Recognition (CVPR), Cited by: §2, Table 1.
  • [24] J. C. Peterson, P. Soulos, A. Nematzadeh, and T. L. Griffiths (2018) Learning hierarchical visual representations in deep neural networks using hierarchical linguistic labels. CoRR abs/1805.07647. Cited by: §2.
  • [25] V. Petsiuk, A. Das, and K. Saenko (2018) RISE: randomized input sampling for explanation of black-box models. In Proceedings of the British Machine Vision Conference (BMVC), Cited by: §2.
  • [26] M. T. Ribeiro, S. Singh, and C. Guestrin (2016) ”Why should I trust you?”: explaining the predictions of any classifier. In Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, San Francisco, CA, USA, August 13-17, 2016, pp. 1135–1144. Cited by: §2.
  • [27] R. R. Selvaraju, M. Cogswell, A. Das, R. Vedantam, D. Parikh, and D. Batra (2017) Grad-cam: visual explanations from deep networks via gradient-based localization. In IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pp. 618–626. Cited by: §2.
  • [28] K. Simonyan, A. Vedaldi, and A. Zisserman (2013) Deep inside convolutional networks: visualising image classification models and saliency maps. arXiv preprint arXiv:1312.6034. Cited by: §2.
  • [29] C. Siu (2019) Transferring tree ensembles to neural networks. In Neural Information Processing, pp. 471–480. Cited by: §2.
  • [30] J. T. Springenberg, A. Dosovitskiy, T. Brox, and M. A. Riedmiller (2014) Striving for simplicity: the all convolutional net. CoRR abs/1412.6806. Cited by: §2.
  • [31] M. Sundararajan, A. Taly, and Q. Yan (2017) Axiomatic attribution for deep networks. International Conference on Machine Learning (ICML) 2017. Cited by: §2.
  • [32] M. Tan and Q. V. Le (2019) EfficientNet: rethinking model scaling for convolutional neural networks. arXiv preprint arXiv:1905.11946. Cited by: item 2.
  • [33] R. Tanno, K. Arulkumaran, D. C. Alexander, A. Criminisi, and A. Nori (2019) Adaptive neural trees. Cited by: Table 1.
  • [34] R. Teja Mullapudi, W. R. Mark, N. Shazeer, and K. Fatahalian (2018-06) HydraNets: specialized dynamic architectures for efficient inference. In The IEEE Conference on Computer Vision and Pattern Recognition (CVPR), Cited by: §2.
  • [35] A. Veit and S. Belongie (2018-09) Convolutional networks with adaptive inference graphs. In The European Conference on Computer Vision (ECCV), Cited by: §2.
  • [36] S. Zagoruyko and N. Komodakis (2016) Wide residual networks. arXiv preprint arXiv:1605.07146. Cited by: item 2.
  • [37] M. D. Zeiler and R. Fergus (2014) Visualizing and understanding convolutional networks. In European Conference on Computer Vision (ECCV), pp. 818–833. Cited by: §2.
  • [38] J. Zhang, Z. Lin, J. Brandt, X. Shen, and S. Sclaroff (2016) Top-down neural attention by excitation backprop. In European Conference on Computer Vision (ECCV), pp. 543–559. Cited by: §2.
  • [39] Q. Zhang, Y. Yang, H. Ma, and Y. N. Wu (2019-06) Interpreting cnns via decision trees. In The IEEE Conference on Computer Vision and Pattern Recognition (CVPR), Cited by: §2.