In many real-world applications of machine learning we experience dataset shift, i.e. the data available in the training and inference stages come from different distributions. There has been increasing recent interest in developing machine learning models that are robust to such shifts.
In this paper we consider a scenario in which the change in the data distribution can be fully explained by some known random variable. Let $P_{\text{train}}$ and $P_{\text{test}}$ denote the data distributions during the training and testing stages. We assume they satisfy the following conditions:
We aim to learn a model that performs anticausal prediction (Schölkopf et al., 2012), i.e. classifies $Y$ given $X$, and successfully generalizes to unseen relationships between $Y$ and $C$. We assume we have access to $C$, but have no access to $P_{\text{test}}$ at training time. We call this setup class-dependent domain shift.
The following artificial and real-world examples demonstrate several details of such setups.
Colored handwritten digits. Let $Y$ denote the digit, $C$ the color and $X$ the handwritten image. Consider the following generative process: we first pick the digit, then choose a color, and then draw an image. At test time, we choose the digits from the same distribution, but for each digit we might choose a color from a different set. The goal is to learn a model that can predict the digit for unseen colors. Note that during testing we might encounter completely new colors, but we might also see a digit in a color that was only used for a different digit during training.
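A small simulation of this generative process may make the shift concrete. It is a hypothetical sketch: the color names and the two color tables are illustrative placeholders (not the colors used in our experiments), and the final step of drawing the image $X$ is omitted. The label distribution is identical in both stages; only the table of colors drawn for each digit changes.

```python
import random

# Hypothetical toy setup: P(Y) is the same in both stages,
# only the color table P(C | Y) changes between training and test.
P_Y = {5: 1 / 3, 6: 1 / 3, 9: 1 / 3}
TRAIN_COLORS = {5: ["red", "green"], 6: ["blue", "yellow"], 9: ["cyan", "pink"]}
TEST_COLORS = {5: ["blue", "pink"], 6: ["red", "cyan"], 9: ["green", "yellow"]}


def sample_pairs(n, colors_given_digit, rng):
    """Draw n (digit, color) pairs: first the digit Y, then a color C given Y.

    The last step of the process, drawing an image X given (Y, C), is omitted;
    in colored MNIST it would be an MNIST image of Y tinted with color C.
    """
    digits = list(P_Y)
    weights = [P_Y[d] for d in digits]
    pairs = []
    for _ in range(n):
        y = rng.choices(digits, weights=weights)[0]
        c = rng.choice(colors_given_digit[y])
        pairs.append((y, c))
    return pairs


def digits_with_color(pairs, color):
    """Set of digits that appear together with the given color."""
    return {y for y, c in pairs if c == color}


rng = random.Random(0)
train = sample_pairs(3000, TRAIN_COLORS, rng)
test = sample_pairs(3000, TEST_COLORS, rng)
```

With these tables, red occurs only with 5s in the training sample and only with 6s in the test sample, while each digit keeps a frequency close to 1/3 in both.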
Heart disease classification. Consider a heart disease with a fixed prevalence in some population. We sample a group of patients, some of whom have the disease, perform electrocardiography (ECG) and extract heartbeats from the ECG data. Let $X$ denote the heartbeats, $C$ the patient ID, and $Y$ the existence of the disease (a binary variable). The goal is to predict the disease from the heartbeats so that the model works for unseen patients.
2 Related work
Quiñonero-Candela et al. (2009) provide a comprehensive analysis and a clear graphical illustration of various dataset shift scenarios. Each scenario is represented by a plot of the underlying causal graphical model, where each node corresponds to one variable, and the nodes whose distribution might change between the training and test environments are highlighted in a darker color (Fig. 1).
Perhaps the most widely explored type of dataset shift is simple covariate shift, where the joint distribution of $X$ and $Y$ is factorized as $P(X, Y) = P(X)\,P(Y \mid X)$, and $P(X)$ differs between the training and test environments while $P(Y \mid X)$ stays the same (Gretton et al., 2009) (Fig. 1(a)). A typical example is the prediction of future events given the current state: the distribution of the states can change over time, but the way they cause future events is stable. The following examples demonstrate that the problems described above do not always satisfy the definition of simple covariate shift.
In the Colored handwritten digits problem, additionally assume that for two digits, say $a$ and $b$, red is used for $a$ but never for $b$ in the training set, and for $b$ but never for $a$ in the test set. Now assume there is a red image which depicts some symbol “between” $a$ and $b$. If the image is from the training set, it is a badly written $a$, because there are no red $b$s in the training set. Similarly, it is a $b$ if it is from the test set. Hence, we have $P_{\text{train}}(Y \mid X) \neq P_{\text{test}}(Y \mid X)$, which violates the covariate shift assumption. An ideal model should not predict $a$ for this image only because $a$s happened to be red in the training set.
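The argument can be checked with exact arithmetic on a toy model. In this hypothetical sketch the image is reduced to a pair (shape, color); the ambiguous shape is equally likely under both digits, the label prior is unchanged, and only $P(C = \text{red} \mid Y)$ differs between training and test. All probabilities are illustrative values, not estimates from data.

```python
from fractions import Fraction

# Toy version of the "red symbol between a and b" example. The image X is
# modeled as (shape, color); P(X | Y, C) is the same in training and test.
PRIOR = {"a": Fraction(1, 2), "b": Fraction(1, 2)}        # P(Y), unchanged
SHAPE_LIK = {"a": Fraction(1, 10), "b": Fraction(1, 10)}  # P(shape | Y), ambiguous
RED_TRAIN = {"a": Fraction(1, 2), "b": Fraction(0)}       # P(C = red | Y), training
RED_TEST = {"a": Fraction(0), "b": Fraction(1, 2)}        # P(C = red | Y), test


def posterior(prior, shape_lik, red_given_y):
    """P(Y | shape, C = red) by Bayes' rule.

    Shape and color are conditionally independent given Y in this toy model,
    so P(shape, red | Y) = P(shape | Y) * P(red | Y).
    """
    joint = {y: prior[y] * shape_lik[y] * red_given_y[y] for y in prior}
    total = sum(joint.values())
    return {y: p / total for y, p in joint.items()}


post_train = posterior(PRIOR, SHAPE_LIK, RED_TRAIN)
post_test = posterior(PRIOR, SHAPE_LIK, RED_TEST)
```

The same red image gets posterior 1 for $a$ under the training distribution and 0 under the test distribution, although neither $P(Y)$ nor $P(X \mid Y, C)$ changed.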
In the Heart disease classification example, assume that the heartbeats of one sick patient from the training set are quite similar to the heartbeats of a healthy individual from the test set. We cannot assume $P_{\text{train}}(Y \mid X) = P_{\text{test}}(Y \mid X)$, as the heartbeat does not determine the existence of the disease. Instead, the disease affects the shape of the heartbeat, and there might be additional factors that cause the heartbeats of healthy and sick patients to look similar.
In the prior probability shift or label shift scenario (Lipton et al., 2018) we assume that $P_{\text{train}}(X \mid Y) = P_{\text{test}}(X \mid Y)$, while $P_{\text{train}}(Y) \neq P_{\text{test}}(Y)$ (Fig. 1(b)). Eq. (3) clearly contradicts the first assumption. If we consider the concatenation of $Y$ and $C$ as an extended label, then our setup becomes similar to label shift, where only one part of the label changes across environments. The main difference, though, is that we are only interested in predicting the stable part of the label, without any constraints on the distribution of the other part of the label in the test set.
Our setup is closely related to domain shift. Quiñonero-Candela et al. (2009) define domain shift using a latent variable $X_0$ which corresponds to the underlying data, and assume the training algorithm has access only to a modified version of it, $X = f(X_0)$, where the modifier function $f$ changes between the training and test environments. The goal is to learn a predictor which can generalize to new modifiers $f$. In many scenarios $X$ can be interpreted as a measurement of $X_0$ (e.g. a photograph of an object taken with an RGB vs. an infrared camera). One difference compared to our setup is the direction of causality between $X_0$ and $Y$. The other difference is that we allow the function $f$ to explicitly depend on $Y$. In the camera example, this means that the choice of the camera might depend on the object category.
Another related scenario is source component shift (Fig. 1(d)). The assumption here is that the data comes from different sources, each source has unique characteristics, and the contributions of the different sources at training and test time are different. If $S$ denotes the random variable corresponding to the source, then the joint distribution is factorized as $P(X, Y, S) = P(Y \mid X, S)\,P(X \mid S)\,P(S)$, where the first two factors are constant across environments.
Our setup is visualized in Fig. 1(e). The main difference from source component shift is that $Y$ causes $C$ in our case. The other difference is the direction of causality between $X$ and $Y$ (or $C$). We believe the second difference is not critical, as in many scenarios (including the colored handwritten digit example from the introduction) the joint distribution can be factorized in both directions.
Recently, Arjovsky et al. (2019) proposed a new learning algorithm called Invariant Risk Minimization (IRM), which can achieve out-of-distribution generalization across a wide range of dataset shift scenarios. It is based on the concept of distinct training environments: the data in each environment is sampled i.i.d. from its own distribution, but the causal relationships between variables can vary across environments. Our experiments with the IRM codebase did not produce good results in our setup (see Section 4.4).
The methods developed for various dataset shift scenarios can be categorized into the ones which require access to the test distribution (without the labels) and ones which do not. Unsupervised domain adaptation algorithms (Ganin and Lempitsky, 2015) and most covariate shift adaptation methods (Sugiyama et al., 2007) are examples of the first category. Our setup assumes no access to the test distribution, similar to (Greenfeld and Shalit, 2019). This is also called zero-shot domain adaptation (Peng et al., 2018).
The methods that attempt to handle spurious correlations between variables can be categorized into those that require explicit annotations of those variables and those that discover such correlations automatically. For example, among algorithms designed to learn invariant representations, the method developed in (Moyer et al., 2018) belongs to the first category, while (Jaiswal et al., 2019) contains examples of the second category. We believe it is impossible to identify the spurious correlations without explicit annotations in our setup, so we assume the models do have access to the variable $C$.
3 The proposed method
To obtain a classifier that generalizes to unseen relationships between $Y$ and $C$, we follow the representation learning approach. We propose to learn a representation $Z$ of $X$ which is rich enough for predicting $Y$ but has no information about $C$ except for the information that is shared between $C$ and $Y$. These ideas are formalized in the optimization problem $\max_Z I(Z;Y)$ under the constraint $I(Z;C \mid Y) = 0$, where $I$ denotes mutual information. This problem is relaxed to the following objective:

$$\max_Z \; I(Z;Y) - \beta\, I(Z;C \mid Y).$$
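The first term is approximated using its variational lower bound (Alemi et al., 2017). For the second term, we note that $I(Z;C \mid Y) = I(Z,Y;C) - I(Y;C)$, where $I(Y;C)$ is constant for the training set. Following (Lopez et al., 2018; Jaiswal et al., 2019), we use the Hilbert-Schmidt Independence Criterion (HSIC) (Gretton et al., 2005) to minimize $I(Z,Y;C)$. The resulting loss function becomes:

$$\mathcal{L} = \mathbb{E}\left[-\log q(Y \mid Z)\right] + \beta\, \mathrm{HSIC}\big((Z,Y),\, C\big),$$

where $q(Y \mid Z)$ is the classifier applied to the representation.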
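The constraint $I(Z;C \mid Y) = 0$ and the chain-rule identity $I(Z;C \mid Y) = I(Z,Y;C) - I(Y;C)$ can be checked exactly on a small discrete example. The sketch below assumes a hypothetical training distribution with two digits and two colors per digit, and compares two idealized deterministic encoders: one that keeps only the digit ($Z = Y$) and one that keeps only the color ($Z = C$). Both carry full information about $Y$ on the training data, but only the first satisfies the constraint.

```python
import math
from collections import defaultdict

# Hypothetical training distribution P(Y, C): digit 0 uses colors 0 and 1,
# digit 1 uses colors 2 and 3, all four pairs equally likely.
P_TRAIN = {(0, 0): 0.25, (0, 1): 0.25, (1, 2): 0.25, (1, 3): 0.25}
Y_IDX, C_IDX, Z_IDX = (0,), (1,), (2,)  # positions inside an outcome (y, c, z)


def with_encoder(p_yc, encoder):
    """Joint distribution over (y, c, z) for a deterministic encoder z = f(y, c)."""
    return {(y, c, encoder(y, c)): p for (y, c), p in p_yc.items()}


def marginal(joint, idx):
    """Marginal distribution over the variables at positions idx."""
    out = defaultdict(float)
    for outcome, p in joint.items():
        out[tuple(outcome[i] for i in idx)] += p
    return out


def mutual_info(joint, a, b):
    """I(A; B) in nats, where a and b are tuples of variable positions."""
    pab, pa, pb = marginal(joint, a + b), marginal(joint, a), marginal(joint, b)
    return sum(
        p * math.log(p / (pa[k[:len(a)]] * pb[k[len(a):]]))
        for k, p in pab.items() if p > 0
    )


def cond_mutual_info(joint, a, b, c):
    """I(A; B | C) = sum p(a,b,c) log[p(a,b,c) p(c) / (p(a,c) p(b,c))]."""
    la, lb = len(a), len(b)
    pabc = marginal(joint, a + b + c)
    pac, pbc, pc = marginal(joint, a + c), marginal(joint, b + c), marginal(joint, c)
    total = 0.0
    for k, p in pabc.items():
        if p > 0:
            ka, kb, kc = k[:la], k[la:la + lb], k[la + lb:]
            total += p * math.log(p * pc[kc] / (pac[ka + kc] * pbc[kb + kc]))
    return total


digit_only = with_encoder(P_TRAIN, lambda y, c: y)  # Z = Y
color_only = with_encoder(P_TRAIN, lambda y, c: c)  # Z = C
```

Here $I(Z;Y) = \log 2$ nats for both encoders, while $I(Z;C \mid Y)$ is 0 for `digit_only` and $\log 2$ for `color_only`. In both cases the conditional term equals `mutual_info(j, Z_IDX + Y_IDX, C_IDX) - mutual_info(j, Y_IDX, C_IDX)`, and the subtracted $I(Y;C)$ depends only on the data.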
4 Experiments on colored MNIST
4.1 The dataset
Our dataset is based on MNIST images (LeCun and Cortes, 2010). For simplicity we take only images of digits 5, 6 and 9. We “color” our images as follows. First, we uniformly sample and fix 12 colors, where a color is defined as a $d$-dimensional vector. We use two distinct colors per digit for the training set and two others for the development set. Then we add a third dimension to each $28 \times 28$ image, repeat the monochrome pixel values across the new dimension, and multiply the image by the color.¹
¹ To visualize the digit, we add back the offset used when constructing the colored image and interpret the first three channels as RGB colors.
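A simplified sketch of the coloring step, using plain nested lists so it runs without extra dependencies. It assumes colors are sampled uniformly from $[0,1]^d$, uses `d = 4` only as an example value, and leaves out the offset mentioned in the footnote, so the colored image is just the monochrome image repeated across $d$ channels and multiplied channel-wise by the color vector.

```python
import random


def sample_colors(n_colors, d, rng):
    """Sample and fix n_colors color vectors (uniform on [0, 1]^d, an assumption)."""
    return [[rng.random() for _ in range(d)] for _ in range(n_colors)]


def colorize(image, color):
    """Turn an H x W monochrome image into an H x W x d colored image.

    Each pixel value is repeated across the d channels (the new third
    dimension) and multiplied by the matching component of the color vector.
    """
    return [[[pixel * channel for channel in color] for pixel in row] for row in image]


rng = random.Random(0)
palette = sample_colors(12, 4, rng)   # 12 fixed colors, as in the text
toy_image = [[0.0, 1.0], [0.5, 0.0]]  # stand-in for a 28 x 28 MNIST digit
colored = colorize(toy_image, [1.0, 0.5, 0.0, 0.25])
```

A white pixel takes the full color vector, a mid-grey pixel half of it, and background pixels stay zero in every channel.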
We evaluate the models on three datasets (Fig. 2). The first one, called the quasi-development set, contains digits with the same colors as in the training set. The second, called the development set, contains images with completely different colors (the remaining 6 of the 12 fixed colors). The last one, called the adversarial development set, contains images with the same colors as the training set, but the colors are assigned to different digits, i.e. the colors used for 6s in the training set are used for 9s, etc. Any classifier that depends on the color (rather than on the shape of the symbol) will make incorrect predictions on this set.
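The color assignment for the four splits can be sketched as a small lookup table. This is a hypothetical illustration: the numbering of color indices is arbitrary, and the cyclic shift used for the adversarial set is one assumed way to complete “the colors used for 6s are used for 9s, etc.”

```python
DIGITS = [5, 6, 9]


def color_splits(digits, n_colors=12):
    """Map each evaluation split to {digit: [color indices]}.

    Colors 0 .. 2k-1 go to training (two per digit) and the next 2k colors to
    the development set. The adversarial set reuses the training colors with a
    cyclic shift over digits, so each digit gets the colors of the previous one.
    """
    k = len(digits)
    assert n_colors >= 4 * k, "need two training and two development colors per digit"
    train = {d: [2 * i, 2 * i + 1] for i, d in enumerate(digits)}
    dev = {d: [2 * k + 2 * i, 2 * k + 2 * i + 1] for i, d in enumerate(digits)}
    adversarial = {d: train[digits[i - 1]] for i, d in enumerate(digits)}
    return {
        "train": train,
        "quasi_dev": dict(train),  # same colors as training
        "dev": dev,                # completely new colors
        "adversarial": adversarial,
    }


splits = color_splits(DIGITS)
```

With this table no digit keeps its own training colors in the adversarial set, which is what makes a color-based classifier fail there.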
4.2 Experimental setup
To show the effectiveness of our approach, we compare it with a basic neural baseline. The neural network consists of a fully convolutional encoder and a linear classifier. The encoder consists of three stacked convolutional layers and transforms the input into a fixed-size representation. The only difference between our model and the basic baseline is an additional term $\beta\,\mathrm{HSIC}((Z,Y), C)$ in the final loss function, where $Z$ is the encoded representation of the input, $Y$ is the label and $C$ is the domain-dependent variable, which is the index of the “color” in our case. We also perform experiments with a modified version of this term.
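A dependency-free sketch of the regularizer, assuming it is $\beta\,\mathrm{HSIC}((Z,Y),C)$ as suggested by the derivation in Section 3. It uses the biased empirical HSIC estimator $\mathrm{tr}(KHLH)/(n-1)^2$ of Gretton et al. (2005), with Gaussian kernels of bandwidth `sigma = 1` and one-hot encodings of $Y$ and $C$; the kernel, the bandwidth and the one-hot encoding are our assumptions. In training, this term would be computed on each mini-batch and added to the cross-entropy loss, which is not shown.

```python
import math


def rbf_gram(xs, sigma=1.0):
    """Gaussian kernel matrix k(x, x') = exp(-||x - x'||^2 / (2 sigma^2))."""
    return [
        [math.exp(-sum((a - b) ** 2 for a, b in zip(xi, xj)) / (2 * sigma ** 2)) for xj in xs]
        for xi in xs
    ]


def center(k):
    """Double centering H K H with H = I - (1/n) 11^T."""
    n = len(k)
    row = [sum(r) / n for r in k]
    col = [sum(k[i][j] for i in range(n)) / n for j in range(n)]
    grand = sum(row) / n
    return [[k[i][j] - row[i] - col[j] + grand for j in range(n)] for i in range(n)]


def hsic(xs, ys, sigma=1.0):
    """Biased HSIC: tr(K H L H) / (n - 1)^2 = sum_ij (HKH)_ij L_ij / (n - 1)^2."""
    n = len(xs)
    kc = center(rbf_gram(xs, sigma))
    lm = rbf_gram(ys, sigma)
    return sum(kc[i][j] * lm[i][j] for i in range(n) for j in range(n)) / (n - 1) ** 2


def hsic_penalty(z, y_onehot, c_onehot, beta, sigma=1.0):
    """beta * HSIC((Z, Y), C): each representation is concatenated with its one-hot label."""
    zy = [list(zi) + list(yi) for zi, yi in zip(z, y_onehot)]
    return beta * hsic(zy, c_onehot, sigma)


def one_hot(k, m):
    """One-hot vector of length m with a 1.0 at position k."""
    return [1.0 if j == k else 0.0 for j in range(m)]
```

On a balanced toy batch in which digit 0 uses colors 0 and 1 and digit 1 uses colors 2 and 3, an encoder that copies the color receives a larger penalty than one that keeps only the digit. The penalty of the digit-only encoder is not zero, mirroring the constant $I(Y;C)$ that is included when minimizing $I(Z,Y;C)$.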
We run all experiments for a fixed 100 epochs with batch size 150. We try several values of $\beta$ and two learning rates. Model selection is tricky: to approximate a real-world zero-shot scenario, we do not want to access the harder development sets at training time.
4.3 Results and discussion
The typical behavior of the models is shown in Fig. 3. The simple baseline gets perfect accuracy on the quasi-development set, sometimes gets better-than-random accuracy on the development set, and always gets worse-than-random accuracy on the adversarial development set (Fig. 3(a)). This behavior indicates that the model predicts based on the color rather than on the shape of the symbol. Confusion matrices support this explanation.
The training process of our method can be described by two distinct phases (Fig. 3(b)). In the first phase, the model quickly learns to classify the digits with identical accuracy on all evaluation sets (up to 95%) and stays stable for some time. This phase is longer for larger values of $\beta$ and smaller learning rates. In the second phase, the model starts to overfit to the colors used in the training set: accuracy on the quasi-development set reaches 100%, while on the other sets it decreases. For some combinations of $\beta$ and learning rate, the second phase does not even start within 100 epochs.
The modified version of our method performs similarly, except that the accuracy does not rise right from the first epoch (Fig. 3(c)). It takes up to 30 epochs (fewer with larger learning rates and smaller values of $\beta$) to reach 90%+ accuracy on all evaluation sets. A possible explanation is that the modified term also penalizes information about $Y$, which the softmax term attempts to maximize. It takes time before the softmax term wins and drives the accuracy up. On the other hand, decreasing $I(Z,Y;C)$ does not reduce $I(Z;Y)$. This explanation is supported by the fact that the HSIC loss term increases towards the end of the first phase and slowly starts to decrease in the second phase (Fig. 3(d)).
Unfortunately, we were not able to find a reliable way to perform model selection. If we choose the checkpoint with the best accuracy on the quasi-development set among all hyperparameters, we get an overfitted model with 100%, 79.77% and 34.31% accuracy on the quasi-development, development and adversarial development sets, respectively. If we choose the model according to its performance on the development set, we get an optimal model with 96.01%, 96.39% and 95.93% accuracy.
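The selection problem can be illustrated with the two checkpoints whose accuracies are reported above. In the real search the candidates are all checkpoints over all hyperparameters; only these two are reproduced here, and the dictionary keys are our own labels.

```python
# Accuracies (%) of the two reported checkpoints, ordered as
# (quasi-development, development, adversarial development).
CHECKPOINTS = {
    "overfitted": (100.0, 79.77, 34.31),
    "optimal": (96.01, 96.39, 95.93),
}
SPLITS = ("quasi_dev", "dev", "adversarial")


def select_checkpoint(checkpoints, split):
    """Return the name of the checkpoint with the highest accuracy on `split`."""
    i = SPLITS.index(split)
    return max(checkpoints, key=lambda name: checkpoints[name][i])


def worst_case(checkpoints, name):
    """Lowest accuracy of a checkpoint across the three evaluation sets."""
    return min(checkpoints[name])
```

Selecting on the quasi-development set, the only set a zero-shot setting would allow, returns the overfitted checkpoint, whose worst-case accuracy is 34.31%; selecting on the development set returns the optimal one.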
4.4 Invariant Risk Minimization
IRM (Arjovsky et al., 2019) requires the training data to be split into distinct environments. The main assumption is that the dependence of $Y$ on its causal parents is the same in all environments, while the rest of the causal graph can change. In our setup, $Y$ has no causal parents, so IRM should be applicable (at least in theory). To adapt our colored MNIST dataset for IRM, we have to define environments for the training data. We kept only two labels to bring the setup closer to the dataset used in the original IRM paper. We tested two sets of environments.
As we have two colors per label, we need four environments to cover our training set while having a fixed color per label in each environment.
Another approach is to have two environments with all color combinations, but with different ratios of colors per label. We chose the ratio of the first and the second colors of each label to be 60:40 in the first environment and 40:60 in the second environment. This is perhaps closer to the experiments described in the original paper.
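Both environment constructions can be sketched in a few lines. This is a hypothetical illustration: the color names are placeholders, and the round-robin split of each label's images and the deterministic rounding of the 60:40 ratio are our assumptions about how such environments could be built, not a description of the code we ran.

```python
import itertools

# Hypothetical placeholders: two labels, each with two training colors.
COLORS = {0: ["c0", "c1"], 1: ["c2", "c3"]}


def fixed_color_environments(colors_per_label):
    """First option: one environment per way of fixing a single color for every label."""
    labels = sorted(colors_per_label)
    combos = itertools.product(*(colors_per_label[y] for y in labels))
    return [dict(zip(labels, combo)) for combo in combos]


def ratio_environments(labels, colors_per_label, first_color_share):
    """Second option: environments mixing both colors of each label in different ratios.

    labels: the digit label of each training image (the images are omitted).
    first_color_share[e]: fraction of each label's images in environment e that
    get the label's first color; the others get its second color.
    Returns one list of (label, color) pairs per environment.
    """
    n_env = len(first_color_share)
    envs = [[] for _ in range(n_env)]
    for y in sorted(set(labels)):
        idx = [i for i, lab in enumerate(labels) if lab == y]
        first, second = colors_per_label[y]
        for e in range(n_env):
            members = idx[e::n_env]  # round-robin split of this label's images
            n_first = round(first_color_share[e] * len(members))
            envs[e].extend([(y, first)] * n_first)
            envs[e].extend([(y, second)] * (len(members) - n_first))
    return envs


four_envs = fixed_color_environments(COLORS)
two_envs = ratio_environments([0] * 100 + [1] * 100, COLORS, [0.6, 0.4])
```

The first function yields the four fixed-color environments (2 colors for label 0 times 2 colors for label 1); the second yields two environments with 60:40 and 40:60 first-to-second color ratios per label.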
We performed experiments using the code provided by the authors. In all cases IRM did not perform better than the regular Empirical Risk Minimization (ERM) baseline (while the “greyscale” baseline worked perfectly). We additionally tried replacing the MLP used in the original code with a convolutional network, as our own method struggles to learn a generalizable model without convolutional layers, but did not see any improvement.
We would like to thank Hrayr Harutyunyan for useful discussions. This material is based on research sponsored by Air Force Research Laboratory (AFRL) under agreement number FA8750-19-1-1000. The U.S. Government is authorized to reproduce and distribute reprints for Government purposes notwithstanding any copyright notation therein. The views and conclusions contained herein are those of the authors and should not be interpreted as necessarily representing the official policies or endorsements, either expressed or implied, of Air Force Laboratory, DARPA or the U.S. Government. T.G. and H.K. were partially supported by an EIF grant. The experiments were performed on Titan V GPUs donated to YerevaNN by NVIDIA.
- Alemi et al. (2017). Deep variational information bottleneck. In 5th International Conference on Learning Representations, ICLR 2017, Toulon, France, April 24-26, 2017, Conference Track Proceedings.
- Arjovsky et al. (2019). Invariant risk minimization. arXiv preprint arXiv:1907.02893.
- Ganin and Lempitsky (2015). Unsupervised domain adaptation by backpropagation. In International Conference on Machine Learning, pp. 1180–1189.
- Greenfeld and Shalit (2019). Robust learning with the Hilbert-Schmidt independence criterion. arXiv preprint arXiv:1910.00270.
- Gretton et al. (2005). Measuring statistical dependence with Hilbert-Schmidt norms. In International Conference on Algorithmic Learning Theory, pp. 63–77.
- Gretton et al. (2009). Covariate shift by kernel mean matching. Dataset Shift in Machine Learning 3 (4), pp. 5.
- Jaiswal et al. (2019). Discovery and separation of features for invariant representation learning. arXiv preprint arXiv:1912.00646.
- LeCun and Cortes (2010). MNIST handwritten digit database. http://yann.lecun.com/exdb/mnist/
- Lipton et al. (2018). Detecting and correcting for label shift with black box predictors. In Proceedings of the 35th International Conference on Machine Learning, J. Dy and A. Krause (Eds.), Proceedings of Machine Learning Research, Vol. 80, Stockholmsmässan, Stockholm, Sweden, pp. 3122–3130.
- Lopez et al. (2018). Information constraints on auto-encoding variational Bayes. In Advances in Neural Information Processing Systems, pp. 6114–6125.
- Lu et al. (2019). The Cells Out of Sample (COOS) dataset and benchmarks for measuring out-of-sample generalization of image classifiers. In Advances in Neural Information Processing Systems 32, pp. 1854–1862.
- Moyer et al. (2018). Invariant representations without adversarial training. In Advances in Neural Information Processing Systems 31, pp. 9084–9093.
- Peng et al. (2018). Zero-shot deep domain adaptation. In Proceedings of the European Conference on Computer Vision (ECCV), pp. 764–781.
- Quiñonero-Candela et al. (2009). Dataset shift in machine learning. The MIT Press.
- Schölkopf et al. (2012). On causal and anticausal learning. In Proceedings of the 29th International Conference on Machine Learning, ICML 2012, Edinburgh, Scotland, UK, June 26 - July 1, 2012.
- Sugiyama et al. (2007). Covariate shift adaptation by importance weighted cross validation. Journal of Machine Learning Research 8 (May), pp. 985–1005.
- Teney et al. (2020). Unshuffling data for improved generalization. arXiv preprint arXiv:2002.11894.
Appendix A Other datasets
Many real-world datasets contain samples from similar but distinct domains and sources. As most machine learning research focuses on the case where the samples in both the training and test sets are i.i.d., datasets are manually shuffled and the domain-specific information is erased. As noted by Arjovsky et al. (2019), the original NIST handwritten data was collected from different writers under different conditions, but the MNIST dataset is not only shuffled; the information about the writers is also removed. Unfortunately, this trend continues today, which limits the development of robust classification methods that rely on annotations of the domain variable $C$.
The recently proposed Cells Out of Sample (COOS) dataset contains microscope images of mouse cells (Lu et al., 2019). The cells are captured from different wells on multiple plates on different days over two years. The label is the protein which is used to highlight various parts of the cells. The paper proposes a benchmark for image classification under covariate shift, and the authors prepared four test sets with varying degrees of dataset shift. In fact, the way the dataset is prepared implies that the class label causes the image $X$, and Eqs. (1)-(5) describe the dataset better than covariate shift does. The original version of the dataset did not include annotations for image date, plate and well (which together can form our variable $C$). The authors of the dataset kindly agreed to add this information in a new release, and we will try our method on the dataset in future work.
In (Teney et al., 2020), the authors attempt to “unshuffle” the popular visual question answering dataset VQA to obtain multiple data domains or environments. They implement a practical method inspired by Invariant Risk Minimization and demonstrate performance improvements compared to the i.i.d. baselines. We leave experiments with our method on this dataset to future work.