1 Introduction
Machine learning models typically rely on a consistency assumption that the training and testing datasets are generated from the same data distribution (i.e., the i.i.d. hypothesis). Under the i.i.d. hypothesis, a trained model can be directly applied to the testing dataset, and its performance should be equivalent to that on the training dataset. Although this approach has empirically proven highly successful on many public datasets, it is flawed in practice. In many real-world applications, there is no guarantee that the training data and the unknown testing data share the same distribution, due to data generation bias. This agnostic distribution shift between training and testing datasets can result in inaccurate parameter estimation and unstable performance across different test datasets. Consider a case from image classification: why can an image of a dog on the grass be labeled as 'dog', while a cat on the grass should not? The reason is that a stable causality exists between certain essential stable features (the features of the dog in this case) and the ground truth. This stable causality remains consistent even if the unstable features (e.g., the features of the grass) change across different unknown environments. However, data generation bias can induce a statistical correlation between unstable features like grass and the ground truth 'dog'. A model trained on such a dataset can easily be trapped by this statistical correlation in pursuit of a lower training loss, resulting in unstable performance across testing datasets with different distributions.
Therefore, how to mitigate this model misspecification problem and learn a stable model has attracted a lot of attention from the community (Pan et al., 2018; Wang et al., 2019; Shen et al., 2019; Kuang et al., 2020). One of the most straightforward ideas for solving this model misspecification problem is to take advantage of prior knowledge of the test data. Based on that idea, many domain adaptation (Chen et al., 2020) and transfer learning methods (Pan et al., 2018; Wang et al., 2019) were proposed recently to correct the distribution shift using prior knowledge of the test dataset. However, in the agnostic distribution shift problem, no prior knowledge of the test data or the true distribution is available. Therefore, these works are not feasible for the agnostic distribution shift problem we focus on in this work. More recently, many causality-based methods (Shen et al., 2018, 2019; Kuang et al., 2020) were proposed, most of which are sample-reweighting-based. Shen et al. (2019) proposed a sample reweighting method to address the collinearity problem among input variables caused by agnostic distribution shift. Kuang et al. (2020) proposed a causality-based, sample-reweighting-based feature decorrelation method. However, these sample-reweighting-based methods require a sample reweighting matrix whose parameter number is proportional to the number of training samples. Hence, these works are both computationally and memory intensive. This disadvantage limits the scalability of these methods in machine learning tasks with a large amount of training data. Moreover, these works are limited to linear models and cannot be applied to deep-learning-based nonlinear models.
In this work, we investigate the stable learning problem under agnostic distribution shift. Specifically, we propose a novel Causality-based Feature Rectification (CFR) method to address the model misspecification problem caused by the agnostic distribution shift. Our motivation is that the stable causality between the stable features and the ground truths remains consistent even if it is partly omitted due to the agnostic distribution shift. Therefore, we can use the correlations between the omitted stable features and other features as a proxy for the omitted causality. We propose a feature rectification weight that reconstructs the causality graph using these correlations as a proxy. We further propose an algorithm to jointly optimize the feature rectification weight and a regressor (or classifier), in which the feature rectification weight matrix rectifies the input features to counteract the impact of the agnostic distribution shift. Our proposal is a feature pretreatment method, and it can be integrated with most commonly used linear regression and image classification methods, including deep-learning-based models. The experimental results demonstrate that the stability of our proposal outperforms all baselines and state-of-the-art stable learning methods on both synthetic and real-world datasets.
We summarize the contributions as follows:
(1) We investigate the stable learning problem under agnostic distribution shift, focusing on the model misspecification problem it causes.
(2) We propose a novel Causality-based Feature Rectification (CFR) method to improve the stability of the model across different test datasets. Specifically, we propose a feature rectification weight to reconstruct the causality graph of features and help learn the true underlying model under agnostic distribution shift. Unlike previous works, our proposal is not restricted to linear models and can also be applied to deep-learning-based models.
(3) We conduct linear regression and image classification experiments on both synthetic and real-world datasets to validate our proposal's performance. The experimental results demonstrate that our proposal achieves state-of-the-art performance on both synthetic and real-world datasets.
2 Related Works
2.1 Causality-based Methods
Most recently proposed causality-based methods are based on sample reweighting: they do not directly change the biased sample features, but shift the training dataset's distribution by varying the importance of the samples. Shen et al. (2018) proposed a Causally Regularized Logistic Regression model to address agnostic distribution shift in logistic regression tasks. However, it did not consider the model misspecification problem, and its algorithm was restricted to the predictive setting with binary predictors and a binary response variable. Shen et al. (2019) proposed a sample reweighting method to address the collinearity among input variables caused by agnostic distribution shift. Kuang et al. (2020) proposed a causality-based sample reweighting method, which makes the features mutually independent with zero mean. However, these works are restricted to regression and binary classification tasks. Meanwhile, these sample-reweighting-based works require a large sample reweighting matrix whose parameter number is proportional to the number of training samples. Hence, these works are not feasible for tasks with massive training data. Different from them, our proposal addresses the model misspecification problem by rectifying the correlations between features. Moreover, the parameter number of our proposal is proportional to the feature dimension, which in most machine-learning tasks is at most a few thousand. Hence, our proposal is more feasible than previous works for tasks with a large training dataset.
2.2 Non-causality-based Methods
In addition to causality-based methods, a variety of domain adaptation (Chen et al., 2020; Liu and Ziebart, 2014; Zadrozny, 2004) and transfer learning methods (Pan et al., 2018; Wang et al., 2019) were proposed to address the non-i.i.d. problem. Most of these methods handle the distribution shift between training and testing datasets by aligning the training dataset to the target dataset or vice versa. To achieve that, these methods require prior knowledge of the distribution of the target domain. However, in the agnostic distribution shift problem, there is no prior knowledge about the test datasets. Hence, these methods cannot be applied to the agnostic distribution shift problem we focus on in this work.
Besides domain adaptation and transfer learning methods, some domain generalization (Muandet et al., 2013) and invariant causal prediction (Peters et al., 2016) methods were also proposed recently to address the distribution shift problem. These works explore the invariant structure between predictors and response variables across multiple training datasets to make predictions (Kuang et al., 2020). However, they cannot handle distribution shifts that are not observed in the training data.
3 Causality-based Feature Rectification
3.1 Problem Formulation and Preliminaries
Stable Learning.
Given a dataset D^e = {X^e, Y^e} collected in an environment e, where X^e denotes the observed features in environment e and Y^e the corresponding ground truths. An environment e is defined as a joint distribution P^e_{XY} on X × Y, and E denotes the collection of all environments. The stable learning task is to learn a predictive model on the training dataset that achieves a uniformly small error across the unknown environments in E (Kuang et al., 2020).
Notations.
In this work, we denote by p the observed feature dimension and by n the sample size. For a matrix W, we let W_{i,·} and W_{·,j} represent the i-th row and the j-th column of W, respectively. W_{−i,·} and W_{·,−j} denote the remaining matrices obtained by removing the i-th row and the j-th column, respectively. For a vector v, we let v_i denote its i-th entry and v_{−i} the vector obtained by removing v_i.
3.2 Causality and Stable Features
According to the stability of the causality between the ground truths Y and the features X, we can divide X into stable features S and unstable features V under Assumption 1 and Assumption 2:
Assumption 1.
(Kuang et al., 2020) There exists a stable function f(·) such that, in any environment e ∈ E, we have E[Y^e | S^e] = f(S^e).
Assumption 1 can be guaranteed by Y ⊥ V | S, i.e., the unstable features carry no information about Y beyond the stable features. So the model misspecification problem under agnostic distribution shift can be addressed by learning the stable function f while ignoring the interference of the unstable features. However, in real-world applications, there is no prior knowledge of which features are stable and which are not, and there is no guarantee that no causality exists between features.
Assumption 2.
All stable features S are observed.
Under Assumption 2, the stable learning problem can be addressed by rectifying the correlation between unstable features and the omitted causality.
3.3 Linear Regression under Agnostic Distribution Shift
We use a linear regression task with model misspecification as an example to illustrate our method. We assume the true underlying model is:

Y = f(S) + ε,    (1)

where f is the true stable causality function and the noise ε obeys the Gauss–Markov assumptions (Henderson, 1975). Hence, the noise terms are mutually uncorrelated, have equal variances, and have expectation 0. Under this setting, when f is linear with coefficients β_S, an ordinary least squares (OLS) estimator is the best linear unbiased estimator (BLUE) of β_S.
We then suppose that, in an environment e, finite samples are drawn from the environment and form a dataset D^e = {X^e, Y^e}. Due to the data generation bias, the distribution of D^e deviates from the true distribution. Part of the causal relationship between the stable features and the ground truths is omitted due to this agnostic distribution shift. If a simple linear regression is applied, it may be trapped by the statistical correlation between the unstable features and the omitted causality. This statistical correlation between the unstable features and the omitted part is the source of model misspecification (Kuang et al., 2020). The model under environment e is given by:

Y^e = S^e β_S + Y_Ω + ε,    (2)

where β_S denotes the linear coefficients of the stable features and Y_Ω denotes the response of the omitted causality, which is statistically correlated with V^e.
Since there is a strong correlation between V^e and the omitted causality, a proxy function can be learned that represents the response of the omitted causality using V^e in environment e. Based on this idea, we propose the feature rectification regularizer.
3.4 Feature Rectification
We propose to learn a feature rectification weight matrix W by reconstructing each feature from all other features. The weight matrix learns the correlations between features; by doing so, the omitted stable-feature patterns correlated with the unstable features can be represented by the weight matrix. The objective function is as follows:
min_W Σ_{j=1}^{p} || X_{·,j} − X_{·,−j} W_{−j,j} ||_2^2,    (3)

where p is the feature dimension, X_{·,−j} denotes all the remaining features after removing the j-th feature, and W_{−j,j} denotes the j-th column of W after removing its j-th entry, so that a feature is never used to reconstruct itself.
Given a training dataset with sample size n, we can write the loss of Equation 3 as:

L_W = (1/n) Σ_{j=1}^{p} || X_{·,j} − X_{·,−j} W_{−j,j} ||_2^2.    (4)
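The reconstruction loss of Equations 3 and 4 can be sketched as follows. The zero-diagonal convention (a feature never reconstructs itself) is our reading of the removal notation; the function name is illustrative:

```python
import numpy as np

def rectification_loss(X, W):
    """Loss of Eq. (3)/(4): reconstruct each feature from all others.

    X: (n, p) feature matrix; W: (p, p) feature rectification weight
    matrix, whose diagonal is forced to zero so a feature cannot be
    used to reconstruct itself.
    """
    n, p = X.shape
    W = W - np.diag(np.diag(W))   # zero the diagonal
    residual = X - X @ W          # column j of X @ W uses only features != j
    return np.sum(residual ** 2) / n
```

With `W = 0`, the loss is simply the mean squared norm of the features; any diagonal entries of a candidate `W` are ignored by construction.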
3.5 Application in Linear Regression
To illustrate our method, we use an OLS estimator together with our proposal to estimate the regression coefficients:

Ŷ = (X − XW) β = Xβ − XWβ,    (5)

where W is the feature rectification weight matrix learned via Equation 3.
As discussed above, the weight matrix W is trained by reconstructing the causality graph between features on the training dataset. The residual of the reconstruction, X − XW, is the part of each feature that is independent of the others. The OLS estimator is thus compelled to learn from this residual.
We claim that the second term of Equation 5, the rectification term XWβ, helps the regressor learn the true underlying model and improves the stability of the model across different test datasets. The experimental results show that our proposal works in both linear regression and image classification tasks.
The loss function for optimizing Equation 5 is:

L_reg = (1/n) || Y − (X − XW) β ||_2^2,    (6)
where n is the sample number of the training dataset. As for image classification tasks with deep-learning-based models, the loss function is:
L_cls = ℓ_CE( g( Φ(x) − Φ(x) W ), y ),    (7)

where x denotes a training image sample, Φ(x) denotes the output of the feature extractor, g denotes the classifier, y denotes the ground-truth label, and ℓ_CE is a cross-entropy loss function.
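A minimal sketch of the rectified classification loss of Equation 7, assuming precomputed embeddings `Z` stand in for the feature extractor's output and a single linear classifier; `Wc` and `b` are our illustrative names for the classifier parameters:

```python
import numpy as np

def rectified_classification_loss(Z, labels, W, Wc, b):
    """Sketch of Eq. (7): rectify the extracted embeddings Z with the
    weight matrix W before the linear classifier, then apply a
    cross-entropy loss. Z: (n, d) embeddings; W: (d, d) rectification
    weights; Wc: (d, k) classifier weights; b: (k,) classifier bias."""
    Zr = Z - Z @ W                                 # rectified embeddings
    logits = Zr @ Wc + b
    logits = logits - logits.max(axis=1, keepdims=True)   # stable softmax
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    # mean negative log-likelihood of the true labels
    return -np.mean(log_probs[np.arange(len(labels)), labels])
```

In practice the embeddings would come from a deep feature extractor and the loss would be backpropagated through both the classifier and the extractor; the numpy version only illustrates the forward computation.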
Based on Equation 4 and Equation 6, we propose the Causality-based Feature Rectification algorithm to jointly optimize the feature rectification weight W and the model parameters β, as shown in Algorithm 1:
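A minimal sketch of this joint optimization, under our reading of the text: plain alternating gradient descent on the losses of Equations 4 and 6. The update rule, step counts, and hyperparameter names are illustrative assumptions, not the authors' exact Algorithm 1:

```python
import numpy as np

def cfr_linear_regression(X, Y, lr_w=0.005, lr_beta=0.001, steps=2000, seed=47):
    """Alternately update the rectification weight W (Eq. 4) and the
    regression coefficients beta of the rectified OLS loss (Eq. 6)."""
    rng = np.random.default_rng(seed)
    n, p = X.shape
    W = np.zeros((p, p))
    beta = rng.normal(scale=0.01, size=p)
    for _ in range(steps):
        # gradient step on (1/n)||X - XW||^2 w.r.t. W, diagonal kept at zero
        R = X - X @ W
        grad_W = (-2.0 / n) * (X.T @ R)
        np.fill_diagonal(grad_W, 0.0)
        W -= lr_w * grad_W
        # gradient step on (1/n)||Y - (X - XW) beta||^2 w.r.t. beta
        Xr = X - X @ W
        err = Y - Xr @ beta
        beta -= lr_beta * (-2.0 / n) * (Xr.T @ err)
    return W, beta
```

The learning rates above mirror the values reported in the experiments section (0.005 for the weight matrix, 0.001 for the regressor), but both updates are simple full-batch gradient descent rather than whatever optimizer the authors actually use.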
3.5.1 Complexity Analysis
The main time costs of our proposal are evaluating the loss functions and updating the parameters W and β. We use DWR (Kuang et al., 2020) as a reference to evaluate the complexity of our work. One evaluation of the loss in Equation 4 costs O(np^2), where n is the sample size and p is the feature dimension, and our proposal maintains only O(p^2) parameters. The state-of-the-art method DWR, in contrast, maintains one weight per training sample and evaluates its loss over the full training set, so its cost grows with n.
4 Experiments
We conduct extensive experiments to evaluate our proposal, comparing it with several baselines and state-of-the-art works on both synthetic and real-world datasets.
4.1 Datasets
4.1.1 Realworld Datasets
We use the benchmark dataset CIFAR-10 (Krizhevsky et al., 2009) to evaluate the performance of our proposal on image classification tasks. CIFAR-10 is an established computer-vision dataset for object recognition. It is a subset of the 80 Million Tiny Images dataset and consists of 60,000 32×32 color images in 10 object classes, with 6,000 images per class. During training, images are randomly cropped and horizontally flipped with probability 0.5.
4.1.2 Synthetic Datasets
Under Assumption 1, there are three kinds of relationships between the stable features S and the unstable features V: S ⊥ V (independence), S → V, and V → S. We construct synthetic datasets following the setting of the previous state-of-the-art work DWR (Kuang et al., 2020). These synthetic datasets mimic the model misspecification problem in which part of the causality is omitted due to the data generation bias, which induces statistical correlations between the unstable features and the omitted causality.
We mimic this situation in the synthetic datasets. We generate the input features X = {S, V} from independent Gaussian distributions with the help of auxiliary variables Z, where p_s and p_v denote the numbers of stable and unstable features, respectively, and X_{·,i} denotes the i-th feature in X. The generating equations and the settings of p_s and p_v can be found in the supplementary materials.
To test the performance with different forms of missing nonlinear terms, we generate the ground truths from a polynomial nonlinear function (Y_poly) and an exponential function (Y_exp) of the stable features. The exact functional forms and coefficient settings follow DWR (Kuang et al., 2020) and are given in the supplementary materials.
To test the stability, we generate a set of environments, each with a distinct joint distribution P^e_{XY}, while preserving P(Y | S). To achieve that, we generate environments by varying P(V | S) on a subset of the unstable features. Following the setting used by DWR (Kuang et al., 2020), we vary P(V | S) via biased sample selection with a bias rate r. For each sample, the probability of being selected depends on r: a bias rate r > 1 induces a positive spurious correlation between the selected unstable features and the response, while r < 1 induces a negative one; the exact selection probability follows DWR.
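An illustrative stand-in for this biased sample selection. The exact selection probability of Kuang et al. (2020) is omitted in this text, so the rule below is a simplified assumption of ours with the same qualitative effect, not the authors' formula:

```python
import numpy as np

def biased_selection(V, Y, r, seed=47):
    """Simplified biased sample selection with bias rate r.

    With r > 1, samples where the unstable feature V agrees in sign
    with the centered response are kept with probability r / (1 + r),
    inducing a spurious positive V-Y correlation in the selected
    subsample; r < 1 induces a negative one. Returns a boolean mask.
    """
    rng = np.random.default_rng(seed)
    Yc = Y - Y.mean()
    agree = np.sign(V) == np.sign(Yc)
    p_keep = np.where(agree, r / (1.0 + r), 1.0 / (1.0 + r))
    return rng.random(len(Y)) < p_keep
```

Even when V and Y are generated independently, selecting with r = 3 produces a clearly positive sample correlation between them, which is exactly the spurious correlation the stable learning methods must resist.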
4.2 Evaluation Metrics
We use RMSE, the coefficient estimation error β_error, Average Error (AE), and Stability Error (SE) to evaluate the performance of our proposal. The β_error between a learned coefficient β̂ and the true coefficient β is defined as β_error = (1/p) Σ_{i=1}^{p} |β̂_i − β_i|, where p is the feature dimension of β. We report both the mean and the variance of β_error over 50 independent experiments. The AE and SE proposed by Kuang et al. (2020) are defined as:

AE = (1/|E_test|) Σ_{e ∈ E_test} RMSE(D^e),    SE = sqrt( (1/(|E_test| − 1)) Σ_{e ∈ E_test} (RMSE(D^e) − AE)^2 ),    (8)

where E_test denotes the set of test environments.
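Under the definitions above, the metrics can be computed as follows; `rmse_per_env` is an assumed list of per-environment RMSE values:

```python
import numpy as np

def beta_error(beta_hat, beta_true):
    # mean absolute coefficient error, as defined in Sec. 4.2
    return np.mean(np.abs(beta_hat - beta_true))

def average_and_stability_error(rmse_per_env):
    """AE and SE over test environments (Kuang et al., 2020): AE is the
    mean RMSE across environments, SE the sample standard deviation of
    the per-environment RMSEs around AE."""
    rmse = np.asarray(rmse_per_env, dtype=float)
    ae = rmse.mean()
    se = np.sqrt(np.sum((rmse - ae) ** 2) / (len(rmse) - 1))
    return ae, se
```

A low AE means the model predicts well on average, while a low SE means its error barely varies across environments; stable learning targets the latter in particular.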
4.3 Compared Methods
We use four methods as baselines in this work: OLS, Lasso (Tibshirani, 1996), Ridge regression (Hoerl and Kennard, 1970), and DWR (Kuang et al., 2020). DWR is the previous state-of-the-art causality-based sample reweighting method; we use the official implementation provided by the authors. Our proposal has three advantages over the previous state-of-the-art works. First, the previous state-of-the-art methods are restricted to linear models and cannot be applied to deep nonlinear models, whereas our proposal can be used for both. Second, the parameter number of the previous state-of-the-art methods is proportional to the sample size of the training dataset, so they are not suitable for scenarios with a large training dataset; the parameter number of our proposal is invariant to the sample size. Meanwhile, the performance of the previous state-of-the-art methods is sensitive to the sample size, while ours is not. Third, the previous state-of-the-art methods require calculating the loss over all training samples before optimizing parameters, which is both computationally and memory intensive. In contrast, our proposal uses a mini-batch training scheme, which is more practical.
4.4 Experiments on Synthetic Datasets
We evaluate all methods by comparing their accuracy in parameter estimation and their prediction stability across unknown test datasets. To evaluate the parameter estimation accuracy, we train all models on the same training dataset with a specific bias rate r. We repeat this training process 50 times independently with different training data drawn under the same bias rate r, and report the mean and variance of β_error on the coefficients of the unstable features V, since the spurious correlation with V is the source of model misspecification in these synthetic datasets. To evaluate the prediction stability, we test all models on several test environments with various bias rates r. For each test bias rate, we generate 50 different test datasets and report the mean RMSE. Using the RMSE results, we further calculate the AE and SE to evaluate the prediction stability across different test environments. As for hyperparameters, we use random seed 47 for all experiments. The initial learning rate for optimizing the weight matrices of DWR and our proposal is 0.005, and 0.001 for optimizing the linear regressor.
Table 1: Results on the synthetic datasets.

Scenario 1: varying sample size n (p=10, r=1.7)

           n=1000                             n=2000                             n=4000
Methods    OLS    Lasso  Ridge  DWR    Our    OLS    Lasso  Ridge  DWR    Our    OLS    Lasso  Ridge  DWR    Our
β_error    0.099  0.102  0.099  0.066  0.027  0.097  0.101  0.097  0.060  0.025  0.097  0.101  0.097  0.057  0.016
AE         0.604  0.639  0.603  0.519  0.629  0.583  0.617  0.583  0.509  0.613  0.587  0.621  0.587  0.505  0.569
SE         0.254  0.285  0.254  0.103  0.086  0.236  0.267  0.236  0.110  0.071  0.236  0.267  0.236  0.114  0.089

Scenario 2: varying feature dimension p (n=2000, r=1.7)

           p=10                               p=20                               p=40
Methods    OLS    Lasso  Ridge  DWR    Our    OLS    Lasso  Ridge  DWR    Our    OLS    Lasso  Ridge  DWR    Our
β_error    0.097  0.101  0.097  0.060  0.025  0.070  0.080  0.070  0.066  0.027  0.044  0.047  0.044  0.038  0.013
AE         0.583  0.617  0.583  0.509  0.613  0.612  0.720  0.612  0.550  0.546  0.538  0.618  0.538  0.519  0.471
SE         0.236  0.267  0.236  0.110  0.071  0.319  0.408  0.319  0.232  0.071  0.312  0.370  0.312  0.297  0.082

Scenario 3: varying bias rate r on training data (n=2000, p=20)

           r=1.5                              r=1.7                              r=2.0
Methods    OLS    Lasso  Ridge  DWR    Our    OLS    Lasso  Ridge  DWR    Our    OLS    Lasso  Ridge  DWR    Our
β_error    0.059  0.067  0.059  0.060  0.010  0.070  0.080  0.070  0.066  0.027  0.079  0.091  0.079  0.077  0.023
AE         0.519  0.590  0.519  0.548  0.497  0.612  0.720  0.612  0.550  0.546  0.660  0.781  0.660  0.613  0.618
SE         0.220  0.297  0.220  0.197  0.031  0.319  0.408  0.319  0.232  0.071  0.364  0.447  0.364  0.303  0.119
As shown in Figure 1, we visualize the results of one representative setting. We can notice that our algorithm achieves the lowest parameter estimation error compared with all baselines, which shows that our proposal can significantly mitigate the model misspecification caused by feature correlations. Meanwhile, our proposal achieves the lowest SE across the different test bias rates r, which indicates that the model trained with our proposal is more stable than all baselines, including the state-of-the-art method.
As shown in Table 1, our method achieves the lowest β_error and SE in all experiments compared with all baselines, including the state-of-the-art method DWR. This demonstrates that our proposal can effectively mitigate the model misspecification problem and improve the stability of the model across different test datasets. Meanwhile, we can notice that the performance of our proposal remains stable as the sample size changes, while the performance of DWR is affected by the sample size, as shown in Table 1. This shows that our method is more robust across different training sample sizes.
These observations lead to the conclusion that our method achieves better stability across different test datasets. Moreover, the experimental results also demonstrate that our proposal is more stable when dealing with limited sample numbers.
The remaining experimental results, including those for the exponential setting, are reported in the supplementary materials.
4.5 Experiments on Realworld Datasets
We then evaluate the performance of our proposal on image classification tasks. We conduct experiments on the benchmark image classification dataset CIFAR-10. We use a ResNet50 model as the feature extractor to extract feature embeddings of the input images, and we then apply a classifier consisting of one linear layer to predict the label of the input image. As in Algorithm 1, we use a feature rectification weight matrix W to learn the causality graph of the extracted feature embeddings. The loss function for optimizing the weight matrix is shown in Equation 7.
As for baselines, we compare our proposal with the vanilla ResNet50 and with ResNet50 combined with DWR (Kuang et al., 2020). Note that the original DWR is restricted to linear regression tasks, so we apply the key idea of DWR to the image classification task. The original loss function of DWR requires calculating the loss over all training samples before backpropagation, which is not feasible in image classification tasks with GPU acceleration. As a trade-off, we update part of the sample weights of DWR in each mini-batch.
As for hyperparameters, we use an initial learning rate of 0.1 for optimizing all models and 5e-6 for optimizing the weight matrix of our proposal. All models are trained for 350 epochs. The learning rate is reduced after the 150th and 160th epochs by a factor of 0.1. The random seed is 47 for all experiments.
Table 2: Classification accuracy (%) on CIFAR-10.

Methods        CIFAR-10
ResNet50       94.98
ResNet50+DWR   95.09
ResNet50+CFR   95.59
The results are shown in Table 2. Our proposal improves the accuracy of the deep-learning-based model ResNet50. These experiments demonstrate that our proposal can be used in both linear regression and image classification tasks.
5 Conclusion
In this work, we address the model misspecification problem under agnostic distribution shift by proposing a novel Causality-based Feature Rectification (CFR) method. Experiments on both synthetic and real-world datasets demonstrate that our proposal improves the performance of the baseline models and outperforms the state-of-the-art stable learning methods. Our method is a general data pretreatment method that can be seamlessly integrated into classical linear regression and classification models. It provides a unified approach to alleviating the model misspecification problem under agnostic distribution shift. Unlike previous works, our proposal is not restricted to linear models and can also be applied to deep-learning-based models.
Broader Impact
Our proposal is a data pretreatment method and causes no explicit ethical problems. As for benefits, our proposal can help improve the stability of models on unknown test datasets and reduce the data annotation costs of fine-tuning models for new scenarios. It can be applied to most deep-learning-based and linear models in many application scenarios, such as quantitative financial analysis, social media, and time-series forecasting tasks. Moreover, our proposal helps to counteract biases in the data. Thus, the likely beneficiaries are companies whose business suffers from the model misspecification problem. The consequence of a failure of the proposal is poor classification or regression performance. Our work may affect the interests of companies whose primary business is data annotation.
References
Chen, M., et al. (2020). Adversarial-learned loss for domain adaptation. arXiv preprint arXiv:2001.01046.
Henderson, C. R. (1975). Best linear unbiased estimation and prediction under a selection model. Biometrics, pp. 423–447.
Hoerl, A. E. and Kennard, R. W. (1970). Ridge regression: biased estimation for nonorthogonal problems. Technometrics 12(1), pp. 55–67.
Krizhevsky, A., et al. (2009). Learning multiple layers of features from tiny images.
Kuang, K., et al. (2020). Stable prediction with model misspecification and agnostic distribution shift. CoRR abs/2001.11713.
Liu, A. and Ziebart, B. (2014). Robust classification under sample selection bias. In Advances in Neural Information Processing Systems, pp. 37–45.
Muandet, K., et al. (2013). Domain generalization via invariant feature representation. In International Conference on Machine Learning, pp. 10–18.
Pan, B., et al. (2018). MacNet: transferring knowledge from machine comprehension to sequence-to-sequence models. In Advances in Neural Information Processing Systems 31, pp. 6092–6102.
Peters, J., et al. (2016). Causal inference by using invariant prediction: identification and confidence intervals. Journal of the Royal Statistical Society: Series B (Statistical Methodology) 78(5), pp. 947–1012.
Shen, Z., et al. (2018). Causally regularized learning with agnostic data selection bias. In Proceedings of the 26th ACM International Conference on Multimedia, pp. 411–419.
Shen, Z., et al. (2019). Stable learning via sample reweighting. arXiv preprint arXiv:1911.12580.
Tibshirani, R. (1996). Regression shrinkage and selection via the lasso. Journal of the Royal Statistical Society: Series B (Methodological) 58(1), pp. 267–288.
Wang, B., et al. (2019). A minimax game for instance based selective transfer learning. In Proceedings of the 25th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining (KDD '19), pp. 34–43.
Zadrozny, B. (2004). Learning and evaluating classifiers under sample selection bias. In Proceedings of the Twenty-First International Conference on Machine Learning, pp. 114.