Most causal inference methods for observational data utilize so-called nuisance functions. These functions are not of primary interest but are used as inputs into estimators [Kennedy2016]. For instance, the propensity score function is the nuisance function for the inverse probability of treatment weighted estimator. In low-dimensional settings, these nuisance functions can be estimated nonparametrically. However, in more realistic settings, such as those that involve continuous covariates or in which the minimal sufficient adjustment set of confounders is large, parametric nuisance models are often used. Proper specification of these nuisance models is then required for the resulting estimator to be consistent; specifically, a properly specified model contains the true function as a possible realization. Given the complex underlying relationships between variables, such parametric models are often difficult to specify correctly.
Data-adaptive supervised machine learning methods, which estimate a function using data-driven tuning parameters [Mooney2018, Bi2019], have been suggested as an alternative approach for estimating nuisance functions in high-dimensional settings without imposing overly restrictive parametric functional form assumptions [Schuler2017, Watkins2013, Pirracchio2015, Lee2010, Westreich2010]. Despite this optimism, issues in the use of machine learning for nuisance function estimation have become apparent [Keil2018, Chernozhukov2018]. Notably, some machine learning algorithms converge to the true answer at a slow rate, leading to substantial under-coverage of corresponding confidence intervals. The use of doubly-robust estimators allows for slower-converging nuisance models, but issues of smoothness and overfitting remain that may preclude the use of complex machine learning approaches.
Cross-fit estimators have been developed to reduce overfitting and impose less restrictive smoothness assumptions on machine learning algorithms used to estimate nuisance functions [Chernozhukov2018, Newey2018]. These estimators relax the convergence conditions to allow for a richer set of algorithms to be used. Cross-fit estimators share similarities with double machine learning [Chernozhukov2018], cross-validated estimators [Zheng2011], and sample splitting [Athey2015]. An extension referred to as double cross-fitting has recently been proposed as an even less restrictive approach [Newey2018]. In this work, we detail a general procedure for double cross-fit estimators and demonstrate their performance in a simulation study. We compare a wide range of estimators in the context of a simulated study of statins and subsequent atherosclerotic cardiovascular disease (ASCVD).
Data generating mechanism
Suppose we observe an independent and identically distributed sample where . Let indicate statin use, indicate the observed incidence of ASCVD, and indicate the potential value would have taken if, possibly counter to fact, an individual received treatment . The average causal effect (ACE) is then
Note that the potential outcomes are not necessarily directly observed, so a set of conditions are needed to identify the ACE from observable data. Specifically, we assume the following conditions hold:
Counterfactual consistency [Cole2009]
Conditional exchangeability [Hernan2006]
Together, these conditions allow the ACE to be identified as:
We considered the following confounders in for the simulated data generating process: age (), low-density lipoprotein (), ASCVD risk scores (), and diabetes (). Therefore, in the following simulated data. These factors were chosen based on the 2018 primary prevention guidelines for the management of blood cholesterol [Grundy2019]. Full details on the data generating mechanism are provided in the Appendix. The incidence of statin use was chosen to be similar to reported empirical trends in US adults [Salami2017], and generated from the following model inspired by the 2018 primary prevention guidelines:
The ASCVD potential outcomes under each potential value of X were generated from the following model:
The observed outcome was calculated as . The nuisance functions considered are:
As the name implies, these nuisance functions are not of interest themselves but are used for estimation of the ACE. Outside of simulated data, the correct specification of these models is generally unknown and, in the context of parametric models, must be specified a priori.
Nuisance function estimators
Before estimating the ACE, the nuisance model parameters () need to be estimated. Much of the previous work in causal inference has relied on parametric regression models. However, these models must be sufficiently flexible to capture the true nuisance functions. In our simulations, we consider two different parametric model specifications. First, we consider the correct model specification as described previously. This is the best-case scenario for researchers, but one that is unlikely to occur in practice. Second, we consider a main-effects model, where all variables are assumed to be linearly related to the outcome and no interaction terms are included. The main-effects model is quite rigid and does not contain the true function.
As an alternative to the parametric models, we consider several data-adaptive machine learning algorithms. There is a wide variety of supervised machine learning algorithms, and no single algorithm is guaranteed to perform best in all scenarios [Wolpert1996]. Therefore, we utilize super-learner with 10-fold cross-validation to estimate the nuisance functions. Super-learner is a generalized ensemble algorithm that combines multiple predictive algorithms into a single prediction function [van2007, Rose2013]. Super-learner has been shown to asymptotically perform as well as the best-performing algorithm included within the super-learner procedure [van2007], with studies of finite-sample performance indicating similar results. Within super-learner, we included the following algorithms: the empirical mean; main-effects logistic regression without regularization; a generalized additive model with 4 splines and a ridge penalty of 0.6 [Hastie2017]; a generalized additive model with 6 splines; a random forest with 500 trees and a minimum of 20 individuals per leaf [Breiman2001]; and a neural network with a single hidden layer consisting of four nodes. Only untransformed main-effects terms were provided to each learner.
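As one concrete illustration, a discrete super-learner (which selects the single best learner by cross-validated risk, rather than the full convex-weighted ensemble used in the paper) can be sketched with scikit-learn. The library below only approximates the one described: the generalized additive models are omitted (scikit-learn has no direct GAM learner), and all settings are illustrative.

```python
# A minimal sketch of a discrete super-learner (cross-validated selector).
# The candidate library loosely mirrors the one described in the text.
import numpy as np
from sklearn.dummy import DummyClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.model_selection import cross_val_predict

library = {
    "empirical mean": DummyClassifier(strategy="prior"),
    "main-effects logit": LogisticRegression(C=1e6, max_iter=5000),  # effectively unregularized
    "random forest": RandomForestClassifier(n_estimators=500, min_samples_leaf=20, random_state=0),
    "neural network": MLPClassifier(hidden_layer_sizes=(4,), max_iter=2000, random_state=0),
}

def discrete_super_learner(library, W, y, folds=10):
    """Select the learner with the best cross-validated Brier score, then refit on all data."""
    scores = {}
    for name, learner in library.items():
        p = cross_val_predict(learner, W, y, cv=folds, method="predict_proba")[:, 1]
        scores[name] = np.mean((p - y) ** 2)  # Brier score on out-of-fold predictions
    best = min(scores, key=scores.get)
    return library[best].fit(W, y)
```

The full super-learner instead regresses the outcome on the cross-validated predictions to form a weighted combination of all learners [van2007]; the selector above is the simplest member of that family.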
Estimators for the ACE
After estimation of the nuisance functions, the predictions can be used in estimators for the ACE. We considered four estimators: g-computation, an inverse probability weighted (IPW) estimator, an augmented inverse probability weighted (AIPW) estimator, and a targeted maximum likelihood estimator (TMLE). IPW requires only the nuisance function from Equation 1, and g-computation requires only the nuisance function from Equation 2. Due to their reliance on a single nuisance function, these methods are said to be singly-robust. However, these singly-robust estimators require fast convergence of nuisance models, severely limiting which algorithms can be used. AIPW and TMLE instead use the nuisance functions from both Equations 1 and 2, and have the property of double robustness, such that if either nuisance function is correctly estimated, the point estimate will be consistent [Schuler2017, Funk2011, Lunceford2004, Bang2005]. Perhaps more important in the context of machine learning, these doubly-robust estimators allow for slower convergence of the nuisance function estimators. However, all of these estimators require the nuisance models to be smooth, in the sense that they belong to the so-called Donsker class [Chernozhukov2018]. Intuitively, members of this class are less prone to overfitting than models outside the class. For models that do not belong to the Donsker class, confidence intervals may be overly narrow and result in misleading inference. Recent work has demonstrated that cross-fitting weakens the smoothness conditions on the nuisance function estimators, which allows for a more diverse set of algorithms. A double cross-fit procedure allows for further theoretical improvements [Newey2018]. Therefore, we additionally considered double cross-fit alternatives for AIPW (DC-AIPW) and TMLE (DC-TMLE). We briefly outline each estimator, with further details and formulas provided in the Appendix.
We used the g-computation procedure described by Snowden et al. to estimate the ACE [Snowden2011]. Briefly, the outcome nuisance model is fit using the observed data. From the fit outcome nuisance model, the probability of the outcome is predicted for all individuals under each treatment plan. The ACE is calculated by taking the average of the differences in the estimated probabilities under each treatment plan. Wald-type 95% confidence intervals were generated using the standard deviation of 250 bootstrap samples drawn with replacement, each the size of the observed sample.
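A minimal sketch of this procedure, assuming scikit-learn and our own variable names (W for confounders, x for treatment, y for the outcome; the paper's symbols are not reproduced here):

```python
# G-computation sketch: fit the outcome nuisance model, predict under each
# treatment plan for everyone, and average the difference in predictions.
import numpy as np
from sklearn.linear_model import LogisticRegression

def g_computation(W, x, y, outcome_model=None):
    """Estimate the ACE by contrasting mean predictions under X=1 vs. X=0."""
    if outcome_model is None:
        outcome_model = LogisticRegression(max_iter=5000)
    design = np.column_stack([x, W])          # treatment plus confounders
    outcome_model.fit(design, y)
    m1 = outcome_model.predict_proba(np.column_stack([np.ones_like(x), W]))[:, 1]
    m0 = outcome_model.predict_proba(np.column_stack([np.zeros_like(x), W]))[:, 1]
    return float(np.mean(m1 - m0))
```

In practice, the bootstrap described above would refit this entire function on each resample to obtain the standard error.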
In contrast to the g-formula, IPW relies on estimation of the treatment nuisance model. From the predicted probabilities of treatment, weights are constructed by taking the inverse of the predicted probability of the observed treatment. These weights are then used to calculate the weighted average of the outcome among subjects with each value of treatment [Hernan2002]. Because confidence intervals were constructed with a conservative robust variance estimator, confidence interval coverage is expected to be at least 95% when the nuisance model is properly specified.
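The corresponding Horvitz-Thompson-style IPW estimator can be sketched as follows (again with scikit-learn and our notation; no weight truncation is applied here):

```python
# IPW sketch: fit the treatment nuisance model Pr(X=1 | W), weight each
# subject by the inverse probability of the treatment actually received, and
# contrast the weighted outcome means under treatment and no treatment.
import numpy as np
from sklearn.linear_model import LogisticRegression

def ipw(W, x, y, treatment_model=None):
    """Estimate the ACE via inverse probability of treatment weighting."""
    if treatment_model is None:
        treatment_model = LogisticRegression(max_iter=5000)
    pi = treatment_model.fit(W, x).predict_proba(W)[:, 1]
    r1 = np.mean(x * y / pi)               # weighted mean outcome under treatment
    r0 = np.mean((1 - x) * y / (1 - pi))   # weighted mean outcome under no treatment
    return float(r1 - r0)
```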
AIPW uses both the treatment and outcome nuisance model predictions to estimate the ACE. Predicted probabilities of the treatment and outcome are combined via a single equation to generate predictions under each value of , with confidence intervals calculated from the influence curve [Funk2011, Lunceford2004].
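The AIPW combination can be sketched directly from precomputed nuisance predictions (our notation: pi for predicted treatment probabilities, m1/m0 for outcome predictions under each treatment value):

```python
# AIPW sketch: combine outcome predictions with inverse-probability-weighted
# residual corrections; the influence curve gives a closed-form standard error.
import numpy as np

def aipw(y, x, pi, m1, m0):
    """Estimate the ACE with 95% Wald-type confidence limits from the influence curve."""
    psi1 = m1 + x * (y - m1) / pi              # augmented estimate of E[Y(1)]
    psi0 = m0 + (1 - x) * (y - m0) / (1 - pi)  # augmented estimate of E[Y(0)]
    ic = psi1 - psi0                           # influence-curve contributions
    est = float(np.mean(ic))
    se = float(np.std(ic, ddof=1) / np.sqrt(len(y)))
    return est, se, (est - 1.96 * se, est + 1.96 * se)
```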
TMLE similarly uses both the treatment and outcome nuisance model predictions to construct a single estimate. Unlike AIPW, TMLE uses a 'targeting' step that optimizes the bias-variance tradeoff for the parameter of interest [Schuler2017]. This is accomplished by fitting a parametric working model, where the observed outcome is modeled as a function of the so-called 'clever covariate' with the outcome nuisance model predictions included as an offset. The targeted predictions under each value of treatment are averaged, and their difference provides an estimate of the ACE. Confidence intervals are calculated from the influence curve.
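The targeting step can be sketched as a one-parameter logistic working model fit by Newton-Raphson. This uses a common single-clever-covariate formulation and our own notation; it is an illustration of the idea, not the paper's implementation.

```python
# TMLE targeting sketch: update the outcome predictions using the clever
# covariate H = X/pi - (1-X)/(1-pi) in a logistic working model with the
# untargeted predictions entering as a fixed offset, logit(m).
import numpy as np

def expit(z):
    return 1 / (1 + np.exp(-z))

def logit(p):
    return np.log(p / (1 - p))

def tmle(y, x, pi, m1, m0, tol=1e-10, max_iter=100):
    """Estimate the ACE from targeted (updated) outcome predictions."""
    mx = np.where(x == 1, m1, m0)        # prediction under the observed treatment
    H = x / pi - (1 - x) / (1 - pi)      # clever covariate
    eps = 0.0
    for _ in range(max_iter):            # Newton-Raphson for the fluctuation epsilon
        p = expit(logit(mx) + eps * H)
        score = np.sum(H * (y - p))
        hessian = -np.sum(H ** 2 * p * (1 - p))
        step = score / hessian
        eps -= step
        if abs(step) < tol:
            break
    m1_t = expit(logit(m1) + eps / pi)        # H evaluated at X=1 is 1/pi
    m0_t = expit(logit(m0) - eps / (1 - pi))  # H evaluated at X=0 is -1/(1-pi)
    return float(np.mean(m1_t - m0_t))
```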
A visualization of the double cross-fit procedure is provided in Figure 1. This process is compatible with both doubly-robust estimators previously described. First, the data set is split into three approximately equal-sized, non-overlapping groups (the procedure generalizes to more than three splits [Chernozhukov2018, Newey2018]), so that each subject appears in only one split. Second, treatment and outcome nuisance model parameters are estimated within the individual sample splits. This involves running the super-learner fitting procedure independently for each split and nuisance model (six fits in total). Third, predicted treatment probabilities and expected outcome values are calculated from the nuisance models in the discordant splits, such that the predictions from super-learner never come from the data used to fit the algorithms. For example, sample split 1 has the probability of treatment predicted with the treatment nuisance model from split 3, and the expected value of the outcome predicted with the outcome nuisance model from split 2. For each of the splits, the estimator (e.g., AIPW or TMLE) is calculated from the treatment and outcome predictions. In a final step, the estimates from all splits are averaged to produce a single point estimate.
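One split of the double cross-fit procedure, paired with AIPW, can be sketched as follows (our notation; fit_pi and fit_m are caller-supplied nuisance model fitters, which in the paper would be super-learner):

```python
# Double cross-fit sketch for a single three-way split: within each fold, the
# treatment and outcome nuisance predictions come from models fit on the two
# *other* (discordant) folds, then the fold-specific AIPW estimates are averaged.
import numpy as np

def dc_aipw_single_split(W, x, y, fit_pi, fit_m, rng):
    """One double cross-fit AIPW estimate from one random three-way split."""
    n = len(y)
    folds = np.array_split(rng.permutation(n), 3)
    estimates = []
    for j in range(3):
        ev, tf, of = folds[j], folds[(j + 1) % 3], folds[(j + 2) % 3]
        # treatment nuisance model fit on fold tf, predicted on evaluation fold ev
        pi = fit_pi(W[tf], x[tf]).predict_proba(W[ev])[:, 1]
        # outcome nuisance model fit on fold of, predicted on evaluation fold ev
        m = fit_m(np.column_stack([x[of], W[of]]), y[of])
        m1 = m.predict_proba(np.column_stack([np.ones(len(ev)), W[ev]]))[:, 1]
        m0 = m.predict_proba(np.column_stack([np.zeros(len(ev)), W[ev]]))[:, 1]
        yj, xj = y[ev], x[ev]
        psi = m1 - m0 + xj * (yj - m1) / pi - (1 - xj) * (yj - m0) / (1 - pi)
        estimates.append(psi.mean())
    return float(np.mean(estimates))  # average the three fold-specific estimates
```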
Since the estimated ACE depends on the particular split of the data, the procedure is repeated across a number of different splits. In our implementation, we used 100 different splits, as recommended for other cross-fitting procedures [Chernozhukov2018]; results were unstable when only 10 splits were used (results not shown). The overall point estimate of the ACE is calculated as the median of the ACE estimates across all splits. While the mean could also be used, it is more susceptible to outliers and may require a larger number of splits [Chernozhukov2018]. The estimated variance for the ACE consists of two parts: the variability of the ACE within a particular split and the variance of the ACE point estimates between splits. The variance for the s different splits is the median of
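The aggregation across repeated splits described above can be sketched as (our notation):

```python
# Median-based aggregation across repeated splits, following the variance
# decomposition described in the text.
import numpy as np

def combine_splits(estimates, variances):
    """Point estimate: median of the split-specific estimates. Variance: median
    of each split's variance plus the squared deviation of that split's
    estimate from the overall median (the between-split penalty)."""
    estimates = np.asarray(estimates, dtype=float)
    variances = np.asarray(variances, dtype=float)
    point = float(np.median(estimates))
    variance = float(np.median(variances + (estimates - point) ** 2))
    return point, variance
```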
To compare the performance of estimators for each of the 18 different combinations, we calculated bias, empirical standard error (ESE), root mean squared error (RMSE), average standard error (ASE), average confidence limit difference (CLD), and 95% confidence limit coverage over 2000 simulated samples. Bias was defined as the mean of the estimated ACEs across simulations minus the true population ACE. ESE was calculated as the standard deviation of the estimated ACEs across simulations. RMSE was defined as the square root of the squared bias plus the squared ESE. ASE was the mean of the estimated standard errors across simulations. CLD was the mean of the upper confidence limit minus the lower confidence limit. 95% confidence limit coverage was calculated as the proportion of confidence intervals containing the true population ACE.
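These performance metrics can be sketched as (assuming Wald-type 95% intervals, estimate ± 1.96 × SE):

```python
# Simulation performance metrics as defined in the text: bias, empirical SE,
# RMSE, average SE, average confidence limit difference, and 95% coverage.
import numpy as np

def performance_metrics(estimates, std_errors, truth):
    """Summarize estimator performance across simulated samples."""
    est = np.asarray(estimates, dtype=float)
    se = np.asarray(std_errors, dtype=float)
    bias = est.mean() - truth
    ese = est.std(ddof=1)                       # empirical standard error
    rmse = np.sqrt(bias ** 2 + ese ** 2)
    ase = se.mean()                             # average estimated standard error
    lower, upper = est - 1.96 * se, est + 1.96 * se
    cld = (upper - lower).mean()                # confidence limit difference
    coverage = np.mean((lower <= truth) & (truth <= upper))
    return {"bias": bias, "ese": ese, "rmse": rmse, "ase": ase,
            "cld": cld, "coverage": coverage}
```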
All simulations were conducted using Python 3.5.1 (Python Software Foundation, Wilmington, DE) with the sklearn implementations of the previously described algorithms [Pedregosa2011]. Outside of the specified parameters, the defaults of the software were used. The true ACE was calculated as the average difference in potential outcomes for a population of 10,000,000 individuals. Simulations for estimation of the ACE were repeated 2000 times for n=3000. The sample size for simulations was chosen such that when split into three equally sized groups (n=1000), the true parametric model could be fit and correctly estimate the true ACE.
Before presenting the simulations, we present the estimators in the context of a single data set. Characteristics of the single study sample are displayed in Table 1, and results for the estimators are presented in Table 2. Nuisance models estimated with machine learning led to substantially narrower confidence intervals, as indicated by the CLD; differences were less stark for the double cross-fit estimators. Broadly, run-times for the ACE estimators were short, although the double cross-fit estimators had substantially longer run-times due to the repeated sample splitting procedure. For reference, a single estimation of DC-AIPW with super-learner required 660 different predictive models to be fit.
| | Statin (n=776) | No statin (n=2224) |
|---|---|---|
| Age (SD) | 58 (9.5) | 53 (7.6) |
| log(LDL) (SD) | 4.92 (0.2) | 4.86 (0.2) |
| Risk score (SD) | 0.15 (0.2) | 0.06 (0.1) |
Descriptive statistics for a single sample from the data generating mechanism. Continuous variables are presented as mean (standard deviation).
SD: standard deviation, LDL: low-density lipoproteins, ASCVD: atherosclerotic cardiovascular disease.
| Estimator | RD | SD(RD) | 95% CL | CLD | Run-time (min)* |
|---|---|---|---|---|---|
| G-formula (machine learning) | -0.09 | 0.015 | -0.12, -0.06 | 0.06 | 82.3 |
| IPTW (machine learning) | -0.11 | 0.028 | -0.16, -0.05 | 0.11 | 0.3 |
| AIPW (machine learning) | -0.11 | 0.016 | -0.14, -0.08 | 0.06 | 0.7 |
| TMLE (machine learning) | -0.12 | 0.016 | -0.15, -0.09 | 0.06 | 0.7 |
| DC-AIPW (machine learning) | -0.11 | 0.022 | -0.16, -0.07 | 0.09 | 128.1 |
| DC-TMLE (machine learning) | -0.11 | 0.021 | -0.15, -0.07 | 0.08 | 129.9 |
RD: risk difference, SD(RD): standard deviation for the risk difference, 95% CL: 95% confidence limits, CLD: confidence limit difference defined as the upper confidence limit minus the lower confidence limit.
Machine learning estimators were super-learner with 10-fold cross validation. Algorithms included were the empirical mean, main-effects logistic regression without regularization, generalized additive model with 4 splines and a ridge penalty of 0.6, generalized additive model with 6 splines, random forest with 500 trees and a minimum of 20 individuals per leaf, and a neural network with a single hidden layer consisting of four nodes.
Double cross-fit procedures included 100 different sample splits.
* Run times are based on a server running on a single 2.5 GHz processor with 5 GB of memory allotted. Run times are indicated in minutes. G-formula run-times are large due to the use of a bootstrap procedure to calculate the variance for the risk difference. IPTW used robust variance estimators. AIPW, TMLE, DC-AIPW, and DC-TMLE variances were calculated using influence curves.
As expected, all ACE estimators with correctly specified parametric nuisance models were unbiased, and confidence intervals resulted in near-95% coverage (Figure 2, Table 3). The most efficient estimator was the g-formula (ESE=0.017), followed by TMLE (ESE=0.021), AIPW (ESE=0.021), and IPTW (ESE=0.023). DC-TMLE and DC-AIPW were comparable to their non-cross-fit counterparts (0.021 and 0.021, respectively). Confidence interval coverage was higher for the double cross-fit estimators.
For main-effects parametric nuisance models, all ACE estimators were biased away from the true target parameter, and the increased RMSE was primarily a result of this bias. The double cross-fit procedure did not reduce the bias arising from model misspecification; its greater confidence interval coverage is likely due solely to the between-split variance penalty in the estimated variance.
Non-cross-fit doubly robust estimators with machine learning resulted in unbiased estimates of the ACE, but confidence interval coverage was below expected levels for AIPW (91.1%) and TMLE (89.5%). Confidence interval coverage for DC-AIPW and DC-TMLE was near nominal levels (95.6% and 95.0%, respectively).
RMSE: root mean squared error, ASE: average standard error, ESE: empirical standard error, CLD: confidence limit difference, Coverage: 95% confidence limit coverage of the true value.
IPTW: inverse probability of treatment weights, AIPW: augmented inverse probability of treatment weights, TMLE: targeted maximum likelihood estimator, DC-AIPW: double cross-fit AIPW, DC-TMLE: double cross-fit TMLE.
True: correct model specification. Main-effects: all variables were assumed to be linearly related to the outcome and no interaction terms were included in the model. Machine learning: super-learner with 10-fold cross-validation including empirical mean, main-effects logistic regression without regularization, generalized additive models, random forest, and a neural network.
Under the unlikely scenario in which parametric nuisance model specifications correctly capture the true density, the ACE estimates are unbiased and subsequent inference is valid for all estimators considered. In these scenarios, confidence intervals were wider for double cross-fit estimators due to the variance between splits being incorporated from the sample splitting procedure. This increase in variance highlights the bias-variance trade-off made when choosing a less-restrictive estimator. However, it is often unreasonable to assume correct parametric model specification in high-dimensional data with weak background information or theory. In these scenarios, the pursuit of weaker parametric assumptions for nuisance model specification is worthwhile, with machine learning being a viable approach. However, naïve use of machine learning may lead to bias and incorrect inference. As highlighted in our simulation, doubly robust estimators with double cross-fit and machine learning outperformed both estimators with incorrectly specified parametric nuisance models and non-cross-fit estimators with machine learning. As such, double cross-fit estimators with machine learning may be the preferred approach for ACE estimation in most epidemiologic studies.
Cross-fitting has a long history in statistics [Bickel1988, Pfanzagl1990, Hajek1962], but recent emphasis has focused on its use for nonparametric nuisance function estimation [Chernozhukov2018, Newey2018, Robins2008, Bickel1982]. Broadly, double cross-fit procedures can be seen as an approach to avoid overfitting of the nuisance models. Single cross-fit procedures uncouple the nuisance model estimation from the corresponding predicted values, preventing so-called own-observation bias [Newey2018]. However, in single cross-fit procedures the treatment and outcome nuisance models are estimated using the same data. Double cross-fit procedures decouple these nuisance models by using separate splits, removing so-called nonlinearity bias [Newey2018]. As demonstrated in the simulations, reductions of these bias terms result in tangible benefits for point estimation and inference with machine learning algorithms.
While cross-fitting has tangible benefits, these benefits are not without cost. First, run-times for the double cross-fit estimators are substantially longer due to the repeated fitting of algorithms across many different splits. We note that the double cross-fit procedure can easily be run in parallel, substantially reducing run-times. Computational costs may limit cross-fit procedures to estimators with closed-form variances, since bootstrapping would require considerable computational resources. A second, and potentially more problematic, cost is that sample splitting reduces the amount of data available for fitting algorithms. Because of the reduced data in each sample split, some complex algorithms may be precluded from use. This problem is compounded by k-fold super-learner, which stretches the available data even further. For small data sets, increasing the number of folds in super-learner may help alleviate this issue [Naimi2018]. For cross-fitting, single cross-fit procedures may be used instead: rather than splitting into three samples, a single cross-fit at minimum requires splitting the data in half. However, the flexibility advantage of machine learning may be limited in such small data sets to begin with [Keil2018]. Whether single cross-fit with machine learning or highly flexible parametric models is preferred in these scenarios is an area for future work.
The problems of sample splitting can manifest themselves as random violations of the positivity assumption [Petersen2012]. As detailed in previous work by Yu et al. 2019, confounders that are strongly related to the exposure may result in positivity violations [Yu2019]. Due to the flexibility of machine learning algorithms, these positivity issues may result in highly variable estimates. Furthermore, positivity issues may not be easy to diagnose, especially in procedures like double cross-fitting. Similar to previous recommendations [Yu2019], using multiple approaches to triangulate estimates may be helpful. For example, researchers may want to compare a flexible parametric AIPW and a double cross-fit AIPW with super-learner.
While our results support the use of machine learning algorithms, machine learning is not a panacea for causal inference. Rather, machine learning can be seen as weakening a single assumption, namely the assumption of proper model specification. Prior substantive knowledge to justify counterfactual consistency, conditional exchangeability, and positivity remains necessary for causal inference [Keil2018, Naimi2018]. For super-learner and other ensemble approaches to provide the maximal benefit in terms of specification, a diverse set of algorithms should be included [Rose2013]. Furthermore, multiple tuning parameters, sometimes referred to as hyperparameters, should be considered. While program defaults are often used, hyperparameters can dramatically change the performance of algorithms [Dominici2002]. Therefore, super-learner should include not only a diverse set of algorithms, but also those same algorithms under a diverse set of hyperparameters. Our simulations did not extensively explore hyperparameters, including only two hyperparameter specifications for the generalized additive models. Because double cross-fit procedures scale poorly in run-time with the number of included algorithms, adding algorithms with different hyperparameters carries a substantial run-time cost. Depending on the complexity of the machine learning algorithms used, alternative approaches may be required for hyperparameter tuning within the cross-fitting procedure [Wong2019]. Despite these concerns, a wide variety of hyperparameters should be explored in applications of double cross-fitting. Lastly, data processing, such as the calculation of interaction terms between variables, may be necessary for performance and should be done in practice [Naimi2017].
Future work should include diagnostics for double cross-fitting and the addition of other nuisance functions. Due to the repeated sample splitting, standard diagnostics (e.g., examining the distributions of predicted treatment probabilities [Yu2019]) may be more difficult to interpret. Realistic analyses often have additional issues that must be addressed, such as missing data and loss to follow-up. Therefore, additional nuisance functions (like inverse probability weights for informative loss to follow-up) are often needed, and cross-fit procedures for these scenarios should be assessed.
Machine learning is not a panacea for the monumental task of causal inference. However, these algorithms do impose less restrictive assumptions regarding the possible forms of the nuisance functions used for estimation. Double cross-fit estimators should be seen as an approach to flexibly fitting nuisance models while retaining valid inference. In practice, double cross-fit estimators should be used regularly with a super-learner that includes a diverse library of learners.
The authors would like to thank Ashley Naimi, Edward Kennedy, and Stephen Cole for their advice and discussion. We would like to thank the University of North Carolina at Chapel Hill and the Research Computing group for providing computational resources that have contributed to these results. PNZ received training support (T32-HD091058, PI: Aiello, Hummer) from the National Institutes of Health.
Section 1: Data generating mechanism
Let indicate age, indicate the natural-log transformed low-density lipoprotein, indicate diabetes, indicate frailty, and indicate the risk score. All observations are indexed by , such that refers to the age of individual . All variables were generated from the following distributions in their respective order.
where . Statin () and atherosclerotic cardiovascular disease () were generated from the following models
From each of the potential outcomes, the observed outcomes were calculated based on the potential outcome under the observed treatment.
Across all simulations the estimand of interest, the population average causal effect, was defined as
The true value for was calculated directly from the potential outcomes of 10,000,000 individuals. Let indicate the sample-specific estimate of the population average causal effect.
Section 2: Estimators
We consider two nuisance models, where is the outcome nuisance model and is the treatment nuisance model. The estimated outcome nuisance model is expressed as:
and the estimated treatment nuisance model is expressed as:
The g-computation algorithm uses the outcome nuisance model predictions only. The estimated parameters of the outcome nuisance model are used to predict outcome values, and the averages of the predicted outcome values under each treatment plan are contrasted. The average causal effect is estimated via:
Wald-type confidence intervals were calculated with the standard error estimated using a bootstrapping procedure. Briefly, samples are drawn with replacement from the observed sample. Nuisance models are fit to each newly generated sample and used to calculate the average causal effect. This procedure was repeated 250 times, with the bootstrap standard error defined as the standard deviation of the 250 re-estimated average causal effects.
Inverse probability weighted estimator
IPW consists of the treatment nuisance model predictions only. The estimated parameters of the treatment nuisance model are used to predict probabilities of treatment. The inverse of these estimated probabilities are used to construct the inverse probability weights. The average causal effect is estimated via:
Confidence intervals could similarly be calculated using a bootstrapping procedure. Instead, we used a robust variance estimator, which ignores the estimation of the nuisance model. Confidence intervals constructed from the robust variance are conservative for the ACE, with coverage expected to be at least 95% when the nuisance model is properly specified.
Augmented inverse probability weights
AIPW consists of the treatment and outcome nuisance model predictions. These predictions are combined via the following formula to estimate the average causal effect:
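The formula itself does not survive in this excerpt. For reference, writing π̂(W) for the estimated treatment nuisance model (Equation 1) and m̂₁(W), m̂₀(W) for the outcome nuisance model predictions under treatment and no treatment (our notation, not necessarily the paper's symbols), the standard AIPW estimator takes the form:

```latex
\hat{\psi}_{\text{AIPW}} = \frac{1}{n} \sum_{i=1}^{n} \left[
  \hat{m}_1(W_i) - \hat{m}_0(W_i)
  + \frac{X_i \bigl(Y_i - \hat{m}_1(W_i)\bigr)}{\hat{\pi}(W_i)}
  - \frac{(1 - X_i) \bigl(Y_i - \hat{m}_0(W_i)\bigr)}{1 - \hat{\pi}(W_i)}
\right]
```

Here the first two terms are the g-computation contrast and the remaining terms are inverse-probability-weighted residual corrections; regrouping the same terms displays the estimator as IPW plus an augmentation term.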
As its name indicates, the AIPW estimator can be seen as the IPW estimator plus an augmentation term. For inference, Wald-type confidence intervals were constructed from the following empirical influence curve-based variance estimator:
This estimated variance is directly used to calculate Wald-type confidence intervals via
Targeted maximum likelihood estimation
TMLE similarly uses both the treatment and outcome nuisance models. For ease of later notation, we define the so-called clever covariate as:
Using the clever covariate, we target (or update) the predictions from the outcome nuisance model. The targeting step is done by first estimating in the following parametric working model
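The specific formulas are not reproduced in this excerpt. In standard notation (ours), with π̂(W) the estimated propensity score and m̂(X,W) the untargeted outcome prediction, one common specification of the clever covariate and the working model is:

```latex
H(X_i, W_i) = \frac{X_i}{\hat{\pi}(W_i)} - \frac{1 - X_i}{1 - \hat{\pi}(W_i)},
\qquad
\operatorname{logit} \hat{P}(Y_i = 1 \mid X_i, W_i)
  = \operatorname{logit} \hat{m}(X_i, W_i) + \epsilon \, H(X_i, W_i)
```

where ε is estimated by maximum likelihood while logit m̂(X,W) is held fixed as an offset.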
The estimate is then used to update the untargeted estimates via the following formulas:
Therefore, we can use the targeted model predictions to calculate the average causal effect as:
For inference, Wald-type confidence intervals were constructed from the following variance estimator: