Unlike conventional A/B testing, where individuals are randomly assigned to a treatment condition, internet companies are increasingly using heterogeneous treatment effect (HTE) models and personalized policies to decide the best predicted treatment for each individual (garcin2013personalized; tougucc2020hybrid). However, to flexibly model heterogeneity in treatment and outcome surfaces, state-of-the-art HTE models used for personalization (wager2018estimation; shi2019adapting; kennedy2020optimal) frequently leverage black-box, uninterpretable base learners such as gradient boosted trees and neural networks. To generate an HTE prediction, some of them combine outputs from multiple models via stacking, weighting, etc. (kunzel2019metalearners; nie2021quasi; montoya2021optimal). Moreover, their treatment effect predictions are sometimes used as inputs to additional black-box policy learning models (imai2011estimation; mcfowland2020prescriptive).
There are several challenges with these approaches. First, when there are multiple (not just binary) outcomes, creating a policy is no longer as trivial as assigning the treatment group that maximizes a single treatment effect. Second, as the dataset size and the number of features used for HTE modeling or policy learning increase, so does the cost of maintaining such models in deployment (NIPS2015_86df7dcf). Finally, the black-box nature of HTE models and the resulting policies makes them difficult to interpret, which can deter their uptake in critical applications.
This paper describes methods for (1) learning explanations of HTE models, to better understand such models, and (2) generating interpretable policies that can perform on par with, or even better than, black-box policies. Our explanations and interpretable policies are constructed from if-else rules, using individual features as inputs (Figure 1). We present several ways to learn interpretable policies, ranging from methods that act on HTE model predictions to methods that work on conditional average treatment effects. We also propose ways to ensemble multiple interpretable policies while retaining interpretability.
Crucially, all of our proposed methods work in the general setting of more than two treatment groups and multiple outcomes, which is needed in production environments at internet companies, where, for example, a new website feature may increase user engagement at the expense of revenue. We achieve this by leveraging techniques such as model distillation (bucilua2006compression; hinton2014distilling) with multitask (caruana1997multitask) and multiclass student models, and linear scalarization of multiple outcomes from the multi-objective optimization literature (gunantara2018review).
2 Related Work
We do not attempt to comprehensively survey the state of HTE estimation since our contribution is agnostic to the choice of HTE model. Most modern HTE estimation methods are compatible with our proposed methods.
Subgroup finding. Our work on interpreting HTE models by surfacing segments with heterogeneity produces a representation analogous to that of subgroup finding methods. These methods fall into three groups: (1) statistical tests to confirm or deny a set of pre-defined hypotheses about subgroups (assmann2000subgroup; song2007method); (2) methods that aim to discover subgroups directly from data (imai2013estimating; chen2015prim; wang2017causal; nagpal2020interpretable); (3) methods that act on inferred individual-level treatment effects to discover subgroups (foster2011subgroup; lee2020robust). However, how best to compare subgroup finding algorithms is still an open question: Loh et al. empirically compared 13 different subgroup finding algorithms (loh2019subgroup) and found that no single algorithm satisfied all the desired properties of such algorithms.
Policy learning. The problem of policy learning from observational or experimental data has received significant attention due to its wide applicability (beygelzimer2009offset; qian2011performance; zhao2012estimating; swaminathan2015counterfactual; kitagawa2018should; zhou2018offline; kallus2018policy), but the interpretability and scalability of the learned policies have received less attention. Several existing works (amram2020optimal; athey2021policy; jo2021learning) do construct interpretable tree policies based on counterfactual estimation. However, they focus on constructing exactly optimal trees, which is prohibitively slow at large scale and unsuitable for production environments.
Interpretability via distillation and imitation learning. Black-box classification and regression models have been distilled into interpretable models such as trees and rule lists (craven1995extracting; frosst2017soft; tan2018distill; tan2018transparent). craven1995extracting distilled a neural network into a decision tree and evaluated the explanation in terms of fidelity, accuracy, and complexity. Distillation has also been used in sequential decision-making (namkoong2020distilled), and interpretability in reinforcement learning has been studied extensively; see alharin2020reinforcement for a survey. These methods can be roughly divided into those that leverage the interpretable representation provided by trees to approximate states and actions (bastani2018verifiable; vasic2019moet), saliency and attention methods (ziyu2016dueling; greydanus2018visualizing; qureshi2017show), and distillation (czarnecki2019distilling). We also use distillation techniques, but in a non-sequential setting.
3.1 Problem Setting and Notation
Suppose we have a dataset of $n$ points and $d$ features with $K$ treatment groups, one control group, and $J$ outcomes. Let $W_i \in \{0, 1, \dots, K\}$ denote the treatment group indicator for the $i$-th individual ($W_i = 0$ means the individual is in the control group). Let $S_k = \{i : W_i = k\}$ denote the set of individuals in the $k$-th treatment group, $k = 0, \dots, K$. Following the potential outcome framework (neyman1923applications; rubin1974estimating), we assume that for each individual $i$ and outcome $j$ there are $K+1$ potential outcomes $Y_{ij}^{(0)}, \dots, Y_{ij}^{(K)}$. For each individual $i$, we observe features $X_i \in \mathbb{R}^d$ and outcomes $Y_{i1}, \dots, Y_{iJ}$, where $Y_{ij} = Y_{ij}^{(W_i)}$ is the observed potential outcome (one of the $K+1$ potential outcomes). We suppose we have access to predicted outcomes $\hat{Y}_{ij}^{(k)}$, where $\hat{Y}_{ij}^{(k)}$ is the predicted $j$-th outcome for individual $i$ assuming treatment group $k$. These could be obtained directly from the HTE model, or by combining its predicted treatment effects with estimates of the baseline surface.
We focus on deterministic policies, $\pi: \mathbb{R}^d \to \{0, 1, \dots, K\}$, which map feature vectors to treatment groups. An individual with feature vector $X_i$ will be given treatment $\pi(X_i)$. A segment, $s \subseteq \{1, \dots, n\}$, is a set of indices representing a set of points $\{X_i\}_{i \in s}$, their treatment assignments $\{W_i\}_{i \in s}$, and their predicted outcomes $\hat{Y}_{ij}^{(k)}$ for $i \in s$. Throughout this work, we refer to lists of segments that form a partition of the input indices, i.e., no index appears in more than one segment.
3.2 Preliminaries: Combining Multiple Outcomes
Throughout this work, we combine multiple outcomes into a single outcome via linear scalarization, parameterized by a set of weights $c_1, \dots, c_J$. Thus, for each individual $i$, the multiple observed outcomes become a single outcome:
$$Y_i = Y_i^{(W_i)} = c_1 Y_{i1}^{(W_i)} + \cdots + c_J Y_{iJ}^{(W_i)} = c_1 Y_{i1} + \cdots + c_J Y_{iJ}.$$
We also combine predicted outcomes in the same way: $\hat{Y}_i^{(k)} = c_1 \hat{Y}_{i1}^{(k)} + \cdots + c_J \hat{Y}_{iJ}^{(k)}$. Our proposed methods take $c_1, \dots, c_J$ as inputs and set $c_j = 1$ for all $j$ (equal weighting on outcomes) if they are not provided. In practice, it is common to either tune the weights (letham2019bayesian) or handpick them to reflect particular preferences or tradeoffs.
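As a minimal illustration of the scalarization above, the NumPy sketch below collapses either observed outcomes of shape (n, J) or predicted outcomes of shape (n, K + 1, J) into one scalar per individual (and per treatment group). The function name `scalarize` is ours, not from the paper.

```python
import numpy as np


def scalarize(outcomes, weights=None):
    """Linear scalarization: weighted sum over the last (outcome) axis.

    outcomes: array of shape (..., J), e.g. (n, J) observed outcomes or
              (n, K + 1, J) predicted outcomes for every treatment group.
    weights:  length-J weights c_1..c_J; defaults to all ones (equal weighting).
    """
    outcomes = np.asarray(outcomes, dtype=float)
    J = outcomes.shape[-1]
    c = np.ones(J) if weights is None else np.asarray(weights, dtype=float)
    if c.shape != (J,):
        raise ValueError(f"expected {J} weights, got shape {c.shape}")
    return outcomes @ c  # contracts the outcome axis, e.g. (n, K + 1, J) -> (n, K + 1)
```

For instance, with hypothetical weights for visits, conversions, and revenue, `scalarize(Y_hat, [1.0, 10.0, 0.1])` would return an (n, K + 1) matrix of combined predicted outcomes.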
3.3 Explaining Heterogeneous Treatment Effect Models
Multi-task decision trees (MTDT) extend single-task decision trees by combining the prediction losses on multiple outcomes into a single scalar loss function. This simple representation is well suited to locating segments where individuals exhibit heterogeneity across multiple outcomes. In this tree, each task (label) is the predicted treatment effect for one outcome, and each node is a segment identified as having elevated or depressed treatment effects across multiple outcomes. We learn these trees using distillation techniques and call the resulting method Distilling HTE model predictions (Distill-HTE). Figure 2 provides a concrete example.
Distillation. To learn an MTDT model in an HTE-model-agnostic fashion, we leverage model distillation, taking the HTE model as the teacher and the MTDT model as the student. After obtaining predicted outcomes $\hat{Y}_{ij}^{(k)}$ from trained HTE models, where $\hat{Y}_{ij}^{(k)}$ is the predicted $j$-th outcome for individual $i$ assuming treatment group $k$, we construct pairwise predicted treatment effects $\hat{\tau}_{ij} = \hat{Y}_{ij}^{(k)} - \hat{Y}_{ij}^{(k')}$, i.e., individual $i$'s predicted treatment effect on outcome $j$ for a particular treatment contrast $(k, k')$ (analogous definitions hold for all treatment contrasts). Let $f$ be the prediction function of an MTDT model trained with the following distillation objective, which minimizes the mean squared error between the HTE model's predicted treatment effects $\hat{\tau}_{ij}$ and $f$:
$$\min_{f} \; \frac{1}{n} \sum_{i=1}^{n} \sum_{j=1}^{J} c_j \left( \hat{\tau}_{ij} - f_j(X_i) \right)^2$$
Here $f_j(X_i)$ is the MTDT model's prediction for outcome $j$ alone, and $c_j$ is a weight that encodes how much outcome $j$ contributes to the overall loss, as reviewed in Section 3.2. When $c_j = 1$ for all $j$, as we assume by default, each outcome contributes equally to the overall distillation loss.
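One way to fit such an MTDT with off-the-shelf tools is sketched below, assuming scikit-learn's `DecisionTreeRegressor`, which supports multi-output targets and whose squared-error criterion averages the per-output variances. To apply the outcome weights, the sketch rescales target column j by sqrt(c_j), so that its squared error is multiplied by c_j, and undoes the scaling at prediction time. This is our own assumption about how to realize the weighted objective, not necessarily the authors' implementation; `fit_mtdt` and `predict_mtdt` are hypothetical helpers.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor


def fit_mtdt(X, tau_hat, weights=None, max_depth=3, min_samples_leaf=50):
    """Distill teacher treatment effects into one multi-task decision tree (hypothetical helper).

    X:       (n, d) features.
    tau_hat: (n, J) teacher-predicted treatment effects for one treatment contrast.
    weights: length-J positive outcome weights c_j (default: all ones).
    """
    tau_hat = np.asarray(tau_hat, dtype=float)
    c = np.ones(tau_hat.shape[1]) if weights is None else np.asarray(weights, dtype=float)
    if np.any(c <= 0):
        raise ValueError("outcome weights must be positive")
    scale = np.sqrt(c)  # (sqrt(c_j) * error)^2 == c_j * error^2
    tree = DecisionTreeRegressor(max_depth=max_depth, min_samples_leaf=min_samples_leaf)
    tree.fit(X, tau_hat * scale)
    return tree, scale


def predict_mtdt(tree, scale, X):
    """Per-outcome effect predictions on the original (unscaled) scale."""
    return tree.predict(X) / scale
```

Each leaf of the fitted tree is one segment, and its per-outcome predictions are the segment's distilled treatment effects.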
To improve the robustness of the subgroups, we introduce three elements based on counterfactual reasoning: (1) terminal nodes without sufficient overlap (i.e., without enough treatment AND control points) are post-pruned from the tree, and the predictions are regenerated; (2) confidence intervals are provided for the treatment effects within each node; (3) honesty (athey2016recursive): splits and predictions are determined on different data splits.
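The overlap check and per-node confidence intervals can be sketched as follows, assuming a binary contrast (treated vs. control), difference-in-means leaf estimates, and normal-approximation intervals; the paper does not specify these choices, and `honest_leaf_effects` is a hypothetical helper. For honesty, the leaf ids should come from data not used to grow the tree, e.g. `tree.apply(X_est)` for a scikit-learn tree fitted on a separate split.

```python
import numpy as np
from statistics import NormalDist


def honest_leaf_effects(leaf_ids, w, y, min_per_arm=20, alpha=0.05):
    """Per-leaf effect estimates and CIs on held-out data (hypothetical helper).

    leaf_ids: leaf of each held-out point, e.g. tree.apply(X_est).
    w:        binary treatment indicator (1 = treated, 0 = control).
    y:        observed (scalarized) outcome.
    Returns ({leaf: (estimate, lower, upper)}, [leaves lacking overlap]).
    """
    leaf_ids, w, y = np.asarray(leaf_ids), np.asarray(w), np.asarray(y, dtype=float)
    z = NormalDist().inv_cdf(1 - alpha / 2)
    effects, no_overlap = {}, []
    for leaf in np.unique(leaf_ids):
        in_leaf = leaf_ids == leaf
        yt, yc = y[in_leaf & (w == 1)], y[in_leaf & (w == 0)]
        if len(yt) < min_per_arm or len(yc) < min_per_arm:
            no_overlap.append(leaf)  # candidate for post-pruning
            continue
        est = yt.mean() - yc.mean()  # difference in means
        se = np.sqrt(yt.var(ddof=1) / len(yt) + yc.var(ddof=1) / len(yc))
        effects[leaf] = (est, est - z * se, est + z * se)
    return effects, no_overlap
```

Leaves in the second returned list are the ones a pruning step would merge back into their parents; that step is not shown.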
From treatment effects to policies. With $K$ treatment groups, 1 control group, and $J$ outcomes, the proposed method learns $K$ MTDT models, each of which predicts the treatment effects on the $J$ outcomes for one treatment group compared to control. Each MTDT model can then be used to generate a policy that indicates, for each segment, which of its pair of treatment groups should be assigned to maximize treatment effects. However, generating a single policy still requires combining the $K$ MTDT models, and the best way to do so is not immediately obvious. This motivates our proposal to directly learn a single policy that applies to multiple treatment groups and outcomes.
3.4 Interpretable Policy Learning
To generate a single interpretable policy that applies to multiple treatment groups and outcomes, we now describe several approaches, all of which use if-else rule representations.
3.4.1 Approaches based on HTE model predictions
An already trained HTE model from which predicted outcomes can be obtained is a prerequisite for the approaches described in this subsection. We describe two ways of leveraging these predictions: GreedyTreeSearch-HTE and Distill-Policy.
Greedy tree search from HTE model predictions (GreedyTreeSearch-HTE). In this approach, we directly solve the following optimization problem over $\Pi_D$, the space of trees with a pre-determined maximum depth $D$, assuming without loss of generality that higher outcomes are better:
$$\max_{\pi \in \Pi_D} \sum_{i=1}^{n} \hat{Y}_i^{(\pi(X_i))}$$
This can be seen as a cost-sensitive classification problem, where the cost of assigning a point to a group is the negative of the predicted potential outcome (foundation_cost_sensitive). For a scalable implementation, we solve the optimization problem greedily instead of resorting to an exact search over the whole space of possible trees; Algorithm 1 gives the detailed implementation.
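Algorithm 1 is not reproduced here, but the following is a minimal sketch of a greedy, depth-limited policy tree for this cost-sensitive objective, written under our own assumptions: each candidate split is scored by letting each child take the treatment with the largest summed predicted outcome, and a node is split only if that strictly increases the total. The function names and the nested-dict tree format are hypothetical.

```python
import numpy as np


def greedy_policy_tree(X, Y_hat, max_depth=2, min_leaf=20):
    """Greedily grow a depth-limited tree maximizing the summed predicted outcome.

    X:     (n, d) features.
    Y_hat: (n, K + 1) scalarized predicted outcomes; column k is treatment k.
    Returns {"action": k} for a leaf, or
    {"feature": j, "threshold": t, "left": subtree, "right": subtree}.
    """
    n = len(X)
    totals = Y_hat.sum(axis=0)
    leaf = {"action": int(np.argmax(totals))}
    if max_depth == 0 or n < 2 * min_leaf:
        return leaf
    best_value, best_split = totals.max(), None
    left_sizes = np.arange(1, n)
    for j in range(X.shape[1]):
        order = np.argsort(X[:, j], kind="stable")
        xs = X[order, j]
        left = np.cumsum(Y_hat[order], axis=0)[:-1]  # per-action totals of the left child
        right = totals - left                          # per-action totals of the right child
        value = left.max(axis=1) + right.max(axis=1)   # each child takes its best action
        valid = (xs[:-1] < xs[1:]) & (left_sizes >= min_leaf) & (n - left_sizes >= min_leaf)
        if not valid.any():
            continue
        i = np.flatnonzero(valid)[np.argmax(value[valid])]
        if value[i] > best_value + 1e-12:  # split only on strict improvement
            best_value, best_split = value[i], (j, (xs[i] + xs[i + 1]) / 2)
    if best_split is None:
        return leaf
    j, t = best_split
    mask = X[:, j] <= t
    return {"feature": j, "threshold": float(t),
            "left": greedy_policy_tree(X[mask], Y_hat[mask], max_depth - 1, min_leaf),
            "right": greedy_policy_tree(X[~mask], Y_hat[~mask], max_depth - 1, min_leaf)}


def predict_policy(tree, X):
    """Route each row down the tree and return the assigned treatment."""
    actions = np.empty(len(X), dtype=int)
    for r, x in enumerate(X):
        node = tree
        while "action" not in node:
            node = node["left"] if x[node["feature"]] <= node["threshold"] else node["right"]
        actions[r] = node["action"]
    return actions
```

Each node scans every feature once after sorting, so for a fixed number of treatment groups a depth-D tree costs roughly O(D · d · n log n), which is what makes the greedy approach attractive relative to exact search.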
This approach is similar to zhou2018offline and amram2020optimal, which respectively solve an exact tree search problem over predicted outcomes and find the optimal tree using coordinate ascent. Because of scalability concerns and the need for faster test-time inference, we take a greedy approach instead. On a sample of 100k data points with 12 features and three treatment groups, the zhou2018offline method took 2.5 hours to run. Given its quadratic runtime in the number of data points, and since the amount of data at large internet companies sometimes exceeds 100 million data points, it is infeasible to run this method at the scale required by internet companies.
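To make the scaling argument concrete, here is a back-of-the-envelope extrapolation from the 2.5-hour run, assuming the runtime grows exactly quadratically in the number of data points and nothing else changes (hardware, number of features, number of treatment groups):

```latex
T(n) \approx T(n_0)\left(\frac{n}{n_0}\right)^2
  = 2.5\ \text{h} \times \left(\frac{10^8}{10^5}\right)^2
  = 2.5 \times 10^6\ \text{h} \approx 285\ \text{years}.
```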
Distilling policy derived from HTE model predictions (Distill-Policy). In this approach, we start from the naive policy implied by the outcome predictions $\hat{Y}_i^{(k)}$, i.e., $\hat{\pi}(X_i) = \arg\max_{k \in \{0, \dots, K\}} \hat{Y}_i^{(k)}$. We then train a decision tree classifier on the dataset $\{(X_i, \hat{\pi}(X_i))\}_{i=1}^{n}$ and output the segments defined by its leaves. The resulting policy assigns all individuals in the same segment to the majority treatment group (under $\hat{\pi}$) of individuals in that segment.
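A minimal sketch of Distill-Policy with scikit-learn, assuming the scalarized predicted outcomes are stacked in an (n, K + 1) array; the function name and default hyperparameters are our own. A `DecisionTreeClassifier` predicts the most frequent training label in each leaf, which matches the majority-vote assignment described above.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier


def distill_policy(X, Y_hat, max_depth=3, min_samples_leaf=50):
    """Fit a shallow classifier to the naive argmax policy (hypothetical helper).

    Y_hat: (n, K + 1) scalarized predicted outcomes; column k is treatment k.
    """
    naive_actions = np.argmax(Y_hat, axis=1)  # best predicted treatment per individual
    clf = DecisionTreeClassifier(max_depth=max_depth, min_samples_leaf=min_samples_leaf)
    clf.fit(X, naive_actions)
    # clf.predict assigns every individual in a leaf (segment) that leaf's majority action
    return clf
```

`sklearn.tree.export_text(clf, feature_names=...)` prints the learned segments as nested if-else rules.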
3.4.2 Approaches Not Based on HTE Model Predictions
The methods proposed in this section do not rely on an HTE model, which makes them useful when we are unable to train an accurate one. Here we aim to generate an interpretable policy without training an HTE model, by leveraging average treatment effects. Concretely, we define the segment average treatment effect for segment $s$ and treatment $k$ as:
$$\tau(s, k) = \frac{1}{|s \cap S_k|} \sum_{i \in s \cap S_k} Y_i \; - \; \frac{1}{|s \cap S_0|} \sum_{i \in s \cap S_0} Y_i,$$
where $S_k$ denotes the set of individuals in the $k$-th treatment group. We then define a splitting metric $g(s) = \max_{k \in \{1, \dots, K\}} \tau(s, k)$, which considers only the most responsive treatment in that segment. We build a tree and consider binary splits that improve this splitting metric.
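The segment average treatment effect and the splitting metric can be written directly from their definitions. The sketch below assumes a scalarized observed outcome `Y`, a treatment indicator `W` with 0 denoting control, and segments given as index arrays; returning NaN when a segment lacks treated or control points is our own convention, as are the function names.

```python
import numpy as np


def segment_ate(seg, W, Y, k):
    """Difference in mean outcomes between treatment k and control (W == 0) within a segment."""
    in_seg = np.zeros(len(W), dtype=bool)
    in_seg[seg] = True
    treated, control = Y[in_seg & (W == k)], Y[in_seg & (W == 0)]
    if len(treated) == 0 or len(control) == 0:
        return np.nan  # no overlap: the effect is not estimable in this segment
    return treated.mean() - control.mean()


def splitting_metric(seg, W, Y, K):
    """g(s): the effect of the most responsive treatment 1..K in the segment."""
    return np.nanmax([segment_ate(seg, W, Y, k) for k in range(1, K + 1)])
```

A candidate split of a parent segment into two children is then judged by comparing `splitting_metric` on each child with its value on the parent.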
Algorithm 3 and Algorithm 4 describe two different implementations. In the greedy implementation, a split is considered only if both the left and right child segments improve over the parent segment. In the iterative implementation, we run several iterations of splitting; in each iteration we keep the best segment and exclude it from the next run. Here, a split is considered as long as one of the child segments improves the outcome compared to the parent node. One pitfall of this iterative implementation is that segments are not necessarily disjoint in feature space, so one individual could appear in several segments. We resolve this by always assigning the individual to the first segment found.
For both methods, if no eligible split exists and thus no segments are found, the policy defaults to assigning all individuals to the treatment group with the highest average treatment effect. Otherwise, the resulting policy is: if $i \in s$ and $\max_k \tau(s, k) > 0$, we assign individual $i$ to treatment $\arg\max_k \tau(s, k)$; if $\max_k \tau(s, k) \le 0$, we assign the individual to the control or default treatment.
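The policy rule, including the first-found tie-breaking used by the iterative implementation, can be sketched as below. We assume each segment is represented by a membership test plus its vector of segment effects for treatments 1..K, and that individuals not covered by any segment receive a default treatment; both the representation and the helper name are hypothetical.

```python
import numpy as np


def assign_from_segments(X, segments, default_action):
    """Turn an ordered list of segments into treatment assignments.

    segments: list of (contains, effects) pairs; contains(X) returns a boolean mask of
              rows in the segment, and effects[k - 1] is the segment effect of treatment k.
    default_action: treatment for rows covered by no segment.
    """
    actions = np.full(len(X), default_action)
    assigned = np.zeros(len(X), dtype=bool)
    for contains, effects in segments:
        rows = contains(X) & ~assigned  # earlier (first-found) segments take precedence
        best = int(np.argmax(effects))
        # use the most responsive treatment only if its effect is positive, else control (0)
        actions[rows] = best + 1 if effects[best] > 0 else 0
        assigned |= rows
    return actions
```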
These non-model-based methods can be viewed as attempts to segment the feature space using empirical conditional average treatment effects; similar efforts have appeared in the literature (e.g., dwivedi2020stable).
3.4.3 Approaches Based on Ensembling Different Interpretable Policies
We can already generate interpretable tree-based policies using the methods described above. However, different policies may exhibit different strengths in different regions of the feature space, and in our experiments simply training deeper trees does not necessarily improve the resulting policy. We leverage ensemble learning to identify such regions, with the aim of generating a better policy. Concretely, suppose we have access to policies $\pi_1, \dots, \pi_M$. We want to train an ensemble policy that uses all or a subset of these policies while remaining interpretable. We introduce two ways of doing so. For ease of notation, we assume the policies were trained on a separate split of the dataset, so we can safely use $\{(X_i, W_i, Y_i)\}_{i=1}^{n}$ as the validation set on which we learn the ensemble policy.
Ensemble based on Uniform Exploration (GUIDE-UniformExplore): This method is inspired by the explore-exploit paradigm in the contextual bandits literature (see slivkins2021introduction for a review). To perform this exploration offline, we use HTE outcome predictions whenever the relevant outcome is not revealed in the dataset. Algorithm 2 provides the implementation. The ensemble policy is then generated using Algorithm 1, but with the HTE predictions replaced by the outputs of Algorithm 2.
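Algorithm 2 is not reproduced here. As one plausible reading of "use HTE outcome predictions when the observed outcome is not revealed", the sketch below builds an (n, M) reward matrix with one column per candidate policy: the observed outcome when the policy recommends the treatment the individual actually received, and the HTE prediction otherwise. This is our assumption, and `policy_rewards` is a hypothetical helper; such a matrix could be fed to a greedy tree search in place of the predicted outcomes, with candidate policies playing the role of treatments.

```python
import numpy as np


def policy_rewards(X, W, Y, Y_hat, policies):
    """Reward of each candidate policy for each individual (hypothetical helper).

    W: received treatments; Y: observed scalarized outcomes;
    Y_hat: (n, K + 1) scalarized HTE predictions; policies: callables mapping X to treatments.
    """
    n = len(W)
    R = np.empty((n, len(policies)))
    for m, policy in enumerate(policies):
        a = np.asarray(policy(X))
        # observed outcome if revealed, otherwise fall back to the HTE prediction
        R[:, m] = np.where(a == W, Y, Y_hat[np.arange(n), a])
    return R
```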
Ensemble using Offline Evaluation (GUIDE-OPE): We aim to find a single feature split such that the left and right children use different candidate policies, which we choose using offline policy evaluation (OPE; see Algorithm 5). Suppose a feature split divides the dataset $D$ into $D_L$ and $D_R$, and let $\pi_{a,b}$ be the ensemble policy that applies $\pi_a$ to $D_L$ and $\pi_b$ to $D_R$. To create the ensemble, we use exhaustive search to find the feature split and candidate policies $(a, b)$ that maximize the OPE estimate of the value of $\pi_{a,b}$. We then assign individual $i$ to policy $\pi_a$ if $i \in D_L$ and to $\pi_b$ otherwise.
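Algorithm 5 and the paper's choice of OPE estimator are not shown here. The sketch below assumes inverse propensity weighting (IPW) with known assignment probabilities, as in a randomized experiment, and approximates the exhaustive split search with a grid of feature quantiles; `ipw_value` and `guide_ope` are hypothetical names.

```python
import numpy as np


def ipw_value(actions, W, Y, propensity):
    """IPW estimate of a policy's mean outcome from randomized data."""
    match = actions == W
    return np.mean(match * Y / propensity[W])


def guide_ope(X, W, Y, candidate_actions, propensity, n_thresholds=20):
    """Search one split and a candidate policy per child (hypothetical helper).

    candidate_actions: (M, n) array; row m holds policy m's recommended treatments.
    propensity: length-(K + 1) array of known assignment probabilities.
    Returns (estimated value, (feature, threshold, left policy, right policy)).
    """
    M = len(candidate_actions)
    best = (-np.inf, None)
    for j in range(X.shape[1]):
        grid = np.quantile(X[:, j], np.linspace(0.05, 0.95, n_thresholds))
        for t in np.unique(grid):
            left = X[:, j] <= t
            for a in range(M):
                for b in range(M):
                    actions = np.where(left, candidate_actions[a], candidate_actions[b])
                    value = ipw_value(actions, W, Y, propensity)
                    if value > best[0]:
                        best = (value, (j, float(t), a, b))
    return best
```

The search evaluates M² policy pairs per candidate split, which stays cheap for the handful of candidate policies considered here.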
Both methods return trees that we call guidance trees; see Figure 3 for an example. While many other ways to ensemble policies exist, such as SuperLearner (montoya2021optimal), we do not consider them because they result in non-interpretable policies.
4.1 Comparing Explanations of HTE Models
|Data|Method|PEHE|Within-subgroup var|Between-subgroup var|
|---|---|---|---|---|
|Synthetic COVID (lee2020robust)|Virtual Twins RF|3.054 ± 0.960|2.009 ± 0.731|0.518 ± 0.715|
| |Virtual Twins GBDT|1.903 ± 0.660|**0.164 ± 0.127**|2.362 ± 0.216|
| |Distill-HTE|**0.648 ± 0.080**|0.337 ± 0.108|1.986 ± 0.282|
| |R2P|3.265 ± 0.671|2.355 ± 0.120|**0.180 ± 0.186**|
| |T-Learner GBDT|1.902 ± 0.654|–|–|
| |T-Learner DT|4.777 ± 1.081|–|–|
|Synthetic A (athey2016recursive)|Virtual Twins RF|0.433 ± 0.050|0.051 ± 0.008|0.174 ± 0.023|
| |Virtual Twins GBDT|**0.151 ± 0.016**|**0.010 ± 0.001**|0.255 ± 0.027|
| |Distill-HTE|0.189 ± 0.033|0.037 ± 0.009|0.241 ± 0.042|
| |R2P|0.719 ± 0.200|0.211 ± 0.077|**0.036 ± 0.071**|
| |T-Learner GBDT|**0.151 ± 0.016**|–|–|
| |T-Learner DT|0.616 ± 0.064|–|–|

Table 1: Test-set performance of subgroup finding methods on synthetic datasets (mean ± standard deviation over ten train-test splits). For PEHE and variance, lower is better. Best method for each column in bold.
We compare the Distill-HTE method proposed in Section 3.3 against other subgroup finding methods that take the predictions of HTE models as input: (1) Virtual Twins (VT) (foster2011subgroup); (2) R2P (lee2020robust). We deliberately restrict the comparison to such post-hoc methods because they share our motivation of explaining already-trained HTE models. We also compare to a black-box model, a T-Learner, which does not find subgroups but instead provides one prediction per individual.
Setup: The first dataset, Synthetic COVID, was proposed by lee2020robust; it uses patient features from an initial clinical trial of the drug Remdesivir, but generates synthetic outcomes in which the drug reduces the time to improvement for patients with a shorter period between symptom onset and the start of the trial. The second dataset, Synthetic A, was proposed in athey2016recursive. For each dataset we generate ten train-test splits, over which we compute the mean and standard deviation of each metric. For R2P, we use the implementation provided by the authors; for VT, we use our own implementation.
Evaluation: On synthetic data, where ground-truth treatment effects are available, we report the Precision in Estimation of Heterogeneous Effect (PEHE) (hill2011bayesian), defined as $\epsilon_{\text{PEHE}} = \frac{1}{n} \sum_{i=1}^{n} \big( (Y_i^{(1)} - Y_i^{(0)}) - (\hat{Y}_i^{(1)} - \hat{Y}_i^{(0)}) \big)^2$. Unlike bias and RMSE metrics computed relative to ground-truth treatment effects, PEHE requires accurate estimation of both counterfactual and factual outcomes (johansson2016learning). We also compute the within- and between-subgroup variance.
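For reference, PEHE and one possible definition of the subgroup variances can be computed as below. The paper does not define the two variances; the sketch assumes they decompose the variance of individual treatment effects via the law of total variance, weighting each subgroup by its size. Some works report the square root of PEHE instead of the mean squared error used here; the function names are ours.

```python
import numpy as np


def pehe(tau_true, tau_pred):
    """Mean squared error of individual treatment-effect estimates."""
    return np.mean((np.asarray(tau_true, dtype=float) - np.asarray(tau_pred, dtype=float)) ** 2)


def subgroup_variances(tau, groups):
    """Size-weighted within- and between-subgroup variance of treatment effects."""
    tau, groups = np.asarray(tau, dtype=float), np.asarray(groups)
    labels, counts = np.unique(groups, return_counts=True)
    weights = counts / counts.sum()
    means = np.array([tau[groups == g].mean() for g in labels])
    within = np.sum(weights * np.array([tau[groups == g].var() for g in labels]))
    between = np.sum(weights * (means - tau.mean()) ** 2)
    return within, between  # within + between equals tau.var()
```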
Results: Table 1 presents the results. We make a few observations. (1) Black-box vs. white-box: As expected, GBDT T-Learners perform well because, unlike all the other methods (VT, Distill-HTE, R2P), they have no interpretability constraints; those other methods modify standard decision trees and remain visualizable as trees. Yet, among the tree-based approaches, Distill-HTE achieves the lowest PEHE on Synthetic COVID and is second only to Virtual Twins GBDT on Synthetic A. (2) Optimization criterion: R2P, the only method presented here that considers not only homogeneity within subgroups but also heterogeneity between subgroups, has the lowest between-subgroup variance. The other methods, which do not target heterogeneity between subgroups, do not fare as well on this metric. On the other hand, they are better than R2P at minimizing within-subgroup variance because, unlike R2P, they do not trade off within- against between-subgroup variance. R2P's tradeoff also comes at the expense of PEHE. (3) HTE model class: The choice of HTE model matters: Virtual Twins performs worse with an RF HTE model than with a GBDT HTE model. Similarly, GBDT T-Learners perform better than DT T-Learners. (4) The impact of distillation: The T-Learner DT model performed poorly, with the worst PEHE on Synthetic COVID and the second worst on Synthetic A. In contrast, Virtual Twins RF, Virtual Twins GBDT, and Distill-HTE, which train modified decision trees on the predictions of a teacher model, show a marked improvement over T-Learner DT. This suggests that distilling a complex GBDT or RF teacher is more beneficial than learning a tree directly, which agrees with the existing distillation literature.
|Data|Method|Regret|
|---|---|---|
|Semi-synthetic IHDP (hill2011bayesian)|Assign all to treatment 0|111.131 ± 10.279|
| |Assign all to treatment 1|21.537 ± 3.672|
| |No-HTE (Greedy and Iterative)*|21.537 ± 3.672|
| |GUIDE-UniformExplore (guide depth=1)|62.306 ± 11.039|
| |GUIDE-UniformExplore (guide depth=2)|56.534 ± 3.116|
| |GUIDE-OPE (guide depth=1)|65.555 ± 9.676|
|Synthetic A (athey2016recursive)|Assign all to treatment 0|62.162 ± 2.232|
| |Assign all to treatment 1|59.021 ± 2.178|
| |No-HTE (Greedy and Iterative)*|59.021 ± 2.178|
| |GUIDE-UniformExplore (guide depth=1)|0.016 ± 0.016|
| |GUIDE-UniformExplore (guide depth=2)|0.021 ± 0.014|
| |GUIDE-OPE (guide depth=1)|0.015 ± 0.011|
|Email Marketing (Hillstrom)|Assign all to treatment 0|48840.305 ± 1887.596|
| |Assign all to treatment 1|41093.311 ± 3377.773|
| |Assign all to treatment 2|33976.605 ± 705.053|
| |GUIDE (UniformExplore and OPE)|33976.605 ± 705.053|

* indicates that no segments were found, so the resulting policy assigned all individuals to one treatment group.
4.2 Comparing Policy Learning Methods
We compare the methods proposed in Section 3.4 to (1) a black-box policy, obtained by training a T-Learner HTE model and then assigning each individual to the treatment group with the best predicted treatment effect, and (2) a policy that chooses the treatment for each unit uniformly at random.
Setup: Besides the synthetic datasets described in Section 4.1, we use other publicly available datasets. The IHDP dataset (hill2011bayesian) studies the impact of specialized home visits on infant cognition using mother and child features. The Email Marketing dataset (Hillstrom) has multiple outcomes: it contains 64k points with visits, conversions, and money spent as outcomes (Appendix A describes how we constructed potential outcomes). We combine multiple outcomes as explained in Section 3.2. The ensemble policies (GUIDE-UniformExplore, GUIDE-OPE) are built from GreedyTreeSearch-HTE and Distill-Policy, which we selected for their individual performance.
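Evaluation: To evaluate the different policies $\pi$, we define a notion of regret:
$$\text{Regret}(\pi) = \sum_{i=1}^{n} \Big( \max_{k \in \{0, \dots, K\}} Y_i^{(k)} - Y_i^{(\pi(X_i))} \Big),$$
where $\pi(X_i)$ is the treatment group prescribed by that particular policy for individual $i$ and $Y_i^{(k)}$ is individual $i$'s (scalarized) potential outcome under treatment $k$.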
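When all potential outcomes are known or imputed, this regret reduces to a few lines. The sketch below assumes scalarized potential outcomes stacked in an (n, K + 1) array, integer treatment assignments, and regret summed (not averaged) over individuals; the function name is ours.

```python
import numpy as np


def regret(Y_potential, actions):
    """Summed regret of a policy against the best treatment for each individual.

    Y_potential: (n, K + 1) scalarized potential outcomes; actions: assigned treatments.
    """
    Y_potential = np.asarray(Y_potential, dtype=float)
    actions = np.asarray(actions)
    chosen = Y_potential[np.arange(len(actions)), actions]
    return np.sum(Y_potential.max(axis=1) - chosen)
```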
Results: Table 2 presents the results. We make a few observations. (1) In many cases, at least one interpretable policy (GreedyTreeSearch-HTE, Distill-Policy, No-HTE, or an ensembled policy) improves on, or is not far behind, the black-box policy. (2) Personalization: In datasets such as IHDP, where the benefit from personalization was limited because the majority of points have positive treatment effects, No-HTE-Greedy and No-HTE-Iterative picked up on this and assigned all points to one treatment, which was the best policy in that case. (3) The impact of ensembling: In general, the ensembled policies GUIDE-UniformExplore and GUIDE-OPE improved over their constituent policies GreedyTreeSearch-HTE and Distill-Policy. However, ensembles with deeper guidance trees are not necessarily better, so the gain from an additional level of guidance should be weighed against the resulting loss of interpretability.
4.3 Ensembling Interpretable Policies
We now display several of the interpretable policies, and show how they can be ensembled while remaining interpretable. On Synthetic A, a simple dataset with only two features, the following policies were obtained:
GreedyTreeSearch-HTE: If feature0 ≤ -0.019, assign control. Otherwise, assign treatment.
Distill-Policy: If feature0 ≤ -0.02, assign control. Otherwise, assign treatment.
We applied the GUIDE-UniformExplore and GUIDE-OPE approaches to learn guidance trees: trees that ensemble these two policies while still remaining interpretable. While the GUIDE-OPE guidance tree suggested simply following Distill-Policy, the GUIDE-UniformExplore guidance tree suggested:
GUIDE-UniformExplore: If feature1 ≤ -0.426, follow Distill-Policy. Otherwise, follow GreedyTreeSearch-HTE.
By leveraging feature1, which neither individual policy had used, GUIDE-UniformExplore achieved one of the lowest regrets of all methods, lower than that of the individual policies it ensembled, GreedyTreeSearch-HTE and Distill-Policy (Table 2). Interestingly, even when the individual policies GreedyTreeSearch-HTE and Distill-Policy were grown to greater depth, feature1 was still not selected. While we do not always see gains from ensembling, in some datasets guidance trees can correct for the greediness of tree learning. Whereas typical ways of ensembling trees (e.g., random forests) reduce interpretability, a depth-constrained guidance tree added on top of interpretable policy trees merely makes the policy tree slightly deeper.
4.4 Bridging Explanations and Policies
We now present the explanations and resulting interpretable policies on the Synthetic COVID data.
Interpreting HTE model: Figure 4 presents the segments found by Distill-HTE on the Synthetic COVID data. The segment with the most negative predicted treatment effect (first segment; red), at -0.0304 ± 0.0010, covers individuals who started taking the drug between 4.5 and 10.5 days after the onset of symptoms and whose Aspartate Aminotransferase and Lactate Dehydrogenase levels were within normal ranges (medicinenet), suggesting that they were not extremely sick. It is therefore not surprising that they were not predicted to benefit as much from the drug.
On the other hand, the segment with the most positive predicted treatment effect (last segment; green) covers individuals who started taking the drug soon (≤ 4.5 days) after exhibiting COVID symptoms. These individuals were predicted to benefit the most from the drug, with a treatment effect of 0.0187 ± 0.0006. This agrees with the finding that Remdesivir results in a faster time to clinical improvement for patients with a shorter time from symptom onset to starting the trial (wang2020remdesivir).
Policy Generation: A simple policy can be derived from the Distill-HTE segments by assigning all red segments to not receive the drug (as they are predicted not to benefit from it) and all green segments to receive the drug. This policy differs from the interpretable policies learned directly (Table 2). For example, the GreedyTreeSearch-HTE policy is extremely simple: if the number of days from symptom onset to starting the trial is ≤ 11, assign to receive the drug; otherwise, assign to not receive the drug. As another example, the rather stringent splitting criterion of No-HTE-Greedy and No-HTE-Iterative (Equation 3) did not find enough heterogeneity to justify assigning different treatment groups, so both methods assigned all individuals to receive the drug.
5 Discussion and Conclusion
The motivation for this work was threefold. (1) HTE models sometimes overfit even on extremely large datasets, due to extremely noisy outcomes, noisy features, etc. (2) Interpretable policies can avoid much of the technical debt that ML-based policies are known to incur (sculley2015hidden). (3) Scalability constraints exclude many optimal tree-based policy learning algorithms from deployment in production environments, as illustrated in Section 3.4.
The interpretable policy methods proposed in this paper are not theoretically optimal, a shortcoming we acknowledge. Nonetheless, we believe that the comparison to sensible baselines, such as policies derived from T-Learner HTE models, and the ability of some of our policies to achieve low regret on simple synthetic datasets demonstrate their merit.
In real, large-scale datasets where the potential benefit from personalization was higher, No-HTE-Greedy and No-HTE-Iterative suggested more personalized policies, especially when feature selection was performed beforehand. This is because the splitting criterion (Equation 3) is evaluated over all features: without prior feature selection, searching over many features, together with using noisy sample average treatment effects as the splitting criterion, makes it easy to find splits that look good on the training set by chance but do not hold up on the test set. Further work is needed to realize the full potential of these methods.
In practice, we have found that the choice between deriving policies from black-box HTE models and directly learning interpretable policies from the ground up depends on whether those using such policies can maintain HTE models in production, and on the need to explain exactly how personalization happens, especially in applications where personalization based on sensitive features incurs disparate treatment unfairness (lal2020fairness). While HTE models are resource- and maintenance-intensive, continuously retraining them allows adjustment to a dynamic user population. Conversely, interpretable policies are easy to implement and maintain, but may not perform best over the long term without policy regeneration as the features of the user population shift.
Appendix A Data Generation Details
The Email Marketing dataset (Hillstrom) is a real dataset from a randomized experiment in which customers were randomized to receive one of three treatments. To generate potential outcomes, for each individual $i$ in treatment group $k$, we found the 5 nearest neighbors of $i$ in each other treatment group $k' \neq k$ and averaged their outcomes to obtain the potential outcome of individual $i$ under treatment group $k'$. Distances between individuals were computed as the Euclidean distance in feature space.
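A minimal sketch of this imputation with scikit-learn's `NearestNeighbors` (Euclidean distance by default), assuming integer treatment labels 0..G-1 and at least k individuals per group. Keeping the factual outcome for an individual's own group is our assumption, and `knn_potential_outcomes` is a hypothetical helper.

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors


def knn_potential_outcomes(X, W, Y, n_groups, k=5):
    """Impute potential outcomes as the mean outcome of the k nearest neighbours per group.

    Y may be (n,) or (n, J); the result is (n, n_groups) or (n, n_groups, J).
    """
    X, W, Y = np.asarray(X, dtype=float), np.asarray(W), np.asarray(Y, dtype=float)
    Y_pot = np.empty((len(X), n_groups) + Y.shape[1:])
    for g in range(n_groups):
        members = np.flatnonzero(W == g)
        nn = NearestNeighbors(n_neighbors=k).fit(X[members])
        _, idx = nn.kneighbors(X)  # k nearest members of group g for every individual
        Y_pot[:, g] = Y[members][idx].mean(axis=1)
        Y_pot[members, g] = Y[members]  # keep the factual outcome for the received group
    return Y_pot
```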
Appendix B Training Details
Unless otherwise mentioned, the HTE models we train are T-learners with GBDT base learners. In general, we use 40% of the data as the test set on which we report results, and 30% of the remaining 60% (i.e., 18% of all data) as the validation set. We learn individual policies on the training set (the remaining 42%). When learning ensemble policies, we learn the ensemble on the validation set.
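The split proportions and the default HTE model can be sketched as follows, assuming scikit-learn's `train_test_split` and `GradientBoostingRegressor` with default hyperparameters (the paper does not report these). The function names are hypothetical; the resulting (n, K + 1) matrix plays the role of the predicted outcomes used throughout Section 3.

```python
import numpy as np
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.model_selection import train_test_split


def split_indices(n, seed=0):
    """40% test, then 30% of the remaining 60% as validation (18% overall), 42% training."""
    rest, test = train_test_split(np.arange(n), test_size=0.4, random_state=seed)
    train, val = train_test_split(rest, test_size=0.3, random_state=seed)
    return train, val, test


def t_learner_predictions(X_train, W_train, Y_train, X_eval, n_groups):
    """T-learner: one GBDT outcome model per treatment group; returns (n_eval, n_groups)."""
    preds = np.empty((len(X_eval), n_groups))
    for g in range(n_groups):
        in_g = W_train == g
        model = GradientBoostingRegressor(random_state=0).fit(X_train[in_g], Y_train[in_g])
        preds[:, g] = model.predict(X_eval)  # predicted outcome had everyone received group g
    return preds
```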