Out-of-Distribution Generalization Analysis via Influence Function

01/21/2021 ∙ by Haotian Ye, et al. ∙ HUAWEI Technologies Co., Ltd., Peking University

The mismatch between training and target data is one major challenge for current machine learning systems. When training data is collected from multiple domains and the target domains include all training domains as well as new domains, we face an Out-of-Distribution (OOD) generalization problem that aims to find a model with the best OOD accuracy. One definition of OOD accuracy is worst-domain accuracy. In general, the set of target domains is unknown, and the worst of the target domains may be unseen when the number of observed domains is limited. In this paper, we show that the worst accuracy over the observed domains may dramatically fail to identify the OOD accuracy. To this end, we introduce the influence function, a classical tool from robust statistics, into the OOD generalization problem and propose using the variance of the influence function to monitor the stability of a model on training domains. We show that the accuracy on test domains and the proposed index together can help us discern whether OOD algorithms are needed and whether a model achieves good OOD generalization.


1 Introduction

Most machine learning systems assume that training and test data are independently and identically distributed, which does not always hold in practice (Bengio et al. (2019)). Consequently, their performance often degrades greatly when the test data come from a different domain (distribution). A classical example is the problem of distinguishing cows from camels (Beery et al. (2018)), where empirical risk minimization (ERM, Vapnik (1992)) may classify images by background color instead of object shape. As a result, when the test domain is “out-of-distribution” (OOD), e.g. when the background color changes, performance drops significantly. The goal of OOD generalization is to obtain a predictor that is robust against this distribution shift.

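Suppose that we have training data $\mathcal{D}$ collected from $m$ domains:

$$\mathcal{D} = \{\mathcal{D}^e\}_{e \in \mathcal{E}_{tr}}, \qquad \mathcal{D}^e = \{z_i^e\}_{i=1}^{n_e} \overset{\text{i.i.d.}}{\sim} P^e, \tag{1}$$

where $P^e$ is the distribution corresponding to domain $e$, $\mathcal{E}_{tr}$ (with $|\mathcal{E}_{tr}| = m$) is the set of all available domains, including validation domains, and $z_i^e = (x_i^e, y_i^e)$ is a data point. The OOD problem we consider is to find a model $f_{OOD}$ such that

$$f_{OOD} = \arg\min_{f} \sup_{e \in \mathcal{E}_{all}} \ell(f, P^e), \tag{2}$$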

where $\mathcal{E}_{all}$ is the set of all target domains and $\ell(f, P^e)$ is the expected loss of $f$ on the domain $P^e$. Recent algorithms address this OOD problem by recovering invariant (causal) features and building the optimal model on top of these features, such as Invariant Risk Minimization (IRM, Arjovsky et al. (2019)), Risk Extrapolation (REx, Krueger et al. (2020)), Group Distributionally Robust Optimization (gDRO, Sagawa et al. (2019)) and Inter-domain Mixup (Mixup, Xu et al. (2020); Yan et al. (2020); Wang et al. (2020)). Most works evaluate on Colored MNIST (see Section 5.1 for details), where we can directly obtain the worst-domain accuracy over $\mathcal{E}_{all}$. Gulrajani and Lopez-Paz (2020) has assembled many algorithms and multi-domain datasets, and finds that OOD algorithms cannot outperform ERM on some domain generalization tasks, e.g. VLCS (Torralba and Efros (2011)) and PACS (Li et al. (2017)). This is not surprising, since these tasks only require high performance on certain domains, while an OOD algorithm is expected to learn truly invariant features and perform well on a large set of target domains $\mathcal{E}_{all}$. This phenomenon is described as the “accuracy-vs-invariance trade-off” in Akuzawa et al. (2019).
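As a minimal illustration of why the min-max objective (2) can disagree with average-risk training, the sketch below picks a model from a table of per-domain losses in two ways. The loss values and the helper names `select_average` and `select_worst_domain` are hypothetical and chosen only for illustration; real per-domain losses would come from evaluating trained models on each domain.

```python
import numpy as np

# Hypothetical per-domain expected losses l(f, P^e): rows are candidate
# models, columns are domains. Values are illustrative only.
domain_losses = np.array([
    [0.10, 0.12, 0.11, 0.90],  # model A: very good on most domains, fails on one
    [0.33, 0.34, 0.32, 0.35],  # model B: uniformly mediocre
    [0.20, 0.22, 0.25, 0.60],  # model C: in between
])

def select_average(losses):
    """ERM-style choice: minimise the mean loss over domains."""
    return int(np.argmin(losses.mean(axis=1)))

def select_worst_domain(losses):
    """Choice implied by the min-max problem (2): minimise the worst-domain loss."""
    return int(np.argmin(losses.max(axis=1)))

print(select_average(domain_losses))       # 0 -> model A
print(select_worst_domain(domain_losses))  # 1 -> model B
```

The average-risk choice prefers model A, which fails badly on one domain, whereas the worst-domain criterion prefers the uniformly mediocre model B.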

Two questions arise from the min-max problem (2). First, previous works assume that there is sufficient diversity among the domains in $\mathcal{E}_{all}$. Then the supremum of $\ell(f, P^e)$ over domains may be much larger than its average, which implies that ERM may fail to discover $f_{OOD}$. In reality, however, we do not know whether this assumption holds. If it does not, the distribution of $\ell(f, P^e)$ is concentrated around its expectation, and ERM is sufficient to find an invariant model for $\mathcal{E}_{all}$. Therefore, we call for a method to judge whether an OOD algorithm is needed. Second, how should we judge a model’s OOD performance? Traditionally, we consider test domains $\mathcal{E}_{test}$ and use the worst-domain accuracy over $\mathcal{E}_{test}$ (which we call test accuracy) to approximate the OOD accuracy. However, test accuracy is a biased estimate of the OOD accuracy unless $\mathcal{E}_{test}$ is close to $\mathcal{E}_{all}$. More seriously, it may be irrelevant or even negatively correlated to the OOD accuracy. This phenomenon is not uncommon, especially when some features are spurious in $\mathcal{E}_{all}$ but strongly correlated with the target in $\mathcal{E}_{test}$.

We give a toy example in Colored MNIST where the test accuracy fails to approximate the OOD accuracy. For more details, please refer to Section 5.1 and Appendix A.4. We choose three domains from Colored MNIST and use cross-validation (Gulrajani and Lopez-Paz (2020)) to select models, i.e. we take turns selecting one domain as the test domain and training on the rest, and select the model with the maximum average test accuracy. Figure 1 compares ERM and IRM. No matter which domain is the test domain, the ERM model uniformly outperforms the IRM model on the test domain. However, the IRM model achieves consistently better OOD accuracy. The shortcomings of test accuracy here are obvious, regardless of whether cross-validation is used. In short, naive use of the test accuracy may result in a non-OOD model.

Figure 1: Experiments in Colored MNIST showing that test accuracy is not enough to reflect a model’s OOD accuracy. The top-left panel shows the test accuracy of ERM and IRM. The other three panels present the relationship between test accuracy (x-axis) and OOD accuracy (y-axis) in three setups.

To address this obstacle, we seek a metric that correlates better with a model’s OOD property, even when $\mathcal{E}_{test}$ is much smaller than $\mathcal{E}_{all}$ and the “worst” domain remains unknown. Without any assumption on $\mathcal{E}_{all}$, this goal is unrealistic. Therefore, we assume that features that are invariant across $\mathcal{E}_{tr}$ are also invariant across $\mathcal{E}_{all}$. This assumption is necessary; otherwise, the only thing we can do is collect more domains. We therefore need to focus on what features the model has learnt. Specifically, we want to check whether the model learns invariant features and avoids varying features.

The influence function (Cook and Weisberg (1980)) can serve our purpose. The influence function was proposed to measure the parameter change when a data point is removed or upweighted by a small perturbation (details in Section 3.2). Modified to the domain level, it measures the influence of a domain, instead of a single data point, on the model. Note that we are not emulating the change of the parameters when a domain is removed. Instead, we care precisely about upweighting a domain by a small perturbation $\delta$ (specified later). Based on this, the variance of the influence function allows us to measure the OOD property and overcome the obstacle.

Contributions

We summarize our contributions here: (i) We extend the influence function to the domain level and propose an index, $\mathcal{V}_{\gamma|\theta}$ (formula 6), based on the influence function of the model $f$. Our index can measure the OOD extent of the available domains, i.e. how different these domains (distributions) are. This measurement provides a basis for deciding whether to adopt an OOD algorithm and whether to collect more diverse domains. See Section 4.1 and Section 5.1.1 for details. (ii) We point out that the proposed index can remedy the weakness of test accuracy. Specifically, in most OOD generalization problems, using test accuracy and our index together, we can discern the OOD property of a model. See Section 4.2 for details. (iii) We propose to use only a small but important part of the model to calculate the influence function. This overcomes the huge computational cost of inverting the Hessian. This choice is not merely for computational efficiency and accuracy; it also coincides with our understanding that only these parameters capture what features a model has learnt (Section 4.3).

We organize our paper as follows: Section 2 reviews related works and Section 3 introduces the preliminaries of OOD methods and influence function. Section 4 presents our proposal and detailed analysis. Section 5 shows our experiments. The conclusion is given in Section 6.

2 Related work

The mismatch between the development dataset and the target domain is one major challenge in machine learning (Castro et al. (2020); Kuang et al. (2020)). Many works assume that the ground truth can be represented by a causal Directed Acyclic Graph (DAG), and they use the DAG structure to discuss the worst-domain performance (Rojas-Carulla et al. (2018); Peters et al. (2016); Subbaswamy et al. (2019); Bühlmann and others (2020); Magliacane et al. (2018)). All these works employ multiple-domain data and causal assumptions to discover the parents of the target variable. Rojas-Carulla et al. (2018) and Magliacane et al. (2018) also apply this idea to Domain Generalization and Multi-Task Learning settings. Starting from multiple-domain data rather than model assumptions, Arjovsky et al. (2019) proposes Invariant Risk Minimization (IRM) to extract causal (invariant) features and learn an invariant optimal predictor on top of the causal features. It analyzes the generalization properties of IRM from the view of sufficient dimension reduction (Cook (2009); Cook et al. (2002)). Ahuja et al. (2020) considers IRM as finding the Nash equilibrium of an ensemble game among several domains and develops a simple training algorithm. Krueger et al. (2020) derives Risk Extrapolation (REx) to extract invariant features and further derives a practical objective function via variance penalization. Xie et al. (2020) employs a framework from distributional robustness to interpret the benefit of REx compared to robust optimization (Ben-Tal et al. (2009); Bagnell (2005)). Besides, Adversarial Domain Adaptation (Li et al. (2018); Koyama and Yamaguchi (2020)) uses a discriminator to look for features that are independent of domains and uses these features for further prediction.

The influence function is a classic method from the robust statistics literature (Robins et al. (2008, 2017); Van der Laan et al. (2003); Tsiatis (2007)). It can be used to track the impact of a training sample on the prediction. Koh and Liang (2017) proposes a second-order optimization technique to approximate the influence function. They verify their method under different assumptions on the empirical risk, ranging from strictly convex and twice-differentiable to non-convex and non-differentiable losses. Koh et al. (2019) also estimates the effect of removing a subgroup of training points via the influence function, and finds that the approximation computed by the influence function is correlated with the actual effect. The influence function has been used in many machine learning tasks. Cheng et al. (2019) proposes an explanation method, Fast Influence Analysis, that applies the influence function to the Latent Factor Model to address the lack of interpretability of collaborative filtering approaches for recommender systems. Cohen et al. (2020) uses the influence function to detect adversarial attacks. Ting and Brochu (2018) proposes an asymptotically optimal sampling method via an asymptotically linear estimator and the associated influence function. Alaa and Van Der Schaar (2019) develops a model validation procedure that estimates the estimation error of causal inference methods. Besides, Fang et al. (2020) leverages the influence function to select a subset of normal users who are influential to the recommendations.

3 Preliminaries

3.1 ERM, IRM and REx

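In this section, we introduce some notation and some recent OOD methods. Recall the multiple-domain setup (1) and the OOD problem (2). For a domain $P^e$ and a hypothetical model $f$, the population loss is $\ell(f, P^e) = \mathbb{E}_{z \sim P^e}[L(f, z)]$, where $L(f, z)$ is the loss function on $z$. The empirical loss, which is the objective of ERM, is $\ell(f) = \frac{1}{m}\sum_{e \in \mathcal{E}_{tr}} \ell(f, \mathcal{D}^e)$ with $\ell(f, \mathcal{D}^e) = \frac{1}{n_e}\sum_{i=1}^{n_e} L(f, z_i^e)$.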
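A minimal sketch of the empirical ERM objective, assuming that each domain's empirical loss is the sample mean of the pointwise losses $L(f, z_i^e)$ and that the $m$ domains are weighted equally (the equal weighting is an assumption of this sketch). Per-sample losses are passed in directly so that no particular model is implied; `empirical_loss` is a hypothetical helper name.

```python
import numpy as np

def empirical_loss(per_sample_losses_by_domain):
    """l(f) = (1/m) * sum_e l(f, D^e), where l(f, D^e) is the mean loss over the n_e samples of domain e."""
    domain_means = [np.mean(losses) for losses in per_sample_losses_by_domain]
    return float(np.mean(domain_means))

# Two domains of different sizes: under equal domain weighting each gets
# weight 1/m regardless of its sample size n_e.
losses = [np.array([1.0, 1.0, 1.0, 1.0]), np.array([3.0])]
print(empirical_loss(losses))  # 2.0, whereas pooling all samples would give 1.4
```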

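Recent OOD methods propose novel regularized objective functions of the form

$$\mathcal{L}(f) = \ell(f) + \lambda R(f) \tag{3}$$

to discover $f_{OOD}$ in (2). Here $R(f)$ is a regularization term and $\lambda$ is the tuning parameter that controls the degree of penalty. Note that ERM is the special case $\lambda = 0$. For simplicity, we use $\mathcal{L}$ to denote the total loss when there is no ambiguity. Arjovsky et al. (2019) focuses on stability across domains and considers the IRM regularization:

$$R_{IRM}(f) = \sum_{e \in \mathcal{E}_{tr}} \left\| \nabla_{w \mid w = 1.0}\, \ell(w \cdot f, \mathcal{D}^e) \right\|^2, \tag{4}$$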

where $w$ is a scalar and fixed “dummy” classifier. Arjovsky et al. (2019) shows that the scalar fixed classifier is sufficient to monitor invariance and corresponds to the idealistic IRM problem, which decomposes the entire predictor into a data representation and one shared optimal top classifier for all training domains. On the other hand, Krueger et al. (2020) encourages uniform performance of $\ell(f, \mathcal{D}^e)$ across domains and proposes the V-REx penalty:

$$R_{REx}(f) = \operatorname{Var}_{e \in \mathcal{E}_{tr}}\big(\ell(f, \mathcal{D}^e)\big).$$

Krueger et al. (2020) derives invariant prediction from robustness to spurious features and finds that REx is more robust than group distributionally robust optimization (Sagawa et al. (2019)). In this work, we also decompose the entire predictor into a feature extractor and a classifier on top of the learnt features. As we will see, unlike Arjovsky et al. (2019) and Krueger et al. (2020), we directly monitor the invariance of the top model.
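The IRM penalty (4) and the V-REx penalty can be sketched for a binary classifier that outputs logits, assuming the binary cross-entropy loss. For this loss, the derivative of a domain's risk with respect to the scalar dummy classifier $w$ at $w = 1$ has the closed form $\frac{1}{n_e}\sum_i (\sigma(s_i) - y_i)\, s_i$, where $s_i$ are the logits, so no automatic differentiation is needed. The function names are hypothetical, and using `np.var` (population variance) for the V-REx term is an assumption about the normalisation.

```python
import numpy as np

def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))

def bce(logits, y):
    """Mean binary cross-entropy on logits: log(1 + e^s) - y * s, computed stably."""
    return float(np.mean(np.logaddexp(0.0, logits) - y * logits))

def irm_penalty(domains):
    """Sum over domains of the squared derivative of the domain risk w.r.t. a scalar w, at w = 1."""
    total = 0.0
    for logits, y in domains:
        # d/dw mean[log(1 + e^{w s}) - y w s] at w = 1 equals mean[(sigmoid(s) - y) * s].
        grad_w = np.mean((sigmoid(logits) - y) * logits)
        total += grad_w ** 2
    return float(total)

def vrex_penalty(domains):
    """Variance of the per-domain risks."""
    risks = np.array([bce(logits, y) for logits, y in domains])
    return float(np.var(risks))
```

Two identical domains give a V-REx penalty of exactly 0, while the IRM penalty is 0 only when every domain's risk is stationary in $w$ at $w = 1$.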

3.2 Influence function and group effect

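Consider a parametric hypothesis $f_\theta$ and the corresponding solution $\hat\theta = \arg\min_{\theta} \frac{1}{|\mathcal{D}|}\sum_{z \in \mathcal{D}} L(f_\theta, z)$. By a quadratic approximation of the loss around $\hat\theta$, the influence function takes the form

$$\mathrm{IF}(\hat\theta, z) = -H_{\hat\theta}^{-1} \nabla_\theta L(f_{\hat\theta}, z), \qquad H_{\hat\theta} = \frac{1}{|\mathcal{D}|}\sum_{z' \in \mathcal{D}} \nabla_\theta^2 L(f_{\hat\theta}, z').$$

When the sample size of $\mathcal{D}$ is sufficiently large, the parameter change due to removing a data point $z$ can be approximated by $-\frac{1}{|\mathcal{D}|}\mathrm{IF}(\hat\theta, z)$ without retraining the model. Here $|\mathcal{D}|$ stands for the cardinality of the set $\mathcal{D}$. Furthermore, Koh et al. (2019) shows that the influence function can also predict the effects of removing large groups of training points (i.e. $\mathcal{Z} \subset \mathcal{D}$), even though the model then changes significantly. The parameter change due to removing the group $\mathcal{Z}$ can be approximated by $-\frac{1}{|\mathcal{D}|}\sum_{z \in \mathcal{Z}} \mathrm{IF}(\hat\theta, z)$.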
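To make the removal approximation concrete, here is a minimal sketch for least-squares regression, chosen only because its Hessian is available in closed form (an illustrative assumption, not the setting studied in this paper). It computes $\mathrm{IF}(\hat\theta, z)$, predicts the effect of removing a point as $-\frac{1}{|\mathcal{D}|}\mathrm{IF}(\hat\theta, z)$, and compares it with actually refitting without that point; the synthetic data are hypothetical.

```python
import numpy as np

# Hypothetical synthetic regression data.
rng = np.random.default_rng(0)
n, d = 500, 3
X = rng.normal(size=(n, d))
y = X @ np.array([1.0, -2.0, 0.5]) + 0.1 * rng.normal(size=n)

def fit(X, y):
    """Minimiser of (1/n) * sum_i 0.5 * (x_i^T theta - y_i)^2."""
    return np.linalg.solve(X.T @ X, X.T @ y)

theta_hat = fit(X, y)
H = X.T @ X / n  # Hessian of the average loss at theta_hat

def influence(i):
    """IF(theta_hat, z_i) = -H^{-1} grad_theta L(theta_hat, z_i)."""
    grad = X[i] * (X[i] @ theta_hat - y[i])
    return -np.linalg.solve(H, grad)

# Removing z_i corresponds to upweighting it by -1/n, so the predicted change is -IF / n.
i = 0
predicted = -influence(i) / n
actual = fit(np.delete(X, i, axis=0), np.delete(y, i)) - theta_hat
print(np.linalg.norm(predicted - actual) / np.linalg.norm(actual))

# Group version: sum the individual influences of the removed points (an approximation).
group = np.arange(10)
predicted_group = -sum(influence(j) for j in group) / n
actual_group = fit(np.delete(X, group, axis=0), np.delete(y, group)) - theta_hat
print(np.linalg.norm(predicted_group - actual_group) / np.linalg.norm(actual_group))
```

For least squares the exact single-point removal effect differs from the influence estimate only by the factor $1/(1 - h_{ii})$, where $h_{ii}$ is the leverage of the removed point, so the printed relative error equals $h_{ii}$ and is small when $|\mathcal{D}|$ is large.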

Motivated by the work of Koh et al. (2019), we introduce the influence function into the OOD problem to address these obstacles.

4 Methodology

4.1 Influence of domains

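We decompose a parametric hypothesis $f(x)$ into a top model $g$ and a feature extractor $\Phi$, i.e. $f(x) = g(\Phi(x, \theta), \gamma)$ and $(\hat\theta, \hat\gamma) = \arg\min_{\theta, \gamma} \mathcal{L}(\theta, \gamma)$. This decomposition coincides with the common understanding of most DNNs: a DNN extracts features and builds a top model on the extracted features. When upweighting a domain $e$ by a small perturbation $\delta$, we do not upweight the regularization term, i.e.

$$\mathcal{L}_{+}(\theta, \gamma, e, \delta) = \ell(\theta, \gamma) + \delta\, \ell(\theta, \gamma, \mathcal{D}^e) + \lambda R(\theta, \gamma),$$

since the stability across different domains, which is encouraged by the regularization, should not depend on the sample size of a domain. For a learnt model $(\hat\theta, \hat\gamma)$, fixing the feature extractor, i.e. fixing $\theta = \hat\theta$, and letting $\hat\gamma_{e,\delta} = \arg\min_\gamma \mathcal{L}_{+}(\hat\theta, \gamma, e, \delta)$, the change of the top model caused by upweighting domain $e$ is

$$\mathrm{IF}(\hat\gamma, e \mid \hat\theta) := \left.\frac{d\hat\gamma_{e,\delta}}{d\delta}\right|_{\delta = 0} = -H_{\hat\gamma}^{-1} \nabla_\gamma \ell(\hat\theta, \hat\gamma, \mathcal{D}^e). \tag{5}$$

Here $H_{\hat\gamma} = \nabla_\gamma^2 \mathcal{L}(\hat\theta, \hat\gamma)$, and we assume $\mathcal{L}$ is twice-differentiable in $\gamma$. Please see Appendix A.3 for the detailed derivation and for why $\theta$ should be fixed. For a regularized method, e.g. IRM or REx, the influence of the regularization term is reflected in $H_{\hat\gamma}$ and in the learnt model $(\hat\theta, \hat\gamma)$. As mentioned above, $\mathrm{IF}(\hat\gamma, e \mid \hat\theta)$ measures the change of the model caused by upweighting domain $e$. Therefore, if $f$ is invariant across domains, the entire model treats all domains equally. As a result, a small perturbation on different domains should cause the same model change. This leads to our proposal.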
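A minimal sketch of the domain-level influence (5), assuming the frozen features $\Phi(x, \hat\theta)$ are given as arrays, the top model is a logistic regression with parameters $\gamma$, and a small ridge term stands in for the regularizer $\lambda R$ (it enters $H_{\hat\gamma}$ but is never upweighted). Each domain is weighted by $1/m$; the data and all function names are hypothetical, and this is an illustration rather than the paper's implementation. For every training domain it computes $-H_{\hat\gamma}^{-1}\nabla_\gamma \ell(\hat\theta, \hat\gamma, \mathcal{D}^e)$.

```python
import numpy as np

def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))

def domain_grad_hess(gamma, feats, y):
    """Gradient and Hessian w.r.t. gamma of one domain's mean logistic loss l(gamma, D^e)."""
    p = sigmoid(feats @ gamma)
    grad = feats.T @ (p - y) / len(y)
    hess = (feats * (p * (1.0 - p))[:, None]).T @ feats / len(y)
    return grad, hess

def total_grad_hess(gamma, domains, weights, ridge):
    """Gradient/Hessian of sum_e weights[e] * l(gamma, D^e) + (ridge / 2) * ||gamma||^2."""
    grad, hess = ridge * gamma, ridge * np.eye(len(gamma))
    for w, (feats, y) in zip(weights, domains):
        g_e, h_e = domain_grad_hess(gamma, feats, y)
        grad, hess = grad + w * g_e, hess + w * h_e
    return grad, hess

def fit_top_model(domains, weights, ridge=1e-2, iters=30):
    """Newton's method for the top-model parameters gamma; the features stay fixed."""
    gamma = np.zeros(domains[0][0].shape[1])
    for _ in range(iters):
        grad, hess = total_grad_hess(gamma, domains, weights, ridge)
        gamma = gamma - np.linalg.solve(hess, grad)
    return gamma

def domain_influence(gamma, domains, weights, e, ridge=1e-2):
    """IF(gamma, e | theta) = -H^{-1} grad_gamma l(gamma, D^e); the ridge term enters H only."""
    _, hess = total_grad_hess(gamma, domains, weights, ridge)
    grad_e, _ = domain_grad_hess(gamma, *domains[e])
    return -np.linalg.solve(hess, grad_e)

# Hypothetical frozen features for three domains with shifted feature means.
rng = np.random.default_rng(0)
true_gamma = np.array([1.0, -1.0, 0.5, 0.0])
domains = []
for shift in (0.0, 0.5, 1.0):
    feats = rng.normal(size=(200, 4)) + shift
    y = (rng.random(200) < sigmoid(feats @ true_gamma)).astype(float)
    domains.append((feats, y))

weights = np.full(len(domains), 1.0 / len(domains))
gamma_hat = fit_top_model(domains, weights)
influences = np.stack([domain_influence(gamma_hat, domains, weights, e) for e in range(len(domains))])
print(influences.shape)  # (3, 4): one influence vector per domain
```

Upweighting domain $e$ by $\delta$ and refitting should move $\hat\gamma$ by approximately $\delta \cdot \mathrm{IF}(\hat\gamma, e \mid \hat\theta)$, so a central finite difference in $\delta$ is a cheap check for an implementation like this.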

4.2 Proposed Index and its Utility

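Based on the domain-level influence function $\mathrm{IF}(\hat\gamma, e \mid \hat\theta)$, we propose our index to measure the fluctuation of the parameter change when different domains are upweighted:

$$\mathcal{V}_{\gamma|\theta} := \ln\Big(\big\| \operatorname{Cov}_{e \in \mathcal{E}_{tr}}\big(\mathrm{IF}(\hat\gamma, e \mid \hat\theta)\big) \big\|_2\Big). \tag{6}$$

Here $\|\cdot\|_2$ is the 2-norm for matrices, which for this symmetric positive semi-definite matrix equals its largest eigenvalue; $\operatorname{Cov}_{e \in \mathcal{E}_{tr}}(\cdot)$ refers to the covariance matrix of the domain-level influence function over $\mathcal{E}_{tr}$; and $\ln(\cdot)$ is a nonlinear transformation that works well in practice.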
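Given the stacked domain-level influence vectors (one row per training domain, each computed as $-H_{\hat\gamma}^{-1}\nabla_\gamma \ell(\hat\theta, \hat\gamma, \mathcal{D}^e)$ by whatever implementation of (5) is at hand), the index (6) reduces to a few lines. This is a minimal sketch with a hypothetical function name; using `np.cov`'s default $N-1$ normalisation is an assumption, and it only shifts $\mathcal{V}_{\gamma|\theta}$ by the constant $\ln\frac{m}{m-1}$ relative to the $N$ normalisation.

```python
import numpy as np

def influence_index(influences):
    """V = ln(||Cov_e(IF(gamma, e | theta))||_2) for an (m, p) array of domain-level influences."""
    cov = np.atleast_2d(np.cov(influences, rowvar=False))  # p x p covariance over the m domains
    # The covariance is symmetric PSD, so its 2-norm is its largest eigenvalue;
    # clamp at 0 to guard against tiny negative values from rounding.
    spectral_norm = max(np.linalg.eigvalsh(cov).max(), 0.0)
    return float(np.log(spectral_norm))

infl = np.array([[1.0, 0.0], [-1.0, 0.0], [0.0, 0.0]])
print(influence_index(infl))
```

If all domains produce identical influence vectors, the covariance is zero and the index is $-\infty$ (NumPy emits a divide-by-zero warning), which would correspond to domains that the learnt model cannot tell apart at all.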

OOD Model

Under the OOD problem in (2), a good OOD model should (i) learn invariant and useful features and (ii) avoid spurious and varying features. Learning useful and invariant features means the model should have high accuracy over a set of test domains $\mathcal{E}_{test}$, no matter which test domain it is. In turn, high accuracy over $\mathcal{E}_{test}$ also means the model truly learns some features that are useful for the test domains. However, this is not enough, since we do not know whether these useful features are invariant features across $\mathcal{E}_{all}$ or merely spurious features on $\mathcal{E}_{test}$. On the other hand, avoiding varying features means that different domains are actually the same to the learnt model, so according to the arguments in Section 4.1, $\mathcal{V}_{\gamma|\theta}$ should be small. Combining these, we derive our proposal: if a learnt model manages to simultaneously achieve a small $\mathcal{V}_{\gamma|\theta}$ and high accuracy over $\mathcal{E}_{test}$, it should have good OOD accuracy. We prove our proposal in a simple but illuminating case, and we conduct various experiments (Section 5) to support it.

Several issues should be clarified. First, not all OOD problems demand that models learn invariant features. For example, the set of all target domains may be so small that the varying features are always strongly correlated with the labels, or the objective may be the mean accuracy over $\mathcal{E}_{all}$ rather than the worst-domain accuracy. In our setting, however, we regard the OOD problem in (2) as a bridge to causal discovery. Thus the set of target domains is large, and such “weak” OOD problems are beyond our scope. To a large extent, invariant features remain the major target, and our proposal is still a good criterion for a model’s OOD property. Second, we admit that there is a gap between being stable on $\mathcal{E}_{tr}$ (small $\mathcal{V}_{\gamma|\theta}$) and avoiding all spurious features on $\mathcal{E}_{all}$. However, to our knowledge, demanding that a model avoid features that vary in $\mathcal{E}_{all}$ but are invariant in $\mathcal{E}_{tr}$ is unrealistic. Therefore, we take a first step by measuring whether the learnt model successfully avoids features that vary across $\mathcal{E}_{tr}$. We leave an index for features that vary over $\mathcal{E}_{all}$ to future work.

The Shuffle

As mentioned above, a smaller metric means stronger stability across $\mathcal{E}_{tr}$, and hence should indicate better OOD accuracy. However, the proposed metric depends on the dataset $\mathcal{D}$ and the learnt model $f$. Therefore, there is no uniform baseline to check whether the metric is “small” enough. To this end, we propose a baseline value of the proposed metric obtained by shuffling the multi-domain data. Consider pooling all data points in $\mathcal{D}$ and randomly redistributing them to $m$ new synthetic domains $\tilde{\mathcal{D}} = \{\tilde{\mathcal{D}}^{s}\}_{s=1}^{m}$. We compute the shuffle version of the metric for a learnt model $f$ over the shuffled data $\tilde{\mathcal{D}}$:

$$\tilde{\mathcal{V}}_{\gamma|\theta} := \ln\Big(\big\| \operatorname{Cov}_{s}\big(\mathrm{IF}(\hat\gamma, s \mid \hat\theta)\big) \big\|_2\Big), \tag{7}$$

and denote the standard version and the shuffle version of the metric as $\mathcal{V}_{\gamma|\theta}$ and $\tilde{\mathcal{V}}_{\gamma|\theta}$, respectively. For any algorithm $\mathcal{A}$ that obtains relatively good test accuracy, if $\mathcal{V}_{\gamma|\theta}$ is much larger than $\tilde{\mathcal{V}}_{\gamma|\theta}$, then $\mathcal{A}$ has learnt features that vary across $\mathcal{E}_{tr}$ and cannot treat the domains in $\mathcal{E}_{tr}$ equally. This implies that $\mathcal{A}$ may not yield an invariant predictor over $\mathcal{E}_{all}$. Otherwise, if the two values are similar, the model has avoided varying features in $\mathcal{E}_{tr}$ and may be invariant across $\mathcal{E}_{all}$: either the model captures the invariance over diverse domains, or the domains are not diverse at all. Note that this process is suitable for any algorithm, hence providing a baseline to see whether $\mathcal{V}_{\gamma|\theta}$ is small. Here we also obtain a method to judge whether an OOD algorithm is needed. Consider $f$ learnt by ERM. If $\mathcal{V}_{\gamma|\theta}$ is relatively larger than $\tilde{\mathcal{V}}_{\gamma|\theta}$, then ERM fails to avoid varying features. In this case, one should consider an OOD algorithm to achieve better OOD generalization. Otherwise, ERM is enough, and any attempt to achieve better OOD accuracy should start with finding more domains instead of using OOD algorithms. This coincides with the experiments in Gulrajani and Lopez-Paz (2020) (Section 5.2); our understanding is that the domains in $\mathcal{E}_{tr}$ there are similar. Therefore, the difference between the shuffle and standard versions of the metric reflects how much a learnt model relies on varying features. We show how to use the two versions of the metric in Section 5.1.1 and Section 5.2.
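The shuffle baseline (7) can be sketched by pooling the samples, permuting them, and splitting them back into $m$ synthetic domains. In this minimal sketch the learnt model ($\hat\gamma$ and $H_{\hat\gamma}$) is held fixed for both versions of the metric, the synthetic domains keep the original domain sizes (an assumption), and, to stay self-contained, the top model is a ridge-regularised least-squares layer on given features rather than a classifier. All names and data are hypothetical.

```python
import numpy as np

def fit_ls_top_model(domains, ridge=1e-3):
    """Ridge least squares with equal domain weights; returns gamma_hat and the Hessian H."""
    p = domains[0][0].shape[1]
    hess = ridge * np.eye(p)
    rhs = np.zeros(p)
    for feats, y in domains:
        scale = 1.0 / (len(y) * len(domains))
        hess += scale * feats.T @ feats
        rhs += scale * feats.T @ y
    return np.linalg.solve(hess, rhs), hess

def index_for_model(gamma, hess, domains):
    """ln of the largest eigenvalue of Cov_e(-H^{-1} grad_gamma l(gamma, D^e)) for a fixed model."""
    infl = np.stack([
        -np.linalg.solve(hess, feats.T @ (feats @ gamma - y) / len(y))
        for feats, y in domains
    ])
    cov = np.atleast_2d(np.cov(infl, rowvar=False))
    return float(np.log(max(np.linalg.eigvalsh(cov).max(), 0.0)))

def shuffle_domains(domains, rng):
    """Pool all samples, permute them, and redistribute into synthetic domains of the original sizes."""
    feats = np.concatenate([f for f, _ in domains])
    ys = np.concatenate([y for _, y in domains])
    perm = rng.permutation(len(ys))
    bounds = np.cumsum([len(y) for _, y in domains])[:-1]
    return list(zip(np.split(feats[perm], bounds), np.split(ys[perm], bounds)))

rng = np.random.default_rng(1)
domains = []
for beta in ([1.0, 0.0], [1.0, 1.0], [1.0, -1.0]):  # second coefficient varies by domain
    feats = rng.normal(size=(300, 2))
    y = feats @ np.array(beta) + 0.1 * rng.normal(size=300)
    domains.append((feats, y))

gamma_hat, hess = fit_ls_top_model(domains)
v_standard = index_for_model(gamma_hat, hess, domains)
v_shuffle = index_for_model(gamma_hat, hess, shuffle_domains(domains, rng))
print(round(v_standard, 2), round(v_shuffle, 2))
```

Because the second feature's coefficient differs across the three synthetic domains, the standard index is expected to sit far above the shuffled one here; for identically distributed domains the two values would be expected to be close.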

4.3 Influence Calculation

A practical question surrounding the influence function is how to efficiently compute and invert the Hessian. Koh and Liang (2017) suggests Conjugate Gradient and stochastic estimation to solve this problem. However, when $\hat\theta$ is obtained by running SGD, it rarely arrives at the global minimum, so the Hessian may not be positive definite. Adding a damping term (i.e. replacing $H_{\hat\theta}$ with $H_{\hat\theta} + \rho I$ for some $\rho > 0$) can moderately alleviate the problem by transforming it into a convex one. However, for large neural networks with non-linear activation functions such as ReLU, this method may still work poorly, since the damping needed for this transformation can be so large that it significantly affects the performance. Most importantly, the eigenvalues of the Hessian span a huge range, making the influence function calculation slow to converge and inaccurate (Basu et al. (2020)).

In our metric, we circumvent the problem by excluding most parameters and directly computing the Hessian with respect to $\gamma$ to obtain an accurate influence function. This modification not only speeds up the calculation but also matches our expectation: that an OOD algorithm should learn invariant features does not mean that the influence function of all parameters should be identical across domains. For example, if $\Phi$ is to extract the same features in different domains, the influence function on $\theta$ should be different across domains. Therefore, if we use all parameters to calculate the influence, given that $\gamma$ is insignificant in size compared with $\theta$, the information about learnt features provided by $\gamma$ is hard to capture. In contrast, considering only the influence on the top model manifests the influence of different domains in terms of features, which enables us to achieve our goal.

As our experiments show, after this modification the influence function calculation can be 2000 times faster, and its utility (correlation with the OOD property) can be even higher. This is not surprising given the huge number of parameters in the embedding model $\Phi$: they slow down the calculation and overshadow the top model’s influence value.
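The saving from restricting the Hessian to the top model can be made concrete with a little arithmetic. This sketch assumes a 2 × 14 × 14 input (the downsampled two-channel format commonly used for Colored MNIST, taken here as an assumption), two hidden layers of width 256 and a single output logit, and treats the last linear layer as the top model $\gamma$. It also includes a hypothetical helper that solves the small top-model system with optional damping and refuses near-singular Hessians, the failure mode discussed in Section 5.1.2.

```python
import numpy as np

def mlp_param_count(layer_sizes):
    """Number of weights plus biases in a fully connected network with the given layer sizes."""
    return sum(a * b + b for a, b in zip(layer_sizes[:-1], layer_sizes[1:]))

# Hypothetical sizes: 2 x 14 x 14 input, two hidden layers of width 256, one logit.
sizes = [2 * 14 * 14, 256, 256, 1]
full = mlp_param_count(sizes)      # all parameters (theta and gamma)
top = mlp_param_count(sizes[-2:])  # last linear layer only (gamma)
print(full, top)                   # 166657 257
print(full ** 2 / top ** 2)        # ratio of Hessian entries, full model vs top model

def solve_top_hessian(hess, grad, damping=0.0, min_eig=1e-8):
    """Return -(H + damping * I)^{-1} grad, refusing (near-)singular systems."""
    h = hess + damping * np.eye(len(grad))
    smallest = np.linalg.eigvalsh(h).min()
    if smallest <= min_eig:
        raise np.linalg.LinAlgError(f"near-singular Hessian (min eigenvalue {smallest:.2e})")
    return -np.linalg.solve(h, grad)
```

Checking the smallest eigenvalue of a 257 × 257 matrix is cheap, which is one practical benefit of working with the top model only: near-singular cases can be detected and rerun instead of silently producing unstable influence values.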

5 Experiment

In this section, we experimentally show that: (1) a model reaches a small $\mathcal{V}_{\gamma|\theta}$ if it has a good OOD property, while a non-OOD model does not; (2) the metric provides additional information on the stability of a learnt model, which overcomes the weakness of test accuracy; (3) the comparison of $\mathcal{V}_{\gamma|\theta}$ and $\tilde{\mathcal{V}}_{\gamma|\theta}$ can check whether a better OOD algorithm is needed.

We consider experiments on a Bayesian Network, Colored MNIST and VLCS. The synthetic data generated by the Bayesian Network include domain-dependent noise and fake associations between features and response. For Colored MNIST, we already know that the digit is the causal feature and the color is non-causal. These causal relationships help us determine the worst domain and obtain the OOD accuracy. VLCS is a real dataset, on which we show the utility of $\mathcal{V}_{\gamma|\theta}$ step by step. Due to space limitations, we put the Bayesian Network experiments in the appendix.

Generally, cross-validation (Gulrajani and Lopez-Paz (2020)) is used to judge a model’s OOD property. In the introduction, we have already shown that leave-one-domain-out cross-validation may fail to discern OOD properties. We also consider two other potential competitors: conditional mutual information and the IRM penalty. The comparison between our metric and these two competitors is deferred to the Appendix.

Figure 2: The index $\mathcal{V}_{\gamma|\theta}$ is highly correlated with $r$. The plot contains 501 learnt ERM models with different values of $r$. The dashed line is the baseline value when the difference between domains is eliminated by pooling and redistributing the training data. The blue solid line is the linear regression of $\mathcal{V}_{\gamma|\theta}$ versus $r$.

5.1 Colored MNIST

Colored MNIST (Arjovsky et al. (2019)) introduces a synthetic binary classification task. The images are colored according to their label, making color a spurious feature for predicting the label. Specifically, for a domain $e$, we assign a preliminary binary label $\tilde{y}$ based on the digit and randomly flip it with probability 0.25 to obtain the label $y$. Then, we color the image according to $y$ but with a flip rate of $p^e$. Clearly, when $p^e < 0.25$ or $p^e > 0.75$, color is more correlated with $y$ than the real digit is. Therefore, the oracle OOD model attains 75% accuracy in all domains, while an ERM model may attain high training accuracy and a poor OOD property if $p^e$ in the training domains is too small or too large. Throughout the Colored MNIST experiments, we use a three-layer MLP with ReLU activation and hidden dimension 256. Although our MLP model has relatively many parameters and is non-convex due to the activation layer, thanks to the technique described in Section 4.3 the influence calculation is still fast and accurate: a single direct influence calculation takes less than 2 seconds.
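The label and colour construction described above can be sketched without images by working with digit labels only. This minimal sketch assumes, following the description of Colored MNIST in Arjovsky et al. (2019), that the preliminary label is 1 for digits 5 to 9 and 0 otherwise; `colored_mnist_labels` is a hypothetical helper, and the digits are sampled uniformly rather than taken from MNIST.

```python
import numpy as np

def colored_mnist_labels(digits, p_e, rng, label_noise=0.25):
    """Return (label y, colour) for one Colored MNIST domain, given digit labels only."""
    y_pre = (digits >= 5).astype(int)                     # preliminary label from the digit
    y = y_pre ^ (rng.random(len(digits)) < label_noise)   # flip with probability 0.25
    color = y ^ (rng.random(len(digits)) < p_e)           # colour = y flipped with probability p_e
    return y, color

rng = np.random.default_rng(0)
digits = rng.integers(0, 10, size=100_000)
for p_e in (0.1, 0.5, 0.9):
    y, color = colored_mnist_labels(digits, p_e, rng)
    digit_acc = np.mean((digits >= 5) == y)  # about 0.75 in every domain
    color_acc = np.mean(color == y)          # about 1 - p_e, so it varies by domain
    print(p_e, round(digit_acc, 3), round(color_acc, 3))
```

The digit rule stays near 75% accuracy in every domain, whereas the colour rule `color == y` tracks $1 - p^e$; when $p^e < 0.25$ or $p^e > 0.75$, colour (or its inverse) is therefore more predictive than the digit, which is exactly the shortcut an ERM model can pick up.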

5.1.1 Identify OOD Problem

In this section, we show that $\mathcal{V}_{\gamma|\theta}$ can discern whether the training domains are sufficiently diverse, as mentioned in Section 4.2. Assume $\mathcal{E}_{tr}$ has five training domains whose flip rates $p^e$ are parameterized by $r$, where $r$ is positively related to the diversity among the training domains. If $r$ is zero, all data points are generated from the same domain, so the learning task on $\mathcal{E}_{tr}$ is not an OOD problem. Conversely, a larger $r$ means that the training domains are more diverse. We repeat the experiment 501 times, each time learning the model with ERM. Given the learnt model and the training data, we compute $\mathcal{V}_{\gamma|\theta}$ and check the correlation between $\mathcal{V}_{\gamma|\theta}$ and $r$. Figure 2 presents the results. Our index is highly correlated with $r$: the Pearson coefficient is 0.9869, and the Spearman coefficient is 0.9873. Also, the baseline value of $\mathcal{V}_{\gamma|\theta}$ on the same training domains ($\tilde{\mathcal{V}}_{\gamma|\theta}$ in Section 4.2) can be derived from the raw data by pooling and redistributing all data points; we mark it by the black dashed line. If $\mathcal{V}_{\gamma|\theta}$ is much higher than this baseline, indicating that $r$ is not small, an OOD algorithm should be considered if better OOD generalization is demanded. Otherwise, the present algorithm (such as ERM) is sufficient. The results match our expectation that $\mathcal{V}_{\gamma|\theta}$ can discern whether the distributions $P^e$ differ across training domains.

5.1.2 Relationship between $\mathcal{V}_{\gamma|\theta}$ and OOD Accuracy

In this section, we use an experiment to support our proposal in Section 4.2. As previously proposed, if a model shows high test accuracy and a small $\mathcal{V}_{\gamma|\theta}$ simultaneously, it captures invariant features and avoids varying features, so it deserves to be called an OOD model. In this experiment, we consider models with high test accuracy and show that a smaller $\mathcal{V}_{\gamma|\theta}$ generally corresponds to better OOD accuracy, which supports our proposal.

Figure 3: The relationship between $\mathcal{V}_{\gamma|\theta}$ and OOD accuracy for REx (left) and IRM (right) under different penalty weights $\lambda$. We train 400 models for each setup. The OOD accuracy and $\mathcal{V}_{\gamma|\theta}$ have high-magnitude Pearson coefficients: -0.9745 (upper left), -0.9761 (lower left), -0.8417 (upper right), -0.9476 (lower right). The coefficients are negative because a lower $\mathcal{V}_{\gamma|\theta}$ indicates a better OOD property.

Consider two setups of training domains. We implement IRM and REx with different penalty weights $\lambda$ (note that ERM corresponds to $\lambda = 0$) to check the relationship between $\mathcal{V}_{\gamma|\theta}$ and OOD accuracy. For IRM and REx, we first run several epochs of pre-training and use early stopping to prevent over-fitting. With this technique, all models achieve good test accuracy (within 0.1 of the oracle accuracy) and meet our requirement. Figure 3 presents the results. $\mathcal{V}_{\gamma|\theta}$ is highly correlated with OOD accuracy for both IRM and REx, with the absolute value of the Pearson coefficient never less than 0.84. Models learned with larger $\lambda$ show a better OOD property, learn fewer varying features, and have smaller $\mathcal{V}_{\gamma|\theta}$. The results are consistent with our proposal, except that when $\lambda$ is large in IRM, $\mathcal{V}_{\gamma|\theta}$ is somewhat unstable. We carefully examined this phenomenon and found that it is caused by numerical instability when inverting a Hessian with eigenvalues very close to 0. Such unstable inversion happens with low probability and can be addressed by repeating the experiment once or twice.

5.2 Domain Generalization: VLCS

In this section, we apply the proposed metric to 4 algorithms, ERM, gDRO, Mixup and IRM, on the VLCS image dataset, which is widely used for domain generalization. We emulate a real scenario with $\mathcal{E}_{tr} = \{C, L, V\}$ and $\mathcal{E}_{all} = \{C, L, S, V\}$. As in Gulrajani and Lopez-Paz (2020), we use the “training-domain validation set” method, i.e. we split off a validation set for each domain in $\mathcal{E}_{tr}$, and the test accuracy is defined as the average accuracy over the three validation sets. Note that our goal is to use the test accuracy and $\mathcal{V}_{\gamma|\theta}$ to measure OOD generalization, rather than to tune for SOTA performance on the unseen domain $S$. Therefore, we do not apply any model selection method and just use the default hyper-parameters in Gulrajani and Lopez-Paz (2020).

5.2.1 Step 1: Test accuracy comparison

Algorithm  C      L      V      Mean
ERM        99.29  73.62  77.07  83.34
Mixup      99.32  74.36  78.84  84.17
gDRO       95.79  70.95  75.25  80.66
IRM        49.44  44.76  41.17  45.12
Table 1: Step 1: Test accuracy (%) on the validation splits of the training domains

For each algorithm, we run the naive training process 12 times and report the average test accuracy in Table 1. Before calculating $\mathcal{V}_{\gamma|\theta}$, the learnt model should at least reach good test accuracy. Otherwise, there is no need to discuss its OOD performance, since OOD accuracy is smaller than test accuracy. In the table, the test accuracy of ERM, Mixup and gDRO is good, but that of IRM is not. In this case, IRM is eliminated. If an algorithm fails to reach high test accuracy, we should first change the hyper-parameters until we observe a relatively high test accuracy.

5.2.2 Step 2: shuffle and standard metric comparison

Now we are ready to check whether the learnt models are invariant across $\mathcal{E}_{tr}$. As mentioned in Section 4.2, the difference between $\mathcal{V}_{\gamma|\theta}$ and $\tilde{\mathcal{V}}_{\gamma|\theta}$ indicates how invariant a model is across $\mathcal{E}_{tr}$. We calculate both values, and the results are shown in Figure 4. For ERM and Mixup, the two values are nearly the same. In this case, we expect ERM and Mixup models to be invariant and to have relatively high OOD accuracy, so no other algorithm is needed. For gDRO, we can clearly see that $\tilde{\mathcal{V}}_{\gamma|\theta}$ is uniformly smaller than $\mathcal{V}_{\gamma|\theta}$. Therefore, gDRO models do not treat different domains equally, and hence we predict that their OOD accuracy will be relatively low. In this case, one who starts with gDRO should turn to other algorithms if better OOD performance is demanded.

Note that, in the whole process, we know nothing about $\mathcal{E}_{all}$, so the OOD accuracy is unseen. However, from the above analysis, we know that (1) in this setting, ERM and Mixup are better than gDRO; (2) one who uses gDRO can turn to other algorithms (such as Mixup) for better OOD performance; (3) one who uses ERM should consider collecting more environments if they still want to improve OOD performance. This completes the judgement using test accuracy and the proposed metric.
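The two-step judgement of Sections 5.2.1 and 5.2.2 can be summarised as a small decision rule. This is a hypothetical sketch: the accuracy threshold `min_acc` and the gap tolerance `gap_tol` are illustrative values that do not come from this paper, the `RunSummary` container is an invented name, and in practice one would compare the distributions of $\mathcal{V}_{\gamma|\theta}$ and $\tilde{\mathcal{V}}_{\gamma|\theta}$ over repeated runs (as in Figure 4) rather than only their means.

```python
from dataclasses import dataclass

import numpy as np

@dataclass
class RunSummary:
    name: str
    test_acc: float         # mean accuracy on the validation splits of E_tr, in [0, 1]
    v_standard: np.ndarray  # standard index V over repeated runs
    v_shuffle: np.ndarray   # shuffle-version index over repeated runs

def judge(run: RunSummary, min_acc: float = 0.7, gap_tol: float = 1.0) -> str:
    # Step 1: without good test accuracy there is no point discussing OOD performance.
    if run.test_acc < min_acc:
        return "retune: test accuracy too low"
    # Step 2: a standard index well above the shuffled baseline signals varying features.
    gap = float(np.mean(run.v_standard) - np.mean(run.v_shuffle))
    if gap > gap_tol:
        return "not invariant: try another OOD algorithm"
    return "invariant on E_tr: collect more diverse domains to improve further"

# Hypothetical index values for illustration only.
print(judge(RunSummary("model-a", 0.45, np.array([-3.0]), np.array([-3.1]))))
print(judge(RunSummary("model-b", 0.83, np.array([-2.0, -2.2]), np.array([-5.0, -4.8]))))
```

Keeping Step 1 first mirrors the text: the index is only interpreted for models whose test accuracy is already acceptable.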

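Figure 4: The standard and shuffle versions of the metric, i.e. $\mathcal{V}_{\gamma|\theta}$ and $\tilde{\mathcal{V}}_{\gamma|\theta}$, for ERM, Mixup and gDRO. For each algorithm and each version of the metric, we run the experiments more than 12 times to reduce statistical error. Similar $\mathcal{V}_{\gamma|\theta}$ and $\tilde{\mathcal{V}}_{\gamma|\theta}$ indicate invariance across $\mathcal{E}_{tr}$, which is the case for ERM and Mixup. For gDRO, $\tilde{\mathcal{V}}_{\gamma|\theta}$ is clearly smaller.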

5.2.3 Step 3: OOD accuracy results (oracle)

       ERM    Mixup  gDRO   IRM
Mean   62.76  63.91  60.17  31.33
Std    1.16   1.57   2.56   13.44
Table 2: Step 3: OOD accuracy (%), mean and standard deviation

In this step, we are fortunately given $\mathcal{E}_{all}$ and can check whether our judgement is reasonable; normally, this step is not possible. Table 2 shows the OOD accuracy of the four algorithms. Consistent with our judgement, ERM and Mixup models achieve higher OOD accuracy than gDRO. The OOD accuracy of IRM (under these hyper-parameters) is lower than its test accuracy. During the above process, we can also compare the metric across models from the same algorithm with different hyper-parameters (as in Section 5.1.2). Besides, one may notice that even the highest OOD accuracy is just 63.91%. That is to say, to obtain higher OOD accuracy, we should consider collecting more environments. In Appendix A.6, we continue our real scenario to see where our metric leads us if $\mathcal{E}_{tr}$ is initially more diverse.

The full results on VLCS can also be found in the same appendix, together with a comparison of the proposed metric with the IRM penalty in formula 4. Besides, we show the comparison with Conditional Mutual Information in Appendix A.5. In summary, we use a realistic task to show how to judge the OOD property of a learnt model using the proposed metric and test accuracy. The judgement coincides well with the real OOD performance.

6 Conclusion

In this paper, we focus on two currently unsolved problems: how to discern the OOD property of a multi-domain problem, and how to discern the OOD property of a learnt model. To this end, we introduce the influence function into the OOD problem and propose a metric to address these issues. Our metric can not only discern whether a multi-domain problem is OOD, but can also judge a model's OOD property when combined with test accuracy. To make the calculation more meaningful, accurate and efficient, we modify the influence function to the domain level and propose using only the top model to calculate the influence. We prove the validity of our method in simple cases, and it works well in experiments. We hope that, with the help of this index, our understanding of OOD generalization will become increasingly precise and thorough.

References

  • K. Ahuja, K. Shanmugam, K. Varshney, and A. Dhurandhar (2020) Invariant risk minimization games. arXiv preprint arXiv:2002.04692. Cited by: §2.
  • K. Akuzawa, Y. Iwasawa, and Y. Matsuo (2019) Domain generalization via invariant representation under domain-class dependency. Cited by: §1.
  • A. Alaa and M. Van Der Schaar (2019) Validating causal inference models via influence functions. K. Chaudhuri and R. Salakhutdinov (Eds.), Proceedings of Machine Learning Research, Vol. 97, Long Beach, California, USA, pp. 191–201. Cited by: §2.
  • M. Arjovsky, L. Bottou, I. Gulrajani, and D. Lopez-Paz (2019) Invariant risk minimization. arXiv preprint arXiv:1907.02893. Cited by: §A.2, §A.2, §1, §2, §3.1, §5.1.
  • J. A. Bagnell (2005) Robust supervised learning. In Proceedings of the national conference on artificial intelligence, Vol. 20, pp. 714. Cited by: §2.
  • S. Basu, P. Pope, and S. Feizi (2020) Influence functions in deep learning are fragile. arXiv preprint arXiv:2006.14651. Cited by: §4.3.
  • S. Beery, G. Van Horn, and P. Perona (2018) Recognition in terra incognita. In Proceedings of the European Conference on Computer Vision (ECCV), pp. 456–473. Cited by: §1.
  • A. Ben-Tal, L. El Ghaoui, and A. Nemirovski (2009) Robust optimization. Vol. 28, Princeton University Press. Cited by: §2.
  • Y. Bengio, T. Deleu, N. Rahaman, N. R. Ke, S. Lachapelle, O. Bilaniuk, A. Goyal, and C. Pal (2019) A meta-transfer objective for learning to disentangle causal mechanisms. In International Conference on Learning Representations, Cited by: §1.
  • P. Bühlmann et al. (2020) Invariance, causality and robustness. Statistical Science 35 (3), pp. 404–426. Cited by: §2.
  • D. C. Castro, I. Walker, and B. Glocker (2020) Causality matters in medical imaging. Nature Communications 11 (1), pp. 1–10. Cited by: §2.
  • W. Cheng, Y. Shen, L. Huang, and Y. Zhu (2019) Incorporating interpretability into latent factor models via fast influence analysis. In Proceedings of the 25th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining, pp. 885–893. Cited by: §2.
  • G. Cohen, G. Sapiro, and R. Giryes (2020) Detecting adversarial samples using influence functions and nearest neighbors. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), Cited by: §2.
  • R. D. Cook, B. Li, et al. (2002) Dimension reduction for conditional mean in regression. The Annals of Statistics 30 (2), pp. 455–474. Cited by: §2.
  • R. D. Cook and S. Weisberg (1980) Characterizations of an empirical influence function for detecting influential cases in regression. Technometrics 22 (4), pp. 495–508. Cited by: §1.
  • R. D. Cook (2009) Regression graphics: ideas for studying regressions through graphics. Vol. 482, John Wiley & Sons. Cited by: §2.
  • M. Fang, N. Z. Gong, and J. Liu (2020) Influence function based data poisoning attacks to top-n recommender systems. In Proceedings of The Web Conference 2020, pp. 3019–3025. Cited by: §2.
  • I. Gulrajani and D. Lopez-Paz (2020) In search of lost domain generalization. arXiv preprint arXiv:2007.01434. Cited by: §A.6.2, §A.6.2, Table 5, §1, §1, §4.2, §5.2, §5.
  • P. W. Koh and P. Liang (2017) Understanding black-box predictions via influence functions. arXiv preprint arXiv:1703.04730. Cited by: §A.3, §2, §4.3.
  • P. W. W. Koh, K. Ang, H. Teo, and P. S. Liang (2019) On the accuracy of influence functions for measuring group effects. In Advances in Neural Information Processing Systems, pp. 5254–5264. Cited by: §2, §3.2.
  • M. Koyama and S. Yamaguchi (2020) Out-of-distribution generalization with maximal invariant predictor. arXiv preprint arXiv:2008.01883. Cited by: §2.
  • D. Krueger, E. Caballero, J. Jacobsen, A. Zhang, J. Binas, R. L. Priol, and A. Courville (2020) Out-of-distribution generalization via risk extrapolation (rex). arXiv preprint arXiv:2003.00688. Cited by: §1, §2, §3.1.
  • K. Kuang, R. Xiong, P. Cui, S. Athey, and B. Li (2020) Stable prediction with model misspecification and agnostic distribution shift.. In AAAI, pp. 4485–4492. Cited by: §2.
  • D. Li, Y. Yang, Y. Song, and T. M. Hospedales (2017) Deeper, broader and artier domain generalization. In Proceedings of the IEEE international conference on computer vision, pp. 5542–5550. Cited by: §1.
  • H. Li, S. Jialin Pan, S. Wang, and A. C. Kot (2018) Domain generalization with adversarial feature learning. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 5400–5409. Cited by: §2.
  • S. Magliacane, T. van Ommen, T. Claassen, S. Bongers, P. Versteeg, and J. M. Mooij (2018) Domain adaptation by using causal inference to predict invariant conditional distributions. In Advances in Neural Information Processing Systems, pp. 10846–10856. Cited by: §2.
  • S. Mukherjee, H. Asnani, and S. Kannan (2020) CCMI: classifier based conditional mutual information estimation. In Uncertainty in Artificial Intelligence, pp. 1083–1093. Cited by: §A.5.
  • J. Peters, P. Bühlmann, and N. Meinshausen (2016) Causal inference by using invariant prediction: identification and confidence intervals. Journal of the Royal Statistical Society: Series B (Statistical Methodology) 78 (5), pp. 947–1012. Cited by: §2.
  • J. Robins, L. Li, E. Tchetgen, A. van der Vaart, et al. (2008) Higher order influence functions and minimax estimation of nonlinear functionals. In Probability and statistics: essays in honor of David A. Freedman, pp. 335–421. Cited by: §2.
  • J. M. Robins, L. Li, R. Mukherjee, E. T. Tchetgen, A. van der Vaart, et al. (2017) Minimax estimation of a functional on a structured high-dimensional model. The Annals of Statistics 45 (5), pp. 1951–1987. Cited by: §2.
  • M. Rojas-Carulla, B. Schölkopf, R. Turner, and J. Peters (2018) Invariant models for causal transfer learning. The Journal of Machine Learning Research 19 (1), pp. 1309–1342. Cited by: §2.
  • S. Sagawa, P. W. Koh, T. B. Hashimoto, and P. Liang (2019) Distributionally robust neural networks. In International Conference on Learning Representations, Cited by: §1, §3.1.
  • R. Sen, A. T. Suresh, K. Shanmugam, A. G. Dimakis, and S. Shakkottai (2017) Model-powered conditional independence test. In Advances in neural information processing systems, pp. 2951–2961. Cited by: Figure 6, §A.5.
  • A. Subbaswamy, P. Schulam, and S. Saria (2019) Preventing failures due to dataset shift: learning predictive models that transport. In The 22nd International Conference on Artificial Intelligence and Statistics, pp. 3118–3127. Cited by: §2.
  • D. Ting and E. Brochu (2018) Optimal subsampling with influence functions. In Advances in Neural Information Processing Systems 31, S. Bengio, H. Wallach, H. Larochelle, K. Grauman, N. Cesa-Bianchi, and R. Garnett (Eds.), pp. 3650–3659. Cited by: §2.
  • A. Torralba and A. A. Efros (2011) Unbiased look at dataset bias. In CVPR 2011, pp. 1521–1528. Cited by: §1.
  • A. Tsiatis (2007) Semiparametric theory and missing data. Springer Science & Business Media. Cited by: §2.
  • M. J. Van der Laan, M. Laan, and J. M. Robins (2003) Unified methods for censored longitudinal data and causality. Springer Science & Business Media. Cited by: §2.
  • A. W. Van der Vaart (2000) Asymptotic statistics. Vol. 3, Cambridge university press. Cited by: §A.3.
  • V. Vapnik (1992) Principles of risk minimization for learning theory. In Advances in neural information processing systems, pp. 831–838. Cited by: §1.
  • Y. Wang, H. Li, and A. C. Kot (2020) Heterogeneous domain generalization via domain mixup. In ICASSP 2020-2020 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), pp. 3622–3626. Cited by: §1.
  • S. Wright (1921) Correlation and causation. J. agric. Res. 20, pp. 557–580. Cited by: §A.2.
  • C. Xie, F. Chen, Y. Liu, and Z. Li (2020) Risk variance penalization: from distributional robustness to causality. arXiv preprint arXiv:2006.07544. Cited by: §2.
  • M. Xu, J. Zhang, B. Ni, T. Li, C. Wang, Q. Tian, and W. Zhang (2020) Adversarial domain adaptation with domain mixup. In Proceedings of the AAAI Conference on Artificial Intelligence, Vol. 34, pp. 6502–6509. Cited by: §1.
  • S. Yan, H. Song, N. Li, L. Zou, and L. Ren (2020) Improve unsupervised domain adaptation with mixup training. arXiv preprint arXiv:2001.00677. Cited by: §1.

Appendix A

A.1 Simple Bayesian Network

In this section, we show that the model with better OOD accuracy achieves a smaller value of the proposed metric. We assume the data are generated from the following Bayesian network:

(8)

where are the features, is the target vector, and are the underlying parameters that are invariant across domains. The variance of the Gaussian noise is , which depends on the domain. For simplicity, we use to denote a domain. The goal here is to linearly regress the response y on the input vector , i.e. According to the Bayesian network (8), is the invariant feature, while the correlation between and is spurious and unstable since varies across domains. Clearly, the model based only on is an invariant model. Any invariant estimator should achieve and .

Method Causal Error Non-causal Error
ERM 15.844 0.582 0.581
IRM 5.254 0.122 0.109
REx 1.341 0.042 0.033
Table 3: Average parameter errors and the stability measurement of 500 models trained with ERM, IRM and REx. Here, “Causal Error” represents and “Non-causal Error” represents .

Now consider five training domains , each containing 1000 data points. We estimate three linear models using ERM, IRM and REx respectively, and record the parameter errors as well as (note that is here). Table 3 presents the results over 500 repetitions. As expected, IRM and REx learn more invariant relationships than ERM (smaller causal error) and better avoid non-causal variables (smaller non-causal error). Furthermore, the proposed measurement is closely related to invariance, i.e. a model with better OOD properties achieves a smaller value of the metric. This result agrees with our understanding.
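A small simulation in the spirit of Table 3 can make the comparison concrete. It is a sketch under several assumptions. Because equation (8) did not survive extraction, we assume the chain x1 → y → x2 with standard normal noise and a domain-dependent noise scale on x2 (five domains, 1000 points each, as above). We use pooled least squares for ERM and least squares on x1 alone as a stand-in for an invariant model; IRM and REx themselves are not implemented here. Finally, we take the metric to be the spectral norm of the across-domain covariance of the domain-level influence vectors −H⁻¹∇R_e(θ̂), which may differ from the exact definition in Section 4 (for example by a logarithm or a normalization). All function names are hypothetical.

```python
import numpy as np


def make_domains(sigmas, n=1000, beta1=1.0, beta2=1.0, seed=0):
    """Simulate one (X, y) pair per domain from an assumed SEM x1 -> y -> x2."""
    rng = np.random.default_rng(seed)
    domains = []
    for s in sigmas:
        x1 = rng.normal(size=n)
        y = beta1 * x1 + rng.normal(size=n)       # invariant mechanism, identical in every domain
        x2 = beta2 * y + s * rng.normal(size=n)   # spurious child of y; noise scale s varies by domain
        domains.append((np.column_stack([x1, x2]), y))
    return domains


def fit_ols(domains, cols):
    """Pooled least squares (ERM with squared loss) on the chosen feature columns."""
    X = np.concatenate([Xe[:, cols] for Xe, _ in domains])
    y = np.concatenate([ye for _, ye in domains])
    return np.linalg.lstsq(X, y, rcond=None)[0]


def influence_variance(domains, theta, cols):
    """Spectral norm of the across-domain covariance of domain-level influences.

    Assumed objective: L(theta) = pooled mean of 0.5 * (y - x @ theta)**2.
    Upweighting domain e by eps adds eps * R_e(theta) (R_e = mean loss on e),
    so to first order d theta / d eps = -H^{-1} grad R_e(theta), H = Hessian of L.
    """
    X = np.concatenate([Xe[:, cols] for Xe, _ in domains])
    H = X.T @ X / len(X)
    infl = []
    for Xe, ye in domains:
        Xe = Xe[:, cols]
        grad_e = -Xe.T @ (ye - Xe @ theta) / len(ye)   # gradient of domain-e risk at theta
        infl.append(-np.linalg.solve(H, grad_e))
    cov = np.atleast_2d(np.cov(np.array(infl), rowvar=False))  # (d, d) even when d == 1
    return float(np.linalg.norm(cov, 2))


domains = make_domains([0.1, 0.5, 1.0, 1.5, 2.0])
theta_erm = fit_ols(domains, [0, 1])   # ERM uses both x1 and x2
theta_inv = fit_ols(domains, [0])      # invariant model on x1 only
print("ERM", theta_erm, influence_variance(domains, theta_erm, [0, 1]))
print("x1-only", theta_inv, influence_variance(domains, theta_inv, [0]))
```

With these settings, ERM should put substantial weight on x2 and show a much larger influence variance than the x1-only model. Passing equal noise scales to `make_domains` makes the domains identically distributed, and the ERM influence variance then drops to sampling-noise level, which matches the claim in A.2 that the index vanishes when the domains become identical.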

A.2 Proof of an Example

In this section, we use a simple model to illustrate the validity of the metric proposed in Section 4. Consider a structural equation model (Wright (1921)):

where is a distribution with a finite second-order moment, i.e. , and is the variance of the noise term in ; both and vary across domains. For simplicity, we assume there are infinitely many training data points collected from two training domains . Our goal is to predict y from using a least-squares predictor . Here we consider two algorithms: ERM and IRM with . According to Arjovsky et al. (2019), using IRM we obtain .

Intuitively, ERM will exploit both and , thus achieving a better regression model. However, since the relationship between y and varies across domains, our index will be large in this case. Conversely, the IRM solution only uses the invariant features , thus . Note that we do not have an embedding model here, so .

ERM. We denote

Note that in , is sampled from . We then have

To proceed further, we denote

By solving the following equations:

and

we have with

Now we calculate our index. It is easy to see that

Therefore,

(9)

On the other hand, calculating the Hessian, we have

Then we have (note that )

where the third equality holds because the rank of the matrix is . Clearly, when (i.e., the two domains become identical), our index . Otherwise, given , we have , showing that ERM captures features that vary across domains.

IRM. We now turn to the IRM model and show that when , thus proving that the model learnt by IRM does achieve a smaller compared with in ERM.

Under the IRM model, assuming the tuning parameter is , we have

Then we have the gradient with respect to :

and the Hessian matrix

(10)

Denote by the solution of the IRM algorithm on when the penalty is . From Arjovsky et al. (2019), we know . To show , we only need to show that

We prove this by showing that

(11)

simultaneously. We add after to emphasize that is a continuous function of . We rewrite in formula 10 as

where

Obviously, is positive definite. Therefore, we have

The first equality holds because has the limit and is not , and the last equality holds because the eigenvalue of goes to when .

Now consider . According to formula 9, we have

This completes the proof of formula 11 and shows that in IRM.
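The limit argument above says that, as the penalty weight grows, the IRM solution moves to the invariant predictor. The following is a numeric sketch of that behaviour under stated assumptions. We assume the IRMv1 form of the penalty from Arjovsky et al. (2019), the squared derivative of each domain's risk with respect to a scalar multiplier w at w = 1. We work with population moments, matching the infinite-data assumption of this section. Because the structural equation model above did not survive extraction, we substitute an assumed SEM x1 → y → x2 with domain-dependent noise on x2 and two training domains. The helpers `moments`, `irmv1_objective` and `irm_path` are hypothetical; only `scipy.optimize.minimize` is a real library call.

```python
import numpy as np
from scipy.optimize import minimize


def moments(s2):
    """Population moments of an assumed SEM: x1 ~ N(0,1), y = x1 + N(0,1), x2 = y + N(0, s2)."""
    H = np.array([[1.0, 1.0], [1.0, 2.0 + s2]])   # E[x x^T]
    b = np.array([1.0, 2.0])                      # E[x y]
    return H, b, 2.0                              # E[y^2]


DOMAINS = [moments(0.25), moments(4.0)]           # two training domains, as in A.2


def irmv1_objective(beta, lam):
    """Sum of domain risks plus lam * IRMv1 penalty; returns (value, gradient)."""
    value, grad = 0.0, np.zeros_like(beta)
    for H, b, yy in DOMAINS:
        value += yy - 2 * beta @ b + beta @ H @ beta   # R_e(beta) = E_e[(y - x^T beta)^2]
        grad += 2 * (H @ beta - b)
        c = beta @ b - beta @ H @ beta                 # E_e[(y - x^T beta) x^T beta]
        dc = b - 2 * H @ beta
        # d/dw R_e(w * beta) at w = 1 equals -2c, so the IRMv1 penalty is (2c)^2.
        value += lam * 4 * c**2
        grad += lam * 8 * c * dc
    return value, grad


def irm_path(lams):
    """Minimize for increasing lam, warm-starting each solve from the previous solution."""
    beta, path = np.zeros(2), []
    for lam in lams:
        beta = minimize(irmv1_objective, beta, args=(lam,), jac=True, method="BFGS").x
        path.append(beta.copy())
    return path


lams = [0, 1, 10, 100, 1e3, 1e4]
for lam, beta in zip(lams, irm_path(lams)):
    print(lam, beta.round(3))
```

At lam = 0 this reduces to pooled ERM, which puts weight on x2. As lam grows, the coefficient on x2 should shrink toward 0 and the coefficient on x1 should approach 1, the invariant solution under this assumed SEM. Warm-starting is a design choice: the IRMv1 penalty is non-convex, and following the path from small to large lam keeps the solver on the branch that starts at the ERM solution.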

A.3 Formula (5)

This section derives expression (5). Recall the training dataset and the objective function

where the second term on the right-hand side is the regularization term. For ERM, the regularization term is zero. With the feature extractor () fixed, we upweight a domain . The new objective function is

Notice that when upweighting a domain, we only upweight the empirical loss on that domain. Further, we denote