Improving Out-of-Distribution Robustness via Selective Augmentation

01/02/2022
by   Huaxiu Yao, et al.

Machine learning algorithms typically assume that training and test examples are drawn from the same distribution. However, distribution shift is a common problem in real-world applications and can cause models to perform dramatically worse at test time. In this paper, we specifically consider the problems of domain shifts and subpopulation shifts (e.g., imbalanced data). While prior works often seek to explicitly regularize internal representations and predictors of the model to be domain invariant, we instead aim to regularize the whole function without restricting the model's internal representations. This leads to LISA, a simple mixup-based technique that learns invariant functions via selective augmentation. LISA selectively interpolates samples either with the same labels but different domains or with the same domain but different labels. We analyze a linear setting and theoretically show how LISA leads to a smaller worst-group error. Empirically, we study the effectiveness of LISA on nine benchmarks ranging from subpopulation shifts to domain shifts, and we find that LISA consistently outperforms other state-of-the-art methods.

1 Introduction

To deploy machine learning algorithms in real-world applications, we must pay attention to distribution shift, i.e., when the test distribution differs from the training distribution, which substantially degrades model performance. In this paper, we refer to this problem as out-of-distribution (OOD) generalization and specifically consider performance gaps caused by two kinds of distribution shifts: domain shifts and subpopulation shifts. In domain shifts, the test data is sampled from domains different from the training data, which requires the trained model to generalize well to test domains without seeing any data from those domains at training time. Take health risk prediction as an example: we may want to train a model on patients from a few sampled hospitals and then deploy the model to a broader set of hospitals (Koh et al., 2021). In subpopulation shifts, the proportions of subpopulations in the test distribution differ from the proportions in the training distribution (e.g., the setting of imbalanced data). When subpopulation shift occurs, models perform poorly if they falsely rely on spurious correlations, which can arise when some subpopulations are under-represented in the training set. For example, in financial risk prediction, a machine learning model trained on the entire population may associate the labels with demographic features (e.g., religion and race), making the model fail on the test set when such an association does not hold in reality.

To improve model robustness under these two kinds of distribution shifts, methods for learning invariance have shown effectiveness in various applications. These methods use regularizers to learn features or predictors that are invariant across different domains while still containing sufficient information for fulfilling the task (Li et al., 2018; Sun & Saenko, 2016; Arjovsky et al., 2019; Krueger et al., 2021; Rosenfeld et al., 2021). However, designing regularizers that are broadly suitable for datasets from diverse domains is especially challenging, and unsuitable regularizers may adversely limit the model's expressive power or yield a difficult optimization problem, leading to inconsistent performance across real-world datasets. For example, on the WILDS datasets, invariant risk minimization (IRM) (Arjovsky et al., 2019) with reweighting outperforms empirical risk minimization (ERM) on CivilComments, but fails to improve robustness on a variety of other datasets such as Camelyon17 and RxRx1 (Koh et al., 2021). CORAL (Sun & Saenko, 2016), a popular domain adaptation method that adds explicit regularization, has been observed to exhibit a similar phenomenon (Koh et al., 2021).

Figure 1: Illustration of the variants of LISA (Intra-label LISA and Intra-domain LISA) on the Colored MNIST dataset. λ represents the interpolation ratio, which is sampled from a Beta distribution. (a) Colored MNIST (CMNIST). We classify MNIST digits into two classes, where the original digits (0,1,2,3,4) and (5,6,7,8,9) are labeled as class 0 and class 1, respectively. Digit color is used as the domain information, which is spuriously correlated with the labels in the training data. (b) Intra-label LISA (LISA-L) cancels out the spurious correlation by interpolating samples with the same label but different domains. (c) Intra-domain LISA (LISA-D) interpolates samples with the same domain but different labels, so that the domain information cannot drive the change of labels and will be ignored by the model.

Instead of explicitly imposing regularization to learn invariant representations or predictors, we turn towards an implicit solution in this paper. Inspired by mixup (Zhang et al., 2018), we aim to learn invariant functions through data interpolation, leading to a simple algorithm called LISA (Learning Invariant Functions with Selective Augmentation). Concretely, LISA linearly interpolates the features of a pair of samples and applies the same interpolation strategy to the corresponding labels. Critically, the pairs are selectively chosen according to two selective augmentation strategies – intra-label LISA (LISA-L) and intra-domain LISA (LISA-D) – which are described below; we illustrate the motivations of both strategies on the Colored MNIST dataset in Figure 1. Intra-label LISA (Figure 1(b)) interpolates samples with the same label but from different domains, aiming to eliminate domain-related spurious correlations. Intra-domain LISA (Figure 1(c)) interpolates samples with the same domain but different labels, where the model should ignore the domain information and generate different predicted values as the interpolation ratio changes. In this way, LISA encourages the model to learn domain-invariant functions without any explicit constraints or regularizers.

The primary contributions of this paper are as follows: (1) We propose a simple yet widely applicable method for learning domain-invariant functions that is shown to be robust to domain shifts and subpopulation shifts. (2) We conduct broad empirical experiments to evaluate the effectiveness of LISA on nine benchmark datasets from diverse domains. In these experiments, we make the following observations. First, we find that LISA consistently outperforms seven prior methods in addressing both domain shifts and subpopulation shifts. Second, we identify that the performance gains of LISA are indeed caused by canceling out domain-specific information (or spurious correlations) and learning invariant functions, rather than simply involving more data via interpolation. Third, as the degree of distribution shift increases, LISA achieves more significant performance gains. (3) Finally, we provide theoretical analysis of the phenomena distilled from the empirical studies, where we provably demonstrate that LISA can mitigate spurious correlations and therefore lead to a smaller worst-domain error compared with ERM and vanilla mixup. We also note that, to the best of our knowledge, our work provides the first theoretical framework for studying how mixup (with or without the selective augmentation strategies) affects the misclassification error.

2 Preliminaries

In this paper, we consider the setting where one predicts the label y ∈ Y based on the input feature x ∈ X. Given a parameter space Θ and a loss function ℓ(·, ·), we aim to train a model f_θ with θ ∈ Θ under the training distribution P_tr over (x, y) pairs. In empirical risk minimization (ERM), the empirical distribution over the training data is denoted by P̂_tr; ERM optimizes the following objective:

θ̂_ERM = argmin_{θ ∈ Θ} E_{(x, y) ∼ P̂_tr} [ℓ(f_θ(x), y)].     (1)

In a traditional machine learning setting, a test set, sampled from a test distribution P_ts, is used to evaluate the generalization of the trained model, where the test distribution is assumed to be the same as the training distribution, i.e., P_ts = P_tr. In this paper, we are interested in the setting where distribution shift occurs, i.e., P_ts ≠ P_tr.

Specifically, following Muandet et al. (2013); Albuquerque et al. (2019); Koh et al. (2021), we regard the overall data distribution as containing D domains, where each domain d is associated with a data distribution P_d over X × Y and n_d denotes the number of samples from domain d. We then formulate the training distribution as a mixture of domains, i.e., P_tr = Σ_{d ∈ D_tr} r_d P_d, where {r_d} denotes the mixture probabilities in the training set. Here, the training domains are defined as D_tr = {d : r_d > 0}. Similarly, the test distribution can be represented as P_ts = Σ_{d ∈ D_ts} q_d P_d, where {q_d} denotes the mixture probabilities in the test set. The test domains are defined as D_ts = {d : q_d > 0}.

In domain shift scenarios, we investigate the problem where the test domains are disjoint from the training domains, i.e., D_tr ∩ D_ts = ∅. In general, we assume the test domains share some common properties with the training domains. For example, in Camelyon17 (Koh et al., 2021), we train the model on some hospitals and test it on a new hospital. We evaluate the worst-domain and average performance of the classifier across all test domains.

In subpopulation shift scenarios, the test set contains domains that have been seen in the training set, but with different proportions of subpopulations, i.e., D_ts ⊆ D_tr but {q_d} ≠ {r_d}. Under this setting, following Sagawa et al. (2020a), we specifically consider group-based spurious correlations, where each group g is defined by a domain d and a label y, i.e., g = (d, y). We assume that the domain is spuriously correlated with the label. For example, in the CMNIST dataset illustrated in Figure 1, the digit color (green or red) is spuriously correlated with the label ([1, 0] or [0, 1]). Based on the group definition, we evaluate the model via the worst test group error, i.e., max_g E_{(x, y) ∼ P_g} [ℓ_{0-1}(f_θ(x), y)], where ℓ_{0-1} represents the 0-1 loss.
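To make this evaluation concrete, below is a minimal sketch (not from the paper; the function name and array-based interface are illustrative assumptions) of computing the worst-group 0-1 error from per-sample predictions, with groups formed by pairing a domain with a label:

import numpy as np

def worst_group_error(y_true, y_pred, domains):
    # Worst-group 0-1 error, where a group is a (domain, label) pair.
    y_true, y_pred, domains = map(np.asarray, (y_true, y_pred, domains))
    group_errors = []
    for d in np.unique(domains):
        for y in np.unique(y_true):
            mask = (domains == d) & (y_true == y)
            if mask.sum() == 0:
                continue  # skip groups that do not appear in the test set
            group_errors.append(np.mean(y_pred[mask] != y_true[mask]))
    return max(group_errors)

The worst-group accuracy reported in the experiments is simply one minus this quantity.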

3 Learning Invariant Functions with Selective Augmentation

This section presents LISA, a simple way to improve robustness to domain shifts and subpopulation shifts. The key idea behind LISA is to encourage the model to learn invariant functions via selective data interpolation, which also alleviates the effects of domain-related spurious correlations. Before detailing how interpolated samples are selected, we first provide a general formulation of data interpolation.

In LISA, we perform linear interpolation between training samples. Specifically, given samples (x_i, y_i) and (x_j, y_j) drawn from domains d_i and d_j, we apply mixup (Zhang et al., 2018), a simple data interpolation strategy, separately on the input features and the corresponding labels as:

x_mix = λ x_i + (1 − λ) x_j,   y_mix = λ y_i + (1 − λ) y_j,     (2)

where the interpolation ratio λ is sampled from a Beta distribution, and y_i and y_j are one-hot vectors in classification problems.

Notice that the mixup approach in equation 2 can be replaced by CutMix (Yun et al., 2019), which shows stronger empirical performance in vision-based applications. In text-based applications, we can use Manifold Mixup (Verma et al., 2019), interpolating the representations of a pre-trained model, e.g., the output of BERT (Devlin et al., 2019).
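To illustrate equation 2, the following is a minimal sketch of the interpolation step (an assumed standard mixup implementation with illustrative default Beta parameters, not code released with the paper):

import numpy as np

def interpolate(x_i, y_i, x_j, y_j, alpha=2.0, beta=2.0):
    # Sample the interpolation ratio from a Beta distribution.
    lam = np.random.beta(alpha, beta)
    # Apply the same ratio to the inputs and the one-hot labels (equation 2).
    x_mix = lam * x_i + (1.0 - lam) * x_j
    y_mix = lam * y_i + (1.0 - lam) * y_j
    return x_mix, y_mix

CutMix or Manifold Mixup can be substituted here by changing how x_mix is formed while keeping the label interpolation unchanged.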

After obtaining the interpolated features and labels, we replace the original features and labels in ERM with the interpolated ones. Then, the optimization process in equation 1 is reformulated as:

θ̂ = argmin_{θ ∈ Θ} E_{(x_mix, y_mix)} [ℓ(f_θ(x_mix), y_mix)].     (3)

Without additional selective augmentation strategies, vanilla mixup regularizes the model and reduces overfitting (Zhang et al., 2021b), allowing it to attain good in-distribution generalization. However, vanilla mixup may not be able to cancel out spurious correlations, causing the model to still fail at attaining good OOD generalization (see empirical comparisons in Section 4.3 and theoretical discussion in Section 5). In LISA, we instead adopt a new strategy where mixup is only applied across specific domains or groups, which encourages learning invariant functions and thus better OOD performance. The two kinds of selective augmentation strategies are presented as follows:

Intra-label LISA (LISA-L): Interpolating samples with the same label. Intra-label LISA interpolates samples with the same label but different domains (i.e., y_i = y_j, d_i ≠ d_j). As shown in Figure 1(b), this produces datapoints in which both domains are partially present, effectively eliminating spurious correlations between domain and label in cases where the pair of domains correlate differently with the label. Additionally, if the domain information does not fully reflect the spurious correlations in some datasets, we can enlarge the interpolation scope to cover more potentially spurious correlations by interpolating samples within the same class regardless of domain information (i.e., only requiring y_i = y_j). As a result, intra-label LISA should learn domain-invariant functions for each class and thus achieve better OOD robustness.

Intra-domain LISA (LISA-D): Interpolating samples with the same domain. When domain information is highly spuriously correlated with the label information, intra-domain LISA (Figure 1(c)) applies the interpolation strategy to samples with the same domain but different labels (i.e., d_i = d_j, y_i ≠ y_j). Intuitively, even within the same domain, the model is supposed to produce different predictions because the randomly sampled interpolation ratio yields different interpolated labels y_mix. This causes the model to make predictions that are less dependent on the domain, again improving OOD robustness.

For each batch of data, we randomly perform intra-label or intra-domain LISA during training with probability p_sel and 1 − p_sel, respectively, where p_sel is treated as a hyperparameter in our experiments. The choice of p_sel depends on the number of domains and the relation between the domain information and the spurious correlations. Empirically, intra-label LISA brings much larger benefits when there are more domains and/or when the domain information cannot fully reflect the spurious correlations. Intra-domain LISA benefits performance when the domain information is highly spuriously correlated with the label, where we find a balanced ratio (i.e., p_sel = 0.5) performs best. The pseudocode of the training procedure of LISA is shown in Algorithm 1.

0:  Training data, step size, learning rate η, shape parameters α, β of the Beta distribution
1:  while not converge do
2:     Sample interpolation ratio λ ∼ Beta(α, β)
3:     Sample a set of samples B uniformly from the training data
4:     Randomly select intra-label or intra-domain LISA with probability p_sel and 1 − p_sel
5:     if use intra-label LISA (LISA-L) then
6:        For each sample (x_i, y_i, d_i) in B, find another sample (x_j, y_j, d_j) from the dataset with the same label (y_j = y_i) but a different domain (d_j ≠ d_i), and construct the interpolated set B_mix via equation 2.
7:     else if use intra-domain LISA (LISA-D) then
8:        Randomly sample a domain d
9:        For each sample (x_i, y_i, d_i) in domain d, find another sample (x_j, y_j, d_j) from the same domain (d_j = d_i) but with a different label (y_j ≠ y_i), constructing the interpolated set B_mix via equation 2.
10:     Update θ with the interpolated data B_mix using learning rate η.
Algorithm 1 Training Procedure of LISA
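To complement Algorithm 1, the following is a minimal PyTorch-style sketch of a single LISA training step. The pairing logic follows the algorithm; helper names such as dataset.sample, the optimizer setup, and the default hyperparameter values are illustrative assumptions rather than the authors' released code.

import numpy as np
import torch

def lisa_step(model, optimizer, loss_fn, batch, dataset, p_sel=0.5, alpha=2.0, beta=2.0):
    # One training step of LISA on a batch of (x, y_onehot, d) triples.
    lam = float(np.random.beta(alpha, beta))
    use_intra_label = np.random.rand() < p_sel
    xs1, ys1, xs2, ys2 = [], [], [], []
    for x_i, y_i, d_i in batch:
        if use_intra_label:
            # Intra-label LISA: pair with a sample of the same label from a different domain.
            x_j, y_j, _ = dataset.sample(label=y_i, exclude_domain=d_i)
        else:
            # Intra-domain LISA: pair with a sample of a different label from the same domain.
            x_j, y_j, _ = dataset.sample(domain=d_i, exclude_label=y_i)
        xs1.append(x_i); ys1.append(y_i); xs2.append(x_j); ys2.append(y_j)
    x1, y1 = torch.stack(xs1), torch.stack(ys1)
    x2, y2 = torch.stack(xs2), torch.stack(ys2)
    # Interpolate inputs and one-hot labels with the same ratio (equation 2).
    x_mix = lam * x1 + (1.0 - lam) * x2
    y_mix = lam * y1 + (1.0 - lam) * y2
    loss = loss_fn(model(x_mix), y_mix)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

For vision tasks the formation of x_mix can be replaced by CutMix, and for text tasks by Manifold Mixup on pre-trained representations, as described above.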

4 Experiments

In this section, we conduct comprehensive experiments to evaluate the effectiveness of LISA. Specifically, we aim to answer the following questions: Q1: Compared to prior methods, can LISA improve robustness to domain shifts and subpopulation shifts (Section 4.1 and Section 4.2)? Q2: Which aspects of LISA are the most important for improving robustness (Section 4.3)? Q3: How does LISA perform with varying degrees of distribution shifts (Section 4.4)? Q4: Does LISA successfully produce invariant functions (Section 4.5)?

To answer Q1, we compare to ERM, IRM (Arjovsky et al., 2019), MMD (Li et al., 2018), DRNN (Ganin & Lempitsky, 2015), GroupDRO (Sagawa et al., 2020a), DomainMix (Xu et al., 2020), and Fish (Shi et al., 2021). Upweighting (UW) is particularly suitable for subpopulation shifts, so we also use it for comparison. We adopt the same model architectures for all approaches.

4.1 Evaluating Robustness to Domain Shifts

Experimental Setup. To study domain shifts, we consider five datasets. Four datasets (Camelyon17, FMoW, RxRx1, and Amazon) are selected from WILDS (Koh et al., 2021), covering real-world distribution shifts across diverse domains (e.g., health, natural language processing, and vision). Besides the WILDS data, we also apply LISA to the MetaShift datasets (Liang & Zou, 2021), which are constructed using the real-world images and natural heterogeneity of Visual Genome (Krishna et al., 2016). We summarize the datasets in Table 7 of Appendix A.1.1, including domain information, evaluation metric, model architecture, and the number of classes. Detailed dataset descriptions and other training details are discussed in Appendix A.1.1 and A.1.2, respectively.

Results. We report the results of domain shifts in Table 1; the full results, including validation performance and other metrics, are listed in Appendix A.1.3. The strategy selection probability is set as p_sel = 1.0 for these domain shift problems, i.e., only intra-label LISA is used. In addition, we only interpolate samples with the same labels regardless of the domain information, which empirically leads to the best performance. According to Table 1, we have two key findings:

(1) There are no significant performance differences between ERM and other invariant learning methods (e.g., IRM, CORAL, DomainMix) on most datasets, which is consistent with the reported results on WILDS (Koh et al., 2021). A potential reason is that the existing domain information may not fully reflect the spurious correlations. For example, in the Camelyon17-wilds dataset, the presence of tumor tissue (i.e., the label) mainly depends on the demographics of patients (e.g., race, gender), which show no significant difference across hospitals (i.e., the domain information). This could also explain why interpolating samples with the same labels regardless of the domain information achieves the best performance, which corroborates our claims in the description of intra-label LISA in Section 3.

(2) LISA consistently outperforms prior methods on all five datasets regardless of the model architecture and data type (i.e., image or text), demonstrating its effectiveness in improving OOD robustness by canceling out both observable and unobservable domain-related correlations with selective augmentation.

Camelyon17 FMoW RxRx1 Amazon MetaShift
Avg. Acc. Worst Acc. Avg. Acc. 10-th Per. Acc. Worst Acc.
ERM 70.3 ± 6.4% 32.3 ± 1.25% 29.9 ± 0.4% 53.8 ± 0.8% 52.1 ± 0.4%
IRM 64.2 ± 8.1% 30.0 ± 1.37% 8.2 ± 1.1% 52.4 ± 0.8% 51.8 ± 0.8%
CORAL 59.5 ± 7.7% 31.7 ± 1.24% 28.4 ± 0.3% 52.9 ± 0.8% 47.6 ± 1.9%
GroupDRO 68.4 ± 7.3% 30.8 ± 0.81% 23.0 ± 0.3% 53.3 ± 0.0% 51.9 ± 0.7%
DomainMix 69.7 ± 5.5% 34.2 ± 0.76% 30.8 ± 0.4% 53.3 ± 0.0% 51.3 ± 0.5%
Fish 74.7 ± 7.1% 34.6 ± 0.18% 10.1 ± 1.5% 53.3 ± 0.0% 49.2 ± 2.1%
LISA (ours) 77.1 ± 6.5% 35.5 ± 0.65% 31.9 ± 0.8% 54.7 ± 0.0% 54.2 ± 0.7%
Table 1: Main domain shifts results. LISA outperforms prior methods on all five datasets. Following the instructions of Koh et al. (2021), we report the performance of Camelyon17 over 10 different seeds and the results of other datasets are obtained over 3 different seeds.
CMNIST Waterbirds CelebA CivilComments
Avg. Worst Avg. Worst Avg. Worst Avg. Worst
ERM 27.8% 0.0% 97.0% 63.7% 94.9% 47.8% 92.2% 56.0%
UW 72.2% 66.0% 95.1% 88.0% 92.9% 83.3% 89.8% 69.2%
IRM 72.1% 70.3% 87.5% 75.6% 94.0% 77.8% 88.8% 66.3%
CORAL 71.8% 69.5% 90.3% 79.8% 93.8% 76.9% 88.7% 65.6%
GroupDRO 72.3% 68.6% 91.8% 90.6% 92.1% 87.2% 89.9% 70.0%
DomainMix 51.4% 48.0% 76.4% 53.0% 93.4% 65.6% 90.9% 63.6%
Fish 46.9% 35.6% 85.6% 64.0% 93.1% 61.2% 89.8% 71.1%
LISA (ours) 74.0% 73.3% 91.8% 89.2% 92.4% 89.3% 89.2% 72.6%
Table 2: Results of subpopulation shifts. Here, we show the average and worst-group accuracy. We repeat the experiments three times and report full results with standard deviations in Table 15.

4.2 Evaluating Robustness to Subpopulation Shifts

Evaluation Protocol. For subpopulation shifts, we evaluate performance on four binary classification datasets: Colored MNIST (CMNIST), Waterbirds (Sagawa et al., 2020a), CelebA (Liu et al., 2015), and CivilComments (Borkan et al., 2019). We summarize brief data statistics in Table 13 of Appendix A.2.1, covering domain information, model architecture, and class information. Full dataset descriptions for the subpopulation shift datasets are also presented in Appendix A.2.1. Following Sagawa et al. (2020a), we use the worst-group accuracy to evaluate all approaches, and the domain identities are highly spuriously correlated with the label information. For example, as suggested in Figure 1, 90% of the images in the CMNIST dataset have the same color and class pairing, i.e., green for label 0 and red for label 1. Full hyperparameter settings and training details are listed in Appendix A.2.2.

Results. In Table 2, we report the overall performance of LISA and other methods. Similar to the observations under domain shifts, LISA consistently outperforms prior methods on CMNIST, CelebA, and CivilComments. On Waterbirds, LISA outperforms other invariant learning methods (e.g., IRM, CORAL, DomainMix, Fish) and shows performance similar to GroupDRO. These results demonstrate the effectiveness of LISA in improving OOD robustness. On CMNIST, Waterbirds, and CelebA, we find that p_sel = 0.5 works well for choosing between the selective augmentation strategies, while p_sel is set as 1.0 on CivilComments. This is not surprising because it can be more beneficial to use intra-label LISA more often to eliminate domain effects when there are more domains, i.e., eight domains in CivilComments vs. two domains in the others.

4.3 Ablation Study: Is the Performance Gain from Data Augmentation?

Camelyon17 FMoW RxRx1 Amazon MetaShift
Avg. Acc. Worst Acc. Avg. Acc. 10-th Per. Acc. Worst Acc.
ERM 70.3 ± 6.4% 32.8 ± 0.45% 29.9 ± 0.4% 53.8 ± 0.8% 52.1 ± 0.4%
Vanilla mixup 71.2 ± 5.3% 34.2 ± 0.45% 26.5 ± 0.5% 53.3 ± 0.0% 51.3 ± 0.7%
In-group mixup 75.5 ± 6.7% 32.2 ± 1.18% 24.4 ± 0.2% 53.8 ± 0.6% 52.7 ± 0.5%
LISA (ours) 77.1 ± 6.5% 35.5 ± 0.65% 31.9 ± 0.8% 54.7 ± 0.0% 54.2 ± 0.7%
Table 3: Comparison of LISA with substitute mixup strategies under domain shifts.
CMNIST Waterbirds CelebA CivilComments
Avg. Worst Avg. Worst Avg. Worst Avg. Worst
ERM 27.8% 0.0% 97.0% 63.7% 94.9% 47.8% 92.2% 56.0%
Vanilla mixup 32.6% 3.1% 81.0% 56.2% 95.8% 46.4% 90.8% 67.2%
Vanilla mixup + UW 72.2% 71.8% 92.1% 85.6% 91.5% 88.0% 87.8% 66.1%
In-group mixup 33.6% 24.0% 88.7% 68.0% 95.2% 58.3% 90.8% 69.2%
In-group mixup + UW 72.6% 71.6% 91.4% 87.1% 92.4% 87.8% 84.8% 69.3%
LISA (ours) 74.0% 73.3% 91.8% 89.2% 92.4% 89.3% 89.2% 72.6%
Table 4: Comparison of LISA with substitute mixup strategies under subpopulation shifts. UW represents upweighting. Full results with standard deviations are listed in Table 16.

In LISA, we apply selective augmentation strategies on samples either with the same label but different domains or with the same domain but different labels. Here, we explore two substitute interpolation strategies: (1) Vanilla mixup: In Vanilla mixup, we do not add any constraint on the sample selection, i.e., the mixup is performed on any pairs of samples; (2) In-group mixup: This strategy applies data interpolation on samples with the same labels and from the same domains. Notice that all the substitute interpolation strategies use the same mixup types (e.g., mixup/Manifold Mixup/CutMix) as LISA. Finally, since upweighting (UW) small groups significantly improves performance in subpopulation shifts, we also evaluate UW combined with Vanilla/In-group mixup.

The ablation results for domain shifts and subpopulation shifts are in Table 3 and Table 4, respectively. Furthermore, we also conduct experiments on datasets without spurious correlations in Table 17 of Appendix A.3. From the results, we make three key observations. First, compared with Vanilla mixup, the stronger performance of LISA verifies that selective data interpolation improves out-of-distribution robustness by canceling out spurious correlations and encouraging the learning of invariant functions, rather than by simply augmenting the data. These findings are further strengthened by the results in Table 17 of Appendix A.3, where Vanilla mixup outperforms LISA and ERM without spurious correlations, but LISA achieves the best performance with spurious correlations. Second, the superiority of LISA over In-group mixup shows that interpolating samples only within each group is incapable of eliminating the spurious information; In-group mixup still acts merely as data augmentation. Finally, though incorporating UW significantly improves the performance of Vanilla mixup and In-group mixup under subpopulation shifts, LISA still achieves larger benefits than these enhanced substitute strategies, demonstrating its stronger ability to improve OOD robustness.

4.4 Effect of the Degree of Distribution Shifts

We further investigate the performance of LISA with respect to the degree of distribution shift. Here, we use MetaShift to evaluate performance, where the distance between training and test domains is measured as node similarity on a meta-graph (Liang & Zou, 2021). To vary the distance between training and test domains, we change the backgrounds of the training objects (see full experimental details in Appendix A.1.1). The performance under varied distances is reported in Table 5, where the four best prior methods (i.e., ERM, GroupDRO, IRM, DomainMix) are included for comparison. We observe that LISA consistently outperforms the other methods under all scenarios. Additionally, another interesting finding is that LISA achieves more substantial improvements as the distance increases. A potential reason is that the models may rely more heavily on domain-related correlations when there is a larger distance between training and test domains.

Distance 0.44 0.71 1.12 1.43
ERM 80.1% 68.4% 52.1% 33.2%
IRM 79.5% 67.4% 51.8% 32.0%
DomainMix 76.0% 63.7% 51.3% 30.8%
GroupDRO 77.0% 68.9% 51.9% 34.2%
LISA (ours) 81.3% 69.7% 54.2% 37.5%
Table 5: Effect of the degree of distribution shift on performance. Distance represents the distribution distance between training and test domains.

4.5 Analysis of Learned Invariance

CMNIST Waterbirds MetaShift CMNIST Waterbirds MetaShift
ERM 12.0486 0.2456 1.8824 6.2858 1.8878 1.2051
Vanilla mixup 0.2769 0.1465 0.2659 4.7365 2.9121 1.1711
IRM 0.0112 0.1243 0.8748 7.7549 1.1219 1.1483
DomainMix 0.1674 0.0995 1.1158 5.1600 2.7467 1.2019
LISA (ours) 0.0012 0.0016 0.2387 0.5670 0.1344 1.0005
Table 6: Results of the analysis of function-level invariance. The variance of test risks across all domains (left three columns) and the pairwise KL divergence of predictions among all domains (right three columns) are used to measure invariance. Smaller values denote stronger invariance.

Finally, we analyze the function-level invariance learned by LISA. Specifically, we use two metrics to measure function-level invariance: (1) Variance of test risks. Motivated by Arjovsky et al. (2019); Krueger et al. (2021), function-level invariance is first measured by the variance of the test risks across all test domains, i.e., the variance of the per-domain risks over the test domains; (2) Pairwise divergence of predictions. We further measure the average pairwise KL divergence of the predicted probabilities among all test domains. Small values of both metrics represent strong function-level invariance. The results on CMNIST, Waterbirds, and MetaShift are reported in Table 6, where the superiority of LISA verifies that it also improves function-level invariance. Besides the function-level invariance, we observe that LISA also leads to stronger representation-level invariance, which is detailed in Appendix A.4.
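Both metrics can be computed directly from per-domain losses and predicted probabilities; the following is a minimal sketch of one way to do so (the function names and the use of mean predicted class probabilities per domain are illustrative assumptions, not the paper's exact implementation):

import numpy as np

def risk_variance(per_domain_risks):
    # Variance of test risks across domains; smaller means more invariant.
    return float(np.var(per_domain_risks))

def pairwise_prediction_kl(per_domain_probs, eps=1e-12):
    # Average pairwise KL divergence between the mean predicted class
    # probabilities of different test domains; smaller means more invariant.
    means = [np.clip(np.asarray(p).mean(axis=0), eps, 1.0) for p in per_domain_probs]
    kls = []
    for i, p in enumerate(means):
        for j, q in enumerate(means):
            if i != j:
                kls.append(float(np.sum(p * np.log(p / q))))
    return float(np.mean(kls))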

5 Theoretical Analysis

In this section, we provide theoretical understanding that explains several of the empirical phenomena from the previous experiments and theoretically compare the worst-group errors of three methods: the proposed LISA, ERM, and vanilla mixup. Specifically, we consider a Gaussian mixture model with subpopulation and domain shifts, which has been widely adopted in theory to shed light upon complex machine learning phenomena (Montanari et al., 2019; Zhang et al., 2021c). We also note that, despite the popularity of mixup in practice, the theoretical analysis of how mixup (with or without the selective augmentation strategies) affects the misclassification error is still largely unexplored in the literature, even in simple models. As discussed in Section 2, we let y denote the label and d the domain information. For each label y = j and domain d = a, we consider the following model:

x | (y = j, d = a) ∼ N(μ_{j,a}, Σ),     (4)

where μ_{j,a} is the conditional mean vector and Σ is the covariance matrix.

To account for the spurious correlation induced by domains, we allow the conditional means to differ across domains and consider an imbalanced case in which the group proportions are unequal. Moreover, we assume there exists some invariance across different domains; specifically, we assume that the within-domain discriminant direction Σ^{-1}(μ_{1,a} − μ_{0,a}) does not depend on the domain a.

According to the theory of Fisher's linear discriminant analysis (Anderson, 1962), the optimal classification rule within each domain is linear with slope Σ^{-1}(μ_{1,a} − μ_{0,a}). The assumption above implies that this common slope is the (unknown) invariant prediction rule for model (4).

Suppose we use some method M to obtain a linear classifier from the training data; we then apply it to test data and compute the worst-group misclassification error, where the misclassification error for domain a and class j is the probability that the classifier mislabels a sample from group (a, j). We denote the worst-group error of method M as

where the slope and intercept are those produced by method M. Specifically, M = ERM denotes the ERM method (minimizing the sum-of-squares loss on the training data altogether), M = mixup denotes the vanilla mixup method (without any selective augmentation strategy), and M = LISA denotes the mixup strategy of LISA. We also denote its finite-sample version by the corresponding empirical quantity.

Let denote the marginal difference and denote the correlation operator between the domain-specific difference and the marginal difference with respect to . We see that smaller indicates larger discrepancy between the marginal difference and the domain-specific difference and therefore implies stronger spurious correlation between the domains and labels. We present the following theorem showing that our proposed LISA algorithm outperforms the ERM and vanilla mixup in the subpopulation shifts setting.

Theorem 1 (Error comparison with subpopulation shifts)

Consider independent samples generated from model (4), , , , and is positive definite. Suppose satisfies that for some large enough constant and . Then for any ,

Theorem 1 implies that, under its conditions, when the correlation between the domain-specific and marginal differences is small (indicating that the domain has a strong spurious correlation with the label), the worst-group classification error of LISA is asymptotically smaller than those of ERM and vanilla mixup. In fact, our analysis shows that LISA yields a classification rule closer to the invariant classification rule by leveraging the domain information.

In the next theorem, we present the misclassification error comparison under domain shifts. That is, we consider samples from a new unseen domain:

Let , where is the mean of the training distribution, and assume . Let and denote the correlation for and for , respectively, with respect to . Let and its sample version be .

Theorem 2 (Error comparison with domain shifts)

Suppose samples are independently generated from model (4), , and is positive definite. Suppose that satisfy that and for some large enough constant and . Then for any ,

Similar to Theorem 1, this result shows that when the domain has a strong spurious correlation with the label (corresponding to a small correlation quantity), the spurious correlation degrades the performance of ERM and vanilla mixup, while our proposed LISA method is able to mitigate this issue via selective data interpolation. Proofs of Theorem 1 and Theorem 2 are provided in Appendix B.

6 Related Work and Discussion

In this paper, we focus on improving the robustness of machine learning models to domain shifts and subpopulation shifts. Here, we discuss related approaches from the following three categories:

Learning Invariant Representations. Motivated by unsupervised domain adaptation (Ben-David et al., 2010; Ganin et al., 2016), the first category of works learns invariant representations by aligning representations across domains. The major research line in this category aims to eliminate the domain dependency by minimizing the divergence of feature distributions under different distance metrics, e.g., maximum mean discrepancy (Tzeng et al., 2014; Long et al., 2015), an adversarial loss (Ganin et al., 2016; Li et al., 2018), or Wasserstein distance (Zhou et al., 2020a). Follow-up works apply data augmentation to (1) generate more domains and enhance the consistency of representations during training (Yue et al., 2019; Zhou et al., 2020b; Xu et al., 2020; Yan et al., 2020; Shu et al., 2021; Wang et al., 2020) or (2) generate new domains in an adversarial way to imitate challenging domains without using training domain information (Zhao et al., 2020; Qiao et al., 2020; Volpi et al., 2018). Unlike these methods, LISA focuses on learning invariant functions without restricting the representations, leading to stronger empirical performance.

Learning Invariant Predictors. Beyond using domain alignment to learn invariant representations, recent work aims to further enhance the correlation between the invariant representations and the labels (Koyama & Yamaguchi, 2020), leading to invariant predictors. Representatively, motivated by causal inference, invariant risk minimization (IRM) (Arjovsky et al., 2019) aims to find a predictor that performs well across all domains. Following IRM, subsequent works propose stronger regularizers by penalizing the variance of risks across all domains (Krueger et al., 2021), aligning the gradients across domains (Koyama & Yamaguchi, 2020), smoothing the cross-domain interpolation paths (Chuang & Mroueh, 2021), or applying a game-theoretic invariant rationalization criterion (Chang et al., 2020). Instead of using regularization, LISA eliminates spurious correlations and learns invariant functions directly via data interpolation.

Group Robustness. The last category of methods combats spurious correlations and is particularly suitable for subpopulation shifts. These approaches include directly optimizing the worst-group performance with distributionally robust optimization (Sagawa et al., 2020a; Zhang et al., 2021a; Zhou et al., 2021), generating samples around the minority groups (Goel et al., 2021), and balancing the majority and minority groups via reweighting (Sagawa et al., 2020b) or regularization (Cao et al., 2019, 2020). In contrast, LISA provides a more general strategy that is suitable for both domain shifts and subpopulation shifts.

7 Conclusion

To tackle distribution shifts, we propose LISA, a simple and efficient algorithm for improving out-of-distribution robustness. LISA aims to eliminate domain-related spurious correlations in the training set via selective sample interpolation. We evaluate the effectiveness of LISA on nine datasets under subpopulation shift and domain shift settings, demonstrating its promise. Our detailed analysis further verifies that the performance gains of LISA result from encouraging the learning of invariant functions and representations. Our theoretical results strengthen the case for LISA by showing a smaller worst-group misclassification error compared with ERM and vanilla data interpolation.

References

Appendix A Additional Experiments

A.1 Domain Shifts

A.1.1 Dataset Details

Datasets Domains Metric Base Model Num. of classes
Camelyon17 5 hospitals Avg. Acc. DenseNet-121 2
FMoW 16 years x 5 regions Worst-group Acc. DenseNet-121 62
RxRx1 51 experimental batches Avg. Acc. ResNet-50 1,139
Amazon 7,676 reviewers 10th Percentile Acc. DistilBERT-uncased 5
MetaShift 4 backgrounds Worst-group Acc. ResNet-50 2
Table 7: Dataset Statistics for Domain Shifts.

In this section, we provide detailed descriptions of datasets used in the experiments of domain shifts and report the data statistics in Table 7.

Camelyon17

We use Camelyon17 from the WILDS benchmark (Koh et al., 2021; Bandi et al., 2018), which provides lymph-node scans sampled from five hospitals. Camelyon17 is a medical image classification task where the input is an image patch and the label indicates whether there exists tumor tissue in the image. The domain denotes the hospital that the patch was taken from. The training dataset is drawn from the first three hospitals, while the out-of-distribution validation and out-of-distribution test datasets are sampled from the fourth and fifth hospitals, respectively.

FMoW

The FMoW dataset is from the WILDS benchmark (Koh et al., 2021; Christie et al., 2018) — a satellite image classification task with 62 classes and 80 domains (16 years × 5 regions). Concretely, the input is an RGB satellite image, the label is one of 62 building or land use categories, and the domain represents both the year the image was taken and its corresponding geographical region – Africa, the Americas, Oceania, Asia, or Europe. The train/validation/test splits are based on the time when the images were taken. Specifically, images taken before 2013 are used as the training set, images taken between 2013 and 2015 are used as the validation set, and images taken after 2015 are used for testing.

RxRx1

RxRx1 (Koh et al., 2021; Taylor et al., 2019) from the WILDS benchmark is a cell image classification task. In the dataset, some cells have been genetically perturbed by siRNA. The goal of RxRx1 is to predict which siRNA the cells have been treated with. Concretely, the input is an image of cells obtained by fluorescent microscopy, the label indicates which of the 1,139 genetic treatments the cells received, and the domain denotes the experimental batch. A subset of the experimental batches is used for training, where each batch contains one sample for each class, and the out-of-distribution validation and test sets contain images from disjoint held-out batches. The average accuracy on the out-of-distribution test set is reported.

Amazon

Each task in the Amazon benchmark (Koh et al., 2021; Ni et al., 2019) is a multi-class sentiment classification task. The input is the text of a review, the label is the corresponding star rating ranging from 1 to 5, and the domain is the corresponding reviewer. The training set contains reviews from one set of reviewers, while the out-of-distribution validation and test sets contain reviews from disjoint sets of held-out reviewers. We evaluate the models by the 10th percentile of per-user accuracies on the test set.

MetaShift

We use the MetaShift (Liang & Zou, 2021), which is derived from Visual Genome (Krishna et al., 2016). MetaShift leverages the natural heterogeneity of Visual Genome to provide many distinct data distributions for a given class (e.g. “cats with cars” or “cats in bathroom” for the “cat” class). A key feature of MetaShift is that it provides explicit explanations of the dataset correlation and a distance score to measure the degree of distribution shift between any pair of sets.

We adopt the “Cat vs. Dog” task in MetaShift, where we evaluate the model on the “dog(shelf)” domain with 306 images and the “cat(shelf)” domain with 235 images. The training data for the “Cat” class is cat(sofa + bed), comprising the cat(sofa) and cat(bed) domains. MetaShift provides 4 different sets of training data for the “Dog” class in an increasingly challenging order, i.e., with increasing amounts of distribution shift. Specifically, dog(cabinet + bed), dog(bag + box), dog(bench + bike), and dog(boat + surfboard) are selected for training, and their corresponding distances to dog(shelf) are 0.44, 0.71, 1.12, and 1.43.

A.1.2 Training Details

Following WILDS (Koh et al., 2021), we adopt a pre-trained DenseNet-121 (Huang et al., 2017) for the Camelyon17 and FMoW datasets, a pre-trained ResNet-50 (He et al., 2016) for the RxRx1 and MetaShift datasets, and DistilBERT (Sanh et al., 2019) for the Amazon dataset.

In each training iteration, we first draw a batch of samples from the training set. With p_sel = 1.0, we then select another batch of samples with the same labels for data interpolation. The interpolation ratio λ is drawn from a Beta distribution. We use the same image transformations as Koh et al. (2021), and all other hyperparameter settings are listed in Table 8.

Dataset Camelyon17 FMoW RxRx1 Amazon MetaShift
Learning rate 1e-4 1e-4 1e-3 2e-6 1e-3
Weight decay 0 0 1e-5 0 1e-4
Scheduler n/a n/a Cosine Warmup n/a n/a
Batch size 32 32 72 8 16
Type of mixup CutMix CutMix CutMix ManifoldMix CutMix
Architecture DenseNet121 DenseNet121 ResNet50 DistilBert ResNet50
Optimizer SGD Adam Adam Adam SGD
Maximum Epoch 2 5 90 3 100
Strategy sel. prob. 1.0 1.0 1.0 1.0 1.0
Table 8: Hyperparameter settings for the domain shifts.

A.1.3 Full Results of WILDS Data

Following Koh et al. (2021), we report additional results on the WILDS datasets in Table 9 - Table 12, including validation performance and results for other metrics. According to these additional results, LISA outperforms the other baseline approaches in all scenarios. We highlight two additional findings: (1) In the Camelyon17 dataset, the test data is much more visually distinctive than the validation data, resulting in a large gap (14.6%) between the validation and test performance of ERM (see Table 9). LISA significantly reduces this gap between the validation and test sets, showing its promise in improving OOD robustness; (2) In the Amazon dataset, although LISA performs worse than ERM in average accuracy, it achieves the best accuracy at the 10th percentile, which is regarded as a more common and important metric for evaluating whether models perform consistently well across all users (Koh et al., 2021).

Validation Acc. Test Acc.
ERM 84.9 ± 3.1% 70.3 ± 6.4%
IRM 86.2 ± 1.4% 64.2 ± 8.1%
CORAL 86.2 ± 1.4% 59.5 ± 7.7%
GroupDRO 85.5 ± 2.2% 68.4 ± 7.3%
DomainMix 83.5 ± 1.1% 69.7 ± 5.5%
Fish 83.9 ± 1.2% 74.7 ± 7.1%
LISA (ours) 81.8 ± 1.3% 77.1 ± 6.5%
Table 9: Full Results of Camelyon17. We report both validation accuracy and test accuracy.
Validation Test
Avg. Acc. Worst Acc. Avg. Acc. Worst Acc.
ERM 59.5 ± 0.37% 48.9 ± 0.62% 53.0 ± 0.55% 32.3 ± 1.25%
IRM 57.4 ± 0.37% 47.5 ± 1.57% 50.8 ± 0.13% 30.0 ± 1.37%
CORAL 56.9 ± 0.25% 47.1 ± 0.43% 50.5 ± 0.36% 31.7 ± 1.24%
GroupDRO 58.8 ± 0.19% 46.5 ± 0.25% 52.1 ± 0.50% 30.8 ± 0.81%
DomainMix 58.6 ± 0.29% 48.9 ± 1.15% 51.6 ± 0.19% 34.2 ± 0.76%
Fish 57.8 ± 0.15% 49.5 ± 2.34% 51.8 ± 0.32% 34.6 ± 0.18%
LISA (ours) 58.7 ± 0.92% 48.7 ± 0.74% 52.8 ± 0.94% 35.5 ± 0.65%
Table 10: Full Results of FMoW. Here, we report the average accuracy and the worst-domain accuracy on both validation and test sets.
Validation Acc. Test ID Acc. Test OOD Acc.
ERM 19.4 ± 0.2% 35.9 ± 0.4% 29.9 ± 0.4%
IRM 5.6 ± 0.4% 9.9 ± 1.4% 8.2 ± 1.1%
CORAL 18.5 ± 0.4% 34.0 ± 0.3% 28.4 ± 0.3%
GroupDRO 15.2 ± 0.1% 28.1 ± 0.3% 23.0 ± 0.3%
DomainMix 19.3 ± 0.7% 39.8 ± 0.2% 30.8 ± 0.4%
Fish 7.5 ± 0.6% 12.7 ± 1.9% 10.1 ± 1.5%
LISA (ours) 20.1 ± 0.4% 41.2 ± 1.0% 31.9 ± 0.8%
Table 11: Full Results of RxRx1. ID: in-distribution; OOD: out-of-distribution
Validation Test
Avg. Acc. 10-th Per. Acc. Avg. Acc. 10-th Per. Acc.
ERM 72.7 ± 0.1% 55.2 ± 0.7% 71.9 ± 0.1% 53.8 ± 0.8%
IRM 71.5 ± 0.3% 54.2 ± 0.8% 70.5 ± 0.3% 52.4 ± 0.8%
CORAL 72.0 ± 0.3% 54.7 ± 0.0% 70.0 ± 0.6% 52.9 ± 0.8%
GroupDRO 70.7 ± 0.6% 54.7 ± 0.0% 70.0 ± 0.6% 53.3 ± 0.0%
DomainMix 71.9 ± 0.2% 54.7 ± 0.0% 71.1 ± 0.1% 53.3 ± 0.0%
Fish 72.5 ± 0.0% 54.7 ± 0.0% 71.7 ± 0.1% 53.3 ± 0.0%
LISA (ours) 71.2 ± 0.3% 55.1 ± 0.61% 70.6 ± 0.3% 54.7 ± 0.0%
Table 12: Full Results of Amazon. Both the average accuracy and the 10th Percentile accuracy are reported.

A.2 Subpopulation Shifts

A.2.1 Dataset Details

We detail the data descriptions of subpopulation shifts below and report the detailed data statistics in Table 13.

Colored MNIST (CMNIST): We classify MNIST digits into 2 classes, where classes 0 and 1 correspond to the original digits (0,1,2,3,4) and (5,6,7,8,9), respectively. The color is treated as a spurious attribute. Concretely, in the training set, the proportion between red and green samples is 8:2 in class 0, while the proportion is 2:8 in class 1. In the validation set, the proportion between green and red samples is 1:1 for all classes. In the test set, the proportion between green and red samples is 1:9 in class 0, while the ratio is 9:1 in class 1. The sizes of the train, validation, and test sets are 30,000, 10,000, and 20,000, respectively.
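As an illustration of this construction, the following is a minimal sketch (an assumed recreation consistent with the proportions above, not the authors' preprocessing script) that colors grayscale digits with the spurious 8:2 / 2:8 training ratios:

import numpy as np

def colorize_cmnist_train(images, digits, p_majority=0.8, seed=0):
    # Class 0 = digits 0-4 (majority color red), class 1 = digits 5-9 (majority color green).
    rng = np.random.default_rng(seed)
    labels = (digits >= 5).astype(int)
    colored = np.zeros((len(images), 2) + images.shape[1:], dtype=images.dtype)
    for i, (img, y) in enumerate(zip(images, labels)):
        majority_channel = 0 if y == 0 else 1        # channel 0 = red, channel 1 = green
        use_majority = rng.random() < p_majority     # 8:2 ratio within each class
        channel = majority_channel if use_majority else 1 - majority_channel
        colored[i, channel] = img
    return colored, labels

The validation and test splits can be generated analogously with the color ratios stated above.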

Waterbirds (Sagawa et al., 2020a): The Waterbirds dataset aims to classify birds as “waterbird” or “landbird”, where each bird image is spuriously associated with the background “water” or “land”. Waterbirds is a synthetic dataset in which each image is composed by pasting a bird image sampled from the CUB dataset (Wah et al., 2011) onto a background drawn from the Places dataset (Zhou et al., 2017). The bird categories in CUB are stratified as land birds or water birds. Specifically, the following bird species are selected to construct the waterbird class: albatross, auklet, cormorant, frigatebird, fulmar, gull, jaeger, kittiwake, pelican, puffin, tern, gadwall, grebe, mallard, merganser, guillemot, and Pacific loon. All other bird species are combined into the landbird class. We define (land background, waterbird) and (water background, landbird) as the minority groups. There are 4,795 training samples in total, of which only 56 are “waterbirds on land” and 184 are “landbirds on water”. The remaining training data include 3,498 samples of “landbirds on land” and 1,057 samples of “waterbirds on water”.

CelebA (Liu et al., 2015; Sagawa et al., 2020a): For the CelebA data (Liu et al., 2015), we follow the data preprocessing procedure of Sagawa et al. (2020a). CelebA defines an image classification task where the input is a face image of a celebrity and the classification label is the corresponding hair color – “blond” or “not blond.” The label is spuriously correlated with gender, i.e., male or female. In CelebA, the minority groups are (blond, male) and (not blond, female). The numbers of samples in each group are 71,629 “dark hair, female”, 66,874 “dark hair, male”, 22,880 “blond hair, female”, and 1,387 “blond hair, male”.

CivilComments (Borkan et al., 2019; Koh et al., 2021): We use CivilComments from the WILDS benchmark (Koh et al., 2021). CivilComments is a text classification task, aiming to predict whether an online comment is toxic or non-toxic. The spurious domain identities are defined by demographic features, including male, female, LGBTQ, Christian, Muslim, other religion, Black, and White. CivilComments contains 450,000 comments collected from online articles. The numbers of samples for training, validation, and test are 269,038, 45,180, and 133,782, respectively. Readers may refer to Table 17 of Koh et al. (2021) for detailed group information.

Datasets Domains Base Model Class Information
CMNIST 2 digit colors ResNet-50 digits (0,1,2,3,4) vs. (5,6,7,8,9)
Waterbirds 2 backgrounds ResNet-50 waterbirds vs. landbirds
CelebA 2 genders ResNet-50 blond vs. non-blond hair
CivilComments 8 demographic identities DistilBERT-uncased toxic vs. non-toxic
Table 13: Dataset Statistics for Subpopulation Shifts. All datasets are binary classification tasks and we use the worst group accuracy as the evaluation metric.

a.2.2 Training Details

We adopt a pre-trained ResNet-50 (He et al., 2016) for the image datasets (i.e., CMNIST, Waterbirds, CelebA) and DistilBERT (Sanh et al., 2019) for the text dataset (i.e., CivilComments). In each training iteration, we sample a batch of data per group. For intra-label LISA, we apply mixup on sample batches with the same labels but different domains. For intra-domain LISA, we instead apply mixup on sample batches with the same domain but different labels. The interpolation ratio λ is sampled from a Beta distribution. All hyperparameters are listed in Table 14.

A.2.3 Additional Results

In this section, we provide the full results of subpopulation shifts in Table 15 and Table 16.

Dataset CMNIST Waterbirds CelebA CivilComments
Learning rate 1e-3 1e-3 1e-4 1e-5
Weight decay 1e-4 1e-4 1e-4 0
Scheduler n/a n/a n/a n/a
Batch size 16 16 16 8
Type of mixup mixup mixup CutMix ManifoldMix
Architecture ResNet50 ResNet50 ResNet50 DistilBert
Optimizer SGD SGD SGD Adam
Maximum Epoch 300 300 50 3
Strategy sel. prob. 0.5 0.5 0.5 1.0
Table 14: Hyperparameter settings for the subpopulation shifts.
CMNIST Waterbirds
Avg. Worst Avg. Worst
ERM 27.8 ± 1.9% 0.0 ± 0.0% 97.0 ± 0.2% 63.7 ± 1.9%
UW 72.2 ± 1.1% 66.0 ± 0.7% 95.1 ± 0.3% 88.0 ± 1.3%
IRM 72.1 ± 1.2% 70.3 ± 0.8% 87.5 ± 0.7% 75.6 ± 3.1%
CORAL 71.8 ± 1.7% 69.5 ± 0.9% 90.3 ± 1.1% 79.8 ± 1.8%
GroupDRO 72.3 ± 1.2% 68.6 ± 0.8% 91.8 ± 0.3% 90.6 ± 1.1%
DomainMix 51.4 ± 1.3% 48.0 ± 1.3% 76.4 ± 0.3% 53.0 ± 1.3%
Fish 46.9 ± 1.4% 35.6 ± 1.7% 85.6 ± 0.4% 64.0 ± 0.3%
LISA (ours) 74.0 ± 0.1% 73.3 ± 0.2% 91.8 ± 0.3% 89.2 ± 0.6%
CelebA CivilComments
Avg. Worst Avg. Worst
ERM 94.9 ± 0.2% 47.8 ± 3.7% 92.2 ± 0.1% 56.0 ± 3.6%
UW 92.9 ± 0.2% 83.3 ± 2.8% 89.8 ± 0.5% 69.2 ± 0.9%
IRM 94.0 ± 0.4% 77.8 ± 3.9% 88.8 ± 0.7% 66.3 ± 2.1%
CORAL 93.8 ± 0.3% 76.9 ± 3.6% 88.7 ± 0.5% 65.6 ± 1.3%
GroupDRO 92.1 ± 0.4% 87.2 ± 1.6% 89.9 ± 0.5% 70.0 ± 2.0%
DomainMix 93.4 ± 0.1% 65.6 ± 1.7% 90.9 ± 0.4% 63.6 ± 2.5%
Fish 93.1 ± 0.3% 61.2 ± 2.5% 89.8 ± 0.4% 71.1 ± 0.4%
LISA (ours) 92.4 ± 0.4% 89.3 ± 1.1% 89.2 ± 0.9% 72.6 ± 0.1%
Table 15: Full results of subpopulation shifts with standard deviations. All results are obtained with three random seeds.
CMNIST Waterbirds
Avg. Worst Avg. Worst
ERM 27.8 ± 1.9% 0.0 ± 0.0% 97.0 ± 0.2% 63.7 ± 1.9%
Vanilla mixup 32.6 ± 3.1% 3.1 ± 2.4% 81.0 ± 0.2% 56.2 ± 0.2%
Vanilla mixup + UW 72.2 ± 0.7% 71.8 ± 0.1% 92.1 ± 0.1% 85.6 ± 1.0%
In-group mixup 33.6 ± 1.9% 24.0 ± 1.1% 88.7 ± 0.3% 68.0 ± 0.4%
In-group mixup + UW 72.6 ± 0.1% 71.6 ± 0.2% 91.4 ± 0.6% 87.1 ± 0.6%
LISA (ours) 74.0 ± 0.1% 73.3 ± 0.2% 91.8 ± 0.3% 89.2 ± 0.6%
CelebA CivilComments
Avg. Worst Avg. Worst
ERM 94.9 ± 0.2% 47.8 ± 3.7% 92.2 ± 0.1% 56.0 ± 3.6%
Vanilla mixup 95.8 ± 0.0% 46.4 ± 0.5% 90.8 ± 0.8% 67.2 ± 1.2%
Vanilla mixup + UW 91.5 ± 0.2% 88.0 ± 0.3% 87.8 ± 1.2% 66.1 ± 1.4%
In-group mixup 95.2 ± 0.3% 58.3 ± 0.9% 90.8 ± 0.6% 69.2 ± 0.8%
In-group mixup + UW 92.4 ± 0.4% 87.8 ± 0.6% 84.8 ± 0.7% 69.3 ± 1.1%
LISA (ours) 92.4 ± 0.4% 89.3 ± 1.1% 89.2 ± 0.9% 72.6 ± 0.1%
Table 16: Full table of the comparison between LISA and other substitute mixup strategies in subpopulation shifts. UW represents upweighting.

A.3 Results on Datasets without Spurious Correlations

In order to analyze the factors that lead to the performance gains of LISA, we conduct experiments on datasets without spurious correlations. To be more specific, we balance the number of samples in each group under the subpopulation shift setting. The results of ERM, Vanilla mixup, and LISA on CMNIST, Waterbirds, and CelebA are reported in Table 17. The results show that LISA performs similarly to ERM when datasets do not have spurious correlations, whereas LISA significantly outperforms ERM when spurious correlations exist. Another interesting finding is that Vanilla mixup outperforms LISA and ERM without spurious correlations, while LISA achieves the best performance with spurious correlations. This finding strengthens our conclusion that the performance gains of LISA come from eliminating spurious correlations rather than from data augmentation.

Dataset ERM Vanilla mixup LISA
CMNIST 73.67% 74.28% 73.18%
Waterbirds 88.07% 88.23% 87.05%
CelebA 86.11% 88.89% 87.22%
Table 17: Results on Datasets without Spurious Correlations

A.4 Analysis of Representation-level Invariance

For each label y, denote the distribution of hidden representations of samples from domain d as P_d^y. Following the analysis of function-level invariance, representation-level invariance is measured by the average pairwise KL divergence of these representation distributions among all domains, where smaller values indicate that the learned representations are more invariant with respect to the labels. We report the results on CMNIST, Waterbirds, and MetaShift in Table 18. Our key observations are: (1) Compared with ERM, LISA learns stronger representation-level invariance. A potential reason is that function-level invariance leads to better representation-level invariance. (2) LISA exhibits greater invariance than vanilla mixup, validating that the invariant representations are not caused by naive data interpolation. (3) LISA provides more invariant representations than regularization-based methods, i.e., IRM and DomainMix.
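One way to estimate this quantity in practice is to fit a simple parametric density to the representations from each domain and use closed-form divergences; below is a minimal sketch assuming diagonal Gaussian fits per domain (a simplifying assumption for illustration, not necessarily the estimator used in the paper):

import numpy as np

def diag_gaussian_kl(mu_p, var_p, mu_q, var_q):
    # KL(N(mu_p, diag(var_p)) || N(mu_q, diag(var_q))) for diagonal Gaussians.
    return 0.5 * float(np.sum(np.log(var_q / var_p) + (var_p + (mu_p - mu_q) ** 2) / var_q - 1.0))

def representation_invariance(reps_by_domain, eps=1e-6):
    # Average pairwise KL between per-domain representation distributions
    # (computed for a fixed label); smaller values mean more invariant features.
    stats = [(np.asarray(r).mean(axis=0), np.asarray(r).var(axis=0) + eps) for r in reps_by_domain]
    kls = [diag_gaussian_kl(*stats[i], *stats[j])
           for i in range(len(stats)) for j in range(len(stats)) if i != j]
    return float(np.mean(kls))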

CMNIST Waterbirds MetaShift
ERM 1.683 3.592 0.632
Vanilla mixup 4.392 3.935 0.634
IRM 1.905 2.413 0.627
DomainMix 2.155 3.716 0.614
LISA (ours) 0.421 1.912 0.585
Table 18: Results of representation-level invariance, where smaller values denote stronger invariance.

Besides the quantitative analysis, following Appendix C of Lee et al. (2019), we visualize the hidden representations of all test samples and the decision boundary on Waterbirds in Figure 2. Compared with other methods, the representations of samples with the same label learned by LISA are closer together regardless of their domains, which further demonstrates the promise of LISA in learning more invariant representations.

Figure 2: Visualization of sample representations and decision boundaries on Waterbirds dataset.

Appendix B Proofs of Theorem 1 and Theorem 2

Outline of the proof. We will first derive the misclassification errors based on the population version of OLS under different mixup strategies. Next, we will develop the convergence rate of the empirical OLS based on n samples towards its population version. These two steps together give us the empirical misclassification errors of the different methods. We will separately show that the upper bounds in Theorem 1 and Theorem 2 hold for the two selective augmentation strategies of LISA and hence hold for any selection probability p_sel. Let LL denote intra-label LISA and LD denote intra-domain LISA.

Let and denote the marginal class proportions in the training samples. Let and denote the marginal subpopulation proportions in the training samples. Let and define , , and similarly.

We consider the setting where is relatively small and .

B.1 Decomposing the loss function

Recall that . We further define , , and .

For the mixup estimators, we will repeatedly use the fact that the interpolation ratio λ has a symmetric distribution with support [0, 1].

For ERM estimator based on , where , we have

Notice that based on the estimator

B.2 Classification errors of four methods with infinite training samples

We first provide the limit of the classification errors as the sample size tends to infinity.

B.2.1 Baseline method: ERM

For the training data, it is easy to show that

For and , the ERM has slope and intercept being

In the extreme case where , we have

Hence,

(5)

where is computed via ERM.

B.2.2 Baseline method: Vanilla mixup

The vanilla mixup does not use the group information. Let be a random draw from . Let be a random draw from independent of . Let

and

We can find that