Machine learning models typically rely on the assumption that the training and testing datasets are generated from the same data distribution (i.e., the i.i.d. hypothesis). Under the i.i.d. hypothesis, a trained model can be applied directly to the testing dataset, and its performance there should match its performance on the training dataset. Although this approach has proven empirically successful on many public datasets, the assumption often fails in practice. In many real-world applications, there is no guarantee that the training data and the unknown testing data share the same distribution, because of data generation bias. This agnostic distribution shift between training and testing datasets can result in inaccurate parameter estimation and unstable performance across different test datasets. Consider an image classification example: why can an image of a dog on grass be labeled as 'dog', while an image of a cat on grass should not? The reason is that a stable causal relationship exists between certain essential stable features (the features of the dog in this case) and the ground truth. This stable causality holds even when the unstable features (e.g., the features of the grass) change across different unknown environments. However, data generation bias can lead to a statistical correlation between unstable features such as grass and the ground truth 'dog'. A model trained on such a dataset can easily be trapped by this statistical correlation because it yields a lower training loss, which results in unstable performance across testing datasets with different distributions.
Therefore, how to mitigate this model misspecification problem and learn a stable model has attracted considerable attention from the community (Pan et al., 2018; Wang et al., 2019; Shen et al., 2019; Kuang et al., 2020). One of the most straightforward ways to address the problem is to exploit prior knowledge of the test data. On this basis, many domain adaptation (Chen et al., 2020) and transfer learning methods (Pan et al., 2018; Wang et al., 2019) have been proposed to correct the distribution shift using prior knowledge of the test dataset. However, in the agnostic distribution shift problem, neither prior knowledge of the test data nor the true distribution is available. These methods are therefore not feasible for the agnostic distribution shift problem that we focus on in this work.
More recently, many causality-based methods (Shen et al., 2018, 2019; Kuang et al., 2020) have been proposed, most of which are based on sample reweighting. Shen et al. (2019) proposed a sample reweighting method to address the collinearity among input variables caused by agnostic distribution shift. Kuang et al. (2020) proposed a causality-based feature decorrelation method built on sample reweighting. However, these sample reweighting methods require a sample reweighting matrix whose number of parameters is proportional to the number of training samples. Hence, they are both computationally and memory intensive, which limits their scalability in machine learning tasks with large amounts of training data. Moreover, these methods are limited to linear models and do not extend to deep-learning based nonlinear models.
In this work, we investigate the stable learning problem under agnostic distribution shift. Specifically, we propose a novel Causality-based Feature Rectification (CFR) method to address the model misspecification problem caused by agnostic distribution shift. Our motivation is that the stable causality between the stable features and the ground truths remains consistent even if part of it is omitted due to the agnostic distribution shift. Therefore, we can use the correlations between the omitted stable features and the other features as a proxy for the omitted causality. We propose a feature rectification weight that reconstructs the causality graph using these correlations as a proxy. We further propose an algorithm to jointly optimize the feature rectification weight and a regressor (or classifier), in which the feature rectification weight matrix rectifies the input features to counter the impact of agnostic distribution shift. Our proposal is a feature pretreatment method that can be integrated with most commonly used linear regression and image classification methods, including deep-learning based models. The experimental results demonstrate that our proposal is more stable than all baselines and state-of-the-art stable learning methods on both synthetic and real-world datasets.
We summarize our contributions as follows:
(1) We investigate the stable learning problem under the agnostic distribution shift in this work. Specifically, we focus on the model misspecification problem under agnostic distribution shift.
(2) We propose a novel Causality-based Feature Rectification (CFR) method to improve the stability of the model across different test datasets. Specifically, we propose a feature rectification weight to reconstruct the causality graph of features and help to learn the true underlying model under the agnostic distribution shift. Unlike previous works, our proposal is not restricted to linear models, and can also be applied to deep-learning based models.
(3) We conduct linear regression and image classification experiments on both synthetic and real-world datasets to validate our proposal’s performance. The experimental results demonstrate that our proposal can achieve state-of-the-art performance on both synthetic and real-world datasets.
2 Related Works
2.1 Causality based Methods
Most recently proposed causality-based methods are based on sample reweighting, which does not directly change the biased sample features but instead shifts the training dataset's distribution by varying the importance of the samples. Shen et al. (2018) proposed a Causally Regularized Logistic Regression model to address the agnostic distribution shift in logistic regression tasks. However, it did not consider the model misspecification problem, and its algorithm was restricted to the predictive setting with binary predictors and a binary response variable. Shen et al. (2019) proposed a sample reweighting method to address the collinearity among input variables caused by the agnostic distribution shift. Kuang et al. (2020) proposed a causality-based sample reweighting method for feature decorrelation, which reweights the samples so that the features become mutually independent with mean 0.
However, these works are restricted to regression tasks and binary classification tasks. In addition, these sample reweighting methods require a large sample reweighting matrix whose number of parameters is proportional to the number of training samples. Hence, they are not feasible for tasks with massive training data. In contrast, our proposal addresses the model misspecification problem by rectifying the correlation between features, and its number of parameters is proportional to the feature dimension. In most machine learning tasks, the feature dimension is no more than a few thousand. Hence, our proposal is more feasible than previous works for tasks with a large training dataset.
2.2 Non-causality based Methods
In addition to causality-based methods, a variety of domain adaptation (Chen et al., 2020; Liu and Ziebart, 2014; Zadrozny, 2004) and transfer learning methods (Pan et al., 2018; Wang et al., 2019) have been proposed to address the non-i.i.d. problem. Most of these methods handle the distribution shift between training and testing datasets by aligning the training dataset to the target dataset or vice versa. To achieve that, they require prior knowledge of the distribution of the target domain. However, in the agnostic distribution shift problem, there is no prior knowledge about the test datasets. Hence, these methods cannot be applied to the agnostic distribution shift problem that we focus on in this work.
Apart from domain adaptation and transfer learning methods, some domain generalization (Muandet et al., 2013) and invariant causal prediction (Peters et al., 2016) methods have also been proposed to address the distribution shift problem. These works explore the invariant structure between predictors and the response variable across multiple training datasets to make predictions (Kuang et al., 2020). However, they cannot handle distribution shifts that are not observed in the training data.
3 Causality-based Feature Rectification
3.1 Problem Formulation and Preliminaries
Given a dataset collected in environment , where is the observed feature in environment , is the corresponding ground truth. The environment is defined as a joint distribution on and is the collection of all environments. The stable learning task is to learn a predictive model on the training dataset that achieves uniformly small error across unknown environments in (Kuang et al., 2020).
In this work, we denote as the observed feature dimension and as the sample size. For a matrix , we let and represent the -th row and the -th column in respectively. and denote the remaining matrices after removing the -th row and the -th column respectively. For a vector, we let and .
3.2 Causality Stable Feature
According to the stability of the causality between ground truths and features , we can divide into stable features and unstable features under Assumption 1 and Assumption 2:
Assumption 1 (Kuang et al., 2020). There exists a stable function , so that in any environment we have .
Assumption 1 can be guaranteed by . The model misspecification problem under agnostic distribution shift can therefore be addressed by learning the stable function while ignoring the interference of unstable features. However, in real-world applications, there is no prior knowledge of which features are stable and which are not, and there is no guarantee that no causality exists between features.
Assumption 2. All stable features are observed.
Under Assumption 2, the stable learning problem can be addressed by rectifying the correlation between unstable features and the omitted causality.
3.3 Linear Regression under Agnostic Distribution Shift
We use a linear regression task with model misspecification as an example to illustrate our method. We assume the true underlying model is:
where is the true stable causality function and obeys the Gauss–Markov assumptions (Henderson, 1975). Hence,.
We then suppose that, in an environment , finite samples are drawn from the environment to form the dataset . Due to the data generation bias, the distribution of deviates from the true distribution . Part of the causal relationship between the stable features and the ground truths is omitted due to this agnostic distribution shift. If a simple linear regression is applied, it may be trapped by the statistical correlation between the unstable features and the omitted causality. This statistical correlation between the unstable features and the omitted part is the cause of model misspecification (Kuang et al., 2020). The model under environment is given by:
where is the response of the omitted causality, which is statistically correlated with .
Since there is a strong correlation between and , a proxy function can be learned to represent using in environment . Based on this idea, we propose the feature rectification regularizer.
3.4 Feature Rectification
We propose to learn a feature rectification weight matrix by reconstructing each feature from all other features. The weight matrix learns the correlations between features. In this way, the omitted stable feature patterns that are correlated with the unstable features can be represented by the weight matrix. The objective function is as follows:
where is the feature dimension, , denotes all the remaining features after removing the -th feature, and denotes all the remaining weights after removing the -th column of the weights.
Given a training dataset with sample size , we can write the loss in Equation 3 as:
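The idea of this subsection, reconstructing every feature from all the others, can be illustrated with a minimal sketch. The function names (`fit_rectification_weights`, `reconstruction_loss`), the closed-form per-feature least-squares solve, the small ridge term, and the toy data are all assumptions made for illustration; the sketch is not the training procedure of Algorithm 1, which optimizes the weight matrix jointly with the predictor.

```python
import numpy as np

def fit_rectification_weights(X, ridge=1e-8):
    """Return a p x p matrix W with zero diagonal such that
    X[:, j] is approximated by X[:, others] @ W[others, j]."""
    n, p = X.shape
    W = np.zeros((p, p))
    for j in range(p):
        others = [k for k in range(p) if k != j]
        A = X[:, others]
        # Small ridge term (an assumption) keeps the solve stable under collinearity.
        coef = np.linalg.solve(A.T @ A + ridge * np.eye(p - 1), A.T @ X[:, j])
        W[others, j] = coef
    return W

def reconstruction_loss(X, W):
    """Mean squared error of reconstructing every feature from the others."""
    return float(np.mean((X - X @ W) ** 2))

# Toy data (hypothetical): two independent features and one feature
# that is strongly correlated with the first one.
rng = np.random.default_rng(47)
S = rng.normal(size=(500, 2))
V = 0.8 * S[:, [0]] + 0.1 * rng.normal(size=(500, 1))
X = np.hstack([S, V])

W = fit_rectification_weights(X)
loss = reconstruction_loss(X, W)
```

In this toy example the entry `W[0, 2]` recovers the dependence of the third feature on the first one, which is the kind of inter-feature correlation the weight matrix is meant to capture. Because the diagonal of `W` is fixed to zero, a feature can never be reconstructed from itself.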
3.5 Application in Linear Regression
To illustrate our method, we use an OLS estimator together with our proposal to estimate the regression coefficients:
As discussed above, the weight matrix is trained by reconstructing the causality graph between features using the training dataset. The residual of the reconstruction is the part of each feature that is independent of the other features. The OLS estimator is compelled to learn from this residual.
We claim that the second term () helps the regressor learn the true underlying model and improves the stability of the model across different test datasets. The experimental results show that our proposal works on both linear regression and image classification tasks.
The loss function for optimizing Equation 5 is:
where is the sample number of the training dataset. As for image classification tasks with deep-learning based models, the loss function is:
where denotes a training image sample, denotes the output of the feature extractor, denotes the output of the classifier and is a Cross-Entropy loss function.
3.5.1 Complexity Analysis
The main time costs of our proposal are calculating the value of the loss function and updating the parameters and . We use DWR (Kuang et al., 2020) as a reference to evaluate the complexity of our work. The computational complexity of Algorithm 1 is , where is the sample size and is the feature dimension. The complexity of the state-of-the-art method DWR is .
We conduct extensive experiments to evaluate our proposal, comparing it with several baselines and state-of-the-art methods on both synthetic and real-world datasets.
4.1.1 Real-world Datasets
We use the benchmark dataset CIFAR-10 (Krizhevsky et al., 2009) to evaluate the performance of our proposal in image classification tasks. CIFAR-10 is an established computer-vision dataset used for object recognition. It is a subset of the 80 million tiny images dataset and consists of 60,000 32x32 color images, each containing one of 10 object classes, with 6,000 images per class. During training, images are randomly cropped and horizontally flipped with a probability of 0.5.
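The training-time augmentation can be sketched as follows. The zero padding of 4 pixels before cropping is an assumption (a common choice for CIFAR-10 that the text does not specify), and the function name `augment` is hypothetical.

```python
import numpy as np

def augment(image, rng, pad=4, flip_prob=0.5):
    """Random crop (after zero padding) and random horizontal flip
    of an image stored as an H x W x C array."""
    h, w, _ = image.shape
    padded = np.pad(image, ((pad, pad), (pad, pad), (0, 0)), mode="constant")
    # Pick the top-left corner of an h x w window inside the padded image.
    top = rng.integers(0, 2 * pad + 1)
    left = rng.integers(0, 2 * pad + 1)
    out = padded[top:top + h, left:left + w, :]
    # Flip left-right with probability flip_prob.
    if rng.random() < flip_prob:
        out = out[:, ::-1, :]
    return out

rng = np.random.default_rng(47)
image = rng.integers(0, 256, size=(32, 32, 3), dtype=np.uint8)
augmented = augment(image, rng)
```

The output keeps the 32x32 input size, so augmented images can be fed to the same network as the originals.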
4.1.2 Synthetic Datasets
Under Assumption 1, there are three kinds of relationship between and , including , and . Here, denotes the stable features and denotes the unstable features. We construct synthetic datasets following the setting of the previous state-of-the-art work DWR (Kuang et al., 2020). This synthetic dataset mimics the model misspecification problem in which part of the causality is omitted due to the data generation bias. The data generation bias causes a statistical correlation between the unstable features and the omitted causality.
We mimic the situation in this synthetic dataset. We generate the input features from independent Gaussian distributions with the help of auxiliary features:
where and are the numbers of stable features and unstable features, respectively, and denotes the -th feature in . The settings of and can be found in the supplementary materials.
To test the performance with different forms of missing nonlinear terms, we generate the ground truths from a polynomial nonlinear function () and an exponential function ():
where , and .
To test the stability, we generate a set of environments, each with a distinct joint distribution , while preserving . To achieve that, we generate environments by varying on a subset of . Following the setting used by DWR (Kuang et al., 2020), we vary via biased sample selection with a bias rate . For each sample, the probability of being selected is defined as , where . If , otherwise .
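A minimal sketch of such biased sample selection is given below. The functional form of the selection probability, `|r| ** (-5 * D)` with `D = |f(S) - sign(r) * V|`, is an assumption modeled on the DWR setting, and the names `selection_probability`, `biased_select`, `f_s` and `v_b` are hypothetical. Here `f_s` stands in for the stable part of the response, and `|r| > 1` is assumed so that the probability does not exceed 1.

```python
import numpy as np

def selection_probability(f_s, v_b, r):
    """Assumed form: product over biased features of |r| ** (-5 * D_i),
    with D_i = |f(S) - sign(r) * V_i|. Requires |r| > 1."""
    sign = 1.0 if r > 0 else -1.0
    d = np.abs(f_s[:, None] - sign * v_b)  # shape (n, number of biased features)
    return np.prod(np.abs(r) ** (-5.0 * d), axis=1)

def biased_select(f_s, v_b, r, rng):
    """Return a boolean mask of the samples kept under bias rate r."""
    prob = selection_probability(f_s, v_b, r)
    return rng.random(len(f_s)) < prob

rng = np.random.default_rng(47)
f_s = rng.normal(size=20000)        # stand-in for the stable response part
v_b = rng.normal(size=(20000, 1))   # one unstable feature, independent of f_s

keep_pos = biased_select(f_s, v_b, r=2.0, rng=rng)
keep_neg = biased_select(f_s, v_b, r=-2.0, rng=rng)
corr_pos = np.corrcoef(f_s[keep_pos], v_b[keep_pos, 0])[0, 1]
corr_neg = np.corrcoef(f_s[keep_neg], v_b[keep_neg, 0])[0, 1]
```

Although `f_s` and `v_b` are independent before selection, the kept samples show a positive correlation for a positive bias rate and a negative correlation for a negative one. This is how a spurious correlation between unstable features and the response can be induced at a controlled strength.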
4.2 Evaluation Metrics
We use RMSE, , Average Error (AE) and Stability Error (SE) to evaluate the performance of our proposal. The between a learned coefficient and the true coefficient is defined as , where is the feature dimension of . We report both the mean and variance of over 50 independent experiments. AE and SE, proposed by Kuang et al. (2020), are defined as:
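For concreteness, the stability metrics can be computed as in the sketch below, assuming the definitions of Kuang et al. (2020): AE is the mean of the per-environment RMSE values, and SE is their sample standard deviation (with the `|E| - 1` denominator). The function names are hypothetical.

```python
import numpy as np

def rmse(y_true, y_pred):
    """Root mean squared error on one test environment."""
    diff = np.asarray(y_true, dtype=float) - np.asarray(y_pred, dtype=float)
    return float(np.sqrt(np.mean(diff ** 2)))

def average_error(rmse_per_env):
    """AE: mean RMSE over all test environments."""
    return float(np.mean(rmse_per_env))

def stability_error(rmse_per_env):
    """SE: sample standard deviation of RMSE over all test environments."""
    return float(np.std(rmse_per_env, ddof=1))

# Hypothetical RMSE values from three test environments.
rmse_per_env = [1.0, 2.0, 3.0]
ae = average_error(rmse_per_env)
se = stability_error(rmse_per_env)
```

A lower AE indicates better average accuracy, while a lower SE indicates that the error varies less from one test environment to another.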
4.3 Compared Methods
We use four methods as baselines in this work: OLS, Lasso (Tibshirani, 1996), Ridge (Hoerl and Kennard, 1970) and DWR (Kuang et al., 2020). DWR is the previous state-of-the-art causality-based sample reweighting method. We use the official implementation of DWR provided by the authors.
Our proposal has three advantages over previous state-of-the-art methods. First, previous state-of-the-art methods are restricted to linear models and cannot be applied to deep nonlinear models, whereas our proposal can be used with both linear and deep models. Second, the number of parameters of previous state-of-the-art methods is proportional to the sample size of the training dataset, so they are not suitable for scenarios with a large training dataset, whereas the number of parameters of our proposal is independent of the sample size. Moreover, the performance of previous state-of-the-art methods is sensitive to the sample size, but ours is not. Third, previous state-of-the-art methods need to calculate the loss over all training samples before optimizing the parameters, which is both computationally and memory intensive. In contrast, our proposal uses a mini-batch training scheme, which is more practical.
4.4 Experiments on Synthetic Datasets
We evaluate the performance of all methods by comparing the accuracy of parameter estimation and the stability of prediction across unknown test datasets. To evaluate the parameter estimation accuracy, we train all models on the same training dataset with a specific bias rate . We repeat this training process 50 times independently with different training data generated with the same bias rate , and report the mean and variance of on since the is the source of model misspecification in this synthetic dataset. To evaluate the prediction stability, we test all models on several test environments with various bias rates . For each test bias rate, we generate 50 different test datasets and report the mean RMSE. Using the RMSE results, we further calculate the AE and SE to evaluate the prediction stability across different test environments. As for hyper-parameters, we use random seed 47 for all experiments. The initial learning rate is 0.005 for optimizing the weight matrix of DWR and of our proposal, and 0.001 for optimizing the linear regressor.
Table 1 scenarios: (1) varying sample size n; (2) varying feature dimension p; (3) varying bias rate r on training data.
Figure 1 visualizes the results for the setting with . Our algorithm achieves the lowest parameter estimation error on compared with all baselines, which shows that our proposal can significantly mitigate the model misspecification caused by feature correlations. Our proposal also achieves the lowest SE across different , which indicates that the model trained with our proposal is more stable than all baselines, including the state-of-the-art method.
As shown in Table 1, our method achieves the lowest and SE in all experiments compared with all baselines, including the state-of-the-art method DWR. This demonstrates that our proposal can effectively mitigate the model misspecification problem and improve the stability of the model across different test datasets. Table 1 also shows that the performance of our proposal remains stable as the sample size changes, whereas the performance of DWR is affected by the sample size. Our method is therefore more robust to the training sample size.
These observations lead to the conclusion that our method achieves better stability across different test datasets. Moreover, the experimental results demonstrate that our proposal is more stable when the number of samples is limited.
The experimental results of , with , and results of are reported in the supplementary materials.
4.5 Experiments on Real-world Datasets
We then evaluate the performance of our proposal in image classification tasks. We conduct experiments on the benchmark image classification dataset CIFAR-10. We use a ResNet-50 model as the feature extractor to extract the feature embedding of the input images, and then apply a classifier consisting of one linear layer to predict the label of the input image. As in Algorithm 1, we use a feature rectification weight matrix to learn the causality graph of the extracted feature embedding. The loss function for optimizing the weight matrix is shown in Equation 7.
As for baselines, we compare our proposal with the naive ResNet-50 and ResNet-50 with DWR (Kuang et al., 2020). Note that the original DWR is restricted to linear regression tasks, so we apply the key idea of DWR to the image classification task. The original loss function of DWR requires computing the loss over all training samples before backpropagation, which is not feasible in image classification tasks with GPU acceleration. As a trade-off, we update part of the sample weights of DWR in each mini-batch.
As for hyper-parameters, we use an initial learning rate of 0.1 for optimizing all models and 5e-6 for optimizing the weight matrix of our proposal. All models are trained for 350 epochs. The learning rate is multiplied by 0.1 after epochs 150 and 160. The random seed is 47 for all experiments.
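The step schedule for the model learning rate can be written as a small function. This is a sketch under the assumption that epochs are counted from 0 and that each decay takes effect from the milestone epoch onward; the function name `learning_rate` is hypothetical.

```python
def learning_rate(epoch, base_lr=0.1, milestones=(150, 160), gamma=0.1):
    """Step decay: multiply base_lr by gamma once for every milestone
    that has been reached (epochs counted from 0)."""
    drops = sum(1 for m in milestones if epoch >= m)
    return base_lr * gamma ** drops

# Learning rate used in each of the 350 training epochs.
schedule = [learning_rate(epoch) for epoch in range(350)]
```

Under these assumptions the learning rate is 0.1 for the first 150 epochs, 0.01 for the next 10 epochs, and 0.001 for the remaining epochs. In PyTorch, the same behavior would typically be obtained with `torch.optim.lr_scheduler.MultiStepLR`.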
The results are shown in Table 2. Our proposal improves the accuracy of the deep-learning based model ResNet-50 on all datasets. These experiments demonstrate that our proposal can be used in both linear regression and image classification tasks.
In this work, we address the model misspecification problem under agnostic distribution shift by proposing a novel Causality-based Feature Rectification (CFR) method. Experiments on both synthetic and real-world datasets demonstrate that our proposal improves the performance of the baseline models and outperforms state-of-the-art stable learning methods. Our method is a general data pretreatment method that can be seamlessly integrated into classical linear regression models and classification models. It provides a unified approach to alleviating the model misspecification problem under agnostic distribution shift. Unlike previous works, our proposal is not restricted to linear models and can also be applied to deep-learning based models.
Our proposal is a data pretreatment method that raises no explicit ethical problems. As for benefits, our proposal can help improve the stability of a model on unknown test datasets and reduce the data annotation costs of fine-tuning a model for new scenarios. It can be applied to most deep-learning based models and linear models in many application scenarios, such as quantitative financial analysis, social media and time-series forecasting tasks. Moreover, our proposal helps to leverage biases in the data. The likely beneficiaries are therefore companies whose business suffers from the model misspecification problem. The consequence of a failure of the proposal is poor classification or regression performance. Our work may affect the interests of companies whose primary business is data annotation.
- Chen et al. (2020). Adversarial-learned loss for domain adaptation. arXiv preprint arXiv:2001.01046.
- Henderson (1975). Best linear unbiased estimation and prediction under a selection model. Biometrics, pp. 423–447.
- Hoerl and Kennard (1970). Ridge regression: biased estimation for nonorthogonal problems. Technometrics 12 (1), pp. 55–67.
- Krizhevsky et al. (2009). Learning multiple layers of features from tiny images.
- Kuang et al. (2020). Stable prediction with model misspecification and agnostic distribution shift. CoRR abs/2001.11713.
- Liu and Ziebart (2014). Robust classification under sample selection bias. In Advances in Neural Information Processing Systems, pp. 37–45.
- Muandet et al. (2013). Domain generalization via invariant feature representation. In International Conference on Machine Learning, pp. 10–18.
- Pan et al. (2018). MacNet: transferring knowledge from machine comprehension to sequence-to-sequence models. In Advances in Neural Information Processing Systems 31, S. Bengio, H. Wallach, H. Larochelle, K. Grauman, N. Cesa-Bianchi, and R. Garnett (Eds.), pp. 6092–6102.
- Peters et al. (2016). Causal inference by using invariant prediction: identification and confidence intervals. Journal of the Royal Statistical Society: Series B (Statistical Methodology) 78 (5), pp. 947–1012.
- Shen et al. (2018). Causally regularized learning with agnostic data selection bias. In Proceedings of the 26th ACM International Conference on Multimedia, pp. 411–419.
- Shen et al. (2019). Stable learning via sample reweighting. arXiv preprint arXiv:1911.12580.
- Tibshirani (1996). Regression shrinkage and selection via the lasso. Journal of the Royal Statistical Society: Series B (Methodological) 58 (1), pp. 267–288.
- Wang et al. (2019). A minimax game for instance based selective transfer learning. In Proceedings of the 25th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, KDD '19, New York, NY, USA, pp. 34–43.
- Zadrozny (2004). Learning and evaluating classifiers under sample selection bias. In Proceedings of the Twenty-First International Conference on Machine Learning, pp. 114.