Learning Representations that Support Robust Transfer of Predictors

10/19/2021
by   Yilun Xu, et al.
MIT

Ensuring generalization to unseen environments remains a challenge. Domain shift can lead to substantially degraded performance unless shifts are well-exercised within the available training environments. We introduce a simple robust estimation criterion – transfer risk – that is specifically geared towards optimizing transfer to new environments. Effectively, the criterion amounts to finding a representation that minimizes the risk of applying any optimal predictor trained on one environment to another. The transfer risk essentially decomposes into two terms, a direct transfer term and a weighted gradient-matching term arising from the optimality of per-environment predictors. Although inspired by IRM, we show that transfer risk serves as a better out-of-distribution generalization criterion, both theoretically and empirically. We further demonstrate the impact of optimizing such transfer risk on two controlled settings, each representing a different pattern of environment shift, as well as on two real-world datasets. Experimentally, the approach outperforms baselines across various out-of-distribution generalization tasks. Code is available at <https://github.com/Newbeeer/TRM>.


1 Introduction

Training and test examples are rarely sampled from the same distribution in real applications. Indeed, training and test scenarios often represent somewhat different domains. Such discrepancies can degrade generalization performance or even cause serious failures unless specifically mitigated. For example, the standard empirical risk minimization (ERM) approach, which builds on the notion of matching training and test distributions, can come to rely on statistically informative but non-causal features such as textures (Geirhos2019ImageNettrainedCA), background scenes (Beery2018RecognitionIT), or word co-occurrences in sentences (Chang2020InvariantR).

Learning to generalize to domains that are unseen during training is a challenging problem. One approach to domain generalization or out-of-distribution generalization is based on reducing variation across the sets or environments one has access to during training. For example, one can align features of different environments (Muandet2013DomainGV; Sun2016DeepCC) or use data augmentation to help prevent overfitting to environment-specific features (Carlucci2019DomainGB; Zhou2020DeepDI). At one extreme, domain adaptation assumes access to unlabeled test examples whose distribution can then be matched in the feature space (e.g., Ganin2016DomainAdversarialTO).

More recent approaches build on causal invariance as the foundation for out-of-distribution generalization. The key assumption is that the available training environments represent nuisance variation, realized by intervening on non-causal variables in the underlying Structural Causal Model (Pearl2000CausalityMR). Since causal relationships can be assumed to remain invariant across the training environments as well as any unseen environments, a number of recent approaches (Peters2015CausalIU; Arjovsky2019InvariantRM; Krueger2020OutofDistributionGV) tailor their objectives to remove spurious (non-causal) features specific to training environments.

In this paper, we propose a simple robust criterion termed Transfer Risk Minimization (TRM). The goal of TRM is to directly translate a model's ability to generalize across environments into a learning objective. As in prior work, we decompose the model into a feature mapping and a predictor operating on the features. Our transfer risk in this setting measures the average risk of applying the optimal predictor learned in one environment to examples from another, adversarially chosen environment. The feature representation is then tailored to support such robust transfer. Although our work is greatly inspired by IRM (Arjovsky2019InvariantRM), we show, both theoretically and empirically, that TRM serves as a better out-of-distribution criterion in the non-linear case. We further show that the TRM objective decomposes into two terms: a direct transfer term and a weighted gradient-matching term with connections to meta-learning. We then propose an alternating update algorithm for optimizing TRM.

To evaluate robustness we introduce two patterns of environment shifts based on 10C-CMNIST and SceneCOCO datasets. We construct these controlled settings so as to exercise different combinations of invariant and non-causal features, highlighting the impact of non-causal features in the training environments. In the absence of non-causal confounders, we show that all the methods achieve decent out-of-distribution generalization. When non-causal features are present, however, TRM offers greater robustness against biased training environments. We further demonstrate that our approach leads to good performance on the two real-world datasets, PACS and Office-Home.

2 Background and related work

Domain generalization

Machine learning models trained with empirical risk minimization may not perform well in unseen environments where examples are sampled from a distribution different from training. This problem is known as out-of-distribution generalization or domain generalization (Blanchard2011GeneralizingFS; Muandet2013DomainGV). A number of approaches have been proposed in this context; we only touch on some of them for brevity. A typical approach to out-of-distribution generalization involves (distributionally) aligning training environments (Muandet2013DomainGV; Ganin2016DomainAdversarialTO; Sun2016DeepCC; Li2018DomainGV; Shi2021GradientMF). Related approaches encourage the model to focus more on shapes via style-adversarial learning (Nam2019ReducingDG), adopt data augmentation (Carlucci2019DomainGB; Zhou2020DeepDI), or use meta-learning (Li2018LearningTG).

Causal invariance

A recent line of work focuses on promoting invariance as a way to isolate causally meaningful features. Ideally, one would specify a structural equation model (Pearl2000CausalityMR) expressing direct and indirect causes and distinguishing them from spurious, environment-specific influences that are unlikely to generalize (Peters2015CausalIU; RojasCarulla2018InvariantMF; Mller2020LearningRM). Invariance serves as a statistically more amenable proxy criterion for identifying causally relevant features for predictors. Arjovsky2019InvariantRM proposed invariant risk minimization over feature-predictor decompositions. The main idea is that the predictor operating on causal features can be assumed to be simultaneously optimal across training environments. A number of related approaches have been proposed: for example, Krueger2020OutofDistributionGV uses the variance of losses as a regularizer, Jin2020DomainEV minimizes the regret induced by held-out environments, and Parascandolo2020LearningET aligns gradient signs across environments via an and-mask.

Distributionally robust optimization (DRO)

DRO specifies a minimax criterion for estimating predictors in which an adversary gets to modify the training distribution. The allowed modifications are typically expressed in terms of divergence balls around the training distribution (BenTal2013RobustSO; Duchi2016StatisticsOR; Esfahani2018DatadrivenDR). Closer to our work, Group DRO (Hu2018DoesDR; Sagawa2019DistributionallyRN) defines uncertainty regions in terms of a simplex over (fixed) training groups. Both DRO and Group DRO minimize the worst-case loss of the predictor within the uncertainty region. Unlike these methods, we use a predictor-representation decomposition and define a regularizer over the representation using a minimax criterion. Moreover, we explicitly measure the risk of a predictor trained in one environment but applied to another.

3 Transfer Risk Minimization

Consider a classification problem from an input space $\mathcal{X}$ (e.g., images) to an output space $\mathcal{Y}$ (e.g., labels). We are given training environments $\mathcal{E}_{tr} = \{e_1, \ldots, e_m\}$, where $\mathcal{D}_e$ is the empirical distribution for environment $e$. We decompose our model into two parts: a feature extractor $\Phi$, which maps the input to a feature representation, and a predictor $f$ that operates on the features to realize the final output. We call their composition $f \circ \Phi$ a classifier. We use $\ell(f \circ \Phi; (x, y))$ to denote the cross-entropy loss on a training point $(x, y)$. As a shorthand, the expected loss with respect to a distribution $\mathcal{D}$ is written $\mathcal{L}_{\mathcal{D}}(f \circ \Phi) = \mathbb{E}_{(x, y) \sim \mathcal{D}}[\ell(f \circ \Phi; (x, y))]$. The broader goal is to learn a pair of feature extractor and predictor that minimize the risk on some unseen environment $e'$:

$$\min_{f, \Phi}\; \mathcal{L}_{\mathcal{D}_{e'}}(f \circ \Phi).$$

As a step towards this goal, we learn a predictively robust model across the available training environments (defined later). While the high level aim here resembles invariant risk minimization (Arjovsky2019InvariantRM), our proposed estimation criterion is based on robustness rather than invariance.

3.1 Estimation criterion

We define group (environment) robustness based on exchangeability of predictors. Specifically, we require that environment-specific predictors generalize also to other training environments. Note that this doesn't imply that a single predictor is per-environment optimal or invariant as in IRM. Instead, our representation aims to minimize the transfer risk across the set of training environments:

$$\mathcal{R}(\Phi) \;=\; \frac{1}{m}\sum_{e \in \mathcal{E}_{tr}} \;\sup_{\mathcal{D} \in \Lambda_{-e}}\; \mathcal{L}_{\mathcal{D}}(f^*_e \circ \Phi), \tag{1}$$

where $f^*_e = \arg\min_{f} \mathcal{L}_{\mathcal{D}_e}(f \circ \Phi)$ refers to the optimal predictor with respect to distribution $\mathcal{D}_e$, and $\Lambda_{-e}$ is the convex hull of the environment-specific distributions, excluding $\mathcal{D}_e$. Unlike methods in the DRO family (BenTal2013RobustSO; Sagawa2019DistributionallyRN) that do not decompose the predictors, the robust estimation criterion here is specifically tailored to measure the goodness of features in terms of their ability to permit generalization across the environments. We will show in later sections that the transfer risk (Eq. (1)) indeed ensures better out-of-distribution generalization.

Remark

We introduced transfer risk in Eq. (1) as a "sum-sup" criterion with respect to the outer and inner terms. Other possible versions with similar estimation consequences include sum-sum, i.e.,

$$\frac{1}{m}\sum_{e \in \mathcal{E}_{tr}} \sum_{e' \neq e} \mathcal{L}_{\mathcal{D}_{e'}}(f^*_e \circ \Phi).$$

Note that this criterion still measures whether the feature representation allows a predictor trained in one environment to generalize to another. We expect this version to behave similarly when the training environments have comparable noise levels and complexities; however, the sum-sum version can be more resistant to environmental outliers.
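To make the criterion concrete, the following is a minimal PyTorch sketch (not the released TRM code) of how the empirical transfer risk of a frozen feature extractor could be estimated; the helper `fit_predictor` and the list `envs` of per-environment `(x, y)` tensors are hypothetical names introduced here for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def fit_predictor(phi, x, y, num_classes, steps=100, lr=0.1):
    """Fit the per-environment optimal predictor f*_e on top of frozen features."""
    with torch.no_grad():
        z = phi(x)                       # features are treated as fixed while fitting f*_e
    f = nn.Linear(z.shape[1], num_classes)
    opt = torch.optim.SGD(f.parameters(), lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        F.cross_entropy(f(z), y).backward()
        opt.step()
    return f

def transfer_risk(phi, envs, num_classes):
    """Eq. (1): average over environments of the worst-case loss of f*_e on the other environments."""
    risk = 0.0
    for e, (x_e, y_e) in enumerate(envs):
        f_e = fit_predictor(phi, x_e, y_e, num_classes)
        with torch.no_grad():
            losses = [F.cross_entropy(f_e(phi(x)), y)
                      for e2, (x, y) in enumerate(envs) if e2 != e]
        risk = risk + torch.stack(losses).max()   # "sum-sup"; use .mean() for the sum-sum variant
    return risk / len(envs)
```

Switching `.max()` to `.mean()` in the last step gives the sum-sum variant discussed in the remark above.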

3.2 Comparison with IRM

IRM (Arjovsky2019InvariantRM) is a popular objective for learning features that are invariant across training environments. Specifically, IRM finds a feature extractor $\Phi$ such that the associated predictor is simultaneously optimal for every training environment. In our notation,

$$\min_{f, \Phi} \;\sum_{e \in \mathcal{E}_{tr}} \mathcal{L}_{\mathcal{D}_e}(f \circ \Phi) \quad \text{s.t.} \quad f \in \arg\min_{f'} \mathcal{L}_{\mathcal{D}_e}(f' \circ \Phi)\ \ \text{for all } e \in \mathcal{E}_{tr}.$$

IRM specifies a more restrictive set of admissible feature extractors than transfer risk. Specifically, the per-environment optimal predictors in IRM must agree (contain a common predictor), whereas transfer risk uses the per-environment optimal predictors to guide the representation learning. Due to the difficulty of solving the IRM bi-level optimization problem, Arjovsky2019InvariantRM introduced a relaxed objective called IRMv1, in which the constraints are replaced by gradient penalties:

$$\min_{f, \Phi} \;\sum_{e \in \mathcal{E}_{tr}} \Big[\, \mathcal{L}_{\mathcal{D}_e}(f \circ \Phi) + \lambda\, \big\| \nabla_{f}\, \mathcal{L}_{\mathcal{D}_e}(f \circ \Phi) \big\|^2 \,\Big]. \tag{2}$$
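In practice, the IRMv1 penalty is commonly computed with a fixed scalar "dummy" classifier, as in the reference implementation of Arjovsky et al.; below is a short sketch of that standard trick (illustrative, not this paper's code), where `logits` are the outputs of $f \circ \Phi$ on one environment.

```python
import torch
import torch.nn.functional as F

def irmv1_penalty(logits, y):
    # Multiply the logits by a dummy scalar classifier w = 1.0 and penalize the
    # squared gradient of the environment loss with respect to that scalar.
    scale = torch.ones(1, requires_grad=True, device=logits.device)
    loss = F.cross_entropy(logits * scale, y)
    grad = torch.autograd.grad(loss, [scale], create_graph=True)[0]
    return (grad ** 2).sum()
```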

To compare with IRM, we use the theoretical framework of Rosenfeld2020TheRO. For each environment $e$, the data are generated by the following process: the binary label $y$ is sampled uniformly from $\{-1, 1\}$, and the latent features are sampled subsequently from label-conditioned Gaussians:

$$z_c \sim \mathcal{N}(y\,\mu_c,\, \sigma_c^2 I), \qquad z_e \sim \mathcal{N}(y\,\mu_e,\, \sigma_e^2 I).$$

The invariant feature mean $\mu_c$ remains the same for all environments, while the non-causal means $\mu_e$ vary across environments. The observation is generated as a function of the latent features: $x = g(z_c, z_e)$, where $g$ is an injective function that maps the low-dimensional features to high-dimensional observations $x$.
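As a quick illustration, here is a small NumPy sketch of this generative process with hypothetical dimensions and means; for simplicity $g$ is taken to be a random linear injection, whereas the theory above allows a general non-linear $g$.

```python
import numpy as np

def sample_environment(n, mu_c, mu_e, sigma_c=1.0, sigma_e=1.0, g=None, rng=None):
    """Sample n points from one environment: y ~ Unif{-1,1}; z_c, z_e label-conditioned Gaussians."""
    rng = rng if rng is not None else np.random.default_rng(0)
    y = rng.choice([-1, 1], size=n)
    z_c = y[:, None] * mu_c + sigma_c * rng.standard_normal((n, mu_c.size))   # invariant features
    z_e = y[:, None] * mu_e + sigma_e * rng.standard_normal((n, mu_e.size))   # non-causal features
    z = np.concatenate([z_c, z_e], axis=1)
    x = z @ g.T if g is not None else z    # observation x = g(z_c, z_e); linear g in this sketch
    return x, y

d_c, d_e, d_x = 2, 2, 10
g = np.random.default_rng(1).standard_normal((d_x, d_c + d_e))   # random linear injection (w.h.p.)
mu_c = np.ones(d_c)                                               # invariant mean, shared across envs
x1, y1 = sample_environment(500, mu_c, mu_e=np.array([2.0, 2.0]), g=g)    # environment 1
x2, y2 = sample_environment(500, mu_c, mu_e=np.array([-2.0, 2.0]), g=g)   # environment 2: mu_e differs
```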

Theorem 3.3 in Rosenfeld2020TheRO shows that for non-linear $g$, there exists a non-linear classifier that has nearly optimal IRMv1 loss. In addition, this classifier is equivalent to the ERM solution on nearly all test points when the non-causal mean in the test environment is sufficiently different from those seen in training. Below we show that TRM can avoid this failure mode of IRM.

Theorem 1 (Informal).

Under some mild assumptions, there exists a classifier that achieves near-optimal IRMv1 loss (Eq. (2)) and yet has high transfer risk (Eq. (1)). In addition, for any test environment with a non-causal mean far from those seen in training, this classifier behaves like an ERM-trained classifier on most of the test distribution.

We defer the formal statement and proof to Appendix A.1. We prove the theorem by constructing a classifier that uses only invariant features for prediction in the high-density region but behaves like the ERM solution on the tails; such a classifier can still attain near-optimal IRMv1 loss. However, the per-environment optimal predictors are distinct when the ERM solution is used on the tails, and this discrepancy leads to a large transfer risk.

Figure 1: IRMv1 loss (left) and transfer risk (right) versus training iterations for models trained by IRMv1 and TRM on 10C-CMNIST (C) and PACS (P). Figure 2: 2-d scenario of the linear case. TRM drives the non-invariant feature toward the invariant one.

In addition to the theoretical analysis, we provide further empirical analysis to characterize the difference between TRM and IRM. Fig. 1 reports the IRMv1 loss and the transfer risk on the 10C-CMNIST (C) and PACS (P) datasets discussed later (Section 5). Although IRMv1 solutions achieve small IRMv1 losses, they have significantly higher transfer risks than TRM solutions. Conversely, TRM solutions have slightly lower IRMv1 loss than IRMv1 solutions. Comparing out-of-distribution test accuracies on 10C-CMNIST / PACS, IRMv1 solutions perform close to ERM solutions, while TRM outperforms both by a large margin. These empirical results support the statement of Theorem 1 that models with near-optimal IRMv1 loss can have large transfer risks and behave like ERM solutions on test environments.

Together, these results suggest that transfer risk is a better criterion than IRMv1 for assessing a model's out-of-distribution generalization. In Fig. 2, we demonstrate the effect of TRM on a toy 2-d example with a linear $g$: TRM drives the non-invariant feature toward the invariant one. We defer the details of the analysis in the 2-d case to Appendix A.2.

4 Method

In this section, we discuss how to optimize the TRM objective (Eq. (1)). We first introduce an exponential gradient ascent algorithm for optimizing the inner supremum, and then discuss how to optimize the feature extractor given the per-environment optimal predictors. An alternating update algorithm incorporating these steps is summarized in Algorithm 1.

4.1 Transfer Risk Optimization

Solving the inner sup

Assume $m$ different environments with associated densities $\mathcal{D}_1, \ldots, \mathcal{D}_m$. Given an environment $e$, we find the corresponding worst-case environment in the inner sup of Eq. (1). The search space is the convex hull of all environment distributions with the exception of $\mathcal{D}_e$: $\Lambda_{-e} = \{\sum_{e' \neq e} w_{e'}\, \mathcal{D}_{e'} : w \in \Delta^{m-2}\}$. Since the optimization is over a simplex, the solution can be found exactly by just selecting the worst environment in $\mathcal{E}_{tr} \setminus \{e\}$: $\max_{e' \neq e} \mathcal{L}_{\mathcal{D}_{e'}}(f^*_e \circ \Phi)$. Empirically, we find that updating the weights $w$ by gradient ascent instead of selecting the single worst environment leads to a more stable training process. This has been observed in related contexts (Sagawa2019DistributionallyRN).

The gradient of the inner objective with respect to $w_{e'}$ is $\mathcal{L}_{\mathcal{D}_{e'}}(f^*_e \circ \Phi)$, indicating that the inner supremum simply up-weights the environments with larger losses relative to the predictor $f^*_e$. We adopt an exponential gradient ascent algorithm (EG) for the updates:

$$w_{e'} \;\leftarrow\; \frac{w_{e'}\, \exp\!\big(\eta\, \mathcal{L}_{\mathcal{D}_{e'}}(f^*_e \circ \Phi)\big)}{\sum_{e'' \neq e} w_{e''}\, \exp\!\big(\eta\, \mathcal{L}_{\mathcal{D}_{e''}}(f^*_e \circ \Phi)\big)}, \tag{3}$$

where $\eta$ is the learning rate and the subscript $e'$ denotes the component of the weight vector corresponding to environment $e'$.
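A minimal sketch of the update in Eq. (3) (illustrative; the tensor names `w` and `losses` are hypothetical): `losses` holds the loss of the fixed predictor $f^*_e$ on each of the other environments, and `w` are the current mixture weights over those environments.

```python
import torch

def eg_update(w, losses, eta=0.1):
    """Exponential gradient ascent step: up-weight environments with larger loss."""
    w = w * torch.exp(eta * losses)   # multiplicative update, numerator of Eq. (3)
    return w / w.sum()                # renormalize onto the simplex

# toy usage with three candidate environments
w = torch.full((3,), 1.0 / 3.0)
losses = torch.tensor([0.4, 1.2, 0.7])   # transfer losses of f*_e on the other environments
w = eg_update(w, losses)                  # the second environment gets up-weighted
```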

Updating the feature extractor

Given an environment $e$ and the corresponding worst-case environment $\mathcal{D}'$, we consider here how to update the feature extractor $\Phi$ in Eq. (1). Denote the risk of using predictor $f$ with data distribution $\mathcal{D}$ as $\mathcal{L}_{\mathcal{D}}(f \circ \Phi)$. Recall that the optimal predictor $f^*_e$ is an implicit function of the feature extractor, i.e., $f^*_e = f^*_e(\Phi)$. In the remainder, we use the shorthand $\bar{f}^*_e = \mathrm{sg}(f^*_e)$ to refer to the predictor, where $\mathrm{sg}$ stands for the stop_gradient operator. In other words, the value of $\bar{f}^*_e$ follows $f^*_e$, but its partial derivatives w.r.t. $\Phi$ are set to zero. It is helpful to distinguish $f^*_e$ from $\bar{f}^*_e$ to clarify what is meant by the different expressions below. Now, the full gradient of the transfer risk w.r.t. $\Phi$ comprises two terms, since $f^*_e$ also depends on $\Phi$:

$$\nabla_\Phi\, \mathcal{L}_{\mathcal{D}'}(f^*_e \circ \Phi) \;=\; \underbrace{\nabla_\Phi\, \mathcal{L}_{\mathcal{D}'}(\bar{f}^*_e \circ \Phi)}_{\text{direct gradient}} \;+\; \underbrace{\Big(\tfrac{\partial f^*_e}{\partial \Phi}\Big)^{\!\top} \nabla_{f}\, \mathcal{L}_{\mathcal{D}'}(\bar{f}^*_e \circ \Phi)}_{\text{implicit gradient}}. \tag{4}$$

We show in Proposition 1 that the implicit gradient can be further simplified as a weighted gradient-matching term.

Proposition 1.

Denote the Hessian as $H = \nabla^2_{f}\, \mathcal{L}_{\mathcal{D}_e}(f^*_e \circ \Phi)$. Suppose the loss is continuously differentiable and $H$ is non-singular; then we have:

$$\Big(\tfrac{\partial f^*_e}{\partial \Phi}\Big)^{\!\top} \nabla_{f}\, \mathcal{L}_{\mathcal{D}'}(\bar{f}^*_e \circ \Phi) \;=\; -\,\nabla_\Phi \Big[\, \nabla_{f}\, \mathcal{L}_{\mathcal{D}_e}(\bar{f}^*_e \circ \Phi)^{\top}\, H^{-1}\, v \,\Big],$$

where $v = \nabla_{f}\, \mathcal{L}_{\mathcal{D}'}(\bar{f}^*_e \circ \Phi)$, and $H^{-1} v$ is treated as a constant vector in the above equation (note the use of the stop-gradient version $\bar{f}^*_e$).

We can interpret the bracketed term on the RHS as a gradient-matching objective: it measures the similarity of the gradients obtained when the loss is measured over $\mathcal{D}_e$ versus $\mathcal{D}'$, weighted by the inverse Hessian $H^{-1}$. It shows that TRM naturally aims to find a representation where the gradients are matched when moving from $\mathcal{D}_e$ to $\mathcal{D}'$ (cf. (Shi2021GradientMF)). By integrating, we can write down an objective function whose gradient with respect to $\Phi$ matches Eq. (4):

$$\mathcal{L}_{\mathcal{D}'}(\bar{f}^*_e \circ \Phi) \;-\; \Big\langle \nabla_{f}\, \mathcal{L}_{\mathcal{D}_e}(\bar{f}^*_e \circ \Phi),\; \mathrm{sg}\big[ H^{-1}\, \nabla_{f}\, \mathcal{L}_{\mathcal{D}'}(\bar{f}^*_e \circ \Phi) \big] \Big\rangle. \tag{5}$$

Note the use of the stop-gradient versions in these expressions. Effectively, the TRM objective for $\Phi$ decomposes into two terms: (i) the direct transfer term, which encourages the predictor $\bar{f}^*_e$ to do well even if the distribution were $\mathcal{D}'$, and (ii) the weighted gradient-matching term. The second term attempts to match the gradients of the loss in the original environment and the worst-case environment by updating the features. The weighted gradient-matching term plays a role analogous to meta-learning (Li2018LearningTG; Shi2021GradientMF), encouraging simultaneous loss descent. Note that the weighted gradient-matching term actually evaluates to zero at the current value of $\Phi$, since $f^*_e$ is set to the per-environment optimal value, but the gradient of this term with respect to $\Phi$ is not zero.
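The stop-gradient bookkeeping in Eq. (5) maps directly onto `detach()` in PyTorch. Below is a sketch under stated assumptions (not the released implementation): `loss_e` and `loss_worst` are the losses of $f^*_e \circ \Phi$ on $\mathcal{D}_e$ and on the worst-case mixture, both computed with gradients flowing through $\Phi$, and `ihvp` is a callable argument returning an approximation of $H^{-1} v$, e.g. the routine sketched in Section 4.2 below.

```python
import torch

def gradient_matching_term(loss_e, loss_worst, f_e_star, ihvp):
    """Weighted gradient-matching term of Eq. (5); the result stays differentiable w.r.t. Phi."""
    params = list(f_e_star.parameters())
    # gradient of the source-environment loss w.r.t. the predictor; create_graph keeps
    # the dependence on Phi so the inner product below can still be differentiated w.r.t. Phi
    g_e = torch.autograd.grad(loss_e, params, create_graph=True)
    # gradient on the worst-case environment, used only inside the stop-gradient
    g_worst = torch.autograd.grad(loss_worst, params, retain_graph=True)
    hv = [h.detach() for h in ihvp(g_worst)]          # sg[H^{-1} v], with H the Hessian of loss_e
    return sum((ge * v).sum() for ge, v in zip(g_e, hv))
```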

4.2 Approximation of Inverse Hessian Vector Product

We denote the number of parameters in $f$ as $p$ and write $v = \nabla_{f}\, \mathcal{L}_{\mathcal{D}'}(\bar{f}^*_e \circ \Phi)$. Dropping the subscripts for clarity, the weighted gradient-matching term in Eq. (5) involves the computation of the inverse Hessian vector product $H^{-1} v$. For minibatch data, computing the inverse Hessian explicitly is prohibitively expensive. To avoid this heavy computation, we use an approach similar to Agarwal2017SecondOrderSO to obtain good approximations via a Taylor expansion and efficient Hessian-vector products (HVP) (Pearlmutter1994FastEM). Let $H^{-1}_j = \sum_{i=0}^{j} (I - H)^i$ denote the truncated Taylor expansion of $H^{-1}$ (assuming, after suitable scaling, that $\|H\| \le 1$). Note that $H^{-1}_j \to H^{-1}$ as $j \to \infty$. We can compute the corresponding matrix-vector product in linear time by recursively evaluating

$$H^{-1}_j v \;=\; v + (I - H)\, H^{-1}_{j-1} v$$

with fast HVPs. These computations are easy to implement in auto-grad systems like PyTorch (Paszke2019PyTorchAI).
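The recursion above can be implemented with two nested autograd calls. The following is a small PyTorch sketch under the assumptions just stated (a hypothetical helper, not the released implementation); `loss` is the per-environment loss, `params` are the predictor's parameters, and `v` is the gradient vector to be multiplied by $H^{-1}$.

```python
import torch

def approx_inverse_hvp(loss, params, v, num_terms=10, scale=1.0):
    """Truncated Neumann series: H^{-1} v ~= sum_i (I - scale*H)^i v, rescaled at the end."""
    grads = torch.autograd.grad(loss, params, create_graph=True)   # graph kept for HVPs
    p = [vi.clone() for vi in v]       # current term (I - scale*H)^i v
    acc = [vi.clone() for vi in v]     # running sum of the series
    for _ in range(num_terms):
        # Hessian-vector product H p with one extra backward pass (Pearlmutter's trick)
        hvp = torch.autograd.grad(grads, params, grad_outputs=p, retain_graph=True)
        p = [pi - scale * hi for pi, hi in zip(p, hvp)]
        acc = [ai + pi for ai, pi in zip(acc, p)]
    return [scale * ai for ai in acc]  # undo the scaling of H
```

In TRM the result enters Eq. (5) inside a stop-gradient, so it can be detached and no higher-order derivatives through the series are needed.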

4.3 Algorithm

In addition to the TRM objective in Eq. (1), we include a standard ERM term for updating the predictor on top of the features. Overall, given the distribution $\mathcal{D}_e$ and the worst-case mixture $\mathcal{D}'$, the per-environment objective for updating the pair $(f_e, \Phi)$ consists of three terms:

$$\mathcal{L}_{\mathcal{D}_e}(f_e \circ \Phi) \;+\; \mathcal{L}_{\mathcal{D}'}(\bar{f}^*_e \circ \Phi) \;-\; \lambda\, \Big\langle \nabla_{f}\, \mathcal{L}_{\mathcal{D}_e}(\bar{f}^*_e \circ \Phi),\; \mathrm{sg}\big[ H^{-1}\, \nabla_{f}\, \mathcal{L}_{\mathcal{D}'}(\bar{f}^*_e \circ \Phi) \big] \Big\rangle, \tag{6}$$

where $\lambda$ is a hyper-parameter for adjusting the gradient-matching term to have the same gradient magnitude as the other terms. We interleave gradient updates on the model parameters $(f_e, \Phi)$ and the environmental weights $w$, as shown in Algorithm 1.

  Input: initial model parameters $(f, \Phi)$, learning rates $\alpha$ and $\eta$, and environment set $\mathcal{E}_{tr}$
  for $t = 1$ to $T$ do
     Randomly pick an environment $e \in \mathcal{E}_{tr}$
     Get the optimal predictor $f^*_e$ on $\mathcal{D}_e$
     Update the model parameters $(f_e, \Phi)$ by a gradient step on the objective in Eq. (6)
     Update the environmental weights $w$ by Eq. (3)
  end for
Algorithm 1 TRM algorithm

Algorithm 1 updates the model parameters and the environmental weights in an online manner. Under some convexity, boundedness, and smoothness assumptions, we can establish a convergence rate for the online updates using the techniques in Nemirovski2009RobustSA. We defer further discussion to Appendix B.3.1.
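Putting the pieces together, the following is a high-level sketch of one TRM training step (illustrative only; it reuses the hypothetical helpers `fit_predictor`, `eg_update`, `gradient_matching_term`, and `approx_inverse_hvp` sketched above, and assumes `envs` is a list of per-environment `(x, y)` tensors and `w[e]` a weight vector over the remaining environments in the same order as they are enumerated here).

```python
import torch
import torch.nn.functional as F

def trm_step(phi, predictors, envs, w, opt, num_classes, lam=1.0, eta=0.1):
    e = int(torch.randint(len(envs), (1,)))                 # randomly pick an environment
    x_e, y_e = envs[e]
    f_star = fit_predictor(phi, x_e, y_e, num_classes)      # per-environment optimal predictor

    others = [i for i in range(len(envs)) if i != e]
    # losses of f*_e composed with Phi on the other environments
    losses = torch.stack([F.cross_entropy(f_star(phi(envs[i][0])), envs[i][1])
                          for i in others])
    transfer = losses @ w[e]                                 # direct transfer term on the mixture D'

    loss_e = F.cross_entropy(f_star(phi(x_e)), y_e)
    ihvp = lambda v: approx_inverse_hvp(loss_e, list(f_star.parameters()), v)
    gm = gradient_matching_term(loss_e, transfer, f_star, ihvp)

    erm = F.cross_entropy(predictors[e](phi(x_e)), y_e)     # ERM term for the pair (f_e, Phi)
    objective = erm + transfer - lam * gm                    # Eq. (6)

    opt.zero_grad()
    objective.backward()
    opt.step()                                               # update (f_e, Phi)
    with torch.no_grad():
        w[e] = eg_update(w[e], losses, eta)                  # Eq. (3)
```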

5 Experiments

In our experiments, we focus on out-of-distribution generalization tasks. We first evaluate our method on two synthetic datasets (10C-CMNIST, SceneCOCO), simulating different kinds of domain shifts in controlled experiments. Next, we evaluate all methods on two real-world datasets (PACS, Office-Home). We compare TRM with standard empirical risk minimization (ERM) and recent methods developed for out-of-distribution generalization: IRM (Arjovsky2019InvariantRM), REx (Krueger2020OutofDistributionGV), GroupDRO (Sagawa2019DistributionallyRN), MLDG (Li2018LearningTG), and Fish (Shi2021GradientMF). We also report ERM trained with data sampled from the test domain as an upper bound (Oracle).

Experiments in the main body use training-domain validation sets for hyper-parameter selection, which are arguably more practical for out-of-distribution generalization tasks (ahmed2021systematic; Gulrajani2020InSO; Krueger2020OutofDistributionGV). We defer results with test-domain validation sets to Appendix C. We also show the efficacy of TRM on group distributional robustness in Appendix C.4.

5.1 Experiments on 10C-CMNIST and SceneCOCO

5.1.1 Datasets

Evaluating out-of-distribution generalization performance in an unambiguous manner necessitates controlled experiments. We synthesize the data from three latent features: (i) the invariant (causal) feature, (ii) the non-causal feature, which is spuriously correlated with the labels, and (iii) the dummy feature, which is not predictive of the labels. We conduct the controlled experiments on two synthetic datasets:

10C(lasses)-C(olored)MNIST

is a more general 10-class version of the 2-class ColoredMNIST (Arjovsky2019InvariantRM). We add digit colors and background colors to allow for domain shifts. Specifically, we set the invariant / non-causal / dummy features to the digit / digit color / background color, respectively. We randomly select ten colors as the digit colors and five colors as the background colors. 10C-CMNIST contains 60000 datapoints of dimension (3, 28, 28) from 10 digit classes.

SceneCOCO

superimposes objects from the COCO dataset (Lin2014MicrosoftCC) on background scenes from the Places dataset (Zhou2018PlacesA1). Following ahmed2021systematic, we select 10 objects and 10 scenes from these two datasets. We set the invariant / non-causal / dummy features to the object / background scene / object color. This dataset consists of 10000 datapoints from 10 object classes.

In addition, we define a measure of the correlation between the label and the non-causal feature. Note that there is a one-to-one correspondence between the label and the non-causal feature, e.g., "2" ↔ "blue digit color" in 10C-CMNIST and "boat" ↔ "beach scene" in SceneCOCO. For each environment, we define the bias degree to be the fraction of the data that obeys this relationship; the remaining data are assigned random non-causal features. In each training environment, the data is generated by an environment-specific combination of features and bias degree. This setting is commonly adopted in the existing literature (Arjovsky2019InvariantRM; Krueger2020OutofDistributionGV; ahmed2021systematic). The label is determined by the class of the invariant feature.
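For concreteness, here is a toy sketch of how such a biased environment can be constructed (a hypothetical helper, not the actual dataset code): with probability equal to the bias degree, an example receives the non-causal feature paired with its label; otherwise it receives a random one.

```python
import numpy as np

def assign_non_causal(labels, num_values, bias_degree, rng=None):
    """With probability bias_degree use the non-causal value paired with the label, else a random one."""
    rng = rng if rng is not None else np.random.default_rng(0)
    biased = rng.random(len(labels)) < bias_degree
    random_vals = rng.integers(num_values, size=len(labels))
    return np.where(biased, labels, random_vals)   # e.g., index of the digit color per example

labels = np.random.default_rng(1).integers(10, size=8)
colors_env1 = assign_non_causal(labels, num_values=10, bias_degree=0.9)   # strongly biased training env
colors_test = assign_non_causal(labels, num_values=10, bias_degree=0.0)   # unbiased test environment
```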

5.1.2 Controlled Scenarios

Figure 3: Visualization of training environment 1 (top) and the test environment (bottom) of 10C-CMNIST and SceneCOCO under (a) label-correlated shift and (b) combined shift.
Figure 4: Test accuracy on the 10C-CMNIST and SceneCOCO datasets under (a) label-correlated shift and (b) combined shift, using various bias degrees.

Next, we consider two scenarios with distinct combinations of latent features.

Label-correlated shift

In this scenario, each training environment is assigned a different non-zero bias degree. The dummy feature is set to a constant, e.g., a black background color in 10C-CMNIST. The bias degree is set to zero in the test environment to evaluate to what extent the model has learned the invariant feature.

Combined shift

In this scenario, the training environments have varying non-zero bias degrees and varying prior distributions over the dummy feature, while the test environment is unbiased. This scenario simulates the joint effect of shifts in the non-causal and dummy features.

Fig. 3 visualizes the label-correlated and combined shifts. We use two training environments in the controlled experiments. To better evaluate the robustness of the algorithms, we vary the bias degree of the second training environment while fixing that of the first; the range of bias degrees differs between 10C-CMNIST and SceneCOCO because SceneCOCO is more complex. The bias degree in the test environment is zero.

We adopt a 4-layer CNN / Wide ResNet (Zagoruyko2016WideRN) as the feature extractor for 10C-CMNIST / SceneCOCO, following prior work (ahmed2021systematic). We train for 10 / 100 epochs on 10C-CMNIST / SceneCOCO, both using a batch size of 128 and SGD with momentum 0.9. For more details on the datasets, hyper-parameter selection, and training, please refer to Appendix B.

5.1.3 Results

In Fig. 4, we report the accuracy on the test environment under the label-correlated and combined shifts. The x-axis in Fig. 4 shows the varying bias degree of the second training environment. We observe a consistent performance drop for all methods as the training environments become more biased. Our main finding is that TRM achieves better test accuracy at most bias degrees on both datasets, and performance competitive with Fish and MLDG on SceneCOCO when the bias degrees are large. The results indicate that models trained with TRM rely more on the invariant features to make predictions.

We also observe that non-causal features combined with the distribution shifts of dummy features degrade the performance of all the methods (Combined shift). In Appendix C.1.3, we show that all the methods have similar performance to the Oracle when only changing the dummy feature distribution. The experiments suggest that when non-causal features exist, distribution shifts on dummy features can further hurt the out-of-distribution generalization. Besides, we show in Appendix C.1.2 that TRM-trained models transfer faster to target environments with limited data for fine-tuning.

5.2 Experiments on PACS and Office-Home

5.2.1 Setups

We further evaluate our methods on two real-world datasets, PACS and Office-Home.

PACS (Li2017DeeperBA) comprises data from four environments, namely art, cartoon, photo, and sketch. This dataset contains 9991 datapoints of dimension (3, 224, 224) from 7 classes.

Office-Home (Venkateswara2017DeepHN) includes four environments, namely art, clipart, product and real. It contains 15588 datapoints of dimension (3,224,224) from 65 classes.

These two datasets are widely used in the domain generalization literature. We follow the standard evaluation protocol (Li2017DeeperBA), which reports the test accuracy on each held-out environment when training on the other three environments. We use an ImageNet-pretrained ResNet18 as the backbone of the feature extractor, and train with the SGD optimizer with momentum, a weight decay of 1e-4, and a fixed learning rate of 1e-4.

5.2.2 Results

In Tables 1 and 2, we report the test accuracy on the PACS and Office-Home datasets. The proposed TRM algorithm achieves superior average accuracy on both datasets. We observe that TRM generalizes better on most held-out environments in the two datasets, except the Photo environment, where all methods have comparable performance. Further, we test TRM without the weighted gradient-matching term (TRM w/o GM) by setting $\lambda = 0$ in Eq. (6). TRM w/o GM still outperforms the other baselines with only the direct transfer term. We also show in Appendix C.2 that TRM consistently improves over other methods with different architectures and validation-set configurations.

Algorithm Art Cartoon Photo Sketch Average
ERM
IRM
REx
GroupDRO 94.7 0.7
MLDG 95.0 0.4
Fish 94.7 0.5
TRM w/o GM 76.9
TRM
Table 1: Test accuracy on PACS dataset
Algorithm Art Clipart Product Real Average
ERM
IRM
REx
GroupDRO
MLDG
Fish
TRM w/o GM
TRM
Table 2: Test accuracy on Office-Home dataset

6 Conclusion

The discrepancy between training and test domains can degrade the performance of algorithms developed for the i.i.d. setting. We propose a robust criterion termed Transfer Risk Minimization (TRM) to tackle the out-of-distribution generalization problem. The transfer risk promotes the transferability of the per-environment predictors, and the feature representation is updated accordingly to support such transfer. We demonstrate through an illustrative example that TRM better recovers the weights associated with the invariant features. Due to the optimality of the per-environment predictors, the TRM objective naturally decomposes into two terms, the direct transfer term and the weighted gradient-matching term. One limitation of TRM is that the inverse Hessian vector product can have a large variance with a small batch size; better optimization of the weighted gradient-matching term is left for future work.

Experimentally, we test our approach on several controlled experiments. We show that TRM achieves better out-of-distribution performance under different combinations of features. We also demonstrate the effectiveness of TRM on the PACS and Office-Home datasets.

Acknowledgements

This research was supported by the HDTV Grand Alliance Fellowship.

References

Appendix A Proofs

a.1 Proof of Theorem 1

In the following theorem, we show a simplified version of the setting in Rosenfeld2020TheRO. The full version can be similarly deduced.

Theorem 2 (Formal statement).

Assume . We further assume there exist two training environments such that and . Then there exists a classifier which achieves near-optimal IRMv1 loss (Eq. (2)) and has high transfer risk (Eq. (1)) when . In addition, for any test environment with a non-causal mean far from those in training:

for some . Then the classifier behaves like the ERM-trained classifier on a large fraction of the test distribution.

Proof.

We follow the proof idea of theorem 6.1 in Rosenfeld2020TheRO and first define . We construct as

where is the ball centered at . We construct the classifier as follows:

outputs the invariant feature in the balls centered at the non-causal means except . Note that is the optimal predictor on environment when using the feature extractor above; hence it automatically has zero gradient penalty on environment . By setting in Theorem D.3 (Rosenfeld2020TheRO), the IRMv1 penalty term on any environment other than is upper bounded by

where . Thus when , the penalty term shrinks rapidly towards 0.

For , the classifier is the invariant classifier that only uses for prediction, and thus has small ERM loss: . For , the incurred loss can be upper bounded by

(7)
(8)
(9)

where and are the PDF and CDF of the standard Gaussian. When , the upper bound Eq. (9) is approximately zero. Eq. (7) holds by . Eq. (8) holds by the following lower bound on the Gaussian CDF: (erfc). Together, the classifier has small ERM loss and IRMv1 penalty; hence it achieves nearly optimal IRMv1 loss when is large.

On the other hand, consider the transfer risk of the constructed classifier. The environmental optimal classifier on environment is , by the construction of . Since , we know that when , . The incurred transfer risk is lower bounded by applying the predictor on environment :

Hence when is large, the transfer risk of the constructed classifier is close to . In addition, by Theorem D.3 (Rosenfeld2020TheRO), for some , the classifier behaves like the optimal ERM classifier in environment on