MultiMBNN: Matched and Balanced Causal Inference with Neural Networks

04/28/2020 ∙ by Ankit Sharma, et al.

Causal inference (CI) in observational studies has received a lot of attention in healthcare, education, ad attribution, policy evaluation, etc. Confounding is a typical hazard, where the context affects both the treatment assignment and the response. In a multiple treatment scenario, we propose the neural network based MultiMBNN, where we overcome confounding by employing generalized propensity score based matching and learning balanced representations. We benchmark the performance on synthetic and real-world datasets using PEHE and the mean absolute percentage error over ATE as metrics. MultiMBNN outperforms state-of-the-art CI algorithms such as TARNet and Perfect Match (PM).




1 Introduction

The primary goal in causal inference (CI) is to uncover the cause-effect relationship between entities, often construed as a problem of quantifying the effect of treatments on individuals. Randomised controlled trials, the most popular choice for obtaining causal relationships, are expensive and entail several logistical and ethical constraints. Causality is therefore a crucial paradigm in several domains where observational data is available, such as healthcare [1], socioeconomic studies [2] and advertising [3]. An impediment to CI in observational studies is the presence of confounding, where both the assignment of and the response to the treatment depend on context covariates, resulting in selection bias.

In order to abate the effect of confounding, the discrepancy between the distributions of individuals receiving different treatments is minimized so as to emulate a randomized trial. In the binary treatment case, confounding is addressed using statistical approaches such as sub-classification, weighting, imputation, and propensity score (PS) matching for unbiased per-individual causal estimates. Deep neural network (DNN) based techniques propose counterfactual distribution modelling [4] and learning balanced representations to minimize selection bias [5, 6, 7]. In the literature, the multiple treatments scenario is interpreted either as different dosage levels of a single treatment, or as one of several distinct treatments. For the latter case, matching and sub-classification techniques have been proposed [1], and in particular, generalized propensity score (GPS) based matching was proposed [8] as an accurate metric in the multiple treatment scenario [1]. DNN-based techniques for CI include counterfactual distribution learning [9], Gaussian process based modelling [10], and the PS matching based Perfect Match (PM) [11]. Here, we propose a novel DNN for counterfactual inference, which overcomes confounding by leveraging both GPS based matching and learning balanced representations.

Contributions: We propose a novel framework for counterfactual inference in the presence of confounding under multiple treatments. We assume strong ignorability and no hidden confounding [5, 6, 7]. We optimize DNN models via mini-batch stochastic gradient descent (SGD) to predict both the factual and the counterfactual response of a given individual to any one of the treatments. We propose GPS based matching, along with learning balanced representations, to address confounding. We estimate the GPS by training a predictive model, and we use this GPS to match every sample within a minibatch with its nearest neighbours. In order to learn the balanced representation, we propose a loss function using the pairwise maximum mean discrepancy (MMD) metric. The novelty of this work is two-fold. First, we show that GPS based matching leads to more accurate counterfactual inference as compared to PS based matching [11]. Next, we generalize the balancing representation based loss function [5, 6] to the multiple treatment scenario. On synthetic and real-world datasets, we demonstrate that by combining matching and the generalized loss function, we outperform state-of-the-art CI techniques such as Perfect Match (PM) [11] and TARNet [6]. We use the precision in estimation of heterogeneous effect (PEHE) and the mean absolute percentage error (MAPE) over the average treatment effect (ATE) as metrics.

2 MultiMBNN Causal Inference Model

In this section, we describe the preliminaries of the CI framework in the multiple treatment scenario, followed by details of the proposed MultiMBNN framework.

2.1 Causal Inference Preliminaries

We consider observational training data comprising $N$ samples, where the $n$-th sample is given by the triple $(x_n, t_n, y_n)$. Each individual (also called context) is represented using a vector of covariates $x_n$, for $n = 1, \dots, N$. An individual is subjected to one of $K$ treatments given by $t_n = [t_n(1), \dots, t_n(K)]$, where each entry of $t_n$ is binary, i.e., $t_n(k) \in \{0, 1\}$. Here, $t_n(k) = 1$ implies that the $k$-th treatment is given. We assume that only one treatment is provided to an individual at any given point in time, and hence $t_n$ is a one-hot vector. Accordingly, the response vector for the $n$-th individual is given by $y_n = [y_n(1), \dots, y_n(K)]$, i.e., the outcome is a continuous random vector with entries $y_n(k)$, the response of the $n$-th individual to the $k$-th treatment. We define the counterfactuals as the alternate treatments, whose responses are unobserved for an individual.
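The data layout described above can be illustrated with a small sketch. The values, the 0-based treatment indices and the helper names `one_hot` and `split_factual` are hypothetical and chosen only for illustration; in real observational data only the factual entry of the response vector is available.

```python
# Hypothetical toy data for K = 3 treatments (all values invented).
K = 3

def one_hot(k, num_treatments):
    """Return the one-hot treatment vector for treatment index k (0-based)."""
    return [1 if j == k else 0 for j in range(num_treatments)]

def split_factual(treatment, responses):
    """Split a full response vector into the factual response and the
    counterfactual responses, given a one-hot treatment vector."""
    assert sum(treatment) == 1, "exactly one treatment per individual"
    k = treatment.index(1)
    factual = responses[k]
    counterfactual = {j: y for j, y in enumerate(responses) if j != k}
    return k, factual, counterfactual

t = one_hot(1, K)          # the individual received treatment 1
y = [0.7, 1.9, -0.4]       # responses to all K treatments (known only in simulation)
k, y_f, y_cf = split_factual(t, y)
print(k, y_f, y_cf)
```

Counterfactual regression then amounts to predicting the entries collected in `y_cf` from the covariates and the factual observations.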

Our goal is to train a DNN model to overcome confounding and perform counterfactual regression, i.e., to predict the response given any context and treatment. We address the issue of confounding using both matching and learning balanced representations. In the sequel, we describe the matching method used and the loss function that caters to the multiple treatment scenario.

2.2 Generalized Propensity Score Matching

Propensity score (PS) based matching is a well-known technique used to emulate the effect of a randomized experiment by obtaining similar covariate distributions across treated populations [1]. Here we employ matching based on the generalized propensity score (GPS), which is a more relevant score for the multiple treatment scenario. The PS is the conditional probability of a given individual with covariates $x_n$ receiving the $k$-th treatment, i.e., $p(t_n(k) = 1 \mid x_n)$. Accordingly, the GPS vector is defined as the vector of the PS values for all $K$ treatments. In practice, we do not have access to the GPS vector, and hence we estimate it by training a predictive model. In this work, we train an SVM or a random forest. We use this tuned model to predict the PS on the training data used for causal inference.

We employ the GPS vector for batch augmentation in every minibatch of SGD. For every sample within a minibatch, the closest neighbour samples are obtained. In order to avoid overfitting, we first obtain a set of nearest neighbours, and pick one out of these samples at random for each counterfactual treatment of the sample. For instance, consider a sample and its factual treatment $k$. We propose the GPS based matching strategy, which selects a neighbour with observed treatment $j$ such that $j \neq k$ and the distance between the GPS vectors of the sample and the neighbour is minimum. On the other hand, the PS matching strategy [11] selects a neighbour with observed treatment $j$ such that $j \neq k$ and the distance between their propensity scores is minimum. Despite its popularity, PS based matching has been described as inadequate, and it sometimes leads to imbalance in parametric models due to model dependence [12]. Hence, we learn balancing representations to achieve better performance.
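The matching step can be sketched as follows, assuming the GPS vectors have already been estimated by some classifier. The Euclidean distance, the `n_neighbors` parameter and the function names are assumptions made for illustration and are not taken from the paper.

```python
import math
import random

def gps_distance(g_a, g_b):
    """Euclidean distance between two GPS vectors (an assumed choice of metric)."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(g_a, g_b)))

def match_counterfactuals(i, gps, treatments, num_treatments, n_neighbors=1, rng=random):
    """For sample i, pick one matched sample per counterfactual treatment.

    gps[m] is the estimated GPS vector of sample m and treatments[m] its
    observed treatment index. Among the samples that received treatment j, the
    n_neighbors closest to sample i in GPS space are shortlisted and one of
    them is drawn at random.
    """
    matches = {}
    for j in range(num_treatments):
        if j == treatments[i]:
            continue  # the factual treatment needs no match
        candidates = [m for m in range(len(gps)) if treatments[m] == j]
        if not candidates:
            continue  # no sample received treatment j
        candidates.sort(key=lambda m: gps_distance(gps[i], gps[m]))
        matches[j] = rng.choice(candidates[:n_neighbors])
    return matches

# Toy example with invented GPS estimates for K = 3 treatments.
gps = [
    [0.7, 0.2, 0.1],  # sample 0, observed treatment 0
    [0.6, 0.3, 0.1],  # sample 1, observed treatment 1
    [0.1, 0.8, 0.1],  # sample 2, observed treatment 1
    [0.5, 0.2, 0.3],  # sample 3, observed treatment 2
]
treatments = [0, 1, 1, 2]
matches = match_counterfactuals(0, gps, treatments, num_treatments=3)
print(matches)
```

With `n_neighbors=1` the match is deterministic; a larger shortlist adds the random pick among nearest neighbours that the text motivates as a guard against overfitting.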

2.3 Learning Balancing Representations

In addition to overcoming the imbalance using matching, we also propose learning balanced representations using a DNN. In [5], the authors perform counterfactual inference by generalizing from the factual to the counterfactual distribution, for the binary treatment scenario. We extend this framework from the binary to the multiple treatment scenario, and modify the loss function, denoted (1), for this setting. In this loss, the weights on the imbalance terms are hyperparameters that control the strength of the imbalance penalties, the regularizer is a model complexity term, the imbalance is measured between the distributions w.r.t. the $k$-th treatment and the $j$-th treatment, and MMD is the maximum mean discrepancy measure as defined in [6]. We learn the balancing representation and the hypothesis jointly by training a deep neural network using a loss function that incorporates the factual and the imbalance error, as depicted in Fig. 1. In (1), the first term on the right hand side represents the factual loss. The second term computes the pairwise maximum mean discrepancy between the factual distributions of different treatments. The loss function in (1) is a generalisation of [6] to the multiple treatment scenario; it reduces to the one proposed in [6] for two treatments ($K = 2$).
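One form of such a loss that is consistent with the description above is sketched here. All notation is assumed for illustration: $\Phi$ is the learned representation, $h_k$ the hypothesis for the $k$-th treatment, $k_n$ the index of the factual treatment of the $n$-th individual, $\ell$ a per-sample regression loss, $p_\Phi^{k}$ the distribution of representations of individuals receiving the $k$-th treatment, $\mathcal{R}$ the model complexity term, and $\alpha$, $\lambda$ the hyperparameters. The paper's own equation may differ, for instance in per-sample or per-pair weights.

```latex
\mathcal{L}(\Phi, h) =
  \frac{1}{N} \sum_{n=1}^{N} \ell\big(h_{k_n}(\Phi(x_n)),\, y_n(k_n)\big)
  + \alpha \sum_{k=1}^{K-1} \sum_{j=k+1}^{K}
      \mathrm{MMD}\big(p_\Phi^{k},\, p_\Phi^{j}\big)
  + \lambda\, \mathcal{R}(h)
```

For $K = 2$ the double sum contains a single MMD term between the two treatment groups, which matches the structure of the binary-treatment objective.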

procedure MultiMBNN(D)
    Split dataset D into one part for CI and one part to compute the GPS.
    Divide the CI part into train, validation and test sets.
    Obtain the GPS for the CI training set, as described in Sec. 2.2.
    Divide the CI training set into batches.
    for each epoch and each batch do
        Augment the batch using GPS based matching, as described in Sec. 2.2.
        Update the representation and hypothesis layers, using the augmented batch as input, by minimizing Eq. 1.
    return the representation and hypothesis layers
Algorithm 1: MultiMBNN algorithm
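The batching and augmentation loop of Algorithm 1 can be sketched as below. This is a minimal illustration under stated assumptions: GPS vectors are given, the single nearest neighbour is used, distance is squared Euclidean, and the model update is left out. The function names are hypothetical.

```python
import random

def nearest_with_treatment(i, j, gps, treatments):
    """Index of the sample with observed treatment j whose GPS vector is
    closest (squared Euclidean distance, an assumed metric) to sample i."""
    candidates = [m for m in range(len(gps)) if treatments[m] == j]
    return min(candidates,
               key=lambda m: sum((a - b) ** 2 for a, b in zip(gps[i], gps[m])))

def augmented_batches(gps, treatments, num_treatments, batch_size, seed=0):
    """Yield minibatches of sample indices, each augmented with one GPS match
    per counterfactual treatment of every sample in the batch."""
    order = list(range(len(gps)))
    random.Random(seed).shuffle(order)
    observed = set(treatments)  # treatments that at least one sample received
    for start in range(0, len(order), batch_size):
        batch = order[start:start + batch_size]
        extra = [nearest_with_treatment(i, j, gps, treatments)
                 for i in batch
                 for j in range(num_treatments)
                 if j != treatments[i] and j in observed]
        yield batch + extra  # a gradient step on the loss would consume this batch

gps = [[0.7, 0.2, 0.1], [0.6, 0.3, 0.1], [0.1, 0.8, 0.1], [0.5, 0.2, 0.3]]
treatments = [0, 1, 1, 2]
for batch in augmented_batches(gps, treatments, num_treatments=3, batch_size=2):
    print(batch)
```

Each yielded batch holds the original indices first, followed by the matched indices, so every sample is accompanied by neighbours from each of its counterfactual treatment groups.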

2.4 Proposed Approach: MultiMBNN

We propose the MultiMBNN algorithm as described in Algorithm 1. As discussed in the previous subsections, we perform batch augmentation based on GPS, and train a DNN to learn the balancing representation and the hypothesis layers, one for each treatment, using the augmented minibatches, via SGD based training. The proposed neural network architecture is depicted in Fig. 1.
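The quantity minimized on each augmented minibatch can be sketched in plain Python. The sketch assumes a linear-kernel MMD estimate (squared distance between group means), a squared-error factual loss and a single weight `alpha`; the MMD used in the paper is the one defined in [6] and may differ, and the model complexity term is omitted here.

```python
from itertools import combinations

def mean_vector(rows):
    """Column-wise mean of a non-empty list of equal-length vectors."""
    n = len(rows)
    return [sum(col) / n for col in zip(*rows)]

def linear_mmd(rows_a, rows_b):
    """Linear-kernel MMD estimate: squared Euclidean distance between the mean
    representations of two groups (an assumed, simplified form)."""
    mu_a, mu_b = mean_vector(rows_a), mean_vector(rows_b)
    return sum((a - b) ** 2 for a, b in zip(mu_a, mu_b))

def pairwise_imbalance(representations, treatments):
    """Sum of linear MMD over all unordered pairs of treatment groups present."""
    groups = {}
    for phi, k in zip(representations, treatments):
        groups.setdefault(k, []).append(phi)
    return sum(linear_mmd(groups[k], groups[j])
               for k, j in combinations(sorted(groups), 2))

def objective(predictions, targets, representations, treatments, alpha=1.0):
    """Mean squared factual error plus alpha times the pairwise imbalance."""
    factual = sum((p - y) ** 2 for p, y in zip(predictions, targets)) / len(targets)
    return factual + alpha * pairwise_imbalance(representations, treatments)

# Toy batch: 2-dimensional representations for K = 3 treatment groups.
reps = [[0.0, 0.0], [2.0, 0.0], [1.0, 1.0], [1.0, 3.0], [3.0, 3.0]]
treat = [0, 0, 1, 1, 2]
print(pairwise_imbalance(reps, treat))
```

In a real implementation the representations and predictions would come from the shared layers and the per-treatment hypothesis layers, and the objective would be differentiated by an autodiff framework.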

Figure 1: MultiMBNN: Proposed neural network architecture

3 Experiments and Results

We illustrate the performance of the MultiMBNN algorithm on synthetic (named Syn) [13], semi-synthetic NEWS [11] and cancer genome TCGA datasets. The synthetic dataset is generated with a parameter that accounts for the treatment assignment bias. We employ the data generating process (DGP) in [11] and bag-of-words context covariates to generate the samples of the NEWS dataset. For the synthetic and NEWS datasets, we generate data for several numbers of treatments $K$, each referred to by the name of the dataset followed by $K$ (e.g., Syn4 and NEWS4 for $K = 4$). The TCGA dataset is obtained using the DGP in [11] with $K = 4$ ('TCGA4').

As metrics, we use PEHE as defined in [11], and MAPE over ATE [7]. We compare the proposed algorithm against TARNet [6]; MultiBNN, which learns balanced representations as described in Sec. 2.3; PM, as described in [11]; and MultiMBNN (PS), which uses PS based matching (not GPS based matching) along with balanced representations.

We demonstrate the performance of the MultiMBNN algorithm using several experimental settings. First, we illustrate the effect of the treatment assignment bias parameter. As illustrated in Fig. 2, for Syn4 there is a value of this parameter at which MultiMBNN performs the best, since imbalance amongst treatments leads to one of the four treatments being suppressed, resulting in a large counterfactual error and hence an elbow point. For NEWS4, the counterfactual error is least at the value for which imbalance amongst treatment groups is minimum, leading to a near-uniform distribution of population samples across all groups.
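The two metrics can be sketched for simulated data in which all potential outcomes are known. The exact definitions are those of [11] and [7]; the sketch below assumes the root of the mean squared error in the pairwise treatment effect for PEHE, the absolute relative error of the pairwise ATE for MAPE, and a plain average over all treatment pairs for the multiple treatment case. Function names are hypothetical.

```python
from itertools import combinations
from math import sqrt

def pehe(y_true, y_pred, a, b):
    """Root PEHE between treatments a and b: RMS error of the estimated
    individual effect y(a) - y(b) over all individuals."""
    errs = [((yt[a] - yt[b]) - (yp[a] - yp[b])) ** 2
            for yt, yp in zip(y_true, y_pred)]
    return sqrt(sum(errs) / len(errs))

def ate(y, a, b):
    """Average treatment effect of treatment a relative to treatment b."""
    return sum(row[a] - row[b] for row in y) / len(y)

def mape_ate(y_true, y_pred, a, b):
    """Absolute relative error of the estimated ATE (true ATE must be non-zero)."""
    true_ate = ate(y_true, a, b)
    return abs((ate(y_pred, a, b) - true_ate) / true_ate)

def pairwise_mean(metric, y_true, y_pred):
    """Average a pairwise metric over all unordered treatment pairs."""
    pairs = list(combinations(range(len(y_true[0])), 2))
    return sum(metric(y_true, y_pred, a, b) for a, b in pairs) / len(pairs)

# Invented potential outcomes for 2 individuals and K = 3 treatments.
y_true = [[1.0, 2.0, 4.0], [0.0, 2.0, 3.0]]
y_pred = [[1.0, 2.5, 4.0], [0.0, 1.5, 3.0]]
print(pairwise_mean(pehe, y_true, y_pred), pairwise_mean(mape_ate, y_true, y_pred))
```

In this toy case the per-individual errors cancel in the average, so the ATE error is zero while PEHE is not, which is why the two metrics are reported side by side.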

Figure 2: Comparison of CI frameworks: counterfactual error across epochs for Syn4 (left), and counterfactual error vs. treatment assignment bias (top right: Syn4, bottom right: NEWS4).
Figure 3: PEHE and MAPE for the Syn and NEWS datasets for a varying number of treatments.
Metric, Dataset   TARNet         MultiBNN       PM             MultiMBNN (PS)   MultiMBNN
PEHE, Syn4        10.21 ± 0.56   9.34 ± 0.61    8.21 ± 0.35    7.98 ± 0.23      7.86 ± 0.37
MAPE, Syn4        0.07 ± 0.02    0.08 ± 0.03    0.06 ± 0.02    0.04 ± 0.01      0.02 ± 0.02
PEHE, NEWS4       9.70 ± 1.15    9.37 ± 0.90    9.16 ± 0.80    8.99 ± 0.94      8.96 ± 0.92
MAPE, NEWS4       0.81 ± 0.11    0.81 ± 0.10    0.81 ± 0.12    0.80 ± 0.13      0.82 ± 0.12
PEHE, TCGA4       29.45 ± 3.48   26.15 ± 3.29   23.57 ± 1.10   23.18 ± 1.39     21.47 ± 0.96
MAPE, TCGA4       0.93 ± 0.14    0.84 ± 0.10    0.92 ± 0.07    0.78 ± 0.03      0.80 ± 0.08
Table 1: PEHE and MAPE (mean ± standard deviation) over multiple runs of Syn4, NEWS4 and TCGA4 (fixed treatment assignment bias).

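In Fig. 3, we illustrate the performance of the proposed algorithm and the baselines with a varying number of treatments for a fixed treatment assignment bias. We see that the MultiMBNN algorithm outperforms the baselines by large margins. Further, we simulate MultiMBNN with different initial seed-points, keeping the number of treatments and the treatment assignment bias fixed, and report the mean and standard deviation of PEHE and MAPE for MultiMBNN and all baselines in Table 1. MultiMBNN, which incorporates both matching and DNN based balancing, attains the lowest mean PEHE on all three datasets (7.86 on Syn4, 8.96 on NEWS4 and 21.47 on TCGA4) and the lowest mean MAPE on Syn4, while MAPE on NEWS4 is nearly identical across methods and the PS based variant has a slightly lower MAPE on TCGA4.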

4 Conclusions

In this work, we propose the MultiMBNN algorithm, which addresses the inadequacies of the matching framework by learning balanced representations in the multiple treatment causal inference scenario. We demonstrate that MultiMBNN outperforms state-of-the-art techniques for multiple treatments, including on datasets with thousands of potential covariate confounders. In future, we shall extend this algorithm to handle sparsity in the presence of a large number of treatments.


  • [1] E. A. Stuart. Matching methods for causal inference: A review and a look forward. Statist. Sci., 25(1):1, 2010.
  • [2] S. Athey. Machine learning and causal inference for policy evaluation. In ACM SIGKDD, pages 5–6. ACM, 2015.
  • [3] L. Bottou, J. Peters, J. Quiñonero-Candela, D. Charles, D. M. Chickering, E. Portugaly, D. Ray, P. Simard, and E. Snelson. Counterfactual reasoning and learning systems: The example of computational advertising. JMLR, 14(1):3207–3260, 2013.
  • [4] C. Louizos, U. Shalit, J. Mooij, D. Sontag, R. Zemel, and M. Welling. Causal effect inference with deep latent-variable models. In NeurIPS, pages 6446–6456, 2017.
  • [5] F. Johansson, U. Shalit, and D. Sontag. Learning representations for counterfactual inference. In ICML, pages 3020–3029, 2016.
  • [6] U. Shalit, F. D. Johansson, and D. Sontag. Estimating individual treatment effect: generalization bounds and algorithms. In ICML, pages 3076–3085, 2017.
  • [7] A. Sharma, G. Gupta, R. Prasad, A. Chatterjee, L. Vig, and G. Shroff. MetaCI: Meta-Learning for Causal Inference in a Heterogeneous Population. NeurIPS CausalML workshop, 2019.
  • [8] K. Imai and D. A. Van Dyk. Causal inference with general treatment regimes: Generalizing the propensity score. JASA, 99(467):854–866, 2004.
  • [9] J. Yoon, J. Jordon, and M. van der Schaar. GANITE: Estimation of Individualized Treatment Effects using Generative Adversarial Nets. ICLR, 2018.
  • [10] A. M. Alaa and M. van der Schaar. Bayesian inference of individualized treatment effects using multi-task Gaussian processes. NeurIPS, pages 3424–3432, 2017.
  • [11] P. Schwab, L. Linhardt, and W. Karlen. Perfect match: A simple method for learning representations for counterfactual inference with neural networks. arXiv:1810.00656, 2018.
  • [12] G. King and R. Nielsen. Why propensity scores should not be used for matching. Political Analysis, 27(4):435–454, 2019.
  • [13] W. Sun, P. Wang, D. Yin, J. Yang, and Y. Chang. Causal inference via sparse additive models with application to online advertising. In AAAI, 2015.