Predicting unknown outcomes based on a model estimated on a training data set is a common machine learning problem. Many machine learning algorithms have been proposed and shown to be very successful for prediction when the test data have the same distribution as the training data or the model specification is correct. In real applications, however, we rarely know the underlying true model for prediction, and we cannot guarantee the unknown test data will have the same distribution as the training data. For example, different geographies, schools, or hospitals may draw from different demographics, and the correlation structure among demographics may also vary (e.g. one ethnic group may be more or less disadvantaged in different geographies). If the model is misspecified, it may exploit subtle statistical relationships among features present in the training data to improve prediction, resulting in inaccuracy of parameter estimation and instability of prediction across test data sets with different distributions.
Recently, several algorithms have been proposed to address agnostic distribution shift in unknown test data, including domain generalization [10], causal transfer learning [13], and invariant causal prediction [12]. Their motivation is to exploit the invariant structure between the predictors and the response variable across multiple training datasets. However, they cannot handle distribution shifts that are not observed in the training data, and they do not consider the interaction between distribution shift and model misspecification. More recent work [7, 14] addresses the stable prediction problem with methods drawn from the causal inference literature and achieves improved performance, but it does not consider model misspecification and is restricted to the predictive setting with binary predictors and a binary response variable.
In this paper, we focus on the problem of stable prediction with model misspecification and agnostic distribution shift. We assume that all features (predictors) fall into one of two categories. The first category contains the stable features $S$, which have causal effects on the outcome that are invariant across environments (e.g., across training and test sets). The second category contains the unstable features $V$, which have no causal effect on the outcome but may be correlated with the stable features, the outcome, or both; this correlation may differ across environments. Under the assumption that all stable features are observed, model misspecification is induced by omitted nonlinear or interaction terms of the stable features. Since different environments (e.g., training and test sets) have different covariate distributions, the parameters estimated from different environments may be quite different even when we use the same parametric model. This variation arises because the parameters on included features capture two components: first, the partial effect of the included features on the expected value of the outcome, and second, a function that depends on the correlation between included and omitted features, as well as the distribution of outcomes conditional on omitted features. We consider the problem of making predictions when that second component of the estimated parameters is unstable across environments. In that case, we prefer an estimator that eliminates the second component, even though including it improves prediction for test sets that are similar to the training data. We look for an algorithm that is effective when the analyst does not know which features are stable and which are not.
One way to address the problem of stable prediction in such a setting is to isolate the impact of each individual feature. One method commonly used in the causal literature is covariate balancing [1, 8, 9], which essentially estimates the impact of the target feature by reweighting the data so that the distribution of covariates is equalized across different values of the target feature. This literature usually focuses on the case where there is a single, pre-specified feature of interest and other features are considered to be “confounders”. In this paper, we consider the case where there are potentially many stable features, and propose a novel Decorrelated Weighting Regression (DWR) algorithm for stable prediction with model misspecification and agnostic distribution shift by jointly optimizing a variable decorrelation regularizer and a weighted regression model. Specifically, the variable decorrelation regularizer constructs sample weights to reduce correlation among covariates and allows the weighted regression to approximately isolate the effect of each variable. The weighted regression model with those sample weights might perform worse than standard methods when predicting in the test data with similar distribution to the training, but it will do better across unknown test data with distribution shift from the training. Using both empirical experiments and theoretical analysis, we show that our algorithm outperforms alternatives in parameter estimation and stability in prediction across unknown test data.
This paper has three main contributions: 1) We investigate the problem of stable prediction with model misspecification and agnostic distribution shift. The problem setting is more general and practical than prior work. 2) We propose a novel DWR algorithm that jointly optimizes variable decorrelation and weighted regression to address the stable prediction problem. 3) We conduct extensive experiments on both synthetic and real-world datasets to demonstrate the advantages of our algorithm on the stable prediction problem.
Problem and Our Algorithm
In this section, we first formulate the stable prediction problem and then introduce the details of our algorithm.
Stable Prediction Problem
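Let $\mathcal{X}$ denote the space of observed features and $\mathcal{Y}$ denote the outcome space. We define an environment $e$ to be a joint distribution $P_{XY}$ on $\mathcal{X} \times \mathcal{Y}$, and let $\mathcal{E}$ denote the set of all environments. In each environment $e \in \mathcal{E}$, we have a dataset $D^e = (X^e, Y^e)$, where $X^e \in \mathcal{X}$ are predictor variables and $Y^e \in \mathcal{Y}$ is a response variable. The joint distribution of features and outcomes on $(X, Y)$ can change across environments: $P^{e}_{XY} \neq P^{e'}_{XY}$ for $e, e' \in \mathcal{E}$.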
In this paper, our goal is to learn a predictive model for stable prediction with model misspecification and agnostic distribution shift. To measure its performance on the stable prediction problem, we adopt the Average_Error and Stability_Error in [7] as:
where $|\mathcal{E}|$ refers to the number of test environments, and $RMSE(D^e)$ represents the Root Mean Square Error of a predictive model on dataset $D^e$. Average_Error and Stability_Error
refer to the mean and variance of the predictive error over all possible environments $e \in \mathcal{E}$.
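As a concrete reading of these two metrics, the sketch below computes them from a list of per-environment RMSE values. It assumes Average_Error is the arithmetic mean and Stability_Error is the sample spread of the RMSE values (the sample variance, or its square root when `take_sqrt=True`). The function names and the `take_sqrt` switch are illustrative assumptions, not part of the original definition.

```python
import math
from statistics import mean, variance


def average_error(rmses):
    """Mean of the per-environment RMSE values."""
    return mean(rmses)


def stability_error(rmses, take_sqrt=True):
    """Spread of the per-environment RMSE values.

    Uses the sample variance (divides by len(rmses) - 1). Whether the
    square root is taken is an assumption exposed through `take_sqrt`.
    """
    var = variance(rmses)
    return math.sqrt(var) if take_sqrt else var


# RMSE of one model on three hypothetical test environments.
rmses = [1.0, 2.0, 3.0]
print(average_error(rmses), stability_error(rmses))
```

A model with a low Average_Error but a high Stability_Error predicts well on average yet degrades sharply in some environments, which is the failure mode the stable prediction problem targets.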
Then, the stable prediction problem [7] is defined as:
Problem 1 (Stable Prediction)
Given one training environment $e \in \mathcal{E}$ with dataset $D^e = (X^e, Y^e)$, the task is to learn a predictive model to predict across unknown environments $\mathcal{E}$ with not only small Average_Error but also small Stability_Error.
Letting $X = \{S, V\}$, we define $S$ as stable features and $V$ as unstable features with Assumption 1:
Assumption 1. There exists a stable function $f(s)$ such that for all environments $e \in \mathcal{E}$, $E(Y^e \mid S^e = s, V^e = v) = E(Y^e \mid S^e = s) = f(s)$.
Assumption 1 can be guaranteed by $Y \perp V \mid S$. Thus, we can address the stable prediction problem by developing a predictive model that learns the stable function $f(s)$. However, we have no prior knowledge of which features are stable and which are unstable.
Assumption 2. All stable features $S$ are observed.
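Under Assumption 2, model misspecification is induced when estimating an outcome function if the model omits some nonlinear transformations and interaction terms of the stable features. Suppose that the true stable function $f(S^e)$ and the outcome $Y^e$ in environment $e$ are given by:
$$Y^e = f(S^e) + \varepsilon^e = S^e \beta_S + V^e \beta_V + g(S^e) + \varepsilon^e, \qquad (3)$$
where $\beta_V = 0$ and $\varepsilon^e \perp X^e$. We assume that the analyst misspecifies the model by omitting the nonlinear term $g(S^e)$ and uses a linear model for prediction.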
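Under Assumption 1, the distribution shift across environments is mainly induced by the variation in the joint distribution over $(S^e, V^e)$. Simple linear regression may estimate nonzero effects of the unstable features $V^e$ when $V^e$ is correlated with the omitted variables $g(S^e)$. For OLS, we have
$$\hat{\beta}_{V_{OLS}} = \beta_V + \Big(\tfrac{1}{n}\sum_{i=1}^{n} V_i^T V_i\Big)^{-1} \Big(\tfrac{1}{n}\sum_{i=1}^{n} V_i^T g(S_i)\Big) + \Big(\tfrac{1}{n}\sum_{i=1}^{n} V_i^T V_i\Big)^{-1} \Big(\tfrac{1}{n}\sum_{i=1}^{n} V_i^T S_i\Big) \big(\beta_S - \hat{\beta}_{S_{OLS}}\big) + o_p(1), \qquad (4)$$
$$\hat{\beta}_{S_{OLS}} = \beta_S + \Big(\tfrac{1}{n}\sum_{i=1}^{n} S_i^T S_i\Big)^{-1} \Big(\tfrac{1}{n}\sum_{i=1}^{n} S_i^T g(S_i)\Big) + \Big(\tfrac{1}{n}\sum_{i=1}^{n} S_i^T S_i\Big)^{-1} \Big(\tfrac{1}{n}\sum_{i=1}^{n} S_i^T V_i\Big) \big(\beta_V - \hat{\beta}_{V_{OLS}}\big) + o_p(1), \qquad (5)$$
where $n$ is the sample size, $\frac{1}{n}\sum_{i=1}^{n} V_i^T g(S_i) = E(V^T g(S)) + o_p(1)$, and $\frac{1}{n}\sum_{i=1}^{n} V_i^T S_i = E(V^T S) + o_p(1)$. To simplify notation, we remove the environment variable $e$ from $X^e$, $S^e$, $V^e$, $\varepsilon^e$.
If $E(V^T g(S)) \neq 0$ or $E(V^T S) \neq 0$ in Eq. (4), $\hat{\beta}_{V_{OLS}}$ will be biased, resulting in biased estimation of $\hat{\beta}_{S_{OLS}}$ in Eq. (5). Its prediction will also be very unstable, since the correlation between $V$ and $g(S)$ (or $S$) might vary across testing environments. Hence, to increase the stability of prediction, we need to precisely estimate the parameters $\beta_S$ and $\beta_V$ by removing the correlation between $V$ and $g(S)$ (or $S$) on the training data, that is, let $E(V^T g(S)) = 0$ and $E(V^T S) = 0$.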
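The bias mechanism described above can be checked numerically. The following is a minimal simulation sketch, not the paper's experimental setup: it assumes a single stable feature $S$ with omitted term $g(S) = S^2 - 1$, a single unstable feature $V = r \cdot g(S) + \text{noise}$, and a true coefficient of zero on $V$. The function name `ols_coefs` and the parameter `r` are hypothetical.

```python
import numpy as np


def ols_coefs(r, n=200_000, beta_s=1.0, seed=0):
    """Fit a misspecified linear model y ~ s + v and return [beta_s_hat, beta_v_hat].

    r controls how strongly the unstable feature v tracks the omitted
    nonlinear term g(s) = s**2 - 1 (illustrative choice of g).
    """
    rng = np.random.default_rng(seed)
    s = rng.standard_normal(n)
    g = s**2 - 1.0                          # omitted nonlinear term, mean zero
    v = r * g + rng.standard_normal(n)      # unstable feature, no causal effect
    y = beta_s * s + g + 0.1 * rng.standard_normal(n)
    X = np.column_stack([s, v])             # linear model omits g(s)
    coef, *_ = np.linalg.lstsq(X, y, rcond=None)
    return coef


print(ols_coefs(r=1.0))   # coefficient on v is pulled away from 0
print(ols_coefs(r=0.0))   # v independent of g(s): coefficient on v near 0
```

Under these assumptions the OLS coefficient on `v` approaches $2r / (2r^2 + 1)$ rather than 0, so it changes whenever `r` changes across environments, while `r = 0` gives a coefficient near zero.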
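Notations. In our paper, $n$ refers to the sample size, and $p$ is the dimension of the variables. For any vector $v \in \mathbb{R}^{p}$, let $\|v\|_1 = \sum_{i=1}^{p} |v_i|$, and $\|v\|_2^2 = \sum_{i=1}^{p} v_i^2$. For any matrix $X \in \mathbb{R}^{n \times p}$, we let $X_{i,}$ and $X_{,j}$ represent the $i$-th sample and the $j$-th variable in $X$, respectively.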
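Variable Decorrelation
In this subsection, we introduce our variable decorrelation regularizer to reduce the correlation between the unstable features $V$ and the omitted terms $g(S)$ (or the stable features $S$) in the training environment.
Lemma 1. If the variables in $X$ are mutually independent with mean 0, then $E(V^T g(S)) = 0$ and $E(V^T S) = 0$.
Lemma 1 and Eqs. (4) and (5) imply that if the covariates are mutually independent, we can unbiasedly estimate the parameter $\beta$ even if $g(S)$ is omitted. This motivates our regularizer.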
From [4], we know that variables $X_{,j}$ and $X_{,k}$ are independent if $E[X_{,j}^{a} X_{,k}^{b}] = E[X_{,j}^{a}] E[X_{,k}^{b}]$ for all $a, b \in \mathbb{N}$. (Footnote 1: In empirical applications, we can discretize $X_{,j}$ and $X_{,k}$ to satisfy the sufficient condition in [4].) Inspired by the weighting methods in the causal literature [1, 5, 9], we propose to make $X_{,j}$ and $X_{,k}$ independent by reweighting samples with weights $W$, which can be learnt with the following objective function:
where $W \in \mathbb{R}^{n \times 1}$ are sample weights, and
$\Sigma_W = \mathrm{diag}(W_1, \ldots, W_n)$ is the corresponding diagonal matrix. In practice, however, it is not feasible to make all the moments of the variables in the objective function of Eq. (6) equal to zero. Fortunately, from Eq. (4) and Eq. (5) we know that reducing correlation among the first moments of the variables helps to improve the precision of parameter estimation and the stability of predictive models, and in practice the analyst can include higher-order moments, for example polynomial functions of covariates, to further improve stability.
In this paper, we focus on the variables' first moment and propose to de-correlate all the predictors by sample reweighting in the training environment. Specifically, we propose a variable decorrelation regularizer for learning the sample weights $W$ as follows:
where $X_{,-j}$ denotes all the remaining variables after removing the $j$-th variable from $X$. (Footnote 2: In experiments, we obtain $X_{,-j}$ by setting the value of the $j$-th variable in $X$ to zero.) The summand represents the loss due to correlation between variable $X_{,j}$ and all other variables $X_{,-j}$. Note that only the first moment is considered in Eq. (7), but it is sufficient for variable decorrelation, and higher-order moments can be easily incorporated.
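A minimal NumPy sketch of a first-moment decorrelation loss follows. It assumes the sample analogue of the regularizer, in which each expectation is replaced by a weighted sample average, and it assumes the removed variable is handled by zeroing its column, so that only pairs of distinct variables contribute. The function name `decorrelation_loss` is illustrative.

```python
import numpy as np


def decorrelation_loss(X, w):
    """Sum over pairs j != k of the squared weighted covariance-like term

        (1/n) sum_i w_i x_ij x_ik  -  ((1/n) sum_i w_i x_ij) * ((1/n) sum_i w_i x_ik).
    """
    X = np.asarray(X, dtype=float)
    w = np.asarray(w, dtype=float)
    n = X.shape[0]
    cross = (X * w[:, None]).T @ X / n    # weighted cross moments, shape (p, p)
    first = X.T @ w / n                   # weighted first moments, shape (p,)
    C = cross - np.outer(first, first)
    np.fill_diagonal(C, 0.0)              # a variable is not compared with itself
    return float(np.sum(C**2))


# Two perfectly correlated columns versus a balanced 2x2 design.
X_corr = np.array([[1.0, 1.0], [-1.0, -1.0]])
X_bal = np.array([[1.0, 1.0], [1.0, -1.0], [-1.0, 1.0], [-1.0, -1.0]])
print(decorrelation_loss(X_corr, np.ones(2)))   # 2.0
print(decorrelation_loss(X_bal, np.ones(4)))    # 0.0
```

The same balanced design becomes correlated again under uneven weights such as `[2, 0, 0, 2]`, which illustrates that the loss is a function of the weights and not only of the raw data.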
The following theoretical results (proved in the supplementary material) show that our variable decorrelation regularizer can make the variables in $X$ mutually uncorrelated by sample reweighting, and hence reduce the correlation among covariates in the training environment and improve the accuracy of parameter estimation.
With , we can denote the loss in Eq. (7) as:
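Lemma 2. If the number of covariates $p$ is fixed, then there exists a sample weight $W \succeq 0$ such that the loss in Eq. (8), denoted $\mathcal{L}_B$, satisfies
$$\lim_{n \to \infty} \mathcal{L}_B = 0 \qquad (9)$$
with probability 1. In particular, a solution to Eq. (9) is $W_i^{*} = \frac{\prod_{j=1}^{p} \hat{f}(X_{i,j})}{\hat{f}(X_{i,1}, \ldots, X_{i,p})}$, where $\hat{f}(X_{i,j})$ and $\hat{f}(X_{i,1}, \ldots, X_{i,p})$ are the kernel density estimators. (Footnote 3: In detail, $\hat{f}(X_{i,j}) = \frac{1}{n h_j} \sum_{k=1}^{n} k\big(\frac{X_{k,j} - X_{i,j}}{h_j}\big)$, where $k(u)$ is a kernel function and $h_j$ is the bandwidth parameter for covariate $X_{,j}$; and $\hat{f}(X_{i,1}, \ldots, X_{i,p}) = \frac{1}{n |H|} \sum_{k=1}^{n} K\big(H^{-1}(X_{k,} - X_{i,})\big)$, where $K(u)$ is a multivariate kernel function, and $H = \mathrm{diag}(h_1, \ldots, h_p)$.)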
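A density-ratio form of such weights can be sketched as follows, assuming the solution is the product of the marginal kernel density estimates divided by the joint kernel density estimate. The Gaussian kernel, the product kernel for the joint density, and the rule-of-thumb bandwidth are all assumptions made for illustration, and the function name `kde_decorrelation_weights` is hypothetical.

```python
import numpy as np


def kde_decorrelation_weights(X, bandwidths=None):
    """Weights w_i = prod_j f_hat(x_ij) / f_hat(x_i1, ..., x_ip).

    Uses Gaussian kernels; memory cost is O(n^2 p), so this is only
    meant for small illustrative datasets.
    """
    X = np.asarray(X, dtype=float)
    n, p = X.shape
    if bandwidths is None:
        # Rule-of-thumb bandwidth per covariate (an assumption).
        bandwidths = 1.06 * X.std(axis=0, ddof=1) * n ** (-1.0 / 5.0)
    h = np.asarray(bandwidths, dtype=float)
    # U[i, k, j] = (x_kj - x_ij) / h_j
    U = (X[None, :, :] - X[:, None, :]) / h
    K = np.exp(-0.5 * U**2) / np.sqrt(2.0 * np.pi)    # kernel values, shape (n, n, p)
    marginals = K.mean(axis=1) / h                    # f_hat(x_ij), shape (n, p)
    joint = K.prod(axis=2).mean(axis=1) / h.prod()    # f_hat(x_i), shape (n,)
    return marginals.prod(axis=1) / joint


rng = np.random.default_rng(1)
a = rng.standard_normal(500)
b = 0.8 * a + 0.6 * rng.standard_normal(500)
w = kde_decorrelation_weights(np.column_stack([a, b]))
print(w.min(), w.max())
```

Samples that are rare under the joint distribution relative to what independence would predict receive larger weights, which is what pushes the weighted data toward uncorrelated covariates. Kernel smoothing means the correlation is reduced rather than removed exactly in finite samples.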
where for some constant .
Then, we have the following theorem on our variable decorrelation regularizer in Eq. (10).
The solution defined in Eq. (10) is unique if , and for some constant .
Property 1. When is fixed, , , and , the variables in become uncorrelated by sample reweighting with . Hence, correlation between and in the training environment will be removed.
Extensive empirical experiments demonstrate that the correlation between the unstable features $V$ and the omitted terms $g(S)$ is also reduced by our regularizer. In summary, the proposed variable decorrelation regularizer in Eq. (10) can learn a unique set of optimal sample weights that de-correlate the variables in $X$, and thus improve the accuracy of parameter estimation and the stability of prediction.
Decorrelated Weighting Regression
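With the learned sample weights $\hat{W}$ from the variable decorrelation regularizer in Eq. (10), one can run weighted least squares (WLS) to estimate the regression coefficient $\beta$ as:
$$\hat{\beta}_{WLS} = \arg\min_{\beta} \sum_{i=1}^{n} \hat{W}_i \cdot (Y_i - X_{i,}\beta)^2. \qquad (11)$$
The estimate $\hat{\beta}_{WLS}$ is expected to have less bias than $\hat{\beta}_{OLS}$ under Property 1, since sample reweighting by $\hat{W}$ de-correlates the variables in $X$.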
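For reference, weighted least squares with fixed sample weights can be computed by rescaling each row by the square root of its weight and solving an ordinary least squares problem. The sketch below assumes non-negative weights and no intercept term; the function name `weighted_least_squares` is illustrative.

```python
import numpy as np


def weighted_least_squares(X, y, w):
    """Minimize sum_i w_i * (y_i - x_i @ beta)**2 over beta."""
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float)
    sw = np.sqrt(np.asarray(w, dtype=float))   # requires w >= 0
    beta, *_ = np.linalg.lstsq(X * sw[:, None], y * sw, rcond=None)
    return beta


# With a single constant column, WLS returns the weighted mean of y.
X = np.array([[1.0], [1.0]])
y = np.array([0.0, 3.0])
print(weighted_least_squares(X, y, w=np.array([1.0, 2.0])))   # [2.]
```

With all weights equal, this reduces to OLS, so the only difference between the two estimators is the reweighting of the training samples.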
By combining the objective functions of the variable decorrelation regularizer in Eq. (10) and the weighted regression in Eq. (11), we propose a Decorrelated Weighting Regression (DWR) algorithm to jointly optimize the sample weights $W$ and the regression coefficient $\beta$ as follows:
where $n$ denotes the sample size and $p$ refers to the dimension of the variables $X$. $X_{i,}$ and $X_{,j}$ represent the $i$-th sample and the $j$-th variable in $X$, respectively. The term $W \succeq 0$ constrains each sample weight to be non-negative. The constraint on the second moment of the sample weights reduces their variation. The constraint on the mean of the sample weights prevents all sample weights from being $0$.
Optimization and Analysis
To optimize our DWR algorithm in Eq. (12), we propose an iterative method. First, we initialize the sample weights $W$ and the regression coefficient $\beta$. Given the initial values, in each iteration we update $W$ and $\beta$ in turn, each with the other held fixed, until the objective function in Eq. (12) converges. The whole algorithm is summarized in Algorithm 1.
In our DWR algorithm, the main time cost is calculating the loss function and updating the parameters $W$ and $\beta$ in each iteration. The complexity of calculating the loss function is $O(np^2)$, where $n$ is the sample size and $p$ is the dimension of the observed variables. The complexity of updating $W$ is also $O(np^2)$. The complexity of updating $\beta$ is $O(np)$.
In total, the complexity of each iteration in Algorithm 1 is $O(np^2)$.
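The alternating scheme can be sketched as below. This is a simplified illustration under stated assumptions, not Algorithm 1 itself: the constraints are folded into a penalized objective with hypothetical multipliers `lam2`, `lam3`, `lam4`, the sparsity constraint on $\beta$ is omitted, the weight update is a projected gradient step with backtracking, and the coefficient update is closed-form WLS. All function names are illustrative.

```python
import numpy as np


def decor_loss_grad(X, w):
    """First-moment decorrelation loss and its gradient w.r.t. the weights."""
    n = X.shape[0]
    m = X.T @ w / n                                   # weighted first moments
    C = (X * w[:, None]).T @ X / n - np.outer(m, m)   # pairwise decorrelation terms
    np.fill_diagonal(C, 0.0)                          # only pairs j != k count
    XC = X @ C
    grad = (2.0 / n) * (np.sum(XC * X, axis=1) - 2.0 * (XC @ m))
    return float(np.sum(C**2)), grad


def wls(X, y, w):
    sw = np.sqrt(w)
    return np.linalg.lstsq(X * sw[:, None], y * sw, rcond=None)[0]


def objective(X, y, w, beta, lam2, lam3, lam4):
    fit = np.mean(w * (y - X @ beta) ** 2)
    return (fit + lam2 * decor_loss_grad(X, w)[0]
            + lam3 * np.mean(w**2) + lam4 * (np.mean(w) - 1.0) ** 2)


def dwr_sketch(X, y, lam2=10.0, lam3=0.01, lam4=1.0, n_iter=100):
    n = X.shape[0]
    w = np.ones(n)                                    # start from uniform weights
    beta = wls(X, y, w)
    history = [objective(X, y, w, beta, lam2, lam3, lam4)]
    for _ in range(n_iter):
        # Step 1: update W with beta fixed (projected gradient with backtracking).
        res2 = (y - X @ beta) ** 2
        grad = (res2 + 2.0 * lam3 * w + 2.0 * lam4 * (w.mean() - 1.0)) / n
        grad = grad + lam2 * decor_loss_grad(X, w)[1]
        step = float(n)                               # per-sample gradients are O(1/n)
        while step > 1e-10:
            w_try = np.maximum(w - step * grad, 0.0)  # keep weights non-negative
            if objective(X, y, w_try, beta, lam2, lam3, lam4) < history[-1]:
                w = w_try
                break
            step *= 0.5
        # Step 2: update beta with W fixed (closed-form weighted least squares).
        beta = wls(X, y, w)
        history.append(objective(X, y, w, beta, lam2, lam3, lam4))
    return w, beta, history
```

Because the weight step is accepted only when it lowers the objective and the WLS step is an exact minimizer over $\beta$, the recorded objective values are non-increasing. Each evaluation of the decorrelation term costs $O(np^2)$, which dominates the per-iteration cost in this sketch.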
In this section, we check the performance of our algorithm with experiments on both synthetic and real-world datasets.
To evaluate the performance of stable prediction, we use RMSE, $\beta$_Error, Average_Error, and Stability_Error as evaluation metrics. The definitions of RMSE and $\beta$_Error are listed as follows:
$$RMSE = \sqrt{\frac{1}{n} \sum_{i=1}^{n} (\hat{Y}_i - Y_i)^2},$$
where $n$ is the sample size, and $\hat{Y}_i$ and $Y_i$ refer to the predicted and true outcome for sample $i$.
where $\hat{\beta}$ and $\beta$ represent the estimated and true regression coefficients.
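A small Python sketch of the two per-dataset metrics follows. RMSE uses its standard definition; for the coefficient error, the sketch assumes the $\ell_1$ distance between the estimated and true coefficient vectors, which is an assumption since the exact norm is not recoverable from the text above. The function names are illustrative.

```python
import math


def rmse(y_true, y_pred):
    """Root mean square error between true and predicted outcomes."""
    n = len(y_true)
    return math.sqrt(sum((p - t) ** 2 for t, p in zip(y_true, y_pred)) / n)


def beta_error(beta_true, beta_hat):
    """Assumed l1 distance between true and estimated coefficients."""
    return sum(abs(b - bh) for b, bh in zip(beta_true, beta_hat))


print(rmse([0.0, 0.0], [3.0, 4.0]))            # sqrt(12.5)
print(beta_error([1.0, 0.0], [0.5, 0.25]))     # 0.75
```

Computing the coefficient error separately on the stable and unstable coordinates shows whether an estimator assigns spurious weight to features whose true coefficient is zero.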
Experiments on Synthetic Data
: In this setting, and are independent, but could be dependent with each other. Hence, we generate
with independent Gaussian distributions with the help of auxiliary variablesas following:
where the number of stable variables and the number of unstable variables . represents the variable in .
: In this setting, the stable features are the causes of unstable features . We first generate dependent stable features with Eq. (16). Then, we generate unstable features based on : , where we let . The function returns the modulus after division of by .
: In this setting, unstable features are the causes of stable features . We first generate the unstable features with Eq. (15). Then, we generate the stable features based on : , where we let .
To test the performance with different forms of missing nonlinear and interaction terms, we generate the outcome from a polynomial nonlinear function () and an exponential one ():
where , , and .
Generating Various Environments
To test the stability of all algorithms, we need to generate a set of environments, each with a distinct joint distribution , while preserving Assumption 1 (and in particular, ). Specifically, we generate different environments in our experiments by varying . For simplification we only change on a subset of unstable features , where the dimension of is .
Note that corresponds to positive unstable correlation between and , while refers to the negative unstable correlation between and . The higher the value of , the stronger correlation between and . Different value of refers to different environments, hence we can generate different environments by varying .
In the experiments, we evaluate all algorithms from two aspects: accuracy of parameter estimation and stability of prediction across unknown test data. To measure the accuracy of parameter estimation, we train all models on one training dataset with a specific bias rate. We repeat model training 50 times independently, with different training data generated under the same bias rate, and report the mean and variance of $\beta$_Error on the stable features $S$ and the unstable features $V$. To evaluate the stability of prediction, we test all models on various test environments with different bias rates. For each test bias rate, we generate 50 different test datasets and report the mean RMSE. With the RMSE from all test environments, we report Average_Error and Stability_Error to evaluate the stability of prediction across unknown test environments.
[Results table: Scenario 1, varying sample size; Scenario 2, varying variables' dimension; Scenario 3, varying bias rate on training data.]
Before reporting the experimental results, we show in Figure 2 the Pearson correlation coefficients between any two variables on both the raw data and the data weighted by our algorithm. In the raw data, the unstable features $V$ are correlated with some stable features $S$, and highly correlated with both the omitted nonlinear term $g(S)$ and the outcome $Y$. Hence, the estimated coefficient of $V$ in the baselines would be large, whereas it should be $0$ in a correctly specified model, leading to unstable prediction. In the weighted data, the sample weights learnt by our algorithm clearly remove the correlation among the predictors $X$. Moreover, the unstable correlation between $V$ and $g(S)$ is significantly reduced, which helps to reduce the unstable correlation between $V$ and $Y$, and in turn the correlations between the stable features $S$ and $Y$ conditional on $V$ are enhanced. Hence, our algorithm can estimate the coefficients of both $S$ and $V$ more precisely. This is the key reason that our algorithm makes more stable predictions across unknown test environments.
We report the results of parameter estimation and stable prediction under setting with in Figure 3 and Table 1. To save space, the experimental results of settings and with , and results with are reported in online Appendix. From the results, we have following observations and analysis:
OLS cannot address the stable prediction problem. The reason is that OLS is biased in estimating both $\beta_S$ and $\beta_V$, as we discussed in the theoretical section. Moreover, OLS will often estimate large effects of the unstable features, which leads to instability across environments.
Lasso, Ridge and IILasso perform even worse than OLS, since their regularizers will generally estimate larger coefficients on the unstable features $V$. For example, Lasso selects only a subset of predictors and exacerbates the omitted variables problem that already exists in our basic setup.
Compared with the baselines, our algorithm achieves more stable prediction across different settings. By reducing the correlation among all predictors, our algorithm avoids using unstable features to proxy for omitted nonlinear functions of the stable features, ensuring less bias in the estimated effects of both stable and unstable features, and hence improving the stability of prediction.
The performance of our algorithm is worse than the baselines on test data whose bias rate is close to that of the training data (Fig. (c)), but much better than the baselines when the test bias rate differs from the training one. This is because, in the former case, the correlations between $V$ and $Y$ are similar between training and test data, and that correlation can be exploited in prediction; in this setting, $V$ is useful to proxy for omitted functions of $S$. In the latter case, however, using $V$ for prediction creates too much instability.
By varying the sample size $n$, the dimension of variables $p$, the training bias rate, and the form of the missing nonlinear and interaction terms, we find that our algorithm consistently outperforms the baselines on parameter estimation and stable prediction across unknown test data.
Overall, our proposed DWR algorithm can be applied to address the problem of stable prediction with model misspecification and agnostic distribution shift.
In our DWR algorithm, we have several hyper-parameters, including $\lambda_1$ for constraining the sparsity of the regression coefficient, $\lambda_2$ for constraining the error of the decorrelation regularizer, $\lambda_3$ for constraining the variance of the sample weights, and $\lambda_4$ for constraining the sum of the sample weights. In our experiments, we tuned these parameters with cross validation by grid search, varying each parameter uniformly over a fixed grid. In Figure 4, we display the errors with respect to the hyper-parameters. From the figures, we find that the errors first decrease monotonically as we increase the value of a hyper-parameter, and then slightly increase once the hyper-parameter passes a threshold and keeps increasing.
Experiments on Real World Data
Datasets and Experimental Setting
We collected air pollutant data and meteorological data from the U.S. EPA's Air Quality System (AQS) database (https://www.epa.gov/outdoor-air-quality-data), which has been widely used for model evaluation [17, 18]. The air pollutant data in this study is PM, and the meteorological variables are those that would affect the air pollutant concentrations, including air temperature, relative humidity, pressure, wind speed and direction.
In our experiments, we let the outcome variable $Y$ be pollution PM, and set the meteorological features as the observed variables $X$. To test the stability of all algorithms, we collected data from 10 different states in the U.S., where the states correspond to the different environments from the theory. Considering a practical setting where a researcher has a single data set and wishes to train a model that can then be applied to other related settings, we trained all models with data from State 1, validated them with data from States 1 to 4, and finally tested them on all 10 States.
To demonstrate the distribution difference between any two environments, we adopt the distribution distance between observed variables as a metric (Footnote 5: A variable's distribution can be uniquely determined by the collection of all its moments. Here, we only consider the first moment. Other metrics can also be applied to measure distribution distance, for example, KL-divergence. We leave this to future work.) with the following definition:
where refers to the dimension of variables, and represents the mean value of variables in environment .
We report the RMSE of air quality prediction over all 10 States in Fig. (a), where we merged the OLS method into Lasso by allowing its hyper-parameter to be $0$ during model training. The results show that the performance of our algorithm is worse than the baselines when the distribution distance between training and test environments is small; in that case, we introduce variance by reweighting the data away from the distribution that approximates both training and test sets. However, our algorithm's performance improves relative to the baselines and ultimately becomes better than the baselines as the distribution distance increases.
To explicitly demonstrate the advantage of our proposed algorithm, we report Average_Error and Stability_Error in Fig. (b). The results show that our algorithm makes the most stable prediction with agnostic distribution shift on test data.
In this paper, we focus on how to facilitate a stable prediction across unknown test data, where we are concerned about two problems that together lead to instability: model misspecification, and agnostic distribution shift between training and test data. We proved that our algorithm can improve the accuracy of parameter estimation and stability on prediction from both theoretical analysis and empirical experiments. The experimental results on both synthetic and real-world datasets demonstrate that our algorithm outperforms the baselines for stable prediction across unknown test environments, when the correlation among covariates varies substantially across those environments.
This work was supported by National Key Research and Development Program of China (No. 2018AAA0102004, No. 2018AAA0101900), National Natural Science Foundation of China (No. 61772304, No. 61521002, No. 61531006, No. U1611461), Beijing Academy of Artificial Intelligence (BAAI). Susan Athey’s research was supported by Sloan foundation and Office of Naval Research grant N00014-17-1-2131. Bo Li’s research was supported by the Tsinghua University Initiative Scientific Research Grant, No. 20165080091; National Natural Science Foundation of China, No. 71490723 and No. 71432004; Science Foundation of Ministry of Education of China, No. 16JJD630006. All authors of this paper are corresponding authors. All opinions, findings, and conclusions in this paper are those of the authors and do not necessarily reflect the views of the funding agencies.
[1] (2018) Approximate residual balancing: debiased inference of average treatment effects in high dimensions. Journal of the Royal Statistical Society: Series B (Statistical Methodology) 80 (4), pp. 597–623.
[2] (2010) A theory of learning from different domains. Machine Learning 79 (1-2), pp. 151–175.
[3] (2009) Discriminative learning under covariate shift. Journal of Machine Learning Research 10 (Sep), pp. 2137–2155.
[4] (2006) When does E(X^k·Y^l) = E(X^k)·E(Y^l) imply independence? Statistics & Probability Letters 76 (11), pp. 1111–1116.
[5] (2018) Covariate balancing propensity score for a continuous treatment: application to the efficacy of political advertisements. The Annals of Applied Statistics 12 (1), pp. 156–177.
[6] (1970) Ridge regression: biased estimation for nonorthogonal problems. Technometrics 12 (1), pp. 55–67.
[7] (2018) Stable prediction across unknown environments. In Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining, pp. 1617–1626.
[8] (2017) Treatment effect estimation with data-driven variable decomposition. In Thirty-First AAAI Conference on Artificial Intelligence, pp. 140–146.
[9] (2017) Estimating treatment effect in the wild via differentiated confounder balancing. In Proceedings of the 23rd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, pp. 265–274.
[10] (2013) Domain generalization via invariant feature representation. In International Conference on Machine Learning, pp. 10–18.
[11] (2009) A survey on transfer learning. IEEE Transactions on Knowledge and Data Engineering 22 (10), pp. 1345–1359.
[12] (2016) Causal inference by using invariant prediction: identification and confidence intervals. Journal of the Royal Statistical Society: Series B (Statistical Methodology) 78 (5), pp. 947–1012.
[13] (2018) Invariant models for causal transfer learning. The Journal of Machine Learning Research 19 (1), pp. 1309–1342.
[14] (2018) Causally regularized learning with agnostic data selection bias. In ACM Multimedia, pp. 411–419.
[15] (2017) Independently interpretable lasso: a new regularizer for sparse regression with uncorrelated variables. arXiv preprint arXiv:1711.01796.
[16] (1996) Regression shrinkage and selection via the lasso. Journal of the Royal Statistical Society. Series B (Methodological), pp. 267–288.
[17] (2017) Decadal application of WRF/Chem for regional air quality and climate modeling over the US under the representative concentration pathways scenarios. Part 1: model evaluation and impact of downscaling. Atmospheric Environment 152, pp. 562–583.
[18] (2018) A machine learning approach for air quality prediction: model regularization and optimization. Big Data and Cognitive Computing 2 (1), pp. 5.