An Empirical Study of Invariant Risk Minimization
Invariant risk minimization (IRM; Arjovsky et al., 2019) is a recently proposed framework designed for learning predictors that are invariant to spurious correlations across different training environments. Because IRM does not assume that the test data follows the same distribution as the training data, it can allow models to learn invariances that generalize well on unseen and out-of-distribution (OOD) samples. Yet, despite this theoretical justification, IRM has not been extensively tested across various settings. In an attempt to gain a better understanding of IRM, we empirically investigate several research questions using IRMv1, which is the first practical algorithm proposed in (Arjovsky et al., 2019) to approximately solve IRM. By extending the ColoredMNIST experiment from (Arjovsky et al., 2019) in multiple ways, we find that IRMv1 (i) performs better as the spurious correlation varies more widely between training environments, (ii) learns an approximately invariant predictor when the underlying relationship is approximately invariant, and (iii) can be extended to multiple environments, multiple outcomes, and different modalities (i.e., text). We hope that this work will shed light on the characteristics of IRM and help with applying IRM to real-world OOD generalization tasks.
Invariant risk minimization (IRM) [2] is a recently proposed machine learning framework where the goal is to learn invariances across multiple training environments [10]. Compared to the widely used framework of empirical risk minimization (ERM), IRM does not assume that training samples are identically distributed. Rather, IRM assumes that training samples come from multiple environments and seeks to find associations that are invariant across those environments. This allows its resulting predictor to be effective in out-of-distribution (OOD) generalization, i.e., achieving low error on test samples that might come from an unseen environment.

Although IRM is a promising framework for OOD generalization, it has not been extensively tested across the many settings in which it is expected to perform well. Experiments in [2] are limited to two-environment binary classification tasks, which do not cover the different multi-environment settings found in real-world datasets [12, 3, 5]. We believe that this lack of empirical validation makes it difficult to apply IRM to real-world tasks that require OOD generalization.
In this paper, we conduct a series of experiments that examine the extent to which the IRM framework can be effective. Specifically, we extend the ColoredMNIST setup from [2] in several ways and compare the OOD generalization performances of ERM and IRMv1, which is the first practical algorithm for IRM proposed by [2]. We find that:
The generalization performance of IRMv1 improves as the difference between training environments, in terms of the degree of spurious correlations, becomes larger. (§3.1)
IRMv1 works even when the invariant correlation is stronger than the spurious correlation. (§3.2)
IRMv1 learns an approximately invariant predictor when the underlying relationship is approximately invariant. (§3.3)
IRMv1 can be extended to an analogous setup for binary sentiment analysis with text inputs. (§4)

We publicly release our code at https://github.com/kakaobrain/irmempiricalstudy for reproducibility.
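Consider a set of environments $\mathcal{E}$. For each environment $e \in \mathcal{E}$, we assume a data distribution $P^e$ on $\mathcal{X} \times \mathcal{Y}$ and a risk function $R^e(f) = \mathbb{E}_{(X, Y) \sim P^e}\left[\ell(f(X), Y)\right]$ for a convex and differentiable loss function $\ell$, such as cross-entropy or mean squared error. Our goal is to find a predictor $f: \mathcal{X} \to \mathcal{Y}$ that minimizes the maximum risk over all environments, or the OOD risk:

$$R^{\mathrm{OOD}}(f) = \max_{e \in \mathcal{E}} R^e(f). \qquad (1)$$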
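The difference between the OOD risk in (1) and the pooled objective minimized by ERM can be illustrated in a few lines of Python. The sketch below is a toy example with made-up predictions and labels: it computes each environment's empirical cross-entropy risk for a fixed binary predictor, then compares the average over equal-sized environments (what ERM sees after pooling) with the maximum, which is what (1) asks us to minimize.

```python
import math

def risk(probs, labels):
    """Empirical cross-entropy risk of a binary predictor in one environment.

    probs[i] is the predicted probability that labels[i] == 1.
    """
    total = 0.0
    for p, y in zip(probs, labels):
        total -= math.log(p if y == 1 else 1.0 - p)
    return total / len(labels)

def pooled_risk(envs):
    """Average risk over equal-sized environments (the ERM view after pooling)."""
    return sum(risk(p, y) for p, y in envs) / len(envs)

def ood_risk(envs):
    """OOD risk in the sense of Eq. (1): the worst risk over the given environments."""
    return max(risk(p, y) for p, y in envs)

# Toy example: the same predictions fit one environment well and another badly.
preds = [0.9, 0.8, 0.2]
envs = [(preds, [1, 1, 0]), (preds, [0, 1, 1])]
print(f"pooled={pooled_risk(envs):.3f}  ood={ood_risk(envs):.3f}")
```

Eq. (1) takes the maximum over all environments in $\mathcal{E}$, including unseen ones; when only training environments are available, the maximum computed above is merely a lower bound on it.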
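A representation function $\Phi: \mathcal{X} \to \mathcal{H}$ is said to elicit an invariant predictor $w \circ \Phi$ if there exists a classifier $w: \mathcal{H} \to \mathcal{Y}$ that is simultaneously optimal for all environments, i.e., $w \in \arg\min_{\bar{w}} R^e(\bar{w} \circ \Phi)$ for all $e \in \mathcal{E}$. Given training environments $\mathcal{E}_{tr} \subset \mathcal{E}$, IRM learns an invariant predictor by solving the following bi-level optimization problem:

$$\min_{\Phi,\, w} \sum_{e \in \mathcal{E}_{tr}} R^e(w \circ \Phi) \quad \text{subject to} \quad w \in \arg\min_{\bar{w}} R^e(\bar{w} \circ \Phi) \ \text{ for all } e \in \mathcal{E}_{tr}. \qquad (2)$$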
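For intuition about the constraint in (2), consider the least-squares setting that [2] uses to motivate IRMv1, with a scalar representation. The inner problem then has the closed form $w^e = \sum_i \Phi(x_i)\, y_i / \sum_i \Phi(x_i)^2$, and $\Phi$ elicits an invariant predictor exactly when these per-environment minimizers coincide. The Python sketch below checks this on two hypothetical features with made-up toy data: one whose relationship to $y$ is identical in both environments, and one whose relationship changes.

```python
def optimal_scale(phi, y):
    """Closed-form minimizer over scalar w of sum_i (w * phi_i - y_i)^2."""
    return sum(a * b for a, b in zip(phi, y)) / sum(a * a for a in phi)

def check_invariance(envs, tol=1e-9):
    """Return (is_invariant, per-environment optimal w) for a scalar feature."""
    ws = [optimal_scale(phi, y) for phi, y in envs]
    return max(ws) - min(ws) <= tol, ws

y = [1.0, -1.0, 2.0, -2.0]
# Hypothetical "shape" feature: the same relation to y in both environments.
shape_envs = [([2 * t for t in y], y), ([2 * t for t in y], y)]
# Hypothetical "color" feature: its relation to y changes between environments.
color_envs = [([2 * t for t in y], y), ([4 * t for t in y], y)]
print(check_invariance(shape_envs))  # (True, [0.5, 0.5])
print(check_invariance(color_envs))  # (False, [0.5, 0.25])
```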
Arjovsky et al. [2] use the theory of invariant causal prediction (ICP) [10] to illustrate that, for linear predictors and a linear underlying causal structure, an invariant predictor can be found using IRM as long as the training environments are sufficiently diverse and the underlying invariance is satisfied. The connection to ICP also yields a causal interpretation of IRM: it can discover causal structures from the underlying data distribution that can be extrapolated to OOD datasets.
Because (2) is highly intractable, particularly when $\Phi$ is allowed to be nonlinear, Arjovsky et al. [2] proposed a tractable variant that approximates (2), called IRMv1:

$$\min_{\Phi} \sum_{e \in \mathcal{E}_{tr}} \Big[ R^e(\Phi) + \lambda \, \big\| \nabla_{w \,|\, w = \mathbf{1}} R^e(w \odot \Phi) \big\|^2 \Big], \qquad (3)$$

where, for output dimension $K$, $w$ is a $K$-dimensional vector that scales each output of $\Phi$ element-wise ($\odot$), and $\mathbf{1}$ is a $K$-dimensional all-1 vector. $\lambda$ is a hyperparameter that balances between predictive power over training tasks, i.e., the ERM loss, and the squared gradient norm, i.e., the IRM penalty. Note that, in [2], the IRMv1 derivation focused on the case where $K = 1$, where the predictor outputs a single scalar (e.g., a logit). Throughout this paper, we build predictors that output $K$-dimensional logits for $K$-class classification, such that the gradient in the penalty term is also $K$-dimensional.

Importantly, the penalty term captures how much the invariant representation can be improved by locally scaling itself. For each $\Phi$ and environment $e$, the squared gradient norm penalty approximates “how close” $\mathbf{1}$ is to a minimizer $w^e \in \arg\min_{w} R^e(w \odot \Phi)$. Indeed, for strongly convex loss functions and full-rank representations $\Phi$, the penalty is zero if and only if $\mathbf{1}$ is the unique minimizer $w^e$. If the penalty term is zero across training environments, then the “classifier” $\mathbf{1}$ is simultaneously optimal for all training environments, making $\Phi$ an invariant predictor among training environments.
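For softmax cross-entropy, the penalty in (3) has a closed form. Writing $p = \mathrm{softmax}(z)$ for logits $z = \Phi(x)$ and a one-hot label $y$, the derivative of the per-example loss with respect to the $k$-th scale at $w = \mathbf{1}$ is $(p_k - y_k)\, z_k$. The pure-Python sketch below (function names are hypothetical) uses this to evaluate the $K$-dimensional IRMv1 penalty and the full objective for a batch of logits per environment, assuming element-wise scaling of the logits. In an actual training loop one would instead obtain this gradient through automatic differentiation, so that the penalty can itself be differentiated with respect to the parameters of $\Phi$.

```python
import math
import random

def log_softmax(z):
    """Numerically stable log-softmax of one logit vector."""
    m = max(z)
    lse = m + math.log(sum(math.exp(v - m) for v in z))
    return [v - lse for v in z]

def scaled_risk(logits, labels, w):
    """R^e(w * Phi): mean cross-entropy after scaling logit dimension k by w[k]."""
    total = 0.0
    for z, y in zip(logits, labels):
        total -= log_softmax([wk * zk for wk, zk in zip(w, z)])[y]
    return total / len(labels)

def irmv1_penalty(logits, labels):
    """Squared norm of grad_w R^e(w * Phi) evaluated at w = 1 (the all-1 vector)."""
    k = len(logits[0])
    grad = [0.0] * k
    for z, y in zip(logits, labels):
        p = [math.exp(v) for v in log_softmax(z)]  # softmax at w = 1
        for j in range(k):
            # d/dw_j of -log softmax(w * z)[y] equals (p_j - [j == y]) * z_j
            grad[j] += (p[j] - (1.0 if j == y else 0.0)) * z[j]
    return sum((g / len(labels)) ** 2 for g in grad)

def irmv1_objective(envs, lam):
    """Eq. (3): sum over environments of the ERM risk plus lam times the penalty."""
    total = 0.0
    for logits, labels in envs:
        ones = [1.0] * len(logits[0])
        total += scaled_risk(logits, labels, ones) + lam * irmv1_penalty(logits, labels)
    return total

# Usage: random 3-class logits for two hypothetical environments.
rng = random.Random(0)
envs = [([[rng.gauss(0, 1) for _ in range(3)] for _ in range(16)],
         [rng.randrange(3) for _ in range(16)]) for _ in range(2)]
print(irmv1_objective(envs, lam=1e4))
```

Setting `lam=0` reduces the objective to a plain sum of per-environment ERM risks, which makes the role of $\lambda$ in trading off the two terms easy to inspect.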
Note that (3) is now an objective with respect to $\Phi$ only and can be optimized using gradient descent for nonlinear $\Phi$ such as a deep neural network. Section 3 of [2] details how (3) approximates (2) in the case of linear least-squares regression given representations $\Phi(X)$ and targets $Y$.

In this section, we examine IRMv1 on several versions of Extended ColoredMNIST, a collection of synthetic image classification tasks derived from MNIST. Extended ColoredMNIST generalizes the original ColoredMNIST setup from [2], in which the training set of MNIST is split into two environments and spurious correlations are introduced by associating specific colors with specific output classes. The overall pipeline for constructing an Extended ColoredMNIST dataset can be summarized as follows.
Randomly split the training data ($\mathcal{D}_{tr}$) into $|\mathcal{E}_{tr}|$ environments, $\{\mathcal{D}^{e}\}_{e \in \mathcal{E}_{tr}}$. The test data ($\mathcal{D}_{test}$) is considered to come from its own environment $e_{test}$.
Corrupt the labels in each environment $e$ with probability $\eta^e$.
Pair each output class with a unique color, e.g., class 0 with green, class 1 with red, and so on.
With probability $p^e$, color the input image with the color paired with its (possibly corrupted) label. Otherwise, i.e., with probability $1 - p^e$, color the input image with a different color.
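The four steps above can be sketched as follows. This is a minimal, hypothetical Python illustration rather than the released code: images are abstracted as flat lists of pixel intensities, each color is represented by its own input channel (color $c$ is paired with class $c$), and both the corrupted label and the “different color” are drawn uniformly from the remaining classes. That last choice is an assumption for $K > 2$; for two classes it reduces to a flip.

```python
import random

def split_environments(samples, num_envs, rng):
    """Step 1: randomly split the training samples into num_envs environments."""
    shuffled = list(samples)
    rng.shuffle(shuffled)
    return [shuffled[i::num_envs] for i in range(num_envs)]

def other_class(label, num_classes, rng):
    """Draw a class different from `label` uniformly at random."""
    return rng.choice([c for c in range(num_classes) if c != label])

def color_environment(samples, num_classes, eta, p, rng):
    """Steps 2-4 for one environment with label corruption eta and coloring probability p.

    samples: list of (pixels, label). Returns (channels, label, color) triples in which
    the image is copied into the channel of its color and all other channels are zero.
    """
    out = []
    for pixels, label in samples:
        if rng.random() < eta:  # step 2: corrupt the label
            label = other_class(label, num_classes, rng)
        # steps 3-4: use the color paired with the label with probability p
        color = label if rng.random() < p else other_class(label, num_classes, rng)
        channels = [[0.0] * len(pixels) for _ in range(num_classes)]
        channels[color] = list(pixels)
        out.append((channels, label, color))
    return out

# Usage on toy data: 1,000 fake 4-pixel "images" with alternating binary labels.
rng = random.Random(0)
data = [([1.0, 0.5, 0.0, 0.25], i % 2) for i in range(1000)]
env1, env2 = split_environments(data, 2, rng)
colored = color_environment(env1, num_classes=2, eta=0.25, p=0.9, rng=rng)
agree = sum(color == label for _, label, color in colored) / len(colored)
print(f"color agrees with label in {agree:.1%} of env1")
```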
The goal of this setup is to build a set of training environments where the spurious correlation is stronger than the invariant correlation in the data, but only the spurious correlation varies among the training environments. The spurious correlation is captured by the correlation between color and label, or $p^e$, while the invariant correlation is captured by the correlation between shape and label, or $1 - \eta^e$. In this setting, we expect ERM to pick up the spurious correlation that appears strong during training and to suffer on the test set, where the spurious correlation is altered. On the other hand, we expect IRM to pick up only the invariant correlation, because it learns an invariant predictor.
The original ColoredMNIST setup from [2] is a special case of Extended ColoredMNIST. It consists of $|\mathcal{E}_{tr}| = 2$ training environments and a test environment, all with label corruption probability $\eta^e = 0.25$. The two training environments have coloring probabilities $p^{e_1} = 0.9$ and $p^{e_2} = 0.8$, respectively, while the test environment has $p^{e_{test}} = 0.1$. Labels are collapsed into two classes: $y = 0$ for digits 0-4 and $y = 1$ for digits 5-9. To color an image, the input is represented in two channels, one for red and one for green. During training, the color of each input image is highly correlated with the binary label: 0 with green and 1 with red (with probability $p^e$ for each environment $e$). At test time, this correlation is reversed: 0 with red and 1 with green (with probability $1 - p^{e_{test}} = 0.9$).
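A quick simulation makes the resulting correlations concrete. The Monte Carlo sketch below covers the binary case with a hypothetical `simulate` helper: it measures how often the clean digit class (standing in for shape) and the color agree with the observed, possibly corrupted label. With $\eta^e = 0.25$ and the coloring probabilities 0.9, 0.8, and 0.1 (the color-flip rates 0.1, 0.2, and 0.9 of [2]), shape agrees with the label about 75% of the time in every environment, while color agrees about 90%, 80%, and 10% of the time, which is why a color-based predictor does well in training and poorly at test time.

```python
import random

def simulate(eta, p, n=50_000, seed=0):
    """Fraction of samples where shape and color agree with the observed label.

    Binary case: `clean` stands for the class implied by the digit's shape,
    `label` is the (possibly corrupted) label, and `color` is the index of the
    color applied to the image (color c is paired with class c).
    """
    rng = random.Random(seed)
    shape_hits = color_hits = 0
    for _ in range(n):
        clean = rng.randrange(2)
        label = 1 - clean if rng.random() < eta else clean  # label corruption
        color = label if rng.random() < p else 1 - label     # coloring step
        shape_hits += (clean == label)
        color_hits += (color == label)
    return shape_hits / n, color_hits / n

for env, p in [("e1", 0.9), ("e2", 0.8), ("test", 0.1)]:
    shape_acc, color_acc = simulate(eta=0.25, p=p)
    print(f"{env}: shape {shape_acc:.3f}, color {color_acc:.3f}")
```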
In the following experiments, the training details follow those of the original implementation (https://github.com/facebookresearch/InvariantRiskMinimization), unless noted otherwise. The representation model is parametrized as a multilayer perceptron (MLP) with ReLU activations, three layers, and an equal number (256) of hidden units in between layers, containing a total of 166,914 trainable parameters. The model is trained using gradient descent for 500 steps, with a fixed learning rate and a fixed L2 weight decay. The penalty weight $\lambda$ is held at a fixed initial value for the first training steps and then set to its final value. Input MNIST images are subsampled from 28x28 to 14x14. All reported accuracies are averaged over 10 random seeds and reported with their standard deviations.
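The stated parameter count pins down the hidden width. Assuming two 14x14 color channels flattened into a 392-dimensional input and two output logits, a network with 256 units in each of its two hidden layers has exactly 166,914 weights and biases, as the short Python check below confirms. The same helper gives 181,202 for a network with a 300-dimensional input, two hidden layers of 300 units, and 2 outputs, which matches the text model used for SST-2.

```python
def mlp_param_count(layer_sizes):
    """Weights plus biases of a fully connected network with the given layer widths."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(layer_sizes, layer_sizes[1:]))

# Two 14x14 color channels -> 256 -> 256 -> 2 logits.
print(mlp_param_count([2 * 14 * 14, 256, 256, 2]))  # 166914
# Averaged 300-d word vectors -> 300 -> 300 -> 2 logits.
print(mlp_param_count([300, 300, 300, 2]))  # 181202
```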
First, we examine how the OOD generalization performance of IRMv1 is affected by the difference between training environments in terms of their spurious correlations. For $|\mathcal{E}_{tr}| = 2$ training environments, the gap between training environments is captured by $\Delta = p^{e_1} - p^{e_2}$, the difference in coloring probabilities. Since IRM exploits this gap in spurious correlations among the environments, we expect that the algorithm can perform well only if this gap is substantially greater than zero. Note that, if $\Delta = 0$, then no algorithm can distinguish the invariant correlation from the (no longer) spurious one.

In Figure 1, we plot the train and test accuracies of ERM and IRMv1 against $\Delta$. For each value of $\Delta$, we define $p^{e_1} = 0.85 + \Delta/2$ and $p^{e_2} = 0.85 - \Delta/2$. Since the average of the coloring probabilities is 0.85, which exceeds the invariant correlation of 0.75, the spurious correlation is always stronger than the invariant correlation during training. All other settings are the same as in the original ColoredMNIST; in particular, a random baseline achieves 50% accuracy and the optimal classifier achieves 75%.

As shown in the right plot of Figure 1, the generalization performance of IRMv1 consistently improves as the gap between training environments grows. IRMv1 outperforms ERM as soon as a nonzero gap exists, and it achieves above-chance accuracy once the gap is sufficiently large. As the gap grows, IRMv1’s test accuracy keeps improving and eventually approaches the accuracy (70.6%) of the grayscale model, i.e., the same model trained on data without the spurious correlation (the colors). This suggests that IRMv1 benefits from spurious correlations that vary more widely between training environments.
In the original ColoredMNIST experiment, by construction, the spurious correlation is stronger than the invariant correlation during training. In reality, the spurious correlation might exist but its degree might not be as strong as the invariant correlation. We pose two questions here: (i) Does the weak spurious correlation still hurt ERM’s generalization performance? (ii) If so, can IRMv1 help avoid this issue?
To answer these questions, we set up two Extended ColoredMNIST settings where the invariant correlation is stronger than the spurious one. In the first setting, the spurious correlations of the training environments are weakened to $p^{e_1} = 0.65$ and $p^{e_2} = 0.55$, with label corruption probabilities kept at $\eta^e = 0.25$. In the second setting, we remove label corruption, i.e., $\eta^e = 0$, and the spurious correlations are unchanged from the original version, i.e., $p^{e_1} = 0.9$ and $p^{e_2} = 0.8$. In both settings, the average spurious correlation ($(p^{e_1} + p^{e_2})/2$, i.e., 0.6 and 0.85, respectively) is now 0.15 lower than the invariant correlation ($1 - \eta^e$, i.e., 0.75 and 1.0), and the gap between training environments ($p^{e_1} - p^{e_2}$) is fixed to 0.1. Also, $p^{e_{test}}$ is kept at 0.1 in both settings.


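| Algorithm | 25% label corruption: Train (%) | 25% label corruption: Test (%) | No label corruption: Train (%) | No label corruption: Test (%) |
|---|---|---|---|---|
| ERM | | 61.6 | | 92.7 |
| IRMv1 | | 71.6 | † | 91.0 † |
| Random | | | | |
| Optimal | | | | |
| Grayscale | | 71.6 | | 97.9 |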

In Table 1, we present train and test accuracies in the two settings, along with random, optimal, and grayscale baselines. With 25% label corruption, IRMv1 achieves the same test accuracy (71.6%) as the grayscale model, suggesting that it effectively ignores the spurious correlation and predicts as well as the same model trained without it. In contrast, while ERM now achieves above-chance test accuracy (61.6%), it fails to completely ignore the spurious correlation when making its predictions, as evidenced by its relatively high train and low test accuracies.
With no label corruption, however, ERM (92.7%) outperforms IRMv1 (91.0%), even though IRMv1 needed many more training steps (10,000) for its training loss and penalty to stop decreasing. One possible explanation is that, because the invariant correlation is “too strong”, the difference in spurious correlations between training environments was comparatively too small for IRMv1 to exploit. This would suggest that IRMv1 is less effective when the invariant correlation is already strong, as in the case of no label corruption. Also, ERM’s test accuracy is still lower than that of the grayscale model (97.9%), suggesting that the spurious correlation still hurts ERM, although IRMv1 does not seem to solve this issue.
In the original ColoredMNIST and our previous generalizations, the association between shape and label was fully invariant: 75% with label corruption and 100% without, across all environments. However, in reality, it may be unrealistic to assume that the association is completely invariant across training environments. Arjovsky et al. [2] state that, when the data follow an approximately invariant model, IRM should return an approximately invariant solution, because its objective is a differentiable function with respect to the training environments. Our goal in this section is to empirically validate this claim.
We set up an experiment analogous to the one in Section 3.1, but we now vary the rate of label corruption across environments. Specifically, we evaluate IRMv1 and ERM on different values of the label-corruption gap $\eta^{e_1} - \eta^{e_2}$, while keeping the average $(\eta^{e_1} + \eta^{e_2})/2$ fixed to that of the test set, i.e., $\eta^{e_{test}} = 0.25$. We also fix the coloring probabilities of the training environments. As the label-corruption gap grows, the association between shape and label becomes less invariant.

Figure 2 shows mean accuracies of ERM (blue), IRMv1 (green), and random (orange) predictors across different values of the label-corruption gap (0.0 to 0.5 in increments of 0.01). First, we find that IRMv1 can achieve high test accuracy even when the association between shape and label is only approximately invariant. For example, at one nonzero value of the gap, IRMv1 achieves 70.4% test accuracy, which is even higher than its test accuracy with no gap (66.2%). As long as the association is not more variant than the spurious correlation between color and label, IRMv1 performs above chance and well above ERM, which quickly overfits to color across all values of the gap.

Interestingly, the overall behavior of IRMv1’s test accuracy (right plot of Figure 2) suggests that the algorithm gives more weight to whichever factor is more invariant, in a nearly smooth manner. IRMv1’s test accuracy tracks the label-corruption gap: high when the gap is small and low when the gap is large. In fact, this pattern is almost a smooth function of the gap, starting from high test accuracy (i.e., predictions based on shape) and eventually reaching the test accuracy of ERM (i.e., predictions based on color). Also notable is that the accuracy is close to the random baseline (50%) precisely when the label-corruption gap equals the gap in coloring probabilities. We posit that, at this point, the algorithm either discards both factors, as they are equally non-invariant, or weighs both factors equally, which would cancel out in terms of test accuracy. In either case, it is interesting to see that IRMv1 does not favor color, which is a stronger indicator of the label than shape in the training set.
Some realistic datasets for IRM may contain examples sourced from many more environments than two. When datasets are collected from multiple sources, which is common for many benchmark datasets, they are likely separable into many environments, each with a different degree of spurious correlation.
To examine the performance of IRMv1 with environments, we follow the Extended ColoredMNIST data construction pipeline we described earlier to build datasets with environments. Each environment possesses a unique probability that the label is correlated to a specific color. In all cases, we set maximum and minimum values of to and , respectively, and spread out the environments evenly. For environments, we use , , and . We also test out uneven gaps when using , , , , and . Note that the average of environment probabilities is always smaller than , meaning that ERM performs poorly (%) in all of these environments.


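| Algorithm | # Environments | Train accuracy (%) | Test accuracy (%) |
|---|---|---|---|
| ERM | 2 | | |
| IRMv1 | 2 | | |
| IRMv1 | 3 | | |
| IRMv1 | 5 | | |
| IRMv1 | 5 (uneven) | | |
| IRMv1 | 10 | | |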

Our results are summarized in Table 2. Overall, we find that IRMv1 achieves high test accuracy (68.3-70.1%) with 3, 5, or 10 environments, spread out evenly or unevenly. Performance also seems slightly better with fewer environments, although the difference is not significant. We posit that performance might degrade with more environments because, with the maximum gap fixed, the gap between neighboring environments shrinks.
Our final Extended ColoredMNIST experiment examines the performance of IRMv1 for multiple outcomes. This is important because many real-world datasets involve multi-dimensional outputs; most notably, multi-class classification tasks require multi-dimensional logits as outputs. Yet, for the sake of clarity, the original formulation of IRMv1 in [2] focused on binary classification with sigmoidal logits, leading to a scalar output. Here, we treat ColoredMNIST as a multi-class classification task, per the extended derivation of IRMv1 in (3).

Analogous to ColoredMNIST with two classes, we construct a $K$-class ColoredMNIST by assigning to each output class a unique color that is highly correlated with it during training, and shifting this correlation for the test set. To avoid introducing unwanted correlation structures, we assign a unique input channel to each color. Note that this makes the first layer of the MLP contain more parameters for larger $K$. We test three values of $K$: $K = 2$ ($y = 0$ for digits 0-4, $y = 1$ for 5-9); $K = 5$ ($y = 0$ for 0-1, $y = 1$ for 2-3, …, $y = 4$ for 8-9); and $K = 10$ (each digit is its own class).
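The class grouping and per-color input channels described above can be sketched as follows. The helper names are hypothetical: `digit_to_class` maps digits to $K$ contiguous classes for any $K$ that divides 10, `color_channels` places an image in the channel of its color, and `first_layer_params` shows how the input layer grows with $K$ when each color gets its own 14x14 channel, assuming 256 hidden units.

```python
def digit_to_class(digit, num_classes):
    """Group digits 0-9 into num_classes contiguous classes of equal size."""
    if 10 % num_classes != 0:
        raise ValueError("num_classes must divide 10")
    return digit * num_classes // 10

def color_channels(pixels, color, num_classes):
    """Copy a grayscale image into the input channel reserved for its color."""
    channels = [[0.0] * len(pixels) for _ in range(num_classes)]
    channels[color] = list(pixels)
    return channels

def first_layer_params(num_classes, side=14, hidden=256):
    """Weights and biases of the first MLP layer with one input channel per color.

    hidden=256 is an assumption matching the binary MLP's parameter count.
    """
    return num_classes * side * side * hidden + hidden

for k in (2, 5, 10):
    print(k, [digit_to_class(d, k) for d in range(10)], first_layer_params(k))
```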


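| Algorithm | # Outcomes | Train accuracy (%) | Test accuracy (%) |
|---|---|---|---|
| ERM | 2 | | |
| IRMv1 | 2 | | |
| Random | 2 | | |
| Grayscale | 2 | | |
| ERM | 5 | ‡ | ‡ |
| IRMv1 | 5 | | |
| Random | 5 | | |
| Grayscale | 5 | | |
| ERM | 10 | ‡ | ‡ |
| IRMv1 | 10 | † | † |
| Random | 10 | | |
| Grayscale | 10 | | |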

Our results are summarized in Table 3. Overall, we find that IRMv1 still generalizes significantly better than ERM with multiple outcomes ($K > 2$). This shows that the IRMv1 penalty for multi-dimensional outputs remains effective, as long as the underlying causal structure is preserved. We do note that the test accuracy of IRMv1 degrades as more output classes are used, unlike the grayscale model, which retains its test accuracy. This suggests that, although effective, the IRMv1 penalty may become less effective when its squared gradients are summed over more dimensions. One possibility is that the gradient norm penalty does not scale each dimension adaptively, although it is unclear how each dimension should be weighted. Another possibility is that evaluating the gradient at the all-1 vector is problematic, although the IRMv1 derivation for least squares in [2] suggests that this should not be an issue.
As described in [2], the tendency of ERM to absorb biases and spurious correlations from the training data is a fundamental problem across machine learning applications. Natural language processing (NLP) is no exception: several recent papers [6, 7, 8] have pointed out the presence of spurious correlations in text classification tasks, often in the form of specific words being highly correlated with specific labels, and have shown that NLP models actively exploit them. As a result, state-of-the-art NLP models often make trivial mistakes and fail to generalize out-of-distribution.

With this in mind, we apply the data construction pipeline of Extended ColoredMNIST to a text classification dataset and evaluate the performance of both ERM and IRMv1. We start with the Stanford Sentiment Treebank (SST-2) [11], a standard benchmark dataset for binary sentiment analysis, and use an analogous pipeline as follows.
Randomly split the SST-2 training data ($\mathcal{D}_{tr}$) into $|\mathcal{E}_{tr}|$ environments. The SST-2 validation data is also randomly split, but into $|\mathcal{E}_{tr}| + 1$ environments: one held-out test split for each training environment and one for the test environment $e_{test}$.
Corrupt the labels with probability $\eta^e$.
Pair each output class with a punctuation mark: positive with a period (.) and negative with an exclamation mark (!).
Remove any existing punctuation mark at the end of each input sentence.
With probability $p^e$, punctuate the input sentence with the mark paired with its (possibly corrupted) label. Otherwise, i.e., with probability $1 - p^e$, punctuate the input sentence with the other mark.
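The per-sentence part of this pipeline (label corruption, removing trailing punctuation, and appending the paired mark) can be sketched as follows. This is a hypothetical illustration, not the released code: it assumes whitespace-tokenized sentences (so the mark is appended as its own token), treats every character in Python's `string.punctuation` as a removable trailing mark, and encodes positive as 1 and negative as 0.

```python
import random
import string

MARK = {1: ".", 0: "!"}  # positive (1) -> period, negative (0) -> exclamation mark

def perturb(sentence, label, eta, p, rng):
    """Inject the spurious punctuation feature into one (sentence, label) pair."""
    if rng.random() < eta:  # corrupt the binary label
        label = 1 - label
    # Remove any trailing punctuation (and the spaces around it).
    text = sentence.rstrip(string.punctuation + " ")
    # Use the mark paired with the label with probability p, else the other mark.
    mark = MARK[label] if rng.random() < p else MARK[1 - label]
    return f"{text} {mark}", label

rng = random.Random(0)
print(perturb("an engaging little film .", 1, eta=0.25, p=0.9, rng=rng))
```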
There are two key differences from the ColoredMNIST pipeline. First, we now keep a separate test set for each training environment, assumed to come from the same environment. This allows us to measure more precisely how much the model picks up the spurious correlation, as the examples in each such test set are unseen but have the same degree of spurious correlation as the corresponding training environment. In particular, a model’s accuracy on the test set of environment $e$ cannot surpass $1 - \eta^e$ unless the model pays attention to the spurious correlation. The final test set, now denoted as coming from environment $e_{test}$, corresponds to the test set we had in ColoredMNIST. Although the same could have been done for ColoredMNIST in Section 3, we preserved the original pipeline of [2] to maintain consistency.
Second, because the inputs are now sequences of words, we need a different way to “color” them based on the label. As shown above, we introduce associations between each output class and a specific punctuation mark. Adding a punctuation mark at the end allows us to preserve both the meaning and grammaticality of the input sentence.
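Aside from these two differences, we use the original ColoredMNIST setup for the following experiments: $|\mathcal{E}_{tr}| = 2$ training environments with $p^{e_1} = 0.9$ and $p^{e_2} = 0.8$, a test environment with $p^{e_{test}} = 0.1$, and label corruption of $\eta^e = 0.25$ across environments.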
Because state-of-the-art models on SST-2 can be quite complex for our tests, we resort to a simple bag-of-words (BOW) model with averaged word embeddings. We use the average of the 300-dimensional GloVe [9] word vectors as inputs and train a 3-layer (300-300-2) perceptron with ReLU activations. The model contains 181,202 trainable parameters, which is comparable to the MLP used for ColoredMNIST. We still train with full-batch gradient descent for 500 steps, but use hyperopt (https://github.com/hyperopt/hyperopt) with tree-structured Parzen estimators (TPE) [4] for a hyperparameter search over the learning rate (1e-3, 2e-3, 5e-3, 1e-2), L2 weight decay (1e-4, 5e-4, 1e-3, 1e-2), and the IRMv1 penalty weight $\lambda$ (5k, 7.5k, 10k, 20k, 50k). Across 50 hyperparameter configurations, we choose the one that gives the highest minimum accuracy over the three test sets (averaged over 5 trials). We then use the best hyperparameter configuration to report the final mean accuracy.
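The model-selection rule can be written down in a few lines. The sketch below is a hypothetical illustration: it samples configurations uniformly at random from the grid above, a simple stand-in for the TPE search that hyperopt performs, and keeps the configuration with the highest worst-case accuracy across test sets. We assume accuracies are first averaged over trials for each test set and the minimum is then taken; `evaluate` and `fake_evaluate` are placeholders for training and evaluating one run.

```python
import random
import statistics

# Search space from the text; "penalty_weight" is the IRMv1 lambda.
SPACE = {
    "lr": [1e-3, 2e-3, 5e-3, 1e-2],
    "weight_decay": [1e-4, 5e-4, 1e-3, 1e-2],
    "penalty_weight": [5e3, 7.5e3, 1e4, 2e4, 5e4],
}

def sample_configs(n, rng):
    """Draw n configurations uniformly from SPACE (a stand-in for TPE)."""
    return [{name: rng.choice(values) for name, values in SPACE.items()} for _ in range(n)]

def select_config(configs, evaluate, num_trials=5):
    """Return the config maximizing the minimum trial-averaged test-set accuracy.

    evaluate(config, seed) must return one accuracy per test set.
    """
    best, best_score = None, float("-inf")
    for config in configs:
        runs = [evaluate(config, seed) for seed in range(num_trials)]
        per_set = [statistics.mean(accs) for accs in zip(*runs)]  # mean over trials
        score = min(per_set)  # worst test set
        if score > best_score:
            best, best_score = config, score
    return best, best_score

# Usage with a made-up evaluation function standing in for real training runs.
def fake_evaluate(config, seed):
    noise = random.Random(seed).uniform(-0.01, 0.01)
    ood = 0.6 - abs(config["lr"] - 2e-3) * 10 + noise
    return [0.75 + noise, 0.74 + noise, ood]

best, score = select_config(sample_configs(50, random.Random(0)), fake_evaluate)
print(best, round(score, 3))
```

Selecting on the minimum rather than the mean favors configurations that do not trade OOD accuracy for in-distribution accuracy.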


| Algorithm | Test accuracy (%) |
|---|---|
| ERM | 30.4 |
| IRMv1 | 61.4 |
| Majority | |
| Oracle | 61.5 |

In Table 4, we report the mean accuracies of ERM and IRMv1, as well as of the majority-vote and oracle (ERM without input perturbations, i.e., the “grayscale” model) baselines. We first find that the ERM model is highly susceptible to the one-word spurious correlation present in the training environments, resulting in a worse-than-chance OOD accuracy (30.4%), analogous to ColoredMNIST. More importantly, we find that the IRMv1 penalty removes the effect of this spurious but varying one-word correlation and achieves almost as high OOD accuracy (61.4%) as the oracle model (61.5%). Note that, unlike on ColoredMNIST, the BOW model is known to achieve only around 80% accuracy on SST-2 without label corruption [13]; as a result, the oracle model’s accuracy after label corruption ($\eta^e = 0.25$) is lower (around 61%) than on ColoredMNIST (around 71%). These results indicate that the IRMv1 algorithm can indeed be extended to text-based inputs.
To deepen our understanding of IRM, we examined the generalization performance of IRMv1 across several extensions of ColoredMNIST. Overall, our findings were encouraging. We found that IRMv1 is capable of detecting and removing small variations across environments and that it can be helpful even when the invariant correlations are stronger than the spurious ones. We also found that IRMv1 learns approximately invariant predictors when the underlying relationship is approximately invariant. Finally, we found that IRMv1 can be extended to settings with multiple environments or outcomes as well as to text classification tasks involving spurious correlations.
We believe that these results serve as initial steps toward making IRM more broadly and effectively applicable. One important future direction is to figure out what kinds of multi-environment settings exist in real-world datasets and how the IRM framework can help. This includes identifying patterns of varying spurious correlations across different tasks, such as the word-label correlations we described in text classification. Another, perhaps more challenging, direction is to identify ways to build meaningful multi-environment settings. Even though nature does not shuffle data [2], many standard benchmark datasets come as a single environment, making it infeasible to distinguish invariant and spurious correlations. It would be important to gain better insight into both how to effectively construct multi-environment datasets and how to identify multiple environments within an existing dataset. Finally, developing more stable approximations of the IRM framework, as pointed out in [1], would be crucial for scaling IRM to larger datasets and models.
Y. C. would like to thank Chiheon Kim for helpful discussion and comments.
Proceedings of the European Conference on Computer Vision (ECCV), pp. 456–473.
Right for the wrong reasons: diagnosing syntactic heuristics in natural language inference. In Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics, Florence, Italy, pp. 3428–3448.
Causal inference by using invariant prediction: identification and confidence intervals. Journal of the Royal Statistical Society: Series B (Statistical Methodology) 78 (5), pp. 947–1012.