1 Introduction
The primary goal in causal inference (CI) is to uncover the cause-effect relationship between entities, often construed as a problem of quantifying the effect of treatments on individuals. Randomised controlled trials, the most popular choice for obtaining causal relationships, are expensive and entail several logistical and ethical constraints. Causality is a crucial paradigm in several domains where observational data are available, such as healthcare [1], socioeconomic studies [2], and advertising [3]. An impediment to CI in observational studies is the presence of confounding, where both the treatment assignment and the response to the treatment depend on context covariates, resulting in selection bias.
In order to abate the effect of confounding, the discrepancy between the distributions of individuals receiving different treatments is minimized so as to emulate a randomized trial. In the binary treatment case, confounding is addressed using statistical approaches such as sub-classification, weighting, imputation, and propensity score (PS) matching for unbiased per-individual causal estimates. Deep neural network (DNN) based techniques propose counterfactual distribution modelling [4] and learning balancing representations to minimize selection bias [5, 6, 7]. In the literature, the multiple treatments scenario is interpreted either as different dosage levels of a single treatment or as one of several distinct treatments. For the latter case, matching and sub-classification techniques have been proposed [1], and in particular, generalized propensity score (GPS) based matching was proposed [8] as an accurate metric in the multiple treatment scenario [1]. DNN-based techniques for CI include counterfactual distribution learning [9], Gaussian process based modeling [10], and the PS matching based Perfect Match [11]. Here, we propose a novel DNN for counterfactual inference, which overcomes confounding by leveraging both GPS based matching and learning balancing representations.

Contributions: We propose a novel framework for counterfactual inference in the presence of confounding due to multiple treatments. We assume strong ignorability and no hidden confounding [5, 6, 7]. We optimize DNN models via mini-batch stochastic gradient descent (SGD) to predict both the factual and the counterfactual response of a given individual to any one of the treatments. We propose GPS based matching, along with learning balanced representations, to address confounding. We estimate the GPS by training a predictive model, and we use this GPS to match every sample within a minibatch with its nearest neighbours. In order to learn the balanced representation, we propose a loss function using the pairwise maximum mean discrepancy (MMD) metric. The novelty of this work is two-fold. First, we show that GPS based matching leads to more accurate counterfactual inference as compared to [11]. Next, we generalize the balancing representation based loss function [5, 6] to the multiple treatment scenario. On synthetic and real-world datasets, we demonstrate that by combining matching and the generalized loss function, we outperform state-of-the-art CI techniques such as Perfect Match (PM) [11] and TARNet [6]. We use precision in estimation of heterogeneous effect (PEHE) and mean absolute percentage error (MAPE) over the average treatment effect (ATE) as metrics.

2 MultiMBNN Causal Inference Model
In this section, we describe the preliminaries of the CI framework in the multiple treatment scenario, followed by details of the proposed MultiMBNN framework.
2.1 Causal Inference Preliminaries
We consider observational training data comprising a set of samples, where each sample consists of the covariates of an individual, the treatment the individual received, and the corresponding response. Each individual (also called context) is represented using a vector of covariates. An individual is subjected to one of several treatments, indicated by a treatment vector whose entries are binary; an entry equal to one implies that the corresponding treatment is given. We assume that only one treatment is provided to an individual at any given point in time, and hence the treatment vector is a one-hot vector. Accordingly, the response of each individual is a continuous random vector whose entries are the responses of that individual to each of the treatments. We define counterfactuals as the alternate treatments which are unobserved for an individual.

Our goal is to train a DNN model to overcome confounding and perform counterfactual regression, i.e., to predict the response given any context and treatment. We address the issue of confounding using both matching and learning balanced representations. In the sequel, we describe the matching method used and the loss function that caters to a multiple treatment scenario.
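As a minimal illustration of this setup, the sketch below separates the factual response from the counterfactual ones for a single individual. The function name `split_factual` and the example values are hypothetical. Note that the full response vector is only available in simulated benchmarks; in real observational data only the factual entry is observed.

```python
def split_factual(treatment_one_hot, responses):
    """Split a response vector into factual and counterfactual parts.

    treatment_one_hot: list of 0/1 entries with exactly one 1.
    responses: list with one response per treatment (same length).
    Returns (factual_index, factual_response, counterfactual_responses).
    """
    expected = [0] * (len(treatment_one_hot) - 1) + [1]
    if sorted(treatment_one_hot) != expected:
        raise ValueError("treatment vector must be one-hot")
    if len(treatment_one_hot) != len(responses):
        raise ValueError("one response per treatment is required")
    k = treatment_one_hot.index(1)
    # Every entry other than the observed one is a counterfactual response.
    counterfactual = {j: y for j, y in enumerate(responses) if j != k}
    return k, responses[k], counterfactual


# Hypothetical individual with three treatments; the second one was given.
k, y_factual, y_counterfactual = split_factual([0, 1, 0], [2.5, 4.0, 1.0])
```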
2.2 Generalized Propensity Score Matching
Propensity score (PS) based matching is a well-known technique used to emulate the effect of a randomized experiment by obtaining similar covariate distributions across treated populations [1]. Here we employ matching based on the Generalised Propensity Score (GPS), which is a more relevant score for the multiple treatment scenario. The PS is the conditional probability of a given individual receiving a given treatment, conditioned on the covariates of that individual. Accordingly, the GPS vector is defined as the vector of these probabilities over all the treatments. In practice, we do not have access to the GPS vector, and hence we estimate it by training a predictive model. In this work, we train an SVM or a random forest. We use this tuned model to predict the PS on the training data used for causal inference. In order to avoid overfitting, we first obtain a set of nearest neighbors, and pick one out of these samples at random for each counterfactual treatment of a given sample.

We employ the GPS vector for batch augmentation in every minibatch of SGD. For every sample within a minibatch, its closest neighbour samples are obtained. For instance, consider a sample and its factual treatment. We propose the GPS-based matching strategy, which selects a neighbour with a different observed treatment such that the distance between the GPS vectors of the two samples is minimum. On the other hand, the PS matching strategy [11] selects such a neighbour by minimizing a distance defined on the scalar PS. Despite its popularity, PS based matching has been described as inadequate, and it sometimes leads to imbalance in parametric models due to model dependence [12]. Hence, we learn balancing representations to achieve better performance.

2.3 Learning Balancing Representations
In addition to overcoming the imbalance using matching, we also propose learning a balanced representation using a DNN. In [5], the authors perform counterfactual inference by generalizing from the factual to the counterfactual distribution, for the binary treatment scenario. We extend this framework from the binary to the multiple treatment scenario, and modify the loss function as follows:
(1)
where the hyperparameters control the strength of the imbalance penalties, the regularizer is a model complexity term, the two distributions in each imbalance penalty are those w.r.t. a pair of distinct treatments, and MMD is the maximum mean discrepancy measure as defined in [6]. We learn the balancing representation and the hypothesis jointly by training a deep neural network using a loss function that incorporates the factual and the imbalance error, as depicted in Fig. 1. In (1), the first term on the right-hand side represents the factual loss. The second term computes the pairwise maximum mean discrepancy between the factual distributions of different treatments. The loss function in (1) is a generalisation of [6] to the multiple treatment scenario; it reduces to the one proposed in [6] for two treatments.

2.4 Proposed Approach: MultiMBNN
We propose the MultiMBNN algorithm as described in Algorithm 1. As discussed in the previous subsections, we perform batch augmentation based on the GPS, and train a DNN to learn the balancing representation and the hypothesis layers, one for each treatment, using the augmented minibatches via SGD based training. The proposed neural network architecture is depicted in Fig. 1.
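The two ingredients discussed above, GPS-based batch augmentation and the pairwise imbalance penalty, can be sketched as follows. This is a simplified illustration and not Algorithm 1 itself: the function names, the Euclidean distance between GPS vectors, and the linear-kernel form of the MMD are assumptions made for the sketch, and the GPS vectors are taken as given rather than estimated by an SVM or random forest.

```python
import math
import random


def gps_match(i, gps, treatments, num_neighbours=1, rng=random):
    """Pick one matched neighbour of sample i per counterfactual treatment.

    gps[n] is the (estimated) GPS vector of sample n and treatments[n] is
    the index of its observed treatment. For each treatment other than the
    factual one, the num_neighbours closest samples that received it are
    found and one of them is drawn at random.
    Assumption: Euclidean distance between GPS vectors.
    """
    matches = {}
    for k in sorted(set(treatments) - {treatments[i]}):
        candidates = [n for n, t in enumerate(treatments) if t == k]
        candidates.sort(key=lambda n: math.dist(gps[i], gps[n]))
        matches[k] = rng.choice(candidates[:num_neighbours])
    return matches


def linear_mmd(group_a, group_b):
    """Squared distance between the mean representations of two groups.

    Assumption: squared MMD with a linear kernel.
    """
    dim = len(group_a[0])
    mean_a = [sum(r[d] for r in group_a) / len(group_a) for d in range(dim)]
    mean_b = [sum(r[d] for r in group_b) / len(group_b) for d in range(dim)]
    return sum((a - b) ** 2 for a, b in zip(mean_a, mean_b))


def pairwise_mmd(representations, treatments):
    """Sum the linear MMD over all unordered pairs of treatment groups."""
    groups = {}
    for rep, t in zip(representations, treatments):
        groups.setdefault(t, []).append(rep)
    keys = sorted(groups)
    return sum(
        linear_mmd(groups[a], groups[b])
        for idx, a in enumerate(keys)
        for b in keys[idx + 1:]
    )


# Hypothetical minibatch: five samples, three treatments.
gps = [
    [0.7, 0.2, 0.1],
    [0.6, 0.3, 0.1],
    [0.2, 0.7, 0.1],
    [0.1, 0.2, 0.7],
    [0.5, 0.4, 0.1],
]
treatments = [0, 1, 1, 2, 2]
matches = gps_match(0, gps, treatments)  # {1: 1, 2: 4}
penalty = pairwise_mmd([[0.0, 0.0], [2.0, 0.0], [0.0, 2.0]], [0, 1, 2])  # 16.0
```

With `num_neighbours` greater than one, the match is drawn at random among the closest candidates, mirroring the random pick used to avoid overfitting. With exactly two treatment groups, `pairwise_mmd` reduces to a single penalty term, in line with the binary case.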

3 Experiments and Results
We illustrate the performance of the MultiMBNN algorithm on synthetic (named Syn) [13], semi-synthetic NEWS [11] and cancer genome TCGA (https://github.com/d909b/perfect_match) datasets. The synthetic dataset is generated with a parameter that accounts for the treatment assignment bias. We employ the data generating process (DGP) in [11] and bag-of-words context covariates to generate samples of the NEWS dataset. In the case of the synthetic and NEWS datasets, we generate data for several numbers of treatments, each referred to by the name of the dataset followed by the number of treatments. The TCGA dataset is obtained using the DGP in [11] with four treatments ('TCGA4'). We use the precision in estimation of heterogeneous effect (PEHE) as defined in [11], and MAPE over the average treatment effect (ATE) [7]. We baseline the proposed algorithm against TARNet [6], MultiBNN, which learns balanced representations as described in Sec. 2.3, PM as described in [11], and MultiMBNN (PS), which uses PS-based (not GPS-based) matching along with balanced representations.

We demonstrate the performance of the MultiMBNN algorithm using several experimental settings. First, we illustrate the effect of the treatment assignment bias parameter. As illustrated in Fig. 2, for Syn4, MultiMBNN performs the best at one particular value of this parameter, since imbalance amongst treatments leads to one of the four treatments being suppressed, resulting in a large counterfactual error and hence an elbow point. For NEWS4, the counterfactual error is least at the value for which the imbalance amongst treatment groups is minimum, leading to a near uniform distribution of population samples in all groups.
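The two metrics named above can be sketched as follows for data where all potential outcomes are known, as in the simulated benchmarks. This is an illustration under stated assumptions, not the exact definitions of [11] and [7]: averaging over all treatment pairs, taking the square root per pair, and the function names are choices made for the sketch.

```python
import math
from itertools import combinations


def pairwise_effects(outcomes, a, b):
    """Per-individual effect of treatment a relative to treatment b."""
    return [y[a] - y[b] for y in outcomes]


def mean_pehe(y_true, y_pred):
    """Root-mean-square error of pairwise effects, averaged over pairs."""
    num_treatments = len(y_true[0])
    pair_errors = []
    for a, b in combinations(range(num_treatments), 2):
        true_eff = pairwise_effects(y_true, a, b)
        pred_eff = pairwise_effects(y_pred, a, b)
        mse = sum((t - p) ** 2 for t, p in zip(true_eff, pred_eff))
        pair_errors.append(math.sqrt(mse / len(true_eff)))
    return sum(pair_errors) / len(pair_errors)


def mean_ate_mape(y_true, y_pred):
    """Absolute percentage error of the ATE, averaged over pairs.

    Undefined (division by zero) when a true pairwise ATE is exactly zero.
    """
    num_treatments = len(y_true[0])
    pair_errors = []
    for a, b in combinations(range(num_treatments), 2):
        true_eff = pairwise_effects(y_true, a, b)
        pred_eff = pairwise_effects(y_pred, a, b)
        ate_true = sum(true_eff) / len(true_eff)
        ate_pred = sum(pred_eff) / len(pred_eff)
        pair_errors.append(abs((ate_true - ate_pred) / ate_true))
    return sum(pair_errors) / len(pair_errors)


# Hypothetical potential outcomes for two individuals and two treatments.
y_true = [[1.0, 3.0], [2.0, 6.0]]
y_pred = [[1.0, 2.0], [2.0, 5.0]]
pehe = mean_pehe(y_true, y_pred)
mape = mean_ate_mape(y_true, y_pred)
```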

Metric, Dataset | TARNet | MultiBNN | PM | MultiMBNN (PS) | MultiMBNN |
---|---|---|---|---|---|
PEHE, Syn4 | 10.21 ± 0.56 | 9.34 ± 0.61 | 8.21 ± 0.35 | 7.98 ± 0.23 | 7.86 ± 0.37 |
MAPE, Syn4 | 0.07 ± 0.02 | 0.08 ± 0.03 | 0.06 ± 0.02 | 0.04 ± 0.01 | 0.02 ± 0.02 |
PEHE, NEWS4 | 9.70 ± 1.15 | 9.37 ± 0.90 | 9.16 ± 0.80 | 8.99 ± 0.94 | 8.96 ± 0.92 |
MAPE, NEWS4 | 0.81 ± 0.11 | 0.81 ± 0.10 | 0.81 ± 0.12 | 0.80 ± 0.13 | 0.82 ± 0.12 |
PEHE, TCGA4 | 29.45 ± 3.48 | 26.15 ± 3.29 | 23.57 ± 1.10 | 23.18 ± 1.39 | 21.47 ± 0.96 |
MAPE, TCGA4 | 0.93 ± 0.14 | 0.84 ± 0.10 | 0.92 ± 0.07 | 0.78 ± 0.03 | 0.80 ± 0.08 |
In Fig. 3, we illustrate the performance of the proposed algorithms and the baselines with a varying number of treatments for a fixed treatment assignment bias. We see that the MultiMBNN algorithm outperforms the baselines by large margins. Further, we simulate MultiMBNN with different initial seed-points, keeping the number of treatments and the treatment assignment bias fixed, and report the mean and standard deviation in PEHE and MAPE for all baselines in Table 1. We infer that MultiMBNN, which incorporates both matching and DNN based balancing, fares considerably well against all the baselines.

4 Conclusions
In this work, we propose the MultiMBNN algorithm, which addresses the inadequacies of the matching framework by learning balanced representations in the multiple treatment causal inference scenario. We demonstrate that MultiMBNN outperforms the state-of-the-art techniques for multiple treatments, including on datasets with thousands of potential covariate confounders. In the future, we shall extend this algorithm to handle sparsity in the presence of a large number of treatments.
References
- [1] E. A. Stuart. Matching methods for causal inference: A review and a look forward. Statist. Sci., 25(1):1, 2010.
- [2] S. Athey. Machine learning and causal inference for policy evaluation. In ACM SIGKDD, pages 5–6. ACM, 2015.
- [3] L. Bottou, J. Peters, J. Quiñonero-Candela, D. Charles, D. M. Chickering, E. Portugaly, D. Ray, P. Simard, and E. Snelson. Counterfactual reasoning and learning systems: The example of computational advertising. JMLR, 14(1):3207–3260, 2013.
- [4] C. Louizos, U. Shalit, J. Mooij, D. Sontag, R. Zemel, and M. Welling. Causal effect inference with deep latent-variable models. In NeurIPS, pages 6446–6456, 2017.
- [5] F. Johansson, U. Shalit, and D. Sontag. Learning representations for counterfactual inference. In ICML, pages 3020–3029, 2016.
- [6] U. Shalit, F. D. Johansson, and D. Sontag. Estimating individual treatment effect: generalization bounds and algorithms. In ICML, pages 3076–3085, 2017.
- [7] A. Sharma, G. Gupta, R. Prasad, A. Chatterjee, L. Vig, and G. Shroff. MetaCI: Meta-Learning for Causal Inference in a Heterogeneous Population. NeurIPS CausalML workshop, 2019.
- [8] K. Imai and D. A. Van Dyk. Causal inference with general treatment regimes: Generalizing the propensity score. JASA, 99(467):854–866, 2004.
- [9] J. Yoon, J. Jordon, and M. van der Schaar. GANITE: Estimation of Individualized Treatment Effects using Generative Adversarial Nets. ICLR, 2018.
- [10] A. M. Alaa and M. van der Schaar. Bayesian inference of individualized treatment effects using multi-task gaussian processes. NeurIPS, pages 3424–3432, 2017.
- [11] P. Schwab, L. Linhardt, and W. Karlen. Perfect match: A simple method for learning representations for counterfactual inference with neural networks. arXiv:1810.00656, 2018.
- [12] G. King and R. Nielsen. Why propensity scores should not be used for matching. Political Analysis, 27(4):435–454, 2019.
- [13] W. Sun, P. Wang, D. Yin, J. Yang, and Y. Chang. Causal inference via sparse additive models with application to online advertising. In AAAI, 2015.