1 Introduction
To deploy machine learning algorithms in real-world applications, we must pay attention to distribution shift, i.e., when the test distribution differs from the training distribution, which substantially degrades model performance. In this paper, we refer to this problem as out-of-distribution (OOD) generalization and specifically consider performance gaps caused by two kinds of distribution shifts: domain shifts and subpopulation shifts. In domain shifts, the test data is sampled from different domains than the training data, which requires the trained model to generalize well to test domains without seeing data from those domains at training time. Take health risk prediction as an example: we may want to train a model on patients from a few sampled hospitals and then deploy the model to a broader set of hospitals (Koh et al., 2021). In subpopulation shifts, the proportions of subpopulations in the test distribution differ from those in the training distribution (e.g., the setting of imbalanced data). When subpopulation shift occurs, models perform poorly if they falsely rely on spurious correlations, which can arise when some subpopulations are underrepresented in the training set. For example, in financial risk prediction, a machine learning model trained on the entire population may associate the labels with demographic features (e.g., religion and race), making the model fail on the test set when such an association does not hold in reality.
To improve model robustness under these two kinds of distribution shifts, methods for learning invariance have shown effectiveness in various applications. These methods use regularizers to learn features or predictors that are invariant across different domains while still containing sufficient information to fulfill the task (Li et al., 2018; Sun & Saenko, 2016; Arjovsky et al., 2019; Krueger et al., 2021; Rosenfeld et al., 2021). However, designing regularizers that are broadly suitable for datasets from diverse domains is especially challenging, and unsuitable regularizers may adversely limit the model's expressive power or yield a difficult optimization problem, leading to inconsistent performance across real-world datasets. For example, on the WILDS datasets, invariant risk minimization (IRM) (Arjovsky et al., 2019) with reweighting outperforms empirical risk minimization (ERM) on CivilComments, but fails to improve robustness on a variety of other datasets such as Camelyon17 and RxRx1 (Koh et al., 2021). CORAL (Sun & Saenko, 2016), a popular domain adaptation method that adds explicit regularization, has been observed to exhibit a similar phenomenon (Koh et al., 2021).
Instead of explicitly imposing regularization to learn invariant representations or predictors, we turn towards an implicit solution in this paper. Inspired by mixup (Zhang et al., 2018), we aim to learn invariant functions through data interpolation, leading to a simple algorithm called LISA (Learning Invariant Functions with Selective Augmentation). Concretely, LISA linearly interpolates the features of a pair of samples and applies the same interpolation strategy to the corresponding labels. Critically, the pairs are selectively chosen according to two selective augmentation strategies, intra-label LISA (LISA-L) and intra-domain LISA (LISA-D), which are described below; we illustrate the motivations of both strategies on the Colored MNIST dataset in Figure 1. Intra-label LISA (Figure 1(b)) interpolates samples with the same label but from different domains, aiming to eliminate domain-related spurious correlations. Intra-domain LISA (Figure 1(c)) interpolates samples from the same domain but with different labels, where the model should ignore the domain information and generate different predicted values as the interpolation ratio changes. In this way, LISA encourages the model to learn domain-invariant functions without any explicit constraints or regularizers.
The primary contributions of this paper are as follows: (1) We propose a simple yet widely applicable method for learning domain-invariant functions that is shown to be robust to domain shifts and subpopulation shifts. (2) We conduct broad empirical experiments to evaluate the effectiveness of LISA on nine benchmark datasets from diverse domains. In these experiments, we make the following observations. First, LISA consistently outperforms seven prior methods in addressing both domain shifts and subpopulation shifts. Second, the performance gains of LISA are indeed caused by canceling out domain-specific information (or spurious correlations) and learning invariant functions, rather than by simply involving more data via interpolation. Third, as the degree of distribution shift increases, LISA achieves more significant performance gains. (3) Finally, we provide theoretical analysis of the phenomena distilled from the empirical studies, provably demonstrating that LISA can mitigate spurious correlations and therefore attain a smaller worst-domain error compared with ERM and vanilla mixup. To the best of our knowledge, our work provides the first theoretical framework for studying how mixup (with or without the selective augmentation strategies) affects the misclassification error.
2 Preliminaries
In this paper, we consider the setting where one predicts the label $y \in \mathcal{Y}$ based on the input features $x \in \mathcal{X}$. Given a parameter space $\Theta$ and a loss function $\ell(\cdot, \cdot)$, we train a model $f_\theta$ with $\theta \in \Theta$ under the training distribution $P_{tr}$. In empirical risk minimization (ERM), the empirical distribution over the training data is denoted $\hat{P}_{tr}$, and ERM optimizes the following objective:

$$\hat{\theta}_{ERM} = \arg\min_{\theta \in \Theta} \mathbb{E}_{(x, y) \sim \hat{P}_{tr}} \left[ \ell(f_\theta(x), y) \right]. \qquad (1)$$
In a traditional machine learning setting, a test set sampled from a test distribution $P_{ts}$ is used to evaluate the generalization of the trained model $f_{\hat{\theta}}$, where the test distribution is assumed to be the same as the training distribution, i.e., $P_{ts} = P_{tr}$. In this paper, we are interested in the setting where distribution shift occurs, i.e., $P_{ts} \neq P_{tr}$.
Specifically, following Muandet et al. (2013); Albuquerque et al. (2019); Koh et al. (2021), we regard the overall data distribution as containing $D$ domains, where each domain $d$ is associated with a data distribution $P_d$ and contributes $n_d$ training samples. We then formulate the training distribution as a mixture of domains, i.e., $P_{tr} = \sum_{d \in \mathcal{D}_{tr}} r_d P_d$, where $r_d$ denotes the mixture probability of domain $d$ in the training set and $\mathcal{D}_{tr}$ denotes the set of training domains. Similarly, the test distribution can be represented as $P_{ts} = \sum_{d \in \mathcal{D}_{ts}} q_d P_d$, where $q_d$ is the mixture probability in the test set and $\mathcal{D}_{ts}$ denotes the set of test domains.

In domain shift scenarios, we investigate the problem where the test domains are disjoint from the training domains, i.e., $\mathcal{D}_{tr} \cap \mathcal{D}_{ts} = \emptyset$. In general, we assume the test domains share some common properties with the training domains. For example, in Camelyon17 (Koh et al., 2021), we train the model on some hospitals and test it on a new hospital. We evaluate the worst-domain and average performance of the classifier across all test domains.

In subpopulation shift scenarios, the test set contains domains that have been seen in the training set, but with different subpopulation proportions, i.e., $\mathcal{D}_{ts} \subseteq \mathcal{D}_{tr}$ but $q_d \neq r_d$. Under this setting, following Sagawa et al. (2020a), we specifically consider group-based spurious correlations, where each group $g$ is associated with a domain $d$ and a label $y$, i.e., $g = (d, y)$. We assume that the domain is spuriously correlated with the label. For example, in the CMNIST dataset illustrated in Figure 1, the digit color (green or red) is spuriously correlated with the label ([1, 0] or [0, 1]). Based on the group definition, we evaluate the model via the worst-group test error, i.e., $\max_g \mathbb{E}_{(x, y) \sim P_g} \left[ \ell_{0\text{-}1}(f_{\hat{\theta}}(x), y) \right]$, where $\ell_{0\text{-}1}$ represents the 0-1 loss.
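As a concrete instance of this evaluation protocol, the worst-group 0-1 error can be computed in a few lines. The following is a minimal NumPy sketch; the function name and array-based interface are our own illustration, not part of the paper:

```python
import numpy as np

def worst_group_error(y_true, y_pred, domains):
    """Worst-group 0-1 error, where a group is a (domain, label) pair."""
    errors = []
    for d in np.unique(domains):
        for y in np.unique(y_true):
            mask = (domains == d) & (y_true == y)
            if mask.any():  # skip empty groups
                errors.append(np.mean(y_pred[mask] != y_true[mask]))
    return max(errors)

# Example: every group is predicted perfectly except group (domain 0, label 1).
y_true  = np.array([0, 1, 0, 1])
y_pred  = np.array([0, 0, 0, 1])
domains = np.array([0, 0, 1, 1])
print(worst_group_error(y_true, y_pred, domains))  # -> 1.0
```

Average accuracy can hide exactly this kind of failure: the overall error above is only 25%, while the worst group is entirely misclassified.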
3 Learning Invariant Functions with Selective Augmentation
This section presents LISA, a simple way to improve robustness to domain shifts and subpopulation shifts. The key idea behind LISA is to encourage the model to learn invariant functions through selective data interpolation, which also alleviates the effects of domain-related spurious correlations. Before detailing how the interpolated sample pairs are selected, we first provide a general formulation for data interpolation.
In LISA, we perform linear interpolation between training samples. Specifically, given samples $(x_i, y_i)$ and $(x_j, y_j)$ drawn from domains $d_i$ and $d_j$, we apply mixup (Zhang et al., 2018), a simple data interpolation strategy, separately on the input features and the corresponding labels:

$$x_{mix} = \lambda x_i + (1 - \lambda) x_j, \qquad y_{mix} = \lambda y_i + (1 - \lambda) y_j, \qquad (2)$$

where the interpolation ratio $\lambda$ is sampled from a Beta distribution and $y_i$ and $y_j$ are one-hot vectors for classification problems.
Notice that the mixup approach in equation 2 can be replaced by CutMix (Yun et al., 2019), which shows stronger empirical performance in vision-based applications. In text-based applications, we can use Manifold Mixup (Verma et al., 2019), interpolating the representations from a pre-trained model, e.g., the output of BERT (Devlin et al., 2019).
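The pairwise interpolation of equation (2) can be sketched directly in NumPy. This is an illustrative snippet under our own naming (`mixup_pair`, the default `alpha`); the paper's implementation operates on batches of framework tensors:

```python
import numpy as np

def mixup_pair(x_i, y_i, x_j, y_j, alpha=2.0, rng=None):
    """Interpolate one pair of inputs and their one-hot labels as in equation (2)."""
    rng = rng or np.random.default_rng()
    lam = rng.beta(alpha, alpha)  # interpolation ratio from Beta(alpha, alpha)
    return lam * x_i + (1 - lam) * x_j, lam * y_i + (1 - lam) * y_j

x_i, y_i = np.ones(4), np.array([1.0, 0.0])    # example from class 0
x_j, y_j = np.zeros(4), np.array([0.0, 1.0])   # example from class 1
x_mix, y_mix = mixup_pair(x_i, y_i, x_j, y_j)
# y_mix is a soft label whose entries sum to 1, and each feature of
# x_mix lies between the two endpoint feature values.
```

Because the labels are mixed with the same ratio as the features, the target remains consistent with the interpolated input.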
After obtaining the interpolated features and labels, we replace the original features and labels in ERM with the interpolated ones. The optimization problem in equation 1 is then reformulated as:

$$\hat{\theta} = \arg\min_{\theta \in \Theta} \mathbb{E} \left[ \ell(f_\theta(x_{mix}), y_{mix}) \right]. \qquad (3)$$
Without additional selective augmentation strategies, vanilla mixup regularizes the model and reduces overfitting (Zhang et al., 2021b), allowing it to attain good in-distribution generalization. However, vanilla mixup may not cancel out spurious correlations, so the model can still fail to attain good OOD generalization (see empirical comparisons in Section 4.3 and theoretical discussion in Section 5). In LISA, we instead adopt a new strategy where mixup is applied only across specific domains or groups, which encourages learning invariant functions and thus better OOD performance. Specifically, the two kinds of selective augmentation strategies are as follows:
Intra-label LISA (LISA-L): Interpolating samples with the same label. Intra-label LISA interpolates samples with the same label but from different domains (i.e., $y_i = y_j$, $d_i \neq d_j$). As shown in Figure 1(a), this produces datapoints in which both domains are partially present, effectively eliminating spurious correlations between domain and label in cases where the pair of domains correlate differently with the label. Additionally, if the domain information does not fully reflect the spurious correlations in some datasets, we can enlarge the interpolation scope to cover more potentially spurious correlations by interpolating samples within the same class regardless of domain information (i.e., requiring only $y_i = y_j$). As a result, intra-label LISA should learn domain-invariant functions for each class and thus achieve better OOD robustness.
Intra-domain LISA (LISA-D): Interpolating samples with the same domain. Supposing the domain information is highly spuriously correlated with the label, intra-domain LISA (Figure 1(b)) interpolates samples from the same domain but with different labels, i.e., $d_i = d_j$, $y_i \neq y_j$. Intuitively, even within the same domain, the model must produce different predicted values as the randomly sampled interpolation ratio changes, since the interpolated label changes accordingly. This causes the model to make predictions that are less dependent on the domain, again improving OOD robustness.
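The pair-selection step behind both strategies can be sketched as a small batch-level search. This is an illustrative NumPy sketch with hypothetical names; the actual implementation samples partners within mini-batches of tensors:

```python
import numpy as np

def select_partners(labels, domains, strategy, rng=None):
    """Pick an interpolation partner j for each sample i.
    'intra-label'  (LISA-L): same label, different domain.
    'intra-domain' (LISA-D): same domain, different label."""
    rng = rng or np.random.default_rng(0)
    partners = np.arange(len(labels))  # fall back to i itself if no candidate
    for i in range(len(labels)):
        if strategy == "intra-label":
            cand = np.where((labels == labels[i]) & (domains != domains[i]))[0]
        else:  # "intra-domain"
            cand = np.where((domains == domains[i]) & (labels != labels[i]))[0]
        if len(cand):
            partners[i] = rng.choice(cand)
    return partners

labels  = np.array([0, 0, 1, 1])
domains = np.array([0, 1, 0, 1])
p = select_partners(labels, domains, "intra-label")
# Every partner shares the sample's label but comes from the other domain.
```

Each selected pair would then be passed through the interpolation of equation (2).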
For each batch of data, we randomly perform intra-label or intra-domain LISA during training with probability $p_{sel}$ and $1 - p_{sel}$, respectively, where $p_{sel}$ is treated as a hyperparameter in our experiments. The choice of $p_{sel}$ depends on the number of domains and the relation between the domain information and the spurious correlations. Empirically, intra-label LISA brings more benefit when there are more domains and/or the domain information cannot fully reflect the spurious correlations. Intra-domain LISA benefits performance when the domain information is highly spuriously correlated with the label, where we find a balanced ratio (i.e., $p_{sel} = 0.5$) performs best. The pseudocode of the training procedure of LISA is shown in Algorithm 1.

4 Experiments
In this section, we conduct comprehensive experiments to evaluate the effectiveness of LISA. Specifically, we aim to answer the following questions: Q1: Compared to prior methods, can LISA improve robustness to domain shifts and subpopulation shifts (Section 4.1 and Section 4.2)? Q2: Which aspects of LISA are the most important for improving robustness (Section 4.3)? Q3: How does LISA perform with varying degrees of distribution shifts (Section 4.4)? Q4: Does LISA successfully produce invariant functions (Section 4.5)?
To answer Q1, we compare to ERM, IRM (Arjovsky et al., 2019), MMD (Li et al., 2018), DANN (Ganin & Lempitsky, 2015), GroupDRO (Sagawa et al., 2020a), DomainMix (Xu et al., 2020), and Fish (Shi et al., 2021). Upweighting (UW) is particularly suitable for subpopulation shifts, so we also include it for comparison. We adopt the same model architectures for all approaches.
4.1 Evaluating Robustness to Domain Shifts
Experimental Setup. To study domain shifts, we consider five datasets. Four datasets (Camelyon17, FMoW, RxRx1, and Amazon) are selected from WILDS (Koh et al., 2021), covering real-world distribution shifts across diverse domains (e.g., health, natural language processing, and vision). Besides the WILDS data, we also apply LISA to the MetaShift datasets (Liang & Zou, 2021), constructed using the real-world images and natural heterogeneity of Visual Genome (Krishna et al., 2016). We summarize the datasets in Table 7 of Appendix A.1.1, including domain information, evaluation metric, model architecture, and the number of classes. Detailed dataset descriptions and other training details are discussed in Appendix A.1.1 and A.1.2, respectively.

Results. We report the results for domain shifts in Table 1; full results that include validation performance and other metrics are listed in Appendix A.1.3. The strategy selection probability is set as $p_{sel} = 1.0$ for these domain shift problems, i.e., only intra-label LISA is used. In addition, we only interpolate samples with the same labels regardless of the domain information, which empirically leads to the best performance. According to Table 1, we have two key findings:
(1) There are no significant performance differences between ERM and other invariant learning methods (e.g., IRM, CORAL, DomainMix) on most datasets, which is consistent with the results reported on WILDS (Koh et al., 2021). A potential reason is that the existing domain information may not fully reflect the spurious correlations. For example, in the Camelyon17 dataset, the presence of tumor tissue (i.e., the label) mainly depends on the demographics of patients (e.g., race, gender), which show no significant difference across hospitals (i.e., the domain information). This could also explain why interpolating only samples with the same labels, regardless of the domain information, achieves the best performance, corroborating our claims in the description of intra-label LISA in Section 3.
(2) LISA consistently outperforms prior methods on all five datasets regardless of model architecture and data type (i.e., image or text), demonstrating its effectiveness in improving OOD robustness by canceling out both observable and unobservable domain-related correlations with selective augmentation.
| Method | Camelyon17 (Avg. Acc.) | FMoW (Worst Acc.) | RxRx1 (Avg. Acc.) | Amazon (10th Per. Acc.) | MetaShift (Worst Acc.) |
|---|---|---|---|---|---|
| ERM | 70.3 ± 6.4% | 32.3 ± 1.25% | 29.9 ± 0.4% | 53.8 ± 0.8% | 52.1 ± 0.4% |
| IRM | 64.2 ± 8.1% | 30.0 ± 1.37% | 8.2 ± 1.1% | 52.4 ± 0.8% | 51.8 ± 0.8% |
| CORAL | 59.5 ± 7.7% | 31.7 ± 1.24% | 28.4 ± 0.3% | 52.9 ± 0.8% | 47.6 ± 1.9% |
| GroupDRO | 68.4 ± 7.3% | 30.8 ± 0.81% | 23.0 ± 0.3% | 53.3 ± 0.0% | 51.9 ± 0.7% |
| DomainMix | 69.7 ± 5.5% | 34.2 ± 0.76% | 30.8 ± 0.4% | 53.3 ± 0.0% | 51.3 ± 0.5% |
| Fish | 74.7 ± 7.1% | 34.6 ± 0.18% | 10.1 ± 1.5% | 53.3 ± 0.0% | 49.2 ± 2.1% |
| LISA (ours) | 77.1 ± 6.5% | 35.5 ± 0.65% | 31.9 ± 0.8% | 54.7 ± 0.0% | 54.2 ± 0.7% |
| Method | CMNIST Avg. | CMNIST Worst | Waterbirds Avg. | Waterbirds Worst | CelebA Avg. | CelebA Worst | CivilComments Avg. | CivilComments Worst |
|---|---|---|---|---|---|---|---|---|
| ERM | 27.8% | 0.0% | 97.0% | 63.7% | 94.9% | 47.8% | 92.2% | 56.0% |
| UW | 72.2% | 66.0% | 95.1% | 88.0% | 92.9% | 83.3% | 89.8% | 69.2% |
| IRM | 72.1% | 70.3% | 87.5% | 75.6% | 94.0% | 77.8% | 88.8% | 66.3% |
| CORAL | 71.8% | 69.5% | 90.3% | 79.8% | 93.8% | 76.9% | 88.7% | 65.6% |
| GroupDRO | 72.3% | 68.6% | 91.8% | 90.6% | 92.1% | 87.2% | 89.9% | 70.0% |
| DomainMix | 51.4% | 48.0% | 76.4% | 53.0% | 93.4% | 65.6% | 90.9% | 63.6% |
| Fish | 46.9% | 35.6% | 85.6% | 64.0% | 93.1% | 61.2% | 89.8% | 71.1% |
| LISA (ours) | 74.0% | 73.3% | 91.8% | 89.2% | 92.4% | 89.3% | 89.2% | 72.6% |
Table 2: Results of subpopulation shifts. Here, we show the average and worst-group accuracy. We repeat the experiments three times and report full results with standard deviations in Table 15.

4.2 Evaluating Robustness to Subpopulation Shifts
Evaluation Protocol. For subpopulation shifts, we evaluate performance on four binary classification datasets: Colored MNIST (CMNIST), Waterbirds (Sagawa et al., 2020a), CelebA (Liu et al., 2015), and CivilComments (Borkan et al., 2019). We summarize brief data statistics in Table 13 of Appendix A.2.1, covering domain information, model architecture, and class information. Full dataset descriptions for subpopulation shifts are also presented in Appendix A.2.1. Following Sagawa et al. (2020a), we use the worst-group accuracy to evaluate the performance of all approaches; in these datasets, the domain identities are highly spuriously correlated with the label information. For example, as shown in Figure 1, 90% of the images in the CMNIST dataset share the same color and class, i.e., green for label 0 and red for label 1. Full hyperparameter settings and training details are listed in Appendix A.2.2.
Results. In Table 2, we report the overall performance of LISA and other methods. Similar to the observations under domain shifts, LISA consistently outperforms prior methods on CMNIST, CelebA, and CivilComments. On Waterbirds, LISA outperforms other invariant learning methods (e.g., IRM, CORAL, DomainMix, Fish) and shows performance similar to GroupDRO. These results demonstrate the effectiveness of LISA in improving OOD robustness. On CMNIST, Waterbirds, and CelebA, we find that $p_{sel} = 0.5$ works well for choosing selective augmentation strategies, while $p_{sel}$ is set as 1.0 on CivilComments. This is not surprising, because it might be more beneficial to use intra-label LISA more often to eliminate domain effects when there are more domains, i.e., eight domains in CivilComments vs. two domains in the others.
4.3 Ablation Study: Is the Performance Gain from Data Augmentation?
| Method | Camelyon17 (Avg. Acc.) | FMoW (Worst Acc.) | RxRx1 (Avg. Acc.) | Amazon (10th Per. Acc.) | MetaShift (Worst Acc.) |
|---|---|---|---|---|---|
| ERM | 70.3 ± 6.4% | 32.8 ± 0.45% | 29.9 ± 0.4% | 53.8 ± 0.8% | 52.1 ± 0.4% |
| Vanilla mixup | 71.2 ± 5.3% | 34.2 ± 0.45% | 26.5 ± 0.5% | 53.3 ± 0.0% | 51.3 ± 0.7% |
| In-group mixup | 75.5 ± 6.7% | 32.2 ± 1.18% | 24.4 ± 0.2% | 53.8 ± 0.6% | 52.7 ± 0.5% |
| LISA (ours) | 77.1 ± 6.5% | 35.5 ± 0.65% | 31.9 ± 0.8% | 54.7 ± 0.0% | 54.2 ± 0.7% |
| Method | CMNIST Avg. | CMNIST Worst | Waterbirds Avg. | Waterbirds Worst | CelebA Avg. | CelebA Worst | CivilComments Avg. | CivilComments Worst |
|---|---|---|---|---|---|---|---|---|
| ERM | 27.8% | 0.0% | 97.0% | 63.7% | 94.9% | 47.8% | 92.2% | 56.0% |
| Vanilla mixup | 32.6% | 3.1% | 81.0% | 56.2% | 95.8% | 46.4% | 90.8% | 67.2% |
| Vanilla mixup + UW | 72.2% | 71.8% | 92.1% | 85.6% | 91.5% | 88.0% | 87.8% | 66.1% |
| In-group mixup | 33.6% | 24.0% | 88.7% | 68.0% | 95.2% | 58.3% | 90.8% | 69.2% |
| In-group mixup + UW | 72.6% | 71.6% | 91.4% | 87.1% | 92.4% | 87.8% | 84.8% | 69.3% |
| LISA (ours) | 74.0% | 73.3% | 91.8% | 89.2% | 92.4% | 89.3% | 89.2% | 72.6% |
In LISA, we apply selective augmentation strategies to samples with either the same label but different domains or the same domain but different labels. Here, we explore two substitute interpolation strategies: (1) Vanilla mixup, which adds no constraint on sample selection, i.e., mixup is performed on arbitrary pairs of samples; (2) In-group mixup, which applies data interpolation to samples with the same label and from the same domain. Notice that all substitute interpolation strategies use the same mixup variants (e.g., mixup/Manifold Mixup/CutMix) as LISA. Finally, since upweighting (UW) small groups significantly improves performance under subpopulation shifts, we also evaluate UW combined with Vanilla/In-group mixup.
The ablation results for domain shifts and subpopulation shifts are in Table 3 and Table 4, respectively. Furthermore, we also conduct experiments on datasets without spurious correlations in Table 17 of Appendix A.3. From the results, we make three key observations. First, the advantage of LISA over Vanilla mixup verifies that selective data interpolation improves out-of-distribution robustness by canceling out spurious correlations and encouraging the learning of invariant functions, rather than by simple data augmentation. This finding is further strengthened by the results in Table 17 of Appendix A.3, where Vanilla mixup outperforms LISA and ERM in the absence of spurious correlations, but LISA achieves the best performance when spurious correlations are present. Second, the superiority of LISA over In-group mixup verifies that interpolating samples only within each group is incapable of eliminating spurious information; In-group mixup still merely plays the role of data augmentation. Finally, although incorporating UW significantly improves the performance of Vanilla mixup and In-group mixup under subpopulation shifts, LISA still achieves larger benefits than these enhanced substitute strategies, demonstrating its stronger power in improving OOD robustness.
4.4 Effect of the Degree of Distribution Shifts
We further investigate the performance of LISA with respect to the degree of distribution shift. Here, we use MetaShift to evaluate performance, where the distance between training and test domains is measured as the node similarity on a meta-graph (Liang & Zou, 2021). To vary the distance between training and test domains, we change the backgrounds of training objects (see full experimental details in Appendix A.1.1). The performance under varied distances is shown in Table 5, where the top four best-performing baselines (i.e., ERM, GroupDRO, IRM, DomainMix) are reported for comparison. We observe that LISA consistently outperforms the other methods in all scenarios. Another interesting finding is that LISA achieves more substantial improvements as the distance increases. A potential reason is that models may rely more heavily on domain-related correlations when there is a larger distance between training and test domains.
| Method | Distance 0.44 | Distance 0.71 | Distance 1.12 | Distance 1.43 |
|---|---|---|---|---|
| ERM | 80.1% | 68.4% | 52.1% | 33.2% |
| IRM | 79.5% | 67.4% | 51.8% | 32.0% |
| DomainMix | 76.0% | 63.7% | 51.3% | 30.8% |
| GroupDRO | 77.0% | 68.9% | 51.9% | 34.2% |
| LISA (ours) | 81.3% | 69.7% | 54.2% | 37.5% |
4.5 Analysis of Learned Invariance
| Method | CMNIST (Risk Var.) | Waterbirds (Risk Var.) | MetaShift (Risk Var.) | CMNIST (KL Div.) | Waterbirds (KL Div.) | MetaShift (KL Div.) |
|---|---|---|---|---|---|---|
| ERM | 12.0486 | 0.2456 | 1.8824 | 6.2858 | 1.8878 | 1.2051 |
| Vanilla mixup | 0.2769 | 0.1465 | 0.2659 | 4.7365 | 2.9121 | 1.1711 |
| IRM | 0.0112 | 0.1243 | 0.8748 | 7.7549 | 1.1219 | 1.1483 |
| DomainMix | 0.1674 | 0.0995 | 1.1158 | 5.1600 | 2.7467 | 1.2019 |
| LISA (ours) | 0.0012 | 0.0016 | 0.2387 | 0.5670 | 0.1344 | 1.0005 |
Table 6: Results of the analysis of function-level invariance. The variance of test risks across all domains ($\mathbb{V}$) and the pairwise KL divergence of predictions among all domains ($\bar{D}_{KL}$) are used to measure invariance. Smaller values denote stronger invariance.

Finally, we analyze the function-level invariance learned by LISA. Specifically, we use two metrics: (1) Variance of test risks ($\mathbb{V}$). Motivated by Arjovsky et al. (2019); Krueger et al. (2021), function-level invariance is first measured by the variance of test risks across all domains, defined as $\mathbb{V} = \mathrm{Var}(\{\mathcal{R}_1, \ldots, \mathcal{R}_{D_{ts}}\})$, where $D_{ts}$ represents the number of test domains and $\mathcal{R}_d$ represents the risk on domain $d$; (2) Pairwise divergence of predictions ($\bar{D}_{KL}$). We further measure the average KL divergence of the predicted probabilities between all pairs of domains, i.e., $\bar{D}_{KL} = \frac{1}{D_{ts}(D_{ts}-1)} \sum_{d \neq d'} \mathrm{KL}(\hat{P}_d \,\|\, \hat{P}_{d'})$. Small values of $\mathbb{V}$ and $\bar{D}_{KL}$ represent strong function-level invariance. The results on CMNIST, Waterbirds, and MetaShift are reported in Table 6, where the superiority of LISA verifies that it also improves function-level invariance. Besides function-level invariance, we observe that LISA also leads to stronger representation-level invariance, as detailed in Appendix A.4.
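The two invariance metrics in this analysis reduce to a short computation over per-domain statistics. Below is a minimal NumPy sketch under the assumption that per-domain test risks and per-domain averaged predicted class probabilities are already available as arrays; the function name and interface are our own illustration:

```python
import numpy as np

def invariance_metrics(risks, probs):
    """risks: per-test-domain risks, shape (D,).
    probs: per-domain averaged predicted class probabilities, shape (D, C).
    Returns (variance of risks, mean pairwise KL divergence)."""
    risks, probs = np.asarray(risks, float), np.asarray(probs, float)
    D = len(probs)
    kls = [np.sum(probs[i] * np.log(probs[i] / probs[j]))
           for i in range(D) for j in range(D) if i != j]
    return risks.var(), float(np.mean(kls))

# Identical behavior across domains gives zero for both metrics.
v, kl = invariance_metrics([0.2, 0.2], [[0.5, 0.5], [0.5, 0.5]])
# -> (0.0, 0.0)
```

A perfectly invariant function attains zero on both metrics; larger values indicate that either the risk or the predictive distribution varies across test domains.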
5 Theoretical Analysis
In this section, we provide theoretical understanding that explains several of the empirical phenomena observed in the previous experiments, and we theoretically compare the worst-group errors of three methods: the proposed LISA, ERM, and vanilla mixup. Specifically, we consider a Gaussian mixture model with subpopulation and domain shifts, a class of models widely adopted in theory to shed light on complex machine learning phenomena (Montanari et al., 2019; Zhang et al., 2021c). We also note that, despite the popularity of mixup in practice, the theoretical analysis of how mixup (with or without selective augmentation strategies) affects the misclassification error remains largely unexplored in the literature, even for simple models. As discussed in Section 2, we define $y$ as the label and $d$ as the domain information. For each label $y$ and domain $d$, we consider a Gaussian model (4) in which the features are normally distributed, with a conditional mean vector and covariance matrix that may depend on both the label and the domain.
To account for the spurious correlation introduced by domains, we consider the general case where the conditional means differ across domains and the group proportions are imbalanced. Moreover, we assume there exists some invariance across the different domains: according to the theory of Fisher's linear discriminant analysis (Anderson, 1962), the optimal classification rule is linear, and our assumption implies that the slope of this rule is shared across domains, i.e., it constitutes the (unknown) invariant prediction rule for model (4).
Suppose we use some method to obtain a linear classifier from the training data; we then apply it to the test data and compute the worst-group misclassification error, where the misclassification error for domain $d$ and class $y$ is the probability of misclassifying a sample from group $(d, y)$, and the slope and intercept of the classifier depend on the method used. We compare the worst-group errors of three methods: ERM (minimizing the sum-of-squares loss on the training data altogether), vanilla mixup (without any selective augmentation strategy), and the mixup strategy of LISA, along with their finite-sample versions.
We further define a correlation quantity between the domain-specific mean difference and the marginal mean difference. A smaller correlation indicates a larger discrepancy between the marginal difference and the domain-specific difference, and therefore implies a stronger spurious correlation between the domains and the labels. We present the following theorem showing that our proposed LISA algorithm outperforms ERM and vanilla mixup in the subpopulation shift setting.
Theorem 1 (Error comparison with subpopulation shifts)
Consider independent samples generated from model (4) with a positive definite covariance matrix. Then, under mild conditions on the sample sizes and the signal strength, the worst-group classification error of LISA is asymptotically smaller than that of both ERM and vanilla mixup.
Theorem 1 implies that when the correlation quantity is small (indicating that the domain has a strong spurious correlation with the label), the worst-group classification error of LISA is asymptotically smaller than that of ERM and vanilla mixup. In fact, our analysis shows that LISA yields a classification rule closer to the invariant classification rule by leveraging the domain information.
In the next theorem, we present the misclassification error comparison under domain shifts; that is, we consider samples from a new, unseen domain. We measure the shift of the new domain relative to the mean of the training distribution and define correlation quantities for the new domain analogous to those above, together with their sample versions.
Theorem 2 (Error comparison with domain shifts)
Suppose samples are independently generated from model (4) with a positive definite covariance matrix. Then, under mild conditions on the sample sizes and the signal strength, the worst-group classification error of LISA on the unseen test domain is asymptotically smaller than that of both ERM and vanilla mixup.
Similar to Theorem 1, this result shows that when the domain has a strong spurious correlation with the label (corresponding to a small correlation quantity), the spurious correlation degrades the performance of ERM and vanilla mixup, while our proposed LISA method mitigates this issue through selective data interpolation. Proofs of Theorems 1 and 2 are provided in Appendix B.
6 Related Work and Discussion
In this paper, we focus on improving the robustness of machine learning models to domain shifts and subpopulation shifts. Here, we discuss related approaches from the following three categories:
Learning Invariant Representations. Motivated by unsupervised domain adaptation (Ben-David et al., 2010; Ganin et al., 2016), the first category of works learns invariant representations by aligning representations across domains. The major research line in this category aims to eliminate domain dependency by minimizing the divergence of feature distributions under different distance metrics, e.g., maximum mean discrepancy (Tzeng et al., 2014; Long et al., 2015), an adversarial loss (Ganin et al., 2016; Li et al., 2018), or Wasserstein distance (Zhou et al., 2020a). Follow-up works applied data augmentation to (1) generate more domains and enhance the consistency of representations during training (Yue et al., 2019; Zhou et al., 2020b; Xu et al., 2020; Yan et al., 2020; Shu et al., 2021; Wang et al., 2020) or (2) generate new domains adversarially to imitate challenging domains without using training domain information (Zhao et al., 2020; Qiao et al., 2020; Volpi et al., 2018). Unlike these methods, LISA focuses on learning invariant functions without restricting the representations, leading to stronger empirical performance.
Learning Invariant Predictors. Beyond using domain alignment to learn invariant representations, recent work aims to further enhance the correlations between the invariant representations and the labels (Koyama & Yamaguchi, 2020), leading to invariant predictors. Representatively, motivated by causal inference, invariant risk minimization (IRM) (Arjovsky et al., 2019) aims to find a predictor that performs well across all domains. Following IRM, subsequent works propose stronger regularizers, e.g., by penalizing the variance of risks across all domains (Krueger et al., 2021), aligning gradients across domains (Koyama & Yamaguchi, 2020), smoothing the cross-domain interpolation paths (Chuang & Mroueh, 2021), or adopting a game-theoretic invariant rationalization criterion (Chang et al., 2020). Instead of using regularization, LISA eliminates spurious correlations and learns invariant functions directly via data interpolation.
Group Robustness. The last category of methods combats spurious correlations and is particularly suitable for subpopulation shifts. These approaches include directly optimizing the worst-group performance with distributionally robust optimization (Sagawa et al., 2020a; Zhang et al., 2021a; Zhou et al., 2021), generating samples around the minority groups (Goel et al., 2021), and balancing the majority and minority groups via reweighting (Sagawa et al., 2020b) or regularization (Cao et al., 2019, 2020). In contrast, LISA provides a more general strategy that is suitable for both domain shifts and subpopulation shifts.
7 Conclusion
To tackle distribution shifts, we propose LISA, a simple and efficient algorithm for improving out-of-distribution robustness. LISA aims to eliminate domain-related spurious correlations in the training set via selective sample interpolation. We evaluate the effectiveness of LISA on nine datasets under subpopulation shift and domain shift settings, demonstrating its promise. Moreover, our detailed analysis verifies that the performance gains of LISA result from encouraging the learning of invariant functions and representations. Our theoretical results further support the superiority of LISA by showing a smaller worst-group misclassification error compared with ERM and vanilla data interpolation.
References
 Albuquerque et al. (2019) Albuquerque, I., Monteiro, J., Darvishi, M., Falk, T. H., and Mitliagkas, I. Generalizing to unseen domains via distribution matching. arXiv preprint arXiv:1911.00804, 2019.
 Anderson (1962) Anderson, T. W. An introduction to multivariate statistical analysis. Technical report, Wiley New York, 1962.
 Arjovsky et al. (2019) Arjovsky, M., Bottou, L., Gulrajani, I., and Lopez-Paz, D. Invariant risk minimization. arXiv preprint arXiv:1907.02893, 2019.
 Bandi et al. (2018) Bandi, P., Geessink, O., Manson, Q., Van Dijk, M., Balkenhol, M., Hermsen, M., Bejnordi, B. E., Lee, B., Paeng, K., Zhong, A., et al. From detection of individual metastases to classification of lymph node status at the patient level: the camelyon17 challenge. IEEE Transactions on Medical Imaging, 2018.
 Ben-David et al. (2010) Ben-David, S., Blitzer, J., Crammer, K., Kulesza, A., Pereira, F., and Vaughan, J. W. A theory of learning from different domains. Machine learning, 79(1):151–175, 2010.
 Borkan et al. (2019) Borkan, D., Dixon, L., Sorensen, J., Thain, N., and Vasserman, L. Nuanced metrics for measuring unintended bias with real data for text classification. In Companion proceedings of the 2019 world wide web conference, pp. 491–500, 2019.
 Cao et al. (2019) Cao, K., Wei, C., Gaidon, A., Arechiga, N., and Ma, T. Learning imbalanced datasets with label-distribution-aware margin loss. NeurIPS, 2019.
 Cao et al. (2020) Cao, K., Chen, Y., Lu, J., Arechiga, N., Gaidon, A., and Ma, T. Heteroskedastic and imbalanced deep learning with adaptive regularization. arXiv preprint arXiv:2006.15766, 2020.
 Chang et al. (2020) Chang, S., Zhang, Y., Yu, M., and Jaakkola, T. Invariant rationalization. In International Conference on Machine Learning, pp. 1448–1458. PMLR, 2020.

 Christie et al. (2018) Christie, G., Fendley, N., Wilson, J., and Mukherjee, R. Functional map of the world. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, 2018.
 Chuang & Mroueh (2021) Chuang, C.-Y. and Mroueh, Y. Fair mixup: Fairness via interpolation. ICLR, 2021.
 Devlin et al. (2019) Devlin, J., Chang, M.-W., Lee, K., and Toutanova, K. BERT: Pre-training of deep bidirectional transformers for language understanding. In NAACL-HLT, 2019.

 Ganin & Lempitsky (2015) Ganin, Y. and Lempitsky, V. Unsupervised domain adaptation by backpropagation. In International conference on machine learning, pp. 1180–1189. PMLR, 2015.
 Ganin et al. (2016) Ganin, Y., Ustinova, E., Ajakan, H., Germain, P., Larochelle, H., Laviolette, F., Marchand, M., and Lempitsky, V. Domain-adversarial training of neural networks. The journal of machine learning research, 17(1):2096–2030, 2016.
 Goel et al. (2021) Goel, K., Gu, A., Li, Y., and Ré, C. Model patching: Closing the subgroup performance gap with data augmentation. In ICLR, 2021.
 He et al. (2016) He, K., Zhang, X., Ren, S., and Sun, J. Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 770–778, 2016.
 Huang et al. (2017) Huang, G., Liu, Z., Van Der Maaten, L., and Weinberger, K. Q. Densely connected convolutional networks. In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 4700–4708, 2017.
 Koh et al. (2021) Koh, P. W., Sagawa, S., Xie, S. M., Zhang, M., Balsubramani, A., Hu, W., Yasunaga, M., Phillips, R. L., Gao, I., Lee, T., et al. WILDS: A benchmark of in-the-wild distribution shifts. In International Conference on Machine Learning, pp. 5637–5664. PMLR, 2021.
 Koyama & Yamaguchi (2020) Koyama, M. and Yamaguchi, S. Out-of-distribution generalization with maximal invariant predictor. arXiv preprint arXiv:2008.01883, 2020.
 Krishna et al. (2016) Krishna, R., Zhu, Y., Groth, O., Johnson, J., Hata, K., Kravitz, J., Chen, S., Kalantidis, Y., Li, L.-J., Shamma, D. A., Bernstein, M., and Fei-Fei, L. Visual genome: Connecting language and vision using crowdsourced dense image annotations. 2016. URL https://arxiv.org/abs/1602.07332.
 Krueger et al. (2021) Krueger, D., Caballero, E., Jacobsen, J.-H., Zhang, A., Binas, J., Zhang, D., Le Priol, R., and Courville, A. Out-of-distribution generalization via risk extrapolation (REx). In International Conference on Machine Learning, pp. 5815–5826. PMLR, 2021.
 Lee et al. (2019) Lee, H. B., Nam, T., Yang, E., and Hwang, S. J. Meta dropout: Learning to perturb latent features for generalization. In International Conference on Learning Representations, 2019.
 Li et al. (2018) Li, H., Pan, S. J., Wang, S., and Kot, A. C. Domain generalization with adversarial feature learning. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 5400–5409, 2018.
 Liang & Zou (2021) Liang, W. and Zou, J. MetaDataset: A dataset of datasets for evaluating distribution shifts and training conflicts. In ICML 2021 ML4data Workshop, 2021.
 Liu et al. (2015) Liu, Z., Luo, P., Wang, X., and Tang, X. Deep learning face attributes in the wild. In ICCV, 2015.
 Long et al. (2015) Long, M., Cao, Y., Wang, J., and Jordan, M. Learning transferable features with deep adaptation networks. In International conference on machine learning, pp. 97–105. PMLR, 2015.
 Montanari et al. (2019) Montanari, A., Ruan, F., Sohn, Y., and Yan, J. The generalization error of max-margin linear classifiers: High-dimensional asymptotics in the overparametrized regime. arXiv preprint arXiv:1911.01544, 2019.
 Muandet et al. (2013) Muandet, K., Balduzzi, D., and Schölkopf, B. Domain generalization via invariant feature representation. In International Conference on Machine Learning, pp. 10–18. PMLR, 2013.
 Ni et al. (2019) Ni, J., Li, J., and McAuley, J. Justifying recommendations using distantly-labeled reviews and fine-grained aspects. In Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP), 2019.
 Qiao et al. (2020) Qiao, F., Zhao, L., and Peng, X. Learning to learn single domain generalization. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 12556–12565, 2020.
 Rosenfeld et al. (2021) Rosenfeld, E., Ravikumar, P., and Risteski, A. The risks of invariant risk minimization. In ICLR, 2021.
 Sagawa et al. (2020a) Sagawa, S., Koh, P. W., Hashimoto, T. B., and Liang, P. Distributionally robust neural networks for group shifts: On the importance of regularization for worst-case generalization. In ICLR, 2020a.
 Sagawa et al. (2020b) Sagawa, S., Raghunathan, A., Koh, P. W., and Liang, P. An investigation of why overparameterization exacerbates spurious correlations. In ICML, pp. 8346–8356. PMLR, 2020b.
 Sanh et al. (2019) Sanh, V., Debut, L., Chaumond, J., and Wolf, T. DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter. arXiv preprint arXiv:1910.01108, 2019.
 Shi et al. (2021) Shi, Y., Seely, J., Torr, P. H., Siddharth, N., Hannun, A., Usunier, N., and Synnaeve, G. Gradient matching for domain generalization. arXiv preprint arXiv:2104.09937, 2021.
 Shu et al. (2021) Shu, Y., Cao, Z., Wang, C., Wang, J., and Long, M. Open domain generalization with domainaugmented metalearning. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 9624–9633, 2021.
 Sun & Saenko (2016) Sun, B. and Saenko, K. Deep coral: Correlation alignment for deep domain adaptation. In European conference on computer vision, pp. 443–450. Springer, 2016.
 Taylor et al. (2019) Taylor, J., Earnshaw, B., Mabey, B., Victors, M., and Yosinski, J. Rxrx1: An image set for cellular morphological variation across many experimental batches. In International Conference on Learning Representations (ICLR), 2019.
 Tzeng et al. (2014) Tzeng, E., Hoffman, J., Zhang, N., Saenko, K., and Darrell, T. Deep domain confusion: Maximizing for domain invariance. arXiv preprint arXiv:1412.3474, 2014.
 Verma et al. (2019) Verma, V., Lamb, A., Beckham, C., Najafi, A., Mitliagkas, I., LopezPaz, D., and Bengio, Y. Manifold mixup: Better representations by interpolating hidden states. In International Conference on Machine Learning, pp. 6438–6447. PMLR, 2019.
 Volpi et al. (2018) Volpi, R., Namkoong, H., Sener, O., Duchi, J., Murino, V., and Savarese, S. Generalizing to unseen domains via adversarial data augmentation. arXiv preprint arXiv:1805.12018, 2018.
 Wah et al. (2011) Wah, C., Branson, S., Welinder, P., Perona, P., and Belongie, S. The Caltech-UCSD Birds-200-2011 Dataset. Technical Report CNS-TR-2011-001, California Institute of Technology, 2011.
 Wang et al. (2020) Wang, Y., Li, H., and Kot, A. C. Heterogeneous domain generalization via domain mixup. In ICASSP 2020 - 2020 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), pp. 3622–3626. IEEE, 2020.

 Xu et al. (2020) Xu, M., Zhang, J., Ni, B., Li, T., Wang, C., Tian, Q., and Zhang, W. Adversarial domain adaptation with domain mixup. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 34, pp. 6502–6509, 2020.
 Yan et al. (2020) Yan, S., Song, H., Li, N., Zou, L., and Ren, L. Improve unsupervised domain adaptation with mixup training. arXiv preprint arXiv:2001.00677, 2020.
 Yue et al. (2019) Yue, X., Zhang, Y., Zhao, S., Sangiovanni-Vincentelli, A., Keutzer, K., and Gong, B. Domain randomization and pyramid consistency: Simulation-to-real generalization without accessing target domain data. In Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 2100–2110, 2019.
 Yun et al. (2019) Yun, S., Han, D., Oh, S. J., Chun, S., Choe, J., and Yoo, Y. Cutmix: Regularization strategy to train strong classifiers with localizable features. In Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 6023–6032, 2019.
 Zhang et al. (2018) Zhang, H., Cisse, M., Dauphin, Y. N., and Lopez-Paz, D. mixup: Beyond empirical risk minimization. In ICLR, 2018.
 Zhang et al. (2021a) Zhang, J., Menon, A., Veit, A., Bhojanapalli, S., Kumar, S., and Sra, S. Coping with label shift via distributionally robust optimisation. In ICLR, 2021a.
 Zhang et al. (2021b) Zhang, L., Deng, Z., Kawaguchi, K., Ghorbani, A., and Zou, J. How does mixup help with robustness and generalization? In ICLR, 2021b.
 Zhang et al. (2021c) Zhang, L., Deng, Z., Kawaguchi, K., and Zou, J. When and how mixup improves calibration. arXiv preprint arXiv:2102.06289, 2021c.
 Zhao et al. (2020) Zhao, L., Liu, T., Peng, X., and Metaxas, D. Maximumentropy adversarial data augmentation for improved generalization and robustness. arXiv preprint arXiv:2010.08001, 2020.

Zhou et al. (2017)
Zhou, B., Lapedriza, A., Khosla, A., Oliva, A., and Torralba, A.
Places: A 10 million image database for scene recognition.
IEEE transactions on pattern analysis and machine intelligence, 40(6):1452–1464, 2017.  Zhou et al. (2021) Zhou, C., Ma, X., Michel, P., and Neubig, G. Examining and combating spurious features under distribution shift. In ICML, 2021.
 Zhou et al. (2020a) Zhou, F., Jiang, Z., Shui, C., Wang, B., and Chaib-draa, B. Domain generalization with optimal transport and metric learning. arXiv preprint arXiv:2007.10573, 2020a.
 Zhou et al. (2020b) Zhou, K., Yang, Y., Hospedales, T., and Xiang, T. Deep domainadversarial image generation for domain generalisation. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 34, pp. 13025–13032, 2020b.
Appendix A Additional Experiments
A.1 Domain Shifts
A.1.1 Dataset Details
Table 7: Details of the domain-shift datasets.
Datasets  Domains  Metric  Base Model  Num. of classes
Camelyon17  5 hospitals  Avg. Acc.  DenseNet121  2 
FMoW  16 years x 5 regions  Worstgroup Acc.  DenseNet121  62 
RxRx1  51 experimental batches  Avg. Acc.  ResNet50  1,139 
Amazon  7,676 reviewers  10th Percentile Acc.  DistilBERTuncased  5 
MetaShift  4 backgrounds  Worstgroup Acc.  ResNet50  2 
In this section, we provide detailed descriptions of the datasets used in the domain-shift experiments and report the data statistics in Table 7.
Camelyon17
We use Camelyon17 from the WILDS benchmark (Koh et al., 2021; Bandi et al., 2018), which provides lymph-node scans sampled from five hospitals. Camelyon17 is a medical image classification task where the input is a patch image and the label indicates whether the patch contains tumor tissue. The domain denotes the hospital that the patch was taken from. The training set is drawn from the first three hospitals, while the out-of-distribution validation and out-of-distribution test sets are sampled from the fourth and fifth hospitals, respectively.
FMoW
The FMoW dataset is from the WILDS benchmark (Koh et al., 2021; Christie et al., 2018) — a satellite image classification task with 62 classes and 80 domains (16 years x 5 regions). Concretely, the input is an RGB satellite image, the label is one of 62 building or land use categories, and the domain represents the year the image was taken together with its geographical region: Africa, the Americas, Oceania, Asia, or Europe. The train/validation/test splits are based on when the images were taken: images taken before 2013 form the training set, images taken between 2013 and 2015 form the validation set, and images taken after 2015 are used for testing.
RxRx1
RxRx1 (Koh et al., 2021; Taylor et al., 2019) from the WILDS benchmark is a cell image classification task. In the dataset, some cells have been genetically perturbed by siRNA. The goal of RxRx1 is to predict which siRNA the cells have been treated with. Concretely, the input is an image of cells obtained by fluorescence microscopy, the label indicates which of the 1,139 genetic treatments the cells received, and the domain denotes the experimental batch. Different batches of images are used for training, where each batch contains one sample for each class. The out-of-distribution validation and test sets contain images from held-out experimental batches. The average accuracy on the out-of-distribution test set is reported.
Amazon
Each task in the Amazon benchmark (Koh et al., 2021; Ni et al., 2019) is a multi-class sentiment classification task. The input is the text of a review, the label is the corresponding star rating from 1 to 5, and the domain is the reviewer who wrote the review. The training, out-of-distribution validation, and out-of-distribution test sets contain reviews from disjoint sets of reviewers. We evaluate models by the 10th percentile of per-user accuracies on the test set.
MetaShift
We use MetaShift (Liang & Zou, 2021), which is derived from Visual Genome (Krishna et al., 2016). MetaShift leverages the natural heterogeneity of Visual Genome to provide many distinct data distributions for a given class (e.g., “cats with cars” or “cats in bathrooms” for the “cat” class). A key feature of MetaShift is that it provides explicit explanations of the dataset correlations and a distance score that measures the degree of distribution shift between any pair of sets.
We adopt the “Cat vs. Dog” task in MetaShift, where we evaluate the model on the “dog(shelf)” domain with 306 images and the “cat(shelf)” domain with 235 images. The training data for the “Cat” class is cat(sofa + bed), comprising the cat(sofa) and cat(bed) domains. MetaShift provides 4 different sets of training data for the “Dog” class in an increasingly challenging order, i.e., with increasing amounts of distribution shift. Specifically, dog(cabinet + bed), dog(bag + box), dog(bench + bike), and dog(boat + surfboard) are selected for training, with corresponding distances to dog(shelf) of 0.44, 0.71, 1.12, and 1.43.
A.1.2 Training Details
Following WILDS (Koh et al., 2021), we adopt a pretrained DenseNet-121 (Huang et al., 2017) for the Camelyon17 and FMoW datasets, ResNet-50 (He et al., 2016) for the RxRx1 and MetaShift datasets, and DistilBERT (Sanh et al., 2019) for the Amazon dataset.
In each training iteration, we first draw a batch of samples from the training set. We then select another batch of samples with the same labels for data interpolation. The interpolation ratio is drawn from a Beta distribution. We use the same image transformations as Koh et al. (2021), and all other hyperparameter settings are listed in Table 8.
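The selection-plus-interpolation step above can be sketched as follows; this is a simplified per-sample version (the paper pairs whole batches), with the Beta shape parameter `alpha` and all variable names being our own illustrative assumptions:

```python
import numpy as np

def intra_label_mixup(x, y, d, rng, alpha=2.0):
    # Each sample is mixed with a partner that shares its label,
    # preferring partners from a different domain, so the label is
    # preserved while domain-specific cues are interpolated away.
    x, y, d = np.asarray(x, dtype=float), np.asarray(y), np.asarray(d)
    n = len(y)
    partner = np.empty(n, dtype=int)
    for i in range(n):
        same_label = np.flatnonzero(y == y[i])
        cross_domain = same_label[d[same_label] != d[i]]
        pool = cross_domain if len(cross_domain) > 0 else same_label
        partner[i] = rng.choice(pool)
    # Interpolation ratio drawn from a symmetric Beta distribution.
    lam = rng.beta(alpha, alpha, size=(n,) + (1,) * (x.ndim - 1))
    return lam * x + (1.0 - lam) * x[partner], y  # labels unchanged
```

Because both interpolation endpoints share a label, the mixed sample keeps that label exactly; only the inputs are convex combinations.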
Table 8: Hyperparameters for the domain-shift experiments.
Dataset  Camelyon17  FMoW  RxRx1  Amazon  MetaShift
Learning rate  1e4  1e4  1e3  2e6  1e3 
Weight decay  0  0  1e5  0  1e4 
Scheduler  n/a  n/a  Cosine Warmup  n/a  n/a 
Batch size  32  32  72  8  16 
Type of mixup  CutMix  CutMix  CutMix  ManifoldMix  CutMix 
Architecture  DenseNet121  DenseNet121  ResNet50  DistilBert  ResNet50 
Optimizer  SGD  Adam  Adam  Adam  SGD 
Maximum Epoch  2  5  90  3  100
Strategy sel. prob.  1.0  1.0  1.0  1.0  1.0 
A.1.3 Full Results on WILDS Datasets
Following Koh et al. (2021), we report further results on the WILDS datasets in Tables 9–12, including validation performance and results on additional metrics. According to these additional results, LISA outperforms the baseline approaches in all scenarios. We highlight two additional findings: (1) On the Camelyon17 dataset, the test data is much more visually distinctive than the validation data, resulting in the large gap between the validation and test performance of ERM (see Table 9). LISA significantly reduces this gap, showing its promise in improving OOD robustness. (2) On the Amazon dataset, although LISA performs worse than ERM in average accuracy, it achieves the best accuracy at the 10th percentile, which is regarded as a more common and important metric for evaluating whether models perform consistently well across all users (Koh et al., 2021).
Table 9: Results on Camelyon17.
Validation Acc.  Test Acc.
ERM  84.9 ± 3.1%  70.3 ± 6.4%
IRM  86.2 ± 1.4%  64.2 ± 8.1%
Coral  86.2 ± 1.4%  59.5 ± 7.7%
GroupDRO  85.5 ± 2.2%  68.4 ± 7.3%
DomainMix  83.5 ± 1.1%  69.7 ± 5.5%
Fish  83.9 ± 1.2%  74.7 ± 7.1%
LISA (ours)  81.8 ± 1.3%  77.1 ± 6.5%
Table 10: Results on FMoW.
Validation  Test
Avg. Acc.  Worst Acc.  Avg. Acc.  Worst Acc.
ERM  59.5 ± 0.37%  48.9 ± 0.62%  53.0 ± 0.55%  32.3 ± 1.25%
IRM  57.4 ± 0.37%  47.5 ± 1.57%  50.8 ± 0.13%  30.0 ± 1.37%
Coral  56.9 ± 0.25%  47.1 ± 0.43%  50.5 ± 0.36%  31.7 ± 1.24%
GroupDRO  58.8 ± 0.19%  46.5 ± 0.25%  52.1 ± 0.50%  30.8 ± 0.81%
DomainMix  58.6 ± 0.29%  48.9 ± 1.15%  51.6 ± 0.19%  34.2 ± 0.76%
Fish  57.8 ± 0.15%  49.5 ± 2.34%  51.8 ± 0.32%  34.6 ± 0.18%
LISA (ours)  58.7 ± 0.92%  48.7 ± 0.74%  52.8 ± 0.94%  35.5 ± 0.65%
Table 11: Results on RxRx1.
Validation Acc.  Test ID Acc.  Test OOD Acc.
ERM  19.4 ± 0.2%  35.9 ± 0.4%  29.9 ± 0.4%
IRM  5.6 ± 0.4%  9.9 ± 1.4%  8.2 ± 1.1%
Coral  18.5 ± 0.4%  34.0 ± 0.3%  28.4 ± 0.3%
GroupDRO  15.2 ± 0.1%  28.1 ± 0.3%  23.0 ± 0.3%
DomainMix  19.3 ± 0.7%  39.8 ± 0.2%  30.8 ± 0.4%
Fish  7.5 ± 0.6%  12.7 ± 1.9%  10.1 ± 1.5%
LISA (ours)  20.1 ± 0.4%  41.2 ± 1.0%  31.9 ± 0.8%
Table 12: Results on Amazon.
Validation  Test
Avg. Acc.  10th Per. Acc.  Avg. Acc.  10th Per. Acc.
ERM  72.7 ± 0.1%  55.2 ± 0.7%  71.9 ± 0.1%  53.8 ± 0.8%
IRM  71.5 ± 0.3%  54.2 ± 0.8%  70.5 ± 0.3%  52.4 ± 0.8%
Coral  72.0 ± 0.3%  54.7 ± 0.0%  70.0 ± 0.6%  52.9 ± 0.8%
GroupDRO  70.7 ± 0.6%  54.7 ± 0.0%  70.0 ± 0.6%  53.3 ± 0.0%
DomainMix  71.9 ± 0.2%  54.7 ± 0.0%  71.1 ± 0.1%  53.3 ± 0.0%
Fish  72.5 ± 0.0%  54.7 ± 0.0%  71.7 ± 0.1%  53.3 ± 0.0%
LISA (ours)  71.2 ± 0.3%  55.1 ± 0.61%  70.6 ± 0.3%  54.7 ± 0.0%
A.2 Subpopulation Shifts
A.2.1 Dataset Details
We detail the data descriptions of subpopulation shifts below and report the detailed data statistics in Table 13.
Colored MNIST (CMNIST): We classify MNIST digits into 2 classes, where class 0 comprises the original digits (0,1,2,3,4) and class 1 comprises (5,6,7,8,9). The color is treated as a spurious attribute. Concretely, in the training set, the ratio of red to green samples is 8:2 in class 0, while it is 2:8 in class 1. In the validation set, the ratio of green to red samples is 1:1 for both classes. In the test set, the ratio of green to red samples is 1:9 in class 0 and 9:1 in class 1. The sizes of the train, validation, and test sets are 30,000, 10,000, and 20,000, respectively.
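The class-dependent colouring rule for the training split can be sketched as follows; the function name is our own, and the sketch only assigns colours (it does not render digits):

```python
import numpy as np

def assign_training_colors(labels, rng):
    # CMNIST training-split colouring: class 0 is coloured red with
    # probability 0.8 and class 1 with probability 0.2, so colour is
    # predictive of the label in training but is a spurious attribute.
    labels = np.asarray(labels)
    p_red = np.where(labels == 0, 0.8, 0.2)  # 8:2 vs. 2:8 ratio
    return np.where(rng.random(len(labels)) < p_red, "red", "green")
```

Flipping these proportions at test time (as the test split above does) is exactly what breaks a classifier that relies on colour rather than digit shape.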
Waterbirds (Sagawa et al., 2020a): The Waterbirds dataset aims to classify birds as “waterbird” or “landbird”, where each bird image is spuriously associated with the background “water” or “land”. Waterbirds is a synthetic dataset in which each image is composed by pasting a bird image sampled from the CUB dataset (Wah et al., 2011) onto a background drawn from the Places dataset (Zhou et al., 2017). The bird categories in CUB are stratified as land birds or water birds. Specifically, the following bird species are selected to construct the waterbird class: albatross, auklet, cormorant, frigatebird, fulmar, gull, jaeger, kittiwake, pelican, puffin, tern, gadwall, grebe, mallard, merganser, guillemot, and Pacific loon. All other bird species are combined into the landbird class. We define (land background, waterbird) and (water background, landbird) as the minority groups. There are 4,795 training samples in total, of which only 56 are “waterbirds on land” and 184 are “landbirds on water”. The remaining training data comprise 3,498 samples of “landbirds on land” and 1,057 samples of “waterbirds on water”.
CelebA (Liu et al., 2015; Sagawa et al., 2020a): For the CelebA data (Liu et al., 2015), we follow the data preprocessing procedure of Sagawa et al. (2020a). CelebA defines an image classification task where the input is a face image of a celebrity and the label is the corresponding hair color: “blond” or “not blond.” The label is spuriously correlated with gender (male or female). In CelebA, the minority group is (blond, male). The numbers of samples per group are 71,629 “dark hair, female”, 66,874 “dark hair, male”, 22,880 “blond hair, female”, and 1,387 “blond hair, male”.
CivilComments (Borkan et al., 2019; Koh et al., 2021): We use CivilComments from the WILDS benchmark (Koh et al., 2021). CivilComments is a text classification task that aims to predict whether an online comment is toxic or non-toxic. The spurious domain identifications are defined by demographic attributes: male, female, LGBTQ, Christian, Muslim, other religion, Black, and White. CivilComments contains 450,000 comments collected from online articles. The numbers of samples for training, validation, and test are 269,038, 45,180, and 133,782, respectively. Readers may refer to Table 17 in Koh et al. (2021) for detailed group information.
Table 13: Details of the subpopulation-shift datasets.
Datasets  Domains  Base Model  Class Information
CMNIST  2 digit colors  ResNet50  digits (0,1,2,3,4) vs. (5,6,7,8,9)
Waterbirds  2 backgrounds  ResNet50  waterbirds vs. landbirds
CelebA  2 genders  ResNet50  blond vs. not blond
CivilComments  8 demographic identities  DistilBERT-uncased  toxic vs. non-toxic
A.2.2 Training Details
We adopt a pretrained ResNet-50 (He et al., 2016) for the image datasets (i.e., CMNIST, Waterbirds, CelebA) and DistilBERT (Sanh et al., 2019) for the text dataset (i.e., CivilComments). In each training iteration, we sample a batch of data per group. For intra-label LISA, we randomly apply mixup to sample batches with the same labels but different domains. For intra-domain LISA, we instead apply mixup to sample batches with the same domain but different labels. The interpolation ratio is sampled from a Beta distribution. All hyperparameters are listed in Table 14.
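The choice between the two augmentation strategies can be sketched as below; `psel` corresponds to the strategy selection probability in the hyperparameter tables, while the per-sample pairing (the paper pairs whole batches) and all names are our own simplifications:

```python
import numpy as np

def pick_partner_indices(y, d, psel, rng):
    # With probability `psel`, pair a sample for intra-label LISA
    # (same label, different domain); otherwise pair it for
    # intra-domain LISA (same domain, different label). Falls back to
    # self-pairing when no valid partner exists.
    y, d = np.asarray(y), np.asarray(d)
    partners = []
    for i in range(len(y)):
        if rng.random() < psel:
            pool = np.flatnonzero((y == y[i]) & (d != d[i]))
        else:
            pool = np.flatnonzero((d == d[i]) & (y != y[i]))
        partners.append(int(rng.choice(pool)) if len(pool) > 0 else i)
    return partners
```

With `psel = 1.0` every sample is paired across domains within its own label group, which is the setting used for the domain-shift experiments.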
A.2.3 Additional Results
Table 14: Hyperparameters for the subpopulation-shift experiments.
Dataset  CMNIST  Waterbirds  CelebA  CivilComments
Learning rate  1e3  1e3  1e4  1e5 
Weight decay  1e4  1e4  1e4  0 
Scheduler  n/a  n/a  n/a  n/a 
Batch size  16  16  16  8 
Type of mixup  mixup  mixup  CutMix  ManifoldMix 
Architecture  ResNet50  ResNet50  ResNet50  DistilBert 
Optimizer  SGD  SGD  SGD  Adam 
Maximum Epoch  300  300  50  3 
Strategy sel. prob.  0.5  0.5  0.5  1.0 
CMNIST  Waterbirds  
Avg.  Worst  Avg.  Worst  
ERM  27.8 ± 1.9%  0.0 ± 0.0%  97.0 ± 0.2%  63.7 ± 1.9%
UW  72.2 ± 1.1%  66.0 ± 0.7%  95.1 ± 0.3%  88.0 ± 1.3%
IRM  72.1 ± 1.2%  70.3 ± 0.8%  87.5 ± 0.7%  75.6 ± 3.1%
Coral  71.8 ± 1.7%  69.5 ± 0.9%  90.3 ± 1.1%  79.8 ± 1.8%
GroupDRO  72.3 ± 1.2%  68.6 ± 0.8%  91.8 ± 0.3%  90.6 ± 1.1%
DomainMix  51.4 ± 1.3%  48.0 ± 1.3%  76.4 ± 0.3%  53.0 ± 1.3%
Fish  46.9 ± 1.4%  35.6 ± 1.7%  85.6 ± 0.4%  64.0 ± 0.3%
LISA (ours)  74.0 ± 0.1%  73.3 ± 0.2%  91.8 ± 0.3%  89.2 ± 0.6%
CelebA  CivilComments  
Avg.  Worst  Avg.  Worst  
ERM  94.9 ± 0.2%  47.8 ± 3.7%  92.2 ± 0.1%  56.0 ± 3.6%
UW  92.9 ± 0.2%  83.3 ± 2.8%  89.8 ± 0.5%  69.2 ± 0.9%
IRM  94.0 ± 0.4%  77.8 ± 3.9%  88.8 ± 0.7%  66.3 ± 2.1%
Coral  93.8 ± 0.3%  76.9 ± 3.6%  88.7 ± 0.5%  65.6 ± 1.3%
GroupDRO  92.1 ± 0.4%  87.2 ± 1.6%  89.9 ± 0.5%  70.0 ± 2.0%
DomainMix  93.4 ± 0.1%  65.6 ± 1.7%  90.9 ± 0.4%  63.6 ± 2.5%
Fish  93.1 ± 0.3%  61.2 ± 2.5%  89.8 ± 0.4%  71.1 ± 0.4%
LISA (ours)  92.4 ± 0.4%  89.3 ± 1.1%  89.2 ± 0.9%  72.6 ± 0.1%
CMNIST  Waterbirds  
Avg.  Worst  Avg.  Worst  
ERM  27.8 ± 1.9%  0.0 ± 0.0%  97.0 ± 0.2%  63.7 ± 1.9%
Vanilla mixup  32.6 ± 3.1%  3.1 ± 2.4%  81.0 ± 0.2%  56.2 ± 0.2%
Vanilla mixup + UW  72.2 ± 0.7%  71.8 ± 0.1%  92.1 ± 0.1%  85.6 ± 1.0%
In-group  33.6 ± 1.9%  24.0 ± 1.1%  88.7 ± 0.3%  68.0 ± 0.4%
In-group + UW  72.6 ± 0.1%  71.6 ± 0.2%  91.4 ± 0.6%  87.1 ± 0.6%
LISA (ours)  74.0 ± 0.1%  73.3 ± 0.2%  91.8 ± 0.3%  89.2 ± 0.6%
CelebA  CivilComments  
Avg.  Worst  Avg.  Worst  
ERM  94.9 ± 0.2%  47.8 ± 3.7%  92.2 ± 0.1%  56.0 ± 3.6%
Vanilla mixup  95.8 ± 0.0%  46.4 ± 0.5%  90.8 ± 0.8%  67.2 ± 1.2%
Vanilla mixup + UW  91.5 ± 0.2%  88.0 ± 0.3%  87.8 ± 1.2%  66.1 ± 1.4%
Within Group  95.2 ± 0.3%  58.3 ± 0.9%  90.8 ± 0.6%  69.2 ± 0.8%
Within Group + UW  92.4 ± 0.4%  87.8 ± 0.6%  84.8 ± 0.7%  69.3 ± 1.1%
LISA (ours)  92.4 ± 0.4%  89.3 ± 1.1%  89.2 ± 0.9%  72.6 ± 0.1%
A.3 Results on Datasets without Spurious Correlations
To analyze the factors behind the performance gains of LISA, we conduct experiments on datasets without spurious correlations. Specifically, we balance the number of samples in each group under the subpopulation-shift setting. The results of ERM, Vanilla mixup, and LISA on CMNIST, Waterbirds, and CelebA are reported in Table 17. The results show that LISA performs similarly to ERM when the datasets have no spurious correlations, whereas LISA significantly outperforms ERM whenever spurious correlations are present. Another interesting finding is that Vanilla mixup outperforms both LISA and ERM in the absence of spurious correlations, while LISA achieves the best performance in their presence. This finding strengthens our conclusion that the performance gains of LISA come from eliminating spurious correlations rather than from data augmentation alone.
Table 17: Average accuracy on datasets without spurious correlations.
Dataset  ERM  Vanilla mixup  LISA
CMNIST  73.67%  74.28%  73.18% 
Waterbirds  88.07%  88.23%  87.05% 
CelebA  86.11%  88.89%  87.22% 
A.4 Analysis of Representation-level Invariance
For each label, consider the distribution of hidden representations within each domain. Following the analysis of function-level invariance, the representation-level invariance is measured by the pairwise KL divergence of these per-domain representation distributions among all domains, averaged over labels, where smaller values indicate that the learned representations are more invariant with respect to the labels. We report the results on CMNIST, Waterbirds, and MetaShift in Table 18. Our key observations are: (1) Compared with ERM, LISA learns stronger representation-level invariance; a potential reason is that function-level invariance encourages representation-level invariance. (2) LISA achieves greater invariance than vanilla mixup, validating that the invariant representations are not simply a byproduct of naive data interpolation. (3) LISA provides more invariant representations than the regularization-based methods, i.e., IRM and DomainMix.
Table 18: Representation-level invariance measured by pairwise KL divergence (lower is better).
  CMNIST  Waterbirds  MetaShift
ERM  1.683  3.592  0.632 
Vanilla mixup  4.392  3.935  0.634 
IRM  1.905  2.413  0.627 
DomainMix  2.155  3.716  0.614 
LISA (ours)  0.421  1.912  0.585 
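The invariance measure used above can be sketched as follows; for illustration we assume one-dimensional representations and a Gaussian fit per domain (both simplifying assumptions of ours), which gives a closed-form KL divergence:

```python
import numpy as np

def gaussian_kl(m1, v1, m2, v2):
    # Closed-form KL( N(m1, v1) || N(m2, v2) ) for 1-D Gaussians.
    return 0.5 * (np.log(v2 / v1) + (v1 + (m1 - m2) ** 2) / v2 - 1.0)

def invariance_score(h, y, d):
    # For each label, fit a Gaussian to each domain's representations
    # and average the pairwise KL divergences across domains; smaller
    # values mean the representations are more domain-invariant.
    h, y, d = map(np.asarray, (h, y, d))
    kls = []
    for label in np.unique(y):
        doms = np.unique(d[y == label])
        stats = {g: (h[(y == label) & (d == g)].mean(),
                     h[(y == label) & (d == g)].var()) for g in doms}
        for g1 in doms:
            for g2 in doms:
                if g1 != g2:
                    kls.append(gaussian_kl(*stats[g1], *stats[g2]))
    return float(np.mean(kls))
```

Identical per-domain distributions yield a score of zero, and the score grows as the domains' representation distributions drift apart.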
Besides the quantitative analysis, following Appendix C of Lee et al. (2019), we visualize the hidden representations of all test samples and the decision boundary on Waterbirds in Figure 2. Compared with other methods, the representations learned by LISA for samples with the same label are closer together regardless of their domain, which further demonstrates the promise of LISA in learning invariant representations.
Appendix B Proofs of Theorem 1 and Theorem 2
Outline of the proof. We first derive the misclassification errors based on the population version of OLS under different mixup strategies. Next, we develop the convergence rate of the empirical OLS estimator towards its population version. Together, these two steps give the empirical misclassification errors of the different methods. We show separately that the upper bounds in Theorem 1 and Theorem 2 hold for each of the two selective augmentation strategies of LISA, and hence they hold for any combination of the two. Let LL denote intra-label LISA and LD denote intra-domain LISA.
Let and denote the marginal class proportions in the training samples. Let and denote the marginal subpopulation proportions in the training samples. Let and define , , and similarly.
We consider the setting where is relatively small and .
B.1 Decomposing the loss function
Recall that . We further define , , and .
For the mixup estimators, we will repeatedly use the fact that
has a symmetric distribution with support . For the ERM estimator based on , where , we have
Notice that based on the estimator
B.2 Classification errors of four methods with infinite training samples
We first provide the limit of the classification errors as the number of training samples tends to infinity.
B.2.1 Baseline method: ERM
For the training data, it is easy to show that
For and , the ERM has slope and intercept being
In the extreme case where , we have
Hence,
(5) 
where is computed via ERM.
B.2.2 Baseline method: Vanilla mixup
Vanilla mixup does not use the group information. Let be a random draw from . Let be a random draw from independent of . Let
and
We can find that