WildWood: a new Random Forest algorithm

09/16/2021
by   Stéphane Gaïffas, et al.

We introduce WildWood (WW), a new ensemble algorithm for supervised learning of Random Forest (RF) type. While standard RF algorithms use bootstrap out-of-bag samples to compute out-of-bag scores, WW uses these samples to produce improved predictions given by an aggregation of the predictions of all possible subtrees of each fully grown tree in the forest. This is achieved by aggregation with exponential weights computed over out-of-bag samples, that are computed exactly and very efficiently thanks to an algorithm called context tree weighting. This improvement, combined with a histogram strategy to accelerate split finding, makes WW fast and competitive compared with other well-established ensemble methods, such as standard RF and extreme gradient boosting algorithms.


1 Introduction

This paper introduces WildWood (WW), a new ensemble method of Random Forest (RF) type [9]. The main contributions of the paper and the main advantages of WW are as follows. Firstly, we use out-of-bag samples (trees in a RF use different bootstrapped samples) very differently from what is done in standard RF [43, 7]. Indeed, WW uses these samples to compute an aggregation of the predictions of all possible subtrees of each tree in the forest, using aggregation with exponential weights [14]. This leads to much improved predictions: while only leaves contribute to the predictions of a tree in standard RF, the full tree structure contributes to predictions in WW. An illustration of this effect is given in Figure 1 on a toy binary classification example, where we can observe that subtrees aggregation leads to improved and regularized decision functions for each individual tree and for the forest.

Figure 1: WW decision functions illustrated on a toy dataset (left) with subtrees aggregation (top) and without it (bottom). Subtrees aggregation improves tree predictions, as illustrated by the smoother decision functions in the top row compared with the bottom one, which improves the overall predictions of the forest (last column).

We further illustrate in Figure 2 that each tree becomes a stronger learner, and that excellent performance can be achieved even when WW uses few trees.

Figure 2: Mean test AUC and standard deviations (y-axis) using 10 train/test splits for WW and scikit-learn's implementations of RF [43] and Extra Trees [32], using default hyperparameters, on several datasets. Thanks to subtrees aggregation, WW improves these baselines, even with few trees (the x-axis is the number of trees).

A remarkable aspect of WW is that this improvement comes only at a small computational cost, thanks to a technique called “context tree weighting”, used in lossless compression or online learning to aggregate all subtrees of a given tree [73, 72, 34, 14, 50]. Also, the predictions of WW do not rely on MCMC approximations required with Bayesian variants of RF [21, 26, 22, 66], which is a clear distinction from such methods.

Secondly, WW uses feature binning (a "histogram" strategy), similarly to what is done in extreme gradient boosting (EGB) libraries such as XGBoost [18], LightGBM [38] and CatBoost [56, 28]. This strategy helps to accelerate computations in WW compared with standard RF algorithms, which typically require sorting features locally in nodes and trying a larger number of splits [43]. This combination of subtrees aggregation and of the histogram strategy makes WW comparable with state-of-the-art implementations of EGB libraries, as illustrated in Figure 3.

Figure 3: Test AUC (top) and training time (bottom) of WW compared with very popular EGB libraries (after hyperoptimization of all algorithms, see Section 4 for details). The performance of WW, which uses only 10 trees in this display, is only slightly below that of these strong baselines, but WW is faster (training times are on a logarithmic scale) on the considered datasets.

Moreover, WW supports optimal split finding for categorical features and missing values, with no need for particular pre-processing (such as one-hot encoding [18] or target encoding [56, 28]). Finally, WW is supported by some theoretical evidence, since we prove that, for a general loss function, the subtrees aggregation used in WW indeed leads to a performance close to that of the best subtree.

Related works.

Since their introduction [9], RF algorithms have become one of the most popular supervised learning algorithms thanks to their ease of use, robustness to hyperparameters [7, 55] and applicability to a wide range of domains; recent examples include bioinformatics [57], genomic data [19], predictive medicine [65, 1], intrusion detection [20], astronomy [35], car safety [67], differential privacy [53] and COVID-19 [64], among many others. A non-exhaustive list of developments in RF methodology includes soft pruning [12], extremely randomized forests [32], decision forests [24], prediction intervals [60, 74, 13], ranking [76], nonparametric smoothing [70], variable importance [44, 37, 45], combination with boosting [33], generalized RF [3], robust forests [41], global refinement [59], online learning [39, 50] and results aiming at a better theoretical understanding of RF [6, 5, 31, 2, 63, 61, 62, 49, 48, 75].

A recent empirical study [75] suggests that tree depth limitation in RF is an effective regularization mechanism which improves performance on low signal-to-noise ratio datasets. Tree regularization is usually performed by pruning methods such as CCP, REP or MEP [58, 11]. Although they are fairly effective at reducing tree over-fitting, these methods are mostly based on heuristics, so that little is known about their theoretical properties. A form of soft pruning was proposed earlier in [12] and referred to as tree smoothing. The latter efficiently computes predictions as approximate Bayesian posteriors over the set of possible prunings; however, the associated complexity is of the order of the tree size, which makes the computation of predictions slow. In [50], an improvement of Mondrian Forests [39] is introduced for online learning, using subtrees aggregation with exponential weights, which is particularly convenient in the online learning setting. However, that paper considers only the online setting, with purely random trees (splits are not optimized using training data), leading to poor performance compared with realistic decision trees. In WW, we use a similar subtrees aggregation mechanism for batch learning in a different way: we exploit the bootstrap, one of the key ingredients of RF, which provides in-the-bag and out-of-bag samples, to perform aggregation with exponential weights, together with efficient decision trees grown using the histogram strategy.

Extreme gradient boosting algorithms are another type of ensemble method. XGBoost [18] provides an extremely popular scalable tree boosting system which has been widely adopted in industry. LightGBM [38] introduced the "histogram strategy" for faster split finding, together with clever downsampling and feature grouping algorithms, in order to achieve high performance in reduced computation times. CatBoost [55] is another boosting library which pays particular attention to categorical features using target encoding, while addressing the potential bias issues associated with such an encoding.

Limitations.

Our implementation of WW is still evolving and is not yet at the level of maturity of state-of-the-art EGB libraries such as [18, 38, 55]. It does not outperform such strong baselines, but it improves on RF algorithms and offers an interesting balance between performance and computational efficiency.

2 WildWood: a new Random Forest algorithm

We consider batch supervised learning, where data comes as a set of i.i.d. training samples (pairs of features and label), with feature vectors containing numerical or categorical entries. Our aim is to design a RF predictor, computed from the training samples, that maps features to the prediction space. Such a RF computes the average of the predictions of several randomized trees, following the principle of bagging [8, 51], where the trees are grown using i.i.d. realizations of a random variable corresponding to bootstrap and feature subsampling (see Section 2.1 below). Each tree is trained independently of the others, in parallel. In what follows we describe only the construction of a single tree and omit from now on the dependence on its randomization.

Feature binning.

The split finding strategy described in Section 2.2 below works on binned features. While this technique is common practice in EGB libraries [18, 38, 56], we are not aware of an implementation of it for RF. The input matrix of features is transformed into another same-size matrix of "binned" features. To each input feature is associated a set of bins, whose number is at most a hyperparameter corresponding to the maximum number of bins a feature can use (the default, similarly to [38], is chosen so that a single byte can be used for each entry of the binned matrix). When a feature is continuous, it is binned using inter-quantile intervals. If it is categorical, each modality is mapped to its own bin whenever the maximum number of bins is larger than its number of modalities; otherwise the sparsest modalities end up binned together. If a feature contains missing values, its rightmost bin is used to encode them. After binning, each column of the matrix takes values in a finite set of bin indices.
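To make the binning step concrete, here is a minimal sketch (not WildWood's actual implementation) of inter-quantile binning for a continuous feature and of dense bin mapping for a categorical feature, assuming at most 256 bins so that binned values fit in a single byte:

import numpy as np

def bin_continuous(x, max_bins=256):
    # Bin edges are inter-quantile boundaries of the feature values.
    quantiles = np.quantile(x, np.linspace(0, 1, max_bins + 1)[1:-1])
    edges = np.unique(quantiles)
    # searchsorted maps each value to the index of its bin.
    return np.searchsorted(edges, x).astype(np.uint8), edges

def bin_categorical(x, max_bins=256):
    # Each modality gets its own bin when there are at most max_bins of them;
    # otherwise the sparsest modalities would be grouped together (not shown here).
    modalities, binned = np.unique(x, return_inverse=True)
    assert len(modalities) <= max_bins
    return binned.astype(np.uint8), modalities

rng = np.random.default_rng(0)
x_bin, edges = bin_continuous(rng.normal(size=1000))
print(x_bin.min(), x_bin.max())  # binned values lie in {0, ..., max_bins - 1}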

2.1 Random decision trees

We work on the binned feature space. A random decision tree is a pair consisting of a finite ordered binary tree and of information attached to each of its nodes, such as split information. The tree is random and its source of randomness comes from the bootstrap and feature subsampling, as explained below.

Finite ordered binary trees.

A finite ordered binary tree is represented as a finite subset of the set of all finite words on {0, 1}. This set of words is endowed with a tree structure (and called the complete binary tree): the empty word is the root and, for any word v, the left (resp. right) child of v is obtained by appending 0 (resp. 1) to v. We distinguish the set of interior nodes and the set of leaves of the tree; both sets are disjoint and together they form the set of all nodes.

Splits and cells.

The split of each interior node is characterized by its dimension (the feature it uses) and by a subset of bins (those sent to the left child). We associate to each node a cell, defined recursively: the cell of the root is the whole binned feature space and, for each interior node, the cell of its left child is the subset of the node's cell whose split feature falls in the selected bins, while the cell of the right child is the complement within the node's cell. When the split dimension corresponds to a continuous feature, bins have a natural order and the selected bins are those below some bin threshold; for a categorical split, an arbitrary subset of bins may be required. By construction, the cells of the leaves form a partition of the binned feature space.
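As an illustration of how a split routes a binned sample, here is a minimal sketch covering the two split types described above (the function and field names are hypothetical, for illustration only):

def goes_left(x_bin, feature, is_categorical, bin_threshold=None, left_bins=None):
    # x_bin is the vector of binned feature values of one sample.
    b = x_bin[feature]
    if is_categorical:
        # Categorical split: an arbitrary subset of bins goes to the left child.
        return b in left_bins
    # Continuous split: bins are ordered, compare against a bin threshold.
    return b <= bin_threshold

# Example: a categorical split on feature 1 sending bins {0, 3, 7} to the left child.
print(goes_left([2, 3, 5], feature=1, is_categorical=True, left_bins={0, 3, 7}))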

Bootstrap and feature subsampling.

The randomization of the tree uses the bootstrap: it samples, uniformly at random and with replacement, as many elements of the set of training sample indices as there are training samples; these are the in-the-bag samples. The indices that are never drawn are the out-of-bag samples. A standard argument shows that the expected proportion of distinct in-the-bag samples converges to 1 - e^{-1} ≈ 0.632, known as the 0.632 rule [30]. The randomization also uses feature subsampling: each time we need to find a split, we do not try all the features but only a subset of them, chosen uniformly at random. This follows what standard RF algorithms do [9, 7, 43], the default subset size being of the order of the square root of the number of features.
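A minimal sketch of the per-tree bootstrap, illustrating in-the-bag and out-of-bag indices and the 0.632 rule (generic bootstrap code, not WildWood's implementation):

import numpy as np

def bootstrap_indices(n, rng):
    # Sample n indices uniformly at random with replacement (in-the-bag);
    # the indices that were never drawn are out-of-bag.
    itb = rng.integers(0, n, size=n)
    in_bag = np.unique(itb)
    oob = np.setdiff1d(np.arange(n), in_bag)
    return itb, oob

rng = np.random.default_rng(42)
n = 100_000
itb, oob = bootstrap_indices(n, rng)
# The proportion of distinct in-the-bag samples is close to 1 - exp(-1) ~ 0.632.
print(len(np.unique(itb)) / n, 1 - np.exp(-1))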

2.2 Split finding on histograms

For K-class classification, when looking for a split at some node, we compute the node's "histogram": for each sampled feature, each bin and each label class seen in the node's samples, the corresponding count (actually weighted counts, to handle bootstrapping and sample weights). Since the histogram of a node is the sum of the histograms of its two children, we do not need to compute two histograms for sibling nodes, but only a single one (the other is obtained by subtracting it from the parent's histogram). Then, we loop over the set of sampled features that are not constant in the node and over the set of non-empty bins to find a split, by comparing standard impurity criteria computed on the histogram's statistics, such as gini or entropy for classification and variance for regression.
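A minimal sketch of histogram-based split finding on a single (continuous) feature for K-class classification, using weighted class counts per bin and the gini criterion (illustrative only, not WildWood's actual split code):

import numpy as np

def gini(counts):
    # Gini impurity of a vector of (weighted) class counts.
    total = counts.sum()
    if total == 0.0:
        return 0.0
    p = counts / total
    return 1.0 - np.sum(p ** 2)

def best_split_on_feature(x_bin, y, sample_weight, n_bins, n_classes):
    # Node "histogram": weighted count of each label class in each bin.
    hist = np.zeros((n_bins, n_classes))
    np.add.at(hist, (x_bin, y), sample_weight)
    # A sibling's histogram can be obtained by subtracting this histogram
    # from the parent's one, so only one child histogram needs to be built.
    total = hist.sum(axis=0)
    best_bin, best_impurity = None, np.inf
    left = np.zeros(n_classes)
    for b in range(n_bins - 1):  # loop over bins in their natural order
        left += hist[b]
        right = total - left
        if left.sum() == 0.0 or right.sum() == 0.0:
            continue  # skip splits leading to an empty child
        impurity = (left.sum() * gini(left) + right.sum() * gini(right)) / total.sum()
        if impurity < best_impurity:
            best_bin, best_impurity = b, impurity
    return best_bin, best_impurity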

Bin order and categorical features.

The order of the bins used in the loop depends on the type of the feature. If it is continuous, we use the natural order of the bins. If it is categorical and the task is binary classification (labels in {0, 1}), we use the bin order that sorts the bins with respect to the proportion of labels equal to 1 in each bin. This makes it possible to find the optimal categorical split with a complexity of the order of the number of bins times its logarithm, see Theorem 9.6 in [10], the logarithm coming from the sorting operation, while the number of possible splits is exponential in the number of bins. This trick is used by EGB libraries as well, using an order based on statistics of the loss considered [18, 38, 56]. For K-class classification with more than two classes, we consider two strategies: (1) one-versus-rest, where we train K times more trees, each tree trained with a binary one-versus-rest label, so that trees can find optimal categorical splits, and (2) heuristic, where we keep the same number of trees and where split finding loops over the bin orders that sort, for each class, the proportion of labels of that class in each bin. If a feature contains missing values, we do not loop only from left to right (along the bin order), but from right to left as well, in order to compare splits that put missing values on the left or on the right. A sketch of the binary-classification bin ordering is given below.
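For binary classification on a categorical feature, sorting bins by their proportion of positive labels is enough to recover the optimal split by scanning along this order; a minimal sketch of the ordering step (illustrative only):

import numpy as np

def categorical_bin_order(x_bin, y, n_bins):
    # Sort bins by the proportion of labels equal to 1 in each bin; scanning
    # splits along this order finds the optimal categorical split, while the
    # number of possible bin subsets is exponential in n_bins.
    counts = np.bincount(x_bin, minlength=n_bins).astype(float)
    positives = np.bincount(x_bin, weights=y, minlength=n_bins)
    proportion = np.divide(positives, counts, out=np.zeros(n_bins), where=counts > 0)
    return np.argsort(proportion)

x_bin = np.array([0, 0, 1, 1, 2, 2, 2])
y = np.array([0, 1, 1, 1, 0, 0, 1])
print(categorical_bin_order(x_bin, y, n_bins=3))  # bins sorted by P(y = 1 | bin)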

Split requirements.

Nodes must hold at least one in-the-bag and one out-of-bag sample to apply aggregation with exponential weights, see Section 2.3 below. A split is discarded if it leads to children holding fewer in-the-bag or out-of-bag samples than a minimum-samples-in-leaf threshold, and we do not split a node holding fewer in-the-bag or out-of-bag samples than a minimum-samples-to-split threshold. These hyperparameters only weakly impact WW's performance, and sticking to the default values (following scikit-learn's defaults [43, 54]) is usually enough (see Section 11 below).

Related works on categorical splits.

In [23], an interesting characterization of an optimal categorical split for multiclass classification is introduced, but no efficient algorithm is, to the best of our understanding, available for it. A heuristic algorithm is proposed therein, but it requires computing, for each split, the top principal component of the covariance matrix of the conditional distribution of labels given bins, which is computationally too demanding for an RF algorithm intended for large datasets. Regularized target encoding is shown in [52] to perform best when compared with many alternative categorical encoding methods. CatBoost [56] uses target encoding, which replaces feature modalities by label statistics, so that a natural bin order can be used for split finding. To avoid overfitting on uninformative categorical features, a debiasing technique uses random permutations of samples and computes the target statistic of each element based only on its predecessors in the permutation. However, for multiclass classification, target encoding is influenced by the arbitrarily chosen ordinal encoding of the labels. LightGBM [38] uses a one-versus-rest strategy, which is also one of the approaches used in WW for categorical splits on multiclass tasks. For categorical splits, where the bin order depends on label statistics, WW does not use debiasing as in [56], since aggregation with exponential weights computed on out-of-bag samples allows it to deal with overfitting.

Tree growth stopping.

We do not split a node, and make it a leaf, if it holds fewer in-the-bag or out-of-bag samples than the minimum required. The same applies when a node's impurity is not larger than a given threshold. When only leaves or non-splittable nodes remain, the growth of the tree is stopped. Trees grow in a depth-first fashion, so that child nodes have memory indexes larger than that of their parent (as required by Algorithm 1 below).

2.3 Prediction function: aggregation with exponential weights

Given a tree grown as described in Sections 2.1 and 2.2, its prediction function is an aggregation of the predictions given by all possible subtrees rooted at its root. While the tree is grown using in-the-bag samples, we use out-of-bag samples to perform aggregation with exponential weights, with a branching process prior over subtrees, which gives more importance to subtrees with a good predictive performance.

Node and subtree prediction.

We define the leaf of the fully grown tree containing a given feature vector. The prediction of a node and of a subtree are given by

(1)

where the node prediction is a generic "forecaster" used in its cell and where the prediction of a subtree is the prediction of its leaf containing the feature vector. A standard choice for regression is the empirical mean forecaster

(2)

namely the average of the labels of the samples falling in the cell of the node. For K-class classification, where predictions belong to the set of probability distributions over the K classes, a standard choice is a Bayes predictive posterior with a Dirichlet prior, namely the Jeffreys prior on the multinomial model, which leads to

(3)

a smoothed empirical distribution of the labels of the samples falling in the cell of the node. By default, WW uses the Dirichlet parameter 1/2 (the Krichevsky-Trofimov forecaster [68]), but one can perfectly use any positive value, so that all the coordinates of the forecaster are positive. This is motivated by the fact that WW uses as default the log loss to assess performance for classification, which would otherwise require an arbitrarily chosen clipping value for zero probabilities. Different choices of this parameter only weakly impact WW's performance, as illustrated in Appendix 11. We use the out-of-bag samples to define the cumulative losses of the predictions of all nodes,

(4)

where ℓ is a loss function. For regression problems, a default choice is the quadratic loss, while for multiclass classification a default is the log-loss, which is well defined when using (3); other loss choices are of course possible.
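For concreteness, the following gives standard forms consistent with the description above; the notation (and, in particular, which samples are counted in each quantity) is our assumption and may differ from the paper's Equations (2)-(4):

% Empirical mean forecaster (regression), with C_v the cell of node v and
% n_v the number of samples falling in C_v:
\hat{y}_v = \frac{1}{n_v} \sum_{i : x_i \in C_v} y_i

% Jeffreys-prior (Dirichlet(1/2)) forecaster for K-class classification,
% with n_v(k) the number of samples of class k falling in C_v
% (the Krichevsky-Trofimov forecaster):
\hat{y}_v(k) = \frac{n_v(k) + 1/2}{n_v + K/2}

% Cumulative out-of-bag loss of node v, used by the aggregation weights:
L_v = \sum_{i \in I_{\mathrm{oob}} : x_i \in C_v} \ell\big(\hat{y}_v, y_i\big)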

Prediction function.

Let η > 0 be a temperature parameter. The prediction function of a tree in WW is given by

(5)

where the sum is over all subtrees of the fully grown tree rooted at its root and where each subtree is weighted by a prior π times an exponential weight depending on its cumulative out-of-bag loss. Note that π is the distribution of a branching process over the nodes of the fully grown tree, with exactly two children whenever it branches, so that a subtree is penalized through the number of its nodes minus the number of its leaves that are also leaves of the fully grown tree. A default choice of η for the log-loss is motivated by Corollary 1 in Section 3 below, but η can also be tuned through hyperoptimization, although we do not observe strong performance gains, see Section 11 below. The prediction function (5) is an aggregation of the predictions of all subtrees rooted at the root, weighted by their performance on out-of-bag samples. This aggregation procedure can be understood as a non-greedy way to prune trees: the weights depend not only on the quality of one single split but also on the performance of each subsequent split.
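A sketch of the assumed form of (5), under the standard context tree weighting prior with branching probability 1/2 (the exact notation of the paper may differ): writing \|T\| for the number of nodes of a subtree T minus its number of leaves that are also leaves of the fully grown tree, and L_T for the sum of the cumulative out-of-bag losses L_v over the leaves v of T,

\hat f(x) \;=\; \frac{\sum_{T} \pi(T)\, e^{-\eta L_T}\, \hat y_T(x)}{\sum_{T} \pi(T)\, e^{-\eta L_T}},
\qquad \pi(T) = 2^{-\|T\|},

where the sums are over all subtrees T rooted at the root and \hat y_T(x) is the forecaster of the leaf of T containing x.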

Computing the prediction naively from Equation (5) is computationally and memory-wise infeasible for a large fully grown tree, since it involves a sum over all subtrees rooted at the root and requires one weight for each of them. Indeed, the number of subtrees of a minimal tree that separates n points is exponential in the number of nodes, and hence exponential in n. However, it turns out that the prediction can be computed exactly and very efficiently, thanks to the choice of the prior π together with an adaptation of context tree weighting [73, 72, 34, 14].

Theorem 1.

The prediction function (5) can be written as a quantity computed at the root, obtained through the recursion

(6)

along the path in the fully grown tree that goes from the leaf containing the feature vector up to the root, where the recursion combines, at each node of the path, the node's forecaster with the quantity computed at the previous step, using node weights that themselves satisfy the recursion

(7)
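Based on the description of Algorithms 1 and 2 and on the standard context tree weighting recursion, a plausible form of (6) and (7) is the following (this is our reconstruction, with hypothetical notation w_v and \tilde f_v; the paper's exact statement may differ):

% Weight recursion (7), computed bottom-up over the fully grown tree:
w_v =
\begin{cases}
  e^{-\eta L_v} & \text{if } v \text{ is a leaf,}\\
  \tfrac{1}{2}\, e^{-\eta L_v} + \tfrac{1}{2}\, w_{v0}\, w_{v1} & \text{otherwise.}
\end{cases}

% Prediction recursion (6), computed upwards along the path
% v_h (leaf containing x), v_{h-1}, ..., v_0 (root):
\tilde f_{v_h}(x) = \hat y_{v_h}(x), \qquad
\tilde f_{v_t}(x) = \frac{\tfrac{1}{2} e^{-\eta L_{v_t}}}{w_{v_t}}\, \hat y_{v_t}(x)
  + \Big(1 - \frac{\tfrac{1}{2} e^{-\eta L_{v_t}}}{w_{v_t}}\Big)\, \tilde f_{v_{t+1}}(x),

with \hat f(x) = \tilde f_{v_0}(x).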

The proof of Theorem 1 is given in Section 6 below; a consequence of this theorem is a very efficient computation of the prediction function, described in Algorithms 1 and 2 below. Algorithm 1 computes the node weights using the fact that trees in WW are grown in a depth-first fashion, so that a single loop, with time and memory usage linear in the number of nodes, over nodes stored in a data structure that respects the parenthood order, is enough. Direct computations can lead to numerical over- or under-flows (many products of exponentially small or large numbers are involved), so Algorithm 1 works recursively over the logarithms of the weights (line 6 uses a log-sum-exp function that can be made overflow-proof).

1:  Inputs: the fully grown tree, the temperature η and the losses of all nodes. Nodes are stored in a data structure that respects the parenthood order: each child node has a memory index larger than that of its parent.
2:  for each node, looping over the node array in reverse order, do
3:     if the node is a leaf then
4:        Put its log-weight using the leaf case of recursion (7)
5:     else
6:        Put its log-weight using the interior-node case of recursion (7), combining the two terms with an overflow-proof log-sum-exp
7:     end if
8:  end for
9:  return  The set of log-weights of all nodes
Algorithm 1 Computation of the log-weights of all nodes.

Algorithm 1 is applied once the tree is fully grown, so that WW is ready to produce predictions using Algorithm 2 below. Note that hyperoptimization of the temperature or of the node forecaster parameter, if required, does not need to grow the tree again, but only to update the log-weights of all nodes with Algorithm 1, making hyperoptimization of these parameters particularly efficient.

1:  Inputs: the tree, the losses of all nodes and the log-weights computed by Algorithm 1
2:  Find the leaf containing the input vector and put it as the current node
3:  Put the current prediction equal to the node forecaster of this leaf (such as (2) for regression or (3) for classification)
4:  while the current node is not the root do
5:     Put the current node equal to its parent
6:     Put the mixture weight of the current node, computed from its loss and its log-weight
7:     Put the current prediction equal to the combination, given by recursion (6), of the current node's forecaster and of the previous prediction
8:  end while
9:  return  The prediction of the tree at the input vector
Algorithm 2 Computation of the prediction function at any input vector.

The recursion used in Algorithm 2 has a complexity proportional to the depth of the leaf containing the input, which is the complexity required to find this leaf: Algorithm 2 only increases by roughly a factor of two the prediction complexity of a standard RF (in order to go down to the leaf and up again to the root along the same path). More details about the construction of Algorithms 1 and 2 can be found in Section 6 below.
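The following sketch illustrates how Algorithms 1 and 2 can be implemented over a flat node array, working with log-weights and a log-sum-exp, under the recursion forms assumed above (this is our own illustrative reconstruction, not WildWood's numba implementation; node fields and formulas are assumptions):

import numpy as np

LOG_HALF = np.log(0.5)

def logsumexp2(a, b):
    # Overflow-proof computation of log(exp(a) + exp(b)).
    m = max(a, b)
    return m + np.log(np.exp(a - m) + np.exp(b - m))

def compute_log_weights(is_leaf, left, right, loss, eta):
    # Algorithm 1: children have larger indexes than their parent, so a single
    # reverse loop over the node array computes all log-weights bottom-up.
    n_nodes = len(loss)
    log_w = np.empty(n_nodes)
    for v in range(n_nodes - 1, -1, -1):
        if is_leaf[v]:
            log_w[v] = -eta * loss[v]  # leaf case of recursion (7)
        else:
            log_w[v] = logsumexp2(     # interior-node case of recursion (7)
                LOG_HALF - eta * loss[v],
                LOG_HALF + log_w[left[v]] + log_w[right[v]],
            )
    return log_w

def predict(path, forecasts, loss, log_w, eta):
    # Algorithm 2: `path` lists the nodes from the leaf containing x up to the
    # root, and `forecasts[t]` is the node forecaster of path[t] evaluated at x.
    f = forecasts[0]  # start from the leaf forecaster
    for t in range(1, len(path)):
        v = path[t]
        # Mixture weight of "stopping the subtree at node v".
        alpha = np.exp(LOG_HALF - eta * loss[v] - log_w[v])
        f = alpha * forecasts[t] + (1.0 - alpha) * f
    return f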

3 Theoretical guarantees

This section proposes some theoretical guarantees on the subtrees aggregation used in WW, see (5). We say that a loss function ℓ is η-exp-concave for some η > 0 whenever z ↦ exp(-η ℓ(z, y)) is concave for any label y. We consider a fully grown tree computed using the in-the-bag samples and the out-of-bag samples on which the cumulative losses are computed using (4).

Theorem 2 (Oracle inequality).

Assume that the loss function is η-exp-concave. Then, the prediction function given by (5) satisfies an oracle inequality

in which the infimum is over all subtrees rooted at the root, and where we recall that the complexity of a subtree is its number of nodes minus its number of leaves that are also leaves of the fully grown tree.

Theorem 2 proves that, for a general loss function, the prediction function of WW is able to perform nearly as well as the best oracle subtree on out-of-bag samples, with a rate which is optimal for model-selection oracle inequalities [69] (logarithmic in the number of "experts", here subtrees, for a well-balanced tree). Let us stress again that, while finding such an oracle is computationally infeasible, since it requires trying out all possible subtrees, WW's prediction function (5) comes at a cost comparable to that of a standard Random Forest, as explained in Section 2.3 above.

The proof of Theorem 2 is given in Section 7 below, and relies on techniques from PAC-Bayesian theory [46, 47, 15]. Compared with [50] about online learning, our proof differs significantly: we do not use results specialized to online learning such as [71], nor online-to-batch conversion [16]. Note that Theorem 2 does not address the generalization error, since this would require studying the generalization error of the random forest itself (and of the fully grown tree), which is a topic well beyond the scope of this paper, and still a very difficult open problem: recent results [31, 2, 63, 61, 62, 49] only study stylized versions of RF (called purely random forests).

Consequences of Theorem 2 are Corollary 1 for the log-loss (classification) and Corollary 2 for the least-squares loss (regression).

Corollary 1 (Classification).

Consider K-class classification and the prediction function given by (5), where node predictions are given by (3) with WW's default Dirichlet parameter, where the loss is the log-loss and where the temperature takes its default value. Then, we have

where the comparison is with any prediction function that is constant on the leaves of a subtree.

Corollary 2 (Regression).

Consider regression with bounded labels and the prediction function given by (5), where node predictions are given by (2), where the loss is the least-squares loss and where the temperature takes its corresponding default value. Then, we have

where the comparison is with any prediction function that is constant on the leaves of a subtree.

The proofs of Corollaries 1 and 2 are given in Section 7. These corollaries motivate the default value of the temperature hyperparameter, in particular for classification.

4 Experiments

Our implementation of WildWood is available at the GitHub repository https://github.com/pyensemble/wildwood.git under the BSD 3-Clause license and is distributed through PyPI. It is a Python package that follows scikit-learn's API conventions and is JIT-compiled to machine code using numba [40]. Trees in the forest are grown in parallel using joblib [36] and CPU threads; GPU training will be supported in future updates. We compare WildWood (denoted WW, with a subscript giving the number of trees) with several strong baselines, including RF: scikit-learn's implementation of Random Forest [54, 43]; HGB: a histogram-based implementation of extreme gradient boosting (inspired by LightGBM) from scikit-learn; and several state-of-the-art and widely adopted extreme gradient boosting libraries, namely XGB: XGBoost [18]; LGBM: LightGBM [38]; and CB: CatBoost [56, 28]. We used a 32-core server with two Intel Xeon Gold CPUs, two Tesla V100 GPUs and 384GB RAM for the experiments involving hyperoptimization (Table 1) and a 12-core Intel i7 MacBook Pro with 32GB RAM and no GPU to obtain training times achievable by a "standard user" (Table 2). All experiments can be reproduced using the Python scripts available in the repository.
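As a usage sketch, and since the package follows scikit-learn's API conventions, training WW looks roughly as follows (the estimator name ForestClassifier and its parameters are our assumption based on the repository and may differ in the released version):

from sklearn.datasets import load_breast_cancer
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
from wildwood import ForestClassifier  # assumed import path

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)

# 10 trees is the default forest size reported in the experiments below.
clf = ForestClassifier(n_estimators=10, random_state=0)
clf.fit(X_train, y_train)
print(roc_auc_score(y_test, clf.predict_proba(X_test)[:, 1]))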

Description of the experiments.

We use publicly available and open-source datasets from the UCI repository [29], including small datasets (hundreds of rows) and large datasets (millions of rows); their main characteristics are given in Table 5, together with URLs in Table 6, see Section 10 below. Each dataset is randomly split into a training set (70%) and a test set (30%). We specify which features are categorical to the algorithms that natively support them (HGB, LGBM, CB and WW) and simply integer-encode these features, while we use one-hot encoding for the other algorithms (RF, XGB). For each algorithm and dataset, hyperoptimization is performed as follows: the training set is further split into a part used for training and a part used for validation, and we do 50 steps of sequential optimization using the Tree Parzen Estimator implemented in the hyperopt library [4]. More details about hyperoptimization are provided in Section 9 below. Then, we refit on the whole training set with the best hyperparameters and report scores on the test set. This is performed 5 times in order to report standard deviations. We use the area under the ROC curve (AUC); for multiclass datasets we average the AUC of each class versus the rest. This leads to the test AUC scores displayed in Table 1 (the same scores with standard deviations are available in Table 3). We also report in Table 2 (see also Table 4 for standard deviations) the test AUC scores obtained with the default hyperparameters of all algorithms on the 5 largest considered datasets, together with their training times (timings can vary by several orders of magnitude with varying hyperparameters for EGB libraries, as observed from the timing differences between Figure 3 and Table 2).
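A minimal sketch of the hyperoptimization loop described above, using hyperopt's Tree Parzen Estimator with 50 evaluations (the tuned estimator and the search space below are illustrative, not the exact grids used in the paper):

from hyperopt import fmin, hp, tpe, Trials
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score

def make_objective(X_train, y_train, X_val, y_val):
    def objective(params):
        clf = RandomForestClassifier(
            n_estimators=int(params["n_estimators"]),
            max_features=params["max_features"],
            n_jobs=-1,
            random_state=0,
        )
        clf.fit(X_train, y_train)
        proba = clf.predict_proba(X_val)
        # Binary AUC uses the positive class; multiclass AUC is averaged one-versus-rest.
        if proba.shape[1] == 2:
            auc = roc_auc_score(y_val, proba[:, 1])
        else:
            auc = roc_auc_score(y_val, proba, multi_class="ovr")
        return -auc  # hyperopt minimizes the objective
    return objective

space = {
    "n_estimators": hp.quniform("n_estimators", 10, 300, 10),
    "max_features": hp.choice("max_features", ["sqrt", "log2", None]),
}

# best = fmin(make_objective(X_train, y_train, X_val, y_val), space,
#             algo=tpe.suggest, max_evals=50, trials=Trials())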

XGB LGBM CB HGB RF RF WW WW
adult 0.930 0.931 0.927 0.930 0.916 0.919 0.918 0.919
bank 0.933 0.935 0.925 0.930 0.917 0.929 0.924 0.931
breastcancer 0.991 0.993 0.987 0.994 0.974 0.978 0.992 0.992
car 0.999 1.000 1.000 1.000 0.996 0.996 0.997 0.998
covtype 0.999 0.999 0.998 0.999 0.996 0.998 0.996 0.998
default-cb 0.780 0.783 0.780 0.779 0.748 0.774 0.773 0.778
higgs 0.853 0.857 0.847 0.853 0.812 0.834 0.818 0.835
internet 0.934 0.910 0.938 0.911 0.841 0.911 0.923 0.928
kddcup 1.000 1.000 1.000 1.000 0.997 0.998 1.000 1.000
kick 0.777 0.770 0.777 0.771 0.736 0.752 0.756 0.763
letter 1.000 1.000 1.000 1.000 0.997 0.999 0.996 0.999
satimage 0.991 0.991 0.991 0.987 0.980 0.989 0.983 0.991
sensorless 1.000 1.000 1.000 1.000 1.000 1.000 1.000 1.000
spambase 0.990 0.990 0.987 0.986 0.980 0.986 0.983 0.987
Table 1: Test AUC of all algorithms after hyperoptimization on the considered datasets. Standard deviations are reported in Table 3. We observe that WW has better (or, in some cases, identical) performance than RF on all datasets and that its performance is close to that of EGB libraries (bold is for the best EGB performance, underline for the best RF or WW performance).

Discussion of the results.

We observe in Table 1 that EGB algorithms, when hyperoptimized, lead to the best performance over the considered datasets compared with RF algorithms, and that WW always improves the performance of RF, with the exception of a few datasets for which the performance is identical. When using default hyperparameters for all algorithms, we observe in Table 2 that the test AUC scores can decrease significantly for EGB libraries while RF algorithms appear more stable, and that there is no clear best-performing algorithm in this case. The results in both tables show that WW is competitive with respect to all baselines, both in terms of performance and of computational time: it always reaches a performance at least comparable to that of the best algorithms, despite using only 10 trees by default. In this respect, WW maintains high scores at a lower computational cost.

Training time (seconds) Test AUC
XGB LGBM CB HGB RF WW XGB LGBM CB HGB RF WW
covtype 10 3 120 14 21 3 0.986 0.978 0.989 0.960 0.998 0.979
higgs 36 30 653 85 1389 179 0.823 0.812 0.840 0.812 0.838 0.813
internet 9 4 188 8 0.4 0.3 0.918 0.828 0.910 0.500 0.862 0.889
kddcup 175 41 2193 31 208 12 1.000 0.638 0.988 0.740 0.998 1.000
kick 7 0.4 50 0.7 31 5 0.768 0.757 0.781 0.773 0.747 0.751
Table 2: Training times (in seconds) of all algorithms with their default hyperparameters (no hyperoptimization) on the 5 largest considered datasets, and the test AUC corresponding to these training times. Test AUC scores are worse than those of Table 1, since no hyperoptimization is used. WW, which uses only 10 trees here (the default number of trees), is almost always the fastest algorithm, for a performance comparable to that of all baselines (bold is for the best EGB training time or performance, underline for the best RF or WW training time or performance). Standard deviations are reported in Table 4.

5 Conclusion

We introduced WildWood, a new Random Forest algorithm for batch supervised learning. Tree predictions in WildWood are an aggregation, with exponential weights, of the predictions of all subtrees, with weights computed on bootstrap out-of-bag samples. This leads to improved predictions from each individual tree, at a small computational cost, since WildWood's prediction complexity is similar to that of a standard Random Forest. Moreover, thanks to the histogram strategy, WildWood's implementation is competitive with strong baselines, including popular extreme gradient boosting libraries, both in terms of performance and of training times. Note also that WildWood has few hyperparameters to tune and that the performance obtained with default hyperparameters is usually good enough in our experiments.

WildWood's implementation is still evolving and many improvements are planned for future updates, including the computation of feature importance, GPU training and distributed training (we only support single-machine training for now), among other enhancements that will further improve performance and accelerate computations. Room for improvement in WildWood comes from the fact that the overall forest prediction is a simple arithmetic mean of the tree predictions, while we could also perform exponentially weighted aggregation between trees. Future work includes a WildWood-based implementation of isolation forest [42], using the same subtrees aggregation mechanism with the log loss for density estimation, to propose a new algorithm for outlier detection.

Acknowledgments.

This research is supported by the Agence Nationale de la Recherche as part of the “Investissements d’avenir” program (reference ANR-19-P3IA-0001; PRAIRIE 3IA Institute). Yiyang Yu is supported by grants from Région Ile-de-France.

6 Proof of Theorem 1 and construction of Algorithms 1 and 2

The expression in Equation (5) involves sums over all subtrees of the fully grown tree (a number of subtrees exponential in its number of leaves). However, it can be computed efficiently because of the specific choice of the prior π. More precisely, we will use the following lemma [34, Lemma 1] several times to efficiently compute sums of products over subtrees.

Lemma 1.

Let g be an arbitrary nonnegative function over the nodes of the fully grown tree and define, for each node, the quantity

(8)

where the sum means the sum over all subtrees of the fully grown tree rooted at that node. Then, this quantity can be computed recursively, from the leaves of the fully grown tree up to its root, for each node.
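For reference, a plausible form of Lemma 1, reconstructed from [34, Lemma 1] and from the way it is used below (the notation g and G is ours and hypothetical):

G(v) \;=\; \sum_{T : \mathrm{root}(T) = v} 2^{-\|T\|} \prod_{u \in \mathrm{leaves}(T)} g(u),
\qquad
G(v) =
\begin{cases}
  g(v) & \text{if } v \text{ is a leaf of the fully grown tree,}\\
  \tfrac{1}{2}\, g(v) + \tfrac{1}{2}\, G(v0)\, G(v1) & \text{otherwise,}
\end{cases}

where the sum is over all subtrees of the fully grown tree rooted at v.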

For the sake of completeness, we include a proof of this statement.

Proof.

First, let us notice that the case of a leaf of the fully grown tree is straightforward, since there is only one subtree rooted at such a leaf (recall that the complexity of a subtree counts its internal nodes and leaves minus the number of its leaves that are also leaves of the fully grown tree). For an interior node, we can expand the sum by singling out the subtree which keeps only the node itself as a leaf; the remaining subtrees can be expressed through pairs of subtrees rooted at the left and right children, respectively. Moreover, the complexity of such a subtree decomposes over the pair, which gives the claimed expansion.

This concludes the proof of Lemma 1. ∎

Let us introduce, for any subtree, its aggregation weight, so that Equation (5) writes as the ratio

(9)

where the sums hold over all the subtrees rooted at the root of the fully grown tree. We will show how to efficiently compute and update the numerator and the denominator of Equation (9). Note that the exponential weight of a subtree may be written as a product over its leaves,

(10)
(11)
(12)

Equality (10) comes from the fact that the cells of the leaves of a subtree form a partition of the binned feature space by construction, and that the stopping criterion used to build the fully grown tree ensures that each leaf node contains at least one out-of-bag sample (see Section 2.2). Equality (11) comes from the fact that the prediction of a node is constant on its cell.

Denominator of Equation (9).

For each node, denote by its weight the sum of the aggregation weights of all subtrees rooted at that node,

(13)

where once again the sum means the sum over all subtrees of the fully grown tree rooted at that node, so that the denominator of (9) is the weight of the root. We have that (12) entails

(14)

so these weights can be computed recursively and very efficiently, using Lemma 1 applied to the per-node exponential weights. This leads to the recursion stated in Theorem 1, see Equation (7).

Now, we can exploit the fact that decision trees are built in a depth-first fashion in WildWood: all the nodes are stored in a "flat" array and, by construction, both child nodes have indexes larger than that of their parent. So, we can simply loop over the array of nodes in reverse order and apply the leaf or interior-node case of the recursion as appropriate: we are guaranteed to have computed the weights of both children before computing the weight of their parent. This algorithm is described in Algorithm 1. Since these computations involve a large number of products of exponentiated numbers, they typically lead to strong over- and under-flows: we therefore describe in Algorithm 1 a version which works recursively over the logarithms of the weights. At the end of this loop, we reach the root and have computed the denominator of (9) with a complexity linear in the number of nodes. Note also that it is sufficient to store a few scalars per node, which makes for a memory consumption linear in the number of nodes as well.

Numerator of Equation (9).

The numerator of Equation (9) follows almost exactly the same argument as the denominator but, since it depends on the input vector of features for which we want to produce a prediction, it is computed at inference time. Recall that the prediction path is the sequence of nodes that leads from the root to the leaf containing the input, and define, for any node, a modified per-node weight equal to the product of the node's exponential weight and of its forecaster at the input whenever the node's cell contains the input, and equal to the plain exponential weight otherwise. We have

(15)
(16)

Note that (15) comes from (12), while (16) comes from the definition of the modified weights (a single term of the product over the leaves of a subtree corresponds to the leaf containing the input, since the cells of the leaves form a partition of the binned feature space). We are now in a position to use Lemma 1 again, applied this time to the modified weights. Defining the corresponding per-node sums over subtrees, we can conclude that

(17)

and that the following recurrence holds:

(18)

This recurrence allows the numerator to be computed from the leaves up but, used directly, it would lead to a complexity linear in the total number of nodes to produce a prediction for a single input. It turns out that we can do much better than that.

Indeed, whenever the cell of a node does not contain the input, its modified weight coincides with its plain exponential weight, and the same holds for all of its descendants, which entails by induction that the corresponding per-node sums coincide with the weights already computed by Algorithm 1. Therefore, we only need to explain how to compute the per-node sums along the prediction path. This is achieved recursively, thanks to (18), starting at the leaf containing the input and going up in the tree to the root:

(19)

Let us explain where this comes from: firstly, at the leaf containing the input, the sum reduces to the modified weight of this leaf. Secondly, we go up in the tree along the prediction path and use (18) again: at each node of the path, the child outside the path contributes its plain weight (already computed by Algorithm 1), while the child on the path contributes the sum computed at the previous step. This recursion has a complexity linear in the length of the path, which is typically orders of magnitude smaller than the number of nodes (in a well-balanced binary tree, the depth is logarithmic in the number of leaves). Moreover, we observe that the recursions used in (7) and (19) only need to save both and