Matching methods are pervasively used throughout the social and health sciences to draw causal conclusions where randomized controlled trials are scarce but observational data are widely available. Matching methods find groups of similar individuals, some of whom select into treatment and some of whom select into control. As such, matching methods are interpretable in that they allow fine-grained troubleshooting of the data; for instance, examining the matched groups may allow the user to locate unmeasured confounders that lead units to have a higher chance of being treated. Matching is also powerful in that it can allow the user to correctly estimate highly nonlinear treatment effects, but only when the matches are of high quality. The quality of the matches is our main consideration in this work.
Typically, matching methods place units into the same matched group when they are close together, where closeness is measured by a pre-defined distance (e.g., exact, coarsened exact, or Euclidean), while maintaining balance constraints between treatment and control units. Despite its merits, this classical paradigm has a serious flaw: it relies heavily on a properly specified distance metric. The distance metric cannot be determined without an understanding of the importance of the variables; for instance, the quality of any prespecified distance metric that weighs all covariates equally will degrade as the number of irrelevant covariates increases. This is true no matter which matching methodology is used, as the distance calculations could be essentially determined by these irrelevant covariates. This has previously been referred to as the toenail problem [wang2017flame], where the weighting of irrelevant covariates (like toenail length) can overwhelm the matching metric. A separate concern is that the covariates may be scaled differently, so that a given distance along one covariate has a different impact than the same distance along another; in this case, the total distance can easily be dominated by less relevant covariates, again leading to lower-quality matched groups.
Ideally, the distance metric would stretch to capture important covariates so that, after matching, estimates of treatment effects within the matched groups would be accurate. If the researcher knew how to choose a distance metric that leads to accurate treatment effect estimates, this would solve the problem. However, there is no reason to believe that this is achievable in high-dimensional settings: constructing high-dimensional functions to characterize data is a task at which humans alone are not naturally adept.
In this work, we propose a framework for matching in which an interpretable distance measure between matched units is learned from a held-out training set. As long as the distance metric generalizes from the training set to the full sample, we are able to compute high-quality matches and accurate estimates of treatment effects (conditional average treatment effects, CATEs) within the matched groups. Any form of distance metric can be trained; in this work, we focus on exact matching for discrete variables and generalized Mahalanobis distances for continuous variables. By definition, the generalized Mahalanobis distance is determined by a matrix. If the matrix is diagonal, the distance calculation applies a stretch to each covariate. Irrelevant covariates are compressed so that their stretched values are effectively always zero, while highly relevant covariates are stretched, so that to be considered a match, units must have very similar values of those covariates. In this way, diagonal matrices lead to very interpretable distance metrics. If the Mahalanobis distance matrix is not constrained to be diagonal, it induces a stretch and a rotation, leading to more flexible but less interpretable notions of distance.
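To make the stretching intuition concrete, the following is a minimal Python sketch of a generalized Mahalanobis distance with a diagonal matrix, computed as the Euclidean norm of the stretched covariate differences. The function name `stretched_distance` and the stretch values are hypothetical illustrations, not learned MALTS output.

```python
import numpy as np

def stretched_distance(x_i, x_k, stretch):
    """Generalized Mahalanobis distance with a diagonal matrix M = diag(stretch).

    Each covariate difference is multiplied by its own stretch factor before
    the Euclidean norm is taken, i.e. ||M (x_i - x_k)||_2.
    """
    diff = np.asarray(x_i, dtype=float) - np.asarray(x_k, dtype=float)
    return float(np.linalg.norm(np.asarray(stretch, dtype=float) * diff))

# Two units that agree on the relevant covariate (index 0) but differ
# substantially on an irrelevant one (index 1).
a = np.array([1.0, 0.0])
b = np.array([1.0, 5.0])
c = np.array([2.0, 0.0])  # differs only on the relevant covariate

# Equal weights: the irrelevant covariate dominates the distance.
print(stretched_distance(a, b, [1.0, 1.0]))  # 5.0
# Hypothetical stretch: relevant covariate stretched, irrelevant one compressed.
print(stretched_distance(a, b, [3.0, 0.1]))  # 0.5
print(stretched_distance(a, c, [3.0, 0.1]))  # 3.0
```

Under the hypothetical stretch, unit `b` becomes a close match for `a` despite its toenail-like irrelevant difference, while `c`, which differs on the relevant covariate, is pushed away.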
The new framework is called Learning-to-Match, and the algorithm introduced in this work is called Matching After Learning to Stretch (MALTS). We tested MALTS against several other matching methods in simulation studies, where ground truth is known; there, MALTS achieves substantially and consistently better results than other methods, including GenMatch, propensity score matching, standard Mahalanobis distance matching, and coarsened exact matching. Even though our method is heavily constrained to produce interpretable matches, it performs at the same level as non-matching methods that are designed to fit extremely flexible but uninterpretable models directly to the response surface.
2 Related work
Since the 1970s, the causal inference literature on matching methods has concentrated on dimension reduction techniques [e.g., rubin1973matching, rubin1973use, rubin1976multivariate, cochran1973controlling]. The leading dimension reduction approach in this literature uses the propensity score, the conditional probability of treatment given covariate information. Propensity score methods target average treatment effects and so do not produce exact or almost-exact matches. When treatment is binary, they project the data onto one dimension, and closeness of units in propensity score does not imply their closeness in covariate space. As a result, the matches cannot directly be used for estimating heterogeneous treatment effects. Regression methods can be used for CATE estimation, but they assume that the regression model is correctly specified or, in the case of doubly robust estimation [e.g., farrell2015robust], that either the propensity model or the outcome model is correctly specified. In practice, neither is likely to be correctly specified. Machine learning approaches generalize regression approaches and can create models that are extremely flexible and predict outcomes accurately for both treatment and control groups [hahn2017bayesian, hill2011bayesian, chernozhukov2016double]. However, this sacrifices the interpretability inherent to almost-exact matches. In our experiments, MALTS performs similarly to (or better than) several machine learning methods, despite being restricted to interpretable almost-exact matches.
A flexible setup for producing high-quality matches is the optimal matching literature [rosenbaum2016imposing], which started with network flow algorithms and has evolved to use integer programming to produce matches that are constrained in user-defined ways [zubizarreta2012using, zubizarreta2014matching, keele2014optimal, resa2016evaluation, AlamRu15, AlamRu15nonparam, pmlr-v54-kallus17a]. In all of these approaches, the user defines the distance metric rather than learning it from data; this is time-consuming and likely inaccurate, potentially leading to poor-quality matched groups. Further, integer programming methods do not scale to problems of the sizes we consider in this manuscript.
An alternative to optimal matching is coarsened exact matching (CEM) [iacus2012causal], an approach that requires users to specify explicit bins for all covariates on which to construct almost-exact matches. This requires users to know in advance which high-dimensional bins the outcome will be insensitive to, which is essentially equivalent to requiring the user to know the answer to the problem we investigate in this work. The large amount of user choice involved in defining these bins can also introduce unintentional user bias. Learning the stretch from data, rather than asking the user to define bins as in CEM, reduces this bias.
3 Learning-to-Match Framework
Throughout, we write $y_i^{(T)}(\mathbf{x}_i)$ for the treated potential outcome of an individual with observed pre-treatment covariates $\mathbf{x}_i$ and $y_i^{(C)}(\mathbf{x}_i)$ for the control potential outcome. When notationally obvious, we simplify these to $y_i^{(T)}$ and $y_i^{(C)}$. We assume that there are no unobserved confounders and that the standard ignorability assumptions hold [Rubin2005]. For each individual $i$, we define their conditional treatment effect as $\tau_i = y_i^{(T)} - y_i^{(C)}$. In our framework, a learning-to-match algorithm consists of three modeling decisions: the form of the distance metric used for matching, the method of learning the parameters of that distance metric, and the method of matching. A training set is used to learn the parameters of the distance metric, and the learned distance metric is then applied to the rest of the sample (the test set) for matching and CATE prediction.
Formally, a distance function $d: \mathcal{X} \times \mathcal{X} \to \mathbb{R}_{\geq 0}$ is a map that is symmetric, positive definite, and obeys the triangle inequality, where covariates are in $\mathcal{X}$. Our goal is to minimize the expected loss between estimated and true treatment effects across a target population (either a finite population or a super-population). We use hat notation for estimates. Let the expected loss be:

$$\mathcal{L}(d) = \mathbb{E}_{\mathbf{x}}\left[\ell\big(\hat{\tau}_d(\mathbf{x}), \tau(\mathbf{x})\big)\right].$$
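We use squared loss, $\ell(\hat{\tau}, \tau) = (\hat{\tau} - \tau)^2$. If we had a random i.i.d. sample $\mathbf{x}_1, \ldots, \mathbf{x}_n$ from the marginal distribution of the covariates, and if we had labels $y_i^{(T)}$ and $y_i^{(C)}$, we could approximate the expected loss by the average loss over the sample:

$$\frac{1}{n}\sum_{i=1}^{n}\big(\hat{\tau}_i - \tau_i\big)^2.$$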
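However, the difficulty in causal inference is that we observe only the treated outcome or the control outcome for an individual $i$, so we cannot directly calculate the treatment effect. Breaking the sum into the treatment and control groups, with $\mathcal{T}$ and $\mathcal{C}$ indexing the treated and control units:

$$\frac{1}{n}\left[\sum_{i \in \mathcal{T}}\big(\hat{\tau}_i - \tau_i\big)^2 + \sum_{i \in \mathcal{C}}\big(\hat{\tau}_i - \tau_i\big)^2\right].$$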
If the target population is the treatment group (for estimating conditional average treatment effects on the treated, CATT), then the sum over the control group is removed from the quantity above. Consider one term from the treatment group. We know $y_i^{(T)}$ and thus do not estimate it, so $\hat{\tau}_i = y_i^{(T)} - \hat{y}_i^{(C)}$ and the term becomes $\big(y_i^{(C)} - \hat{y}_i^{(C)}\big)^2$. We still do not know $y_i^{(C)}$ and must estimate it. If we use matching, we compute the estimated control outcome $\hat{y}_i^{(C)}$ as the average of the observed control outcomes within unit $i$'s matched group. Let us define the matched group for unit $i$ in terms of the observed control units, indexed by $k$, where $r > 0$ bounds the matching distance:

$$\mathrm{MG}_i(d) = \big\{k \in \mathcal{C} : d(\mathbf{x}_i, \mathbf{x}_k) \leq r\big\}. \tag{1}$$
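We allow matched groups to overlap in this notation; these are therefore bounded-distance nearest-neighbor matches under the specified distance metric. Hence,

$$\hat{y}_i^{(C)} = \frac{1}{|\mathrm{MG}_i(d)|}\sum_{k \in \mathrm{MG}_i(d)} y_k^{(C)}, \tag{2}$$

where $|\mathrm{MG}_i(d)|$ is the size of the matched group $\mathrm{MG}_i(d)$.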
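A minimal sketch of the bounded-distance matched group and the averaging estimator of the control outcome, assuming a user-supplied distance function and a fixed matching radius. The names `matched_group`, `estimate_control_outcome`, `euclidean`, and `radius` are hypothetical.

```python
import numpy as np

def matched_group(x_i, X_control, dist, radius):
    """Indices of control units whose distance to x_i is at most `radius`."""
    return [k for k, x_k in enumerate(X_control) if dist(x_i, x_k) <= radius]

def estimate_control_outcome(x_i, X_control, y_control, dist, radius):
    """Average the observed control outcomes inside the matched group.

    Returns NaN if no control unit lies within `radius` of x_i.
    """
    group = matched_group(x_i, X_control, dist, radius)
    if not group:
        return float("nan")
    return float(np.mean(np.asarray(y_control, dtype=float)[group]))

def euclidean(u, v):
    return float(np.linalg.norm(np.asarray(u, dtype=float) - np.asarray(v, dtype=float)))

X_c = np.array([[0.0], [0.1], [2.0]])  # three control units, one covariate
y_c = np.array([1.0, 3.0, 10.0])
print(matched_group([0.0], X_c, euclidean, 0.5))                  # [0, 1]
print(estimate_control_outcome([0.0], X_c, y_c, euclidean, 0.5))  # 2.0
```

Because every unit gets its own radius ball, the same control unit may appear in several matched groups, which is the overlap the text allows.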
In order for matching to yield a high quality estimate , it is sufficient (but not necessary) for the distance function to have the following property:
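Smoothness-of-Outcomes-for-Distance Property (SOD). A distance measure $d$ has the SOD property if, for any observed values of $\mathbf{x}_i$ and $\mathbf{x}_k$ and for each treatment arm $t \in \{T, C\}$:

$$d(\mathbf{x}_i, \mathbf{x}_k) \text{ small} \;\Longrightarrow\; \big|y^{(t)}(\mathbf{x}_i) - y^{(t)}(\mathbf{x}_k)\big| \text{ small}.$$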
In other words, under the SOD condition, if we consider all units within a small distance of $\mathbf{x}_i$, the value of the outcome is almost constant. If this were true for all units, we would automatically have

$$\hat{y}_i^{(C)} = \frac{1}{|\mathrm{MG}_i(d)|}\sum_{k} y_k^{(C)} \approx \frac{1}{|\mathrm{MG}_i(d)|}\sum_{k} y_i^{(C)} = y_i^{(C)},$$

where both summations are over $k \in \mathrm{MG}_i(d)$. Thus, the SOD condition would suffice to provide high-quality estimates of counterfactuals and, therefore, low losses on estimates of conditional average treatment effects.
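Our framework learns a distance metric from a separate training set of data (not the test samples considered in the averages above), so that the distance approximately has the SOD property. Denote this training set by $S_{tr}$, with $|S_{tr}| = n_{tr}$, and let $S_{tr}^{(T)}$ and $S_{tr}^{(C)}$ denote its treated and control units. To learn $d$ so that it approximately has the SOD property, we minimize the following:

$$\sum_{i \in S_{tr}^{(T)}}\big(y_i^{(T)} - \hat{y}_i^{(T)}(d)\big)^2 + \sum_{i \in S_{tr}^{(C)}}\big(y_i^{(C)} - \hat{y}_i^{(C)}(d)\big)^2, \tag{3}$$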
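where $\hat{y}_i^{(C)}(d)$ is defined by Equations (1) and (2), including its dependence on $d$, using the training data (excluding unit $i$ itself) to create matched groups, and $\hat{y}_i^{(T)}(d)$ is defined analogously. More generally, we can write this step as:

$$\min_{d} \sum_{i \in S_{tr}^{(T)}} \ell\big(y_i^{(T)}, \hat{y}_i^{(T)}(d)\big) + \sum_{i \in S_{tr}^{(C)}} \ell\big(y_i^{(C)}, \hat{y}_i^{(C)}(d)\big). \tag{4}$$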
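The training objective can be sketched as follows, assuming hard K-nearest-neighbor matched groups drawn from the same treatment arm, with the unit itself excluded so that the loss measures how well neighbors predict an outcome they do not contain. The helpers `loo_knn_loss` and `diag_distance` are hypothetical and are not part of any released MALTS package.

```python
import numpy as np

def diag_distance(stretch):
    """Return d(u, v) = ||diag(stretch) (u - v)||_2 as a closure."""
    s = np.asarray(stretch, dtype=float)
    def d(u, v):
        return float(np.linalg.norm(s * (u - v)))
    return d

def loo_knn_loss(X, y, treated, dist, K):
    """Squared error of predicting each training unit's observed outcome by the
    mean outcome of its K nearest neighbours in the same treatment arm,
    excluding the unit itself, summed over both arms."""
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float)
    treated = np.asarray(treated, dtype=bool)
    loss = 0.0
    for arm in (True, False):
        idx = np.flatnonzero(treated == arm)
        for i in idx:
            others = idx[idx != i]
            if others.size == 0:
                continue  # no peer in this arm to match against
            d = np.array([dist(X[i], X[k]) for k in others])
            nearest = others[np.argsort(d, kind="stable")[:K]]
            loss += (y[i] - y[nearest].mean()) ** 2
    return float(loss)

# Column 0 drives the outcome; column 1 is irrelevant.
X = np.array([[0, 0], [0, 10], [1, 0], [1, 10]] * 2, dtype=float)
treated = np.array([False] * 4 + [True] * 4)
y = 10.0 * X[:, 0] + 5.0 * treated

print(loo_knn_loss(X, y, treated, diag_distance([1.0, 0.0]), K=1))  # 0.0
print(loo_knn_loss(X, y, treated, diag_distance([0.0, 1.0]), K=1))  # 800.0
```

A distance that stretches the relevant covariate achieves zero loss, whereas one that attends only to the irrelevant covariate is heavily penalized, which is the signal a learner of $d$ exploits.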
4 Matching After Learning to Stretch
MALTS is an almost-exact matching method for causal inference and treatment effect estimation designed to work on both experimental and observational datasets. MALTS performs treatment effect estimation in the following three stages: (1) distance metric learning, (2) matching samples, and (3) estimating CATEs.
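Let us start with nearest-neighbor matching using a learned weight function that depends on the distances between points. The weights for weighted nearest-neighbor matching can be learned via the following specification of Eq. (4):

$$\min_{d} \sum_{i \in S_{tr}^{(T)}}\Big(y_i^{(T)} - \sum_{k \in S_{tr}^{(T)},\, k \neq i} w_{ik}(d)\, y_k^{(T)}\Big)^2 + \sum_{i \in S_{tr}^{(C)}}\Big(y_i^{(C)} - \sum_{k \in S_{tr}^{(C)},\, k \neq i} w_{ik}(d)\, y_k^{(C)}\Big)^2. \tag{5}$$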
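We let $w_{ik}(d)$ be a function of the distances $d(\mathbf{x}_i, \mathbf{x}_k)$. For example, the $w_{ik}$ can be binary (scaled by $1/K$ so that they sum to one) to encode whether $k$ belongs to $i$'s $K$-nearest neighbors. Alternatively, they can encode soft KNN weights, where $w_{ik} = \exp\big(-d(\mathbf{x}_i, \mathbf{x}_k)\big) \big/ \sum_{k'} \exp\big(-d(\mathbf{x}_i, \mathbf{x}_{k'})\big)$. In the remainder of the paper, we consider distance functions $d_{\mathcal{M}}$ that are parameterized by a set of parameters $\mathcal{M}$.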
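The two weighting schemes described above can be contrasted in a short sketch. It assumes the soft weights are proportional to exp(-d) and normalized to sum to one; the function names are hypothetical.

```python
import numpy as np

def knn_weights(distances, K):
    """Binary KNN weights scaled by 1/K: equal weight on the K closest candidates."""
    distances = np.asarray(distances, dtype=float)
    w = np.zeros_like(distances)
    w[np.argsort(distances, kind="stable")[:K]] = 1.0 / K
    return w

def soft_knn_weights(distances):
    """Soft KNN weights proportional to exp(-distance), normalised to sum to one."""
    distances = np.asarray(distances, dtype=float)
    # Shifting by the minimum does not change the normalised weights but
    # avoids underflow when all distances are large.
    z = np.exp(-(distances - distances.min()))
    return z / z.sum()

d = [0.0, 1.0, 5.0]                    # distances from unit i to three peers
y_peers = np.array([2.0, 4.0, 100.0])  # the peers' observed outcomes
print(knn_weights(d, K=2) @ y_peers)   # 3.0
print(soft_knn_weights(d) @ y_peers)   # smooth blend dominated by the closest peers
```

The soft version makes the objective continuous in the distance parameters, which can help derivative-free optimizers; the hard version gives matched groups that are easier to inspect.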
Euclidean distance in covariate space works best for continuous covariates. As such, one might consider distances of the form $d_{\mathcal{M}}(\mathbf{x}_i, \mathbf{x}_k) = \|\mathcal{M}\mathbf{x}_i - \mathcal{M}\mathbf{x}_k\|_2$, where $\mathcal{M}$ codes the orientation of the data. An example of this in the causal inference literature is the classical Mahalanobis distance, where $\mathcal{M}^{\top}\mathcal{M}$ is hardcoded as the inverse covariance matrix of the observed covariates. This approach has been demonstrated to perform well in settings where all covariates are observed and the inferential target is the average treatment effect [stuart2010]. Our interest, however, is in individualized treatment effects, and just as the choice of metric in Mahalanobis distance matching depends on the estimand of interest, the stretch metric needs to be adapted to this new estimand. We propose learning $\mathcal{M}$ directly from the observed data rather than setting it beforehand, because we do not know the true distance metric, and humans are not good at constructing high-dimensional functions from data in their heads. The distance metric is learned by minimizing the matching objective above with respect to $\mathcal{M}$.
For discrete (categorical/unordered) data, Euclidean distance is not a natural distance metric, and mixed (ordered and unordered) data pose a different set of challenges than either type alone, given the geometry of the space. Approximate closeness is ill-defined for unordered discrete covariates, while exact matching is usually not possible on continuous covariates because it is unlikely for two points to have exactly the same values (assuming a continuous distribution for continuous covariates, an exact match occurs with probability zero). If we use the same distance metric for both discrete and continuous data, units that are close in ordinal space might be arbitrarily far apart in categorical space. Because of this, it is not natural to parameterize a single form of distance metric to enforce both exact matching on categorical data and almost-exact matching on ordinal data. While Mahalanobis-distance-matching papers recommend converting unordered categorical variables to binary indicators, this approach does not scale and can in fact introduce an overwhelming number of irrelevant covariates.
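We propose parameterizing our distance metric in terms of two components: a learned weighted Euclidean distance for continuous (or, more generally, ordered) covariates and a learned weighted Hamming distance for discrete (or, more generally, unordered) covariates. These components are separately parameterized by matrices $\mathcal{M}_c$ and $\mathcal{M}_d$, respectively. Let $\mathbf{x}^{c}$ denote the continuous covariates and $\mathbf{x}^{d}$ the discrete ones. The distance metric we propose is:

$$d_{\mathcal{M}}(\mathbf{x}_i, \mathbf{x}_k) = \big\|\mathcal{M}_c\mathbf{x}_i^{c} - \mathcal{M}_c\mathbf{x}_k^{c}\big\|_2 + \sum_{j} \mathcal{M}_d^{(j,j)}\, \mathbb{1}\big(x_{ij}^{d} \neq x_{kj}^{d}\big),$$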
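where $\mathbb{1}(A)$ is the indicator that event $A$ occurred. We thus perform almost-exact matching on the unordered covariates and learned-Mahalanobis-distance matching on the remaining covariates. We learn $\mathcal{M} = (\mathcal{M}_c, \mathcal{M}_d)$ so that we approximately solve:

$$\min_{\mathcal{M}} \sum_{i \in S_{tr}^{(T)}}\Big(y_i^{(T)} - \frac{1}{K}\sum_{k \in \mathrm{MG}_i^{(T)}(\mathcal{M})} y_k^{(T)}\Big)^2 + \sum_{i \in S_{tr}^{(C)}}\Big(y_i^{(C)} - \frac{1}{K}\sum_{k \in \mathrm{MG}_i^{(C)}(\mathcal{M})} y_k^{(C)}\Big)^2,$$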
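where $\mathrm{MG}_i^{(T)}(\mathcal{M})$ (respectively $\mathrm{MG}_i^{(C)}(\mathcal{M})$) is defined as the set of the $K$ nearest treated (respectively control) training points to $\mathbf{x}_i$, excluding $i$ itself, under the metric parameterized by $\mathcal{M}$. For interpretability, we let $\mathcal{M}_c$ be a diagonal matrix, which allows stretches or reflections of the continuous covariates. This way, the magnitude of an entry in $\mathcal{M}_c$ or $\mathcal{M}_d$ indicates the relative importance of the corresponding covariate for the causal inference problem.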
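A minimal sketch of the mixed distance, assuming both $\mathcal{M}_c$ and $\mathcal{M}_d$ are diagonal and are passed as their diagonals. The function name `mixed_distance` and all numbers are hypothetical.

```python
import numpy as np

def mixed_distance(xc_i, xd_i, xc_k, xd_k, mc_diag, md_diag):
    """Stretched Euclidean distance on continuous covariates plus a weighted
    Hamming distance on discrete covariates, with diagonal M_c and M_d.

    mc_diag : diagonal of M_c (one stretch per continuous covariate)
    md_diag : diagonal of M_d (one mismatch penalty per discrete covariate)
    """
    diff = np.asarray(xc_i, dtype=float) - np.asarray(xc_k, dtype=float)
    continuous_part = np.linalg.norm(np.asarray(mc_diag, dtype=float) * diff)
    mismatches = np.asarray(xd_i) != np.asarray(xd_k)  # indicator 1(x_ij != x_kj)
    discrete_part = np.sum(np.asarray(md_diag, dtype=float) * mismatches)
    return float(continuous_part + discrete_part)

# Continuous part: stretched difference (0, -1) has norm 1.0.
# Discrete part: only the first category differs, costing its weight 10.0.
print(mixed_distance([1.0, 2.0], ["a", "x"], [1.0, 4.0], ["b", "x"],
                     mc_diag=[2.0, 0.5], md_diag=[10.0, 3.0]))  # 11.0
```

A large entry in `md_diag` makes any mismatch on that category costly enough that nearest neighbors effectively agree on it exactly, which is how the additive form yields almost-exact matching on unordered covariates.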
For simplicity in the formulation above, we assumed additive separability of the ordered and unordered parts. If desired, separate stretch metrics can be learned for each discrete “bin”; a natural extension (not explored here) is to use Bayesian shrinkage to pull the stretch matrices for the different bins toward each other. We use the implementation of COBYLA, a gradient-free optimization method, in Python's SciPy library to learn the optimal $\mathcal{M}$ [scipy, Powell1994].
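The following is a minimal sketch of learning a diagonal stretch with SciPy's COBYLA through `scipy.optimize.minimize(method="COBYLA")`. To keep the objective continuous, it assumes soft-KNN weights proportional to exp(-d) within each treatment arm; the objective `soft_loo_loss`, the simulated data, and the optimizer settings are all hypothetical and are not the authors' implementation.

```python
import numpy as np
from scipy.optimize import minimize

def soft_loo_loss(stretch, X, y, treated):
    """Leave-one-out soft-KNN squared error within each treatment arm under the
    distance ||diag(|stretch|) (x_i - x_k)||_2."""
    m = np.abs(stretch)  # report and use non-negative stretch factors
    loss = 0.0
    for arm in (True, False):
        Xa = X[treated == arm] * m
        ya = y[treated == arm]
        D = np.linalg.norm(Xa[:, None, :] - Xa[None, :, :], axis=-1)
        np.fill_diagonal(D, np.inf)  # a unit may not match itself
        W = np.exp(-(D - D.min(axis=1, keepdims=True)))  # soft KNN, stabilised
        W /= W.sum(axis=1, keepdims=True)
        loss += np.sum((ya - W @ ya) ** 2)
    return float(loss)

# Hypothetical training data: column 0 drives the outcome, column 1 is noise.
rng = np.random.default_rng(0)
X = rng.uniform(0.0, 1.0, size=(60, 2))
treated = rng.uniform(size=60) < 0.5
y = 3.0 * X[:, 0] + 2.0 * treated

m_init = np.ones(2)
result = minimize(soft_loo_loss, m_init, args=(X, y, treated),
                  method="COBYLA", options={"rhobeg": 1.0, "maxiter": 500})
learned_stretch = np.abs(result.x)
print(learned_stretch)  # the relevant covariate should receive the larger stretch
```

COBYLA only needs objective values, which suits objectives built from neighbor rankings that have no convenient gradient.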
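After learning the distance matrix on the training data, we use the learned distance metric to predict conditional average treatment effects (CATEs) for each unit in the test set, using nearest neighbors from the training set. For any new point $\mathbf{x}$ in the $p$-dimensional covariate space, we construct a matched group $\mathrm{MG}^{(C)}(\mathbf{x})$ from the control set and a matched group $\mathrm{MG}^{(T)}(\mathbf{x})$ from the treatment set. The estimated CATE at $\mathbf{x}$ is calculated via KNN for treatment and control:

$$\hat{\tau}(\mathbf{x}) = \frac{1}{K}\sum_{k \in \mathrm{MG}^{(T)}(\mathbf{x})} y_k^{(T)} - \frac{1}{K}\sum_{k \in \mathrm{MG}^{(C)}(\mathbf{x})} y_k^{(C)}.$$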
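The KNN CATE estimator can be sketched as below, assuming a diagonal stretch that has already been learned. The function name `knn_cate` and the toy data are hypothetical.

```python
import numpy as np

def knn_cate(x, X_train, y_train, treated, stretch, K):
    """CATE estimate at x: mean outcome of the K nearest treated units minus the
    mean outcome of the K nearest control units, under a diagonal stretch."""
    d = np.linalg.norm((X_train - x) * stretch, axis=1)

    def arm_mean(mask):
        idx = np.flatnonzero(mask)
        nearest = idx[np.argsort(d[idx], kind="stable")[:K]]
        return y_train[nearest].mean()

    return float(arm_mean(treated) - arm_mean(~treated))

# Column 0 is relevant; column 1 is an irrelevant covariate with a large value.
X_train = np.array([[0, 0], [1, 9], [2, 0], [0, 0], [1, 9], [2, 0]], dtype=float)
treated = np.array([True, True, True, False, False, False])
y_train = np.array([10.0, 15.0, 20.0, 0.0, 1.0, 2.0])
query = np.array([0.9, 0.0])

print(knn_cate(query, X_train, y_train, treated, np.array([1.0, 0.0]), K=1))  # 14.0
print(knn_cate(query, X_train, y_train, treated, np.array([1.0, 1.0]), K=1))  # 10.0
```

With the irrelevant covariate compressed, the query matches units with the closest relevant covariate; with equal weights, the irrelevant covariate pulls the matches toward a different region and changes the estimate.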
If we were to abandon the goal of an interpretable distance function, the framework would trivially generalize to complex data structures by introducing a neural network (or another flexible learning framework) to encode the data. That is, we can redefine the distance metric via

$$\tilde{d}_{\mathcal{M}}(\mathbf{x}_i, \mathbf{x}_k) = d_{\mathcal{M}}\big(\phi(\mathbf{x}_i), \phi(\mathbf{x}_k)\big),$$

where $\phi(\mathbf{x})$ is a summary of relevant data features learned using a neural network or another complex modeling framework. As deep neural networks mainly show improvements over other methods on problems where latent spaces need to be constructed (e.g., computer vision and speech), we expect that the stretch/almost-exact match combination should suffice for most datasets.
In this section, we discuss the performance of MALTS on synthetically generated datasets (with continuous covariates and with mixed covariates) and on real-world datasets.
5.1 Continuous Covariates
We designed a variety of experimental settings in which the covariates were sampled from a continuous distribution. We analyzed MALTS by studying error rates in predicting CATEs, the matched groups produced by MALTS, and the change in error rate as the number of covariates and the number of units in the dataset change.
The data generation process used for testing MALTS on continuous covariates includes quadratic treatment effect terms in addition to a linear treatment effect and a linear baseline effect, as follows. (For each data generation process, we generate $p$ covariates, only a subset of which contribute to the outcome; the remaining covariates are irrelevant.)
Data generative process-1 (with independent covariates):
Data generative process-2 (with correlated covariates):
Let us now discuss the results on these data generative processes.
5.1.1 Variance within the Matched groups
We generated different training sets where and . We tested the method on different test datasets with . Recall that during the training phase, MALTS learns the distance metric in a way that stretches the more relevant covariates while compressing the irrelevant covariates in order to better predict the outcome. Because the coefficients of the true model are exponentially decreasing in absolute value, MALTS should ideally learn a distance metric with the largest stretch for the first covariate, the second-largest stretch for the second, and so on. This ensures that the units within a matched group are closest on the covariates that affect the outcome the most. A natural measure of covariate balance in the test set is the variability of a covariate within a matched group. After running MALTS, we expect the average within-group variance of the first covariate to be the lowest, followed by the second, then the third, until prediction uncertainty overwhelms the (diminishing) importance of the covariates. For this collection of datasets, Figure 1 plots the average variance of each covariate. We note that the variance stops increasing beyond the fourth covariate; that is, the matching mechanism does not necessarily distinguish between the importance of covariates beyond index four. Since the contribution of these covariates to the outcome is less than 10% of that of the first covariate, this agrees with our intuition based on the data generation process.
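The balance measure used for Figure 1 can be sketched as follows, assuming population variance (`ddof=0`) within each group and matched groups supplied as lists of row indices. The function name and toy data are hypothetical.

```python
import numpy as np

def mean_within_group_variance(X, groups):
    """Per-covariate variance inside each matched group, averaged over groups.

    X      : (n, p) covariate matrix
    groups : list of index lists, one per matched group
    Returns a length-p array; lower values mean tighter matching on that covariate.
    """
    X = np.asarray(X, dtype=float)
    return np.mean([X[np.asarray(g)].var(axis=0) for g in groups], axis=0)

X = np.array([[0.0, 0.0], [0.0, 4.0], [1.0, 0.0], [1.0, 2.0]])
groups = [[0, 1], [2, 3]]
print(mean_within_group_variance(X, groups))  # covariate 0: 0.0, covariate 1: 2.5
```

Sorting covariates by this quantity and comparing the order with the magnitude of the true coefficients gives the kind of check described above.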
5.1.2 Number of covariates and size of training set
In this simulation, we consider the performance of our distance-metric-learning task as a function of the size of the training set and the number of covariates. Using the same simulation setup as above, we vary the training-set size and the number of covariates to study this behavior. In this simulation, the ATE is given by . Figure 2 plots the average absolute CATE error for different training set sizes. For a fixed number of covariates, we observe that increasing the size of the training set from 200 to 5000 samples reduces the absolute error in estimating the CATEs by nearly 50%. Increasing the number of covariates makes the inferential task substantially more complex, thus leading to an increase in absolute error.
5.1.3 Error-rate analysis
In this simulation, we compare MALTS with several other methods: BART [Chipman10bart:bayesian], CRF [athey2015exact], difference of random forests, GenMatch [genmatch], and propensity score nearest-neighbor matching [ross2015propensity]. We study two different settings: one with uncorrelated and one with correlated covariates.
: The training set contains control units and treated units. There are total covariates observed, but only associated with the outcome.
: The training set contains units with total covariates observed but only associated with the outcome.
In both settings, the test set has units.
Figure 3 compares the errors in estimating CATEs in both settings across the different methods. The other matching methods are not designed for CATE estimation and so perform poorly. MALTS, on the other hand, performs comparably to the best modeling method, BART. Figure 4(a), based on the uncorrelated simulated data, plots the reciprocal of the diameter of each matched group (where the diameter is the maximum distance from the query sample to the matched samples in its group) against the absolute CATE error. Tighter groups are of higher quality and lead to better CATE estimates. This suggests that we can threshold at a chosen diameter to remove poor-quality matched groups, or weight each matched group as a function of its diameter when estimating any quantity of interest.
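Both options, thresholding and diameter-based weighting, can be sketched as below. The weight 1 / (1 + diameter) is a hypothetical choice made here to avoid division by zero; it is not necessarily the weighting used in the paper, and the function names and numbers are illustrative.

```python
import numpy as np

def group_diameter(x_query, X_group, dist):
    """Diameter as defined in the text: the largest distance from the query unit
    to any unit in its matched group."""
    return max(dist(x_query, x_k) for x_k in X_group)

def diameter_weighted_mean(estimates, diameters, max_diameter=np.inf):
    """Drop groups wider than max_diameter, then average the remaining
    per-group estimates with weights 1 / (1 + diameter)."""
    est = np.asarray(estimates, dtype=float)
    diam = np.asarray(diameters, dtype=float)
    keep = diam <= max_diameter
    if not keep.any():
        raise ValueError("every matched group exceeds max_diameter")
    w = 1.0 / (1.0 + diam[keep])  # hypothetical weighting: tighter groups count more
    return float(np.sum(w * est[keep]) / np.sum(w))

def euclidean(u, v):
    return float(np.linalg.norm(np.asarray(u, dtype=float) - np.asarray(v, dtype=float)))

print(group_diameter([0.0, 0.0], [[3.0, 4.0], [1.0, 0.0]], euclidean))  # 5.0
cate_hats = [1.0, 3.0, 100.0]  # per-group CATE estimates (made-up numbers)
diams = [0.0, 1.0, 50.0]
print(diameter_weighted_mean(cate_hats, diams, max_diameter=10.0))  # about 1.667
```

Thresholding discards the wide third group outright, whereas weighting alone would keep it with a small influence.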
Figures 5 and 6 provide a detailed view of the performance of the different methods. MALTS outperforms the other matching methods and is on par with modeling methods such as causal forest and difference of random forests. MALTS does not outperform BART; recall, however, that the MALTS distance metric was deliberately not chosen to be particularly flexible, in order to maintain its interpretability. We have similar findings when the covariates are correlated.
5.2 Mixed Covariates
We test MALTS on mixed data using the following data generative process ($\mathbf{x}_i^{c}$ represents the continuous part and $\mathbf{x}_i^{d}$ the discrete part of the covariates for sample $i$ in a dataset of $n$ samples; the full covariate vector $\mathbf{x}_i$ contains continuous covariates, only some of which are important, and discrete covariates, only some of which are important in determining the outcome $y_i$):
We consider continuous covariates where are associated with the outcome. We also consider discrete covariates where are associated with the outcome. Letting represent the 9 important covariates and letting we generate the outcomes according to the following quadratic model:
Figure 7 summarizes the errors in estimating CATEs across the different methods. MALTS continues to perform as well as methods that directly model the outcome, while outperforming the more interpretable matching methods. Figure 8 provides an in-depth view of these approaches. In particular, MALTS adapts better to the unimportant discrete covariates than both CRF and the difference of two random forests.
5.3 Real Dataset: Lalonde
In this section, we study the performance of MALTS on the classical Lalonde dataset. The data describe the National Supported Work Demonstration (NSW) temporary employment program and its effect on the income level of the participants [lalonde]. This dataset is frequently used as a benchmark for methods of observational causal inference. We employ the male subsample from the NSW in our analysis, as well as the PSID-2 control sample (male household heads under age 55 who did not classify themselves as retired in 1975 and who were not working when surveyed in the spring of 1976) [dehejia_wahba_nonexp]. The outcome variable for both the experimental and observational analyses is earnings in 1978, and the covariates considered are age, education, whether the respondent is Black, is Hispanic, is married, or has a degree, and their earnings in 1975 [website_dehejia]. It has previously been demonstrated that almost any adjustment during the analysis of the experimental and observational variants of these data (whether by modeling the outcome or by modeling the treatment variable) can lead to extreme bias in the estimate of the average treatment effect. Tables 1 and 2 present the average treatment effect estimates based on MALTS and on state-of-the-art modeling and matching methods. MALTS (after appropriately down-weighting high-diameter matched groups according to the procedure described in Section 5.1.3, with the weight for a matched group equal to ) achieves accurate ATE estimates on both the experimental and observational datasets.
As MALTS is a method for constructing interpretable matched groups, Table 3 presents sample matched groups produced by MALTS for the observational Lalonde dataset. At the top of the table, we present the learned stretches for the distance metric ($\mathcal{M}$) and note that matching on age appears to be extremely important, followed by education and income in 1975. We present two individuals for whom we construct matched groups. Query 1 is a 23-year-old individual with no income in 1975; we are able to construct a tight matched group for this individual (in both control and treatment). Query 2 is a 19-year-old high-income individual, an extremely unlikely scenario, which leads to a matched group with a very large diameter that should probably not be used in the analysis.
| Method | ATE Estimate | Estimation Bias (%) | Number of units matched |

| Method | ATE Estimate | Estimation Bias (%) | Number of units matched |
6 Conclusion and Discussion
This paper introduced MALTS, a new method for causal inference and treatment effect estimation. MALTS learns a metric on the covariate space that puts more weight on important covariates and thus produces high-quality matched groups for analysis. MALTS handles irrelevant covariates by downweighting their importance in the weighted nearest-neighbor algorithm that produces the matches. Unlike competitive black-box methods such as BART, causal forest, or the difference of two random forests, MALTS produces interpretable matched groups and also returns the stretch matrix, highlighting the relevant covariates. A natural extension of the MALTS framework is to use neural networks or support vector machines to learn a flexible distance metric in a latent space, allowing us to match on images and text documents and thus to study treatment effects on text or image data. The MALTS framework can further be extended to handle missing covariates and can be adapted to instrumental variables; this is an ongoing effort.