Training data in many emerging applications is necessarily limited, fragmented, or otherwise heterogeneous. It is therefore important to ensure that model predictions derived from such data generalize substantially beyond where the training samples lie. For instance, in molecule property prediction Wu et al. (2018), models are often evaluated under scaffold split, which introduces structural separation between the chemical spaces of training and test compounds. In protein homology detection Rao et al. (2019), models are evaluated under protein superfamily split where entire evolutionary groups are held out from the training set, forcing models to generalize across larger evolutionary gaps.
The key technological challenge is to be able to estimate models that can extrapolate beyond their training data. The ability to extrapolate implies a notion of invariance to the differences between the available training data and where predictions are sought. A recently proposed approach known as invariant risk minimization (IRM) Arjovsky et al. (2019) seeks to find predictors that are simultaneously optimal across different such scenarios (called environments). Indeed, one can apply IRM with environments corresponding to molecules sharing the same scaffold Bemis and Murcko (1996) or proteins from the same family El-Gebali et al. (2019) (see Figure 1). However, this is challenging since, for example, scaffolds are combinatorially defined substructure descriptors and can often uniquely identify each example in the training set. Another difficulty is that IRM collapses to empirical risk minimization (ERM) if the model can achieve zero training error across the environments, a scenario typical of over-parameterized models Zhang et al. (2016).
To address these difficulties we propose a new method called regret minimization (RGM). This new approach seeks to find a feature-predictor combination that generalizes well to unseen environments. We quantify generalization in terms of regret that guides the feature extractor, encouraging it to focus on information that enables generalization. The setup is easily simulated by using part of the training set as held-out environments. Specifically, our regret measures how well the feature extractor enables a predictor trained without a held-out environment to perform on that environment, in comparison to an oracle with hindsight access to it. Since our regret measures the ability to predict, it need not collapse even with powerful models. To handle combinatorial environments, we appeal to domain perturbation and introduce two additional, dynamically defined environments. The perturbed environments operate over the same set of training examples, but differ in terms of their associated representations. The idea is to explicitly highlight to the predictor domain variability that it should not rely on.
Our method is evaluated on both synthetic and real datasets such as molecule property prediction and protein classification. We illustrate on the synthetic dataset how RGM overcomes some of the IRM challenges. On the real datasets, we compare RGM with various domain generalization techniques including CrossGrad Shankar et al. (2018), MLDG Li et al. (2018a) as well as IRM. Our method significantly outperforms all these baselines, with a wide margin on molecule property prediction (COVID dataset: 0.654 versus 0.402 AUC; BACE dataset: 0.590 versus 0.530 AUC).
2 Related work
Domain generalization (DG) assumes that samples from target domains are not available during training. DG has been widely studied in computer vision Ghifary et al. (2015); Motiian et al. (2017); Li et al. (2017, 2018b, 2018c), where domain shift is typically caused by different image styles or dataset bias Khosla et al. (2012). As a result, each domain contains a fair amount of data and the number of distinct domains is relatively small (e.g., the commonly adopted PACS and VLCS benchmarks Fang et al. (2013); Li et al. (2017) contain only four domains). We study domain generalization in combinatorially defined domains, where the number of domains is much larger. For instance, in a protein homology detection benchmark Fox et al. (2013); Rao et al. (2019), there are over 1000 domains defined by protein families. Our method is related to prior DG methods in two aspects:
Simulated domain shift: Meta-learning based DG methods Li et al. (2018a); Balaji et al. (2018); Li et al. (2019a, b); Dou et al. (2019) simulate domain shift by dividing source domains into meta-training and meta-test sets. They seek to minimize the model’s generalization error on meta-test domains after training on meta-training domains. Similarly, our formulation also creates held-out domains during training and minimizes the model’s regret. However, our objective enforces a stronger requirement for domain generalization: we not only minimize the model’s generalization error on held-out domains, but also require it to be optimal, i.e., to perform as well as the best predictor trained on held-out domains.
Domain augmentation: CrossGrad Shankar et al. (2018) augments the dataset with domain-guided perturbations of input examples for domain generalization. Likewise, Volpi et al. (2018) augment the dataset with adversarially perturbed examples that serve as a fictitious target domain. Our domain perturbation method in §4 is closely related to CrossGrad, but it operates over learned features because our inputs are discrete. Moreover, our domain perturbation is only used to compute regret in our RGM objective. Unlike in data augmentation, our predictor is not directly trained on the perturbed examples.
Learning invariant representation One way of domain extrapolation is to enforce an appropriate invariance constraint over learned representations Muandet et al. (2013); Ganin et al. (2016); Arjovsky et al. (2019). Various strategies for invariant feature learning have been proposed. They can be roughly divided into three categories:
Domain adversarial training (DANN) Ganin et al. (2016) enforces the latent representation $\phi(x)$ to have the same distribution across different domains (environments). If we denote by $P_e(x)$ the data distribution in environment $e$, then we require $P_e(\phi(x)) = P_{e'}(\phi(x))$ for all $e, e'$. With some abuse of notation, we can write this condition as $\phi(X) \perp E$. A single predictor is learned based on $\phi(x)$, i.e., all the domains share the same predictor. As a result, the predicted label distribution will also be the same across the domains. This can be problematic when the training and test domains have very different label distributions Zhao et al. (2019).
Conditional domain adversarial network (CDAN) Long et al. (2018); Li et al. (2018c) instead conditions the invariance criterion on the label, i.e., $P_e(\phi(x) \mid y) = P_{e'}(\phi(x) \mid y)$ for all $e, e'$ and $y$. In other words, we aim to satisfy the independence statement $\phi(X) \perp E \mid Y$. The formulation allows the label distribution to vary between domains, but the constraint becomes too restrictive when domains are combinatorially defined and many domains have only one example (see Figure 1). In this case, $P_e(\phi(x) \mid y)$ degenerates to a Dirac distribution and the constraint will require the representation $\phi$ to map all $x$ to the same vector within each class. As a result, CDAN (as well as DANN) requires each domain to have a fair number of examples in practice.
Invariant risk minimization (IRM) Arjovsky et al. (2019) requires that the predictor operating on $\phi(x)$ is simultaneously optimal across different environments. The associated conditional independence criterion is $Y \perp E \mid \phi(X)$. In other words, knowing the environment should not provide any additional information about $Y$ beyond the features $\phi(X)$. However, IRM tends to collapse to ERM when the model is over-parameterized and perfectly fits the training set (see §3). Moreover, when most of the domains can uniquely specify $x$ in the training set, $E$ would act similarly to $X$ and the IRM principle reduces to $Y \perp X \mid \phi(X)$, which is not a useful criterion for domain extrapolation. We propose to handle this issue via domain perturbation (see §4).
3 Domain extrapolation via regret minimization
The IRM principle provides a useful way to think about domain extrapolation but it does not work well with strong predictors. Indeed, a zero training error reduces the IRM criterion to standard risk minimization or ERM. The main reason for this collapse is that the simultaneous optimality condition in IRM is not applied in a predictive sense (as regret). To see this, consider a training set divided into $n$ environments $E_1, \dots, E_n$. Let $\mathcal{L}^e(f \circ \phi)$ be the empirical loss of predictor $f$ operating on feature representation $\phi$, i.e., $f \circ \phi$, in environment $E_e$. The specific form of the loss depends on the task. IRM finds $f$ and $\phi$ as the solution to the constrained optimization problem:
$$\min_{\phi, f} \; \mathcal{L}(f \circ \phi) \quad \text{s.t.} \quad f \in \arg\min_{h} \mathcal{L}^e(h \circ \phi) \quad \forall e$$
where $\mathcal{L}(f \circ \phi) = \sum_e \mathcal{L}^e(f \circ \phi)$. The key simultaneous optimality constraint can be satisfied trivially if the model achieves zero training error across the environments, i.e., $\mathcal{L}^e(f \circ \phi) = 0$ for all $e$. This setting is not uncommon with over-parameterized neural networks, even if labels were set at random Zhang et al. (2016).
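The collapse argument can be spelled out in one line. The following is a sketch under two assumptions that are not stated explicitly above: the per-example loss is non-negative, and $\mathcal{L}^e(h \circ \phi)$ denotes the empirical loss of a predictor $h$ on environment $e$.

```latex
% Assumption: the per-example loss is non-negative, so every empirical loss is >= 0.
\[
\mathcal{L}^e(f \circ \phi) = 0 \;\; \forall e
\quad\Longrightarrow\quad
\mathcal{L}^e(h \circ \phi) \;\ge\; 0 \;=\; \mathcal{L}^e(f \circ \phi) \;\; \forall h,\, \forall e
\quad\Longrightarrow\quad
f \in \arg\min_{h} \mathcal{L}^e(h \circ \phi) \;\; \forall e .
\]
```

Any interpolating pair therefore satisfies the constraint in every environment, so the constraint no longer distinguishes such a solution from an ERM solution.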
Regret minimization We can replace the simultaneous optimality constraint in IRM with a predictive regret. This is analogous to one-step regret in on-line learning but cast here in terms of held-out environments. We calculate this regret for each held-out environment $E_e$ as the comparison between the losses of two auxiliary predictors that are trained with and without access to $E_e$. Specifically, we define the regret as
$$R^e(\phi) = \mathcal{L}^e(f_{-e} \circ \phi) - \mathcal{L}^e(f_e \circ \phi)$$
where the two auxiliary predictors are obtained from
$$f_e = \arg\min_{h} \mathcal{L}^e(h \circ \phi), \qquad f_{-e} = \arg\min_{h} \mathcal{L}^{-e}(h \circ \phi)$$
and $\mathcal{L}^{-e}$ denotes the loss aggregated over all environments except $E_e$.
Note that the oracle predictor $f_e$ is trained and tested on the same environment $E_e$, while $f_{-e}$ is estimated based on all the environments except $E_e$ but evaluated on $E_e$. The regret is always non-negative since it is impossible for $f_{-e}$ to beat the oracle. Note that, unlike in IRM, even when $\phi$ and $f_{-e}$ are strong enough to ensure zero training loss across the environments they are trained on, i.e., $\mathcal{L}^{-e}(f_{-e} \circ \phi) = 0$, the combination $f_{-e} \circ \phi$ may still generalize poorly to a held-out environment, giving $R^e(\phi) > 0$. In fact, the regret expresses a stronger requirement that $f_{-e}$ should be nearly as good as the best predictor with hindsight, analogously to on-line regret.
Note that $R^e(\phi)$ does not depend on the predictor $f$ we are seeking to estimate; it is a function of the representation $\phi$ as well as the auxiliary pair of predictors $f_{-e}$ and $f_e$. For notational simplicity, we suppress the dependence on $f_{-e}$ and $f_e$. The overall regret $\sum_e R^e(\phi)$ expresses our stated goal of finding a representation that facilitates extrapolation to each held-out training environment. Our RGM objective then balances the ERM loss against the predictive regret: representation $\phi$ and predictor $f$ are found by minimizing
$$\mathcal{L}(f \circ \phi) + \lambda \sum_e R^e(\phi) \tag{4}$$
where $\lambda$ is the regularization weight.
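To make the regret concrete, here is a minimal sketch of the computation on fixed one-dimensional features. It assumes a scalar linear predictor f(z) = w * z fitted in closed form under squared loss; the function names and the toy data are hypothetical and are not the paper's implementation.

```python
def fit_scalar(points):
    """Least-squares slope w for the predictor f(z) = w * z."""
    num = sum(z * y for z, y in points)
    den = sum(z * z for z, _ in points)
    return num / den if den else 0.0


def mse(w, points):
    """Mean squared error of f(z) = w * z on (feature, label) pairs."""
    return sum((w * z - y) ** 2 for z, y in points) / len(points)


def regret(envs, e):
    """Loss on environment e of a predictor fitted without e, minus the oracle loss on e."""
    held_out = envs[e]
    others = [p for k, env in enumerate(envs) if k != e for p in env]
    w_minus = fit_scalar(others)      # trained on all environments except e
    w_oracle = fit_scalar(held_out)   # oracle with hindsight access to e
    return mse(w_minus, held_out) - mse(w_oracle, held_out)


# Toy features z with labels y; the third environment follows a different rule.
envs = [
    [(1.0, 1.0), (2.0, 2.0)],  # y = z
    [(1.0, 1.0), (3.0, 3.0)],  # y = z
    [(1.0, 2.0), (2.0, 4.0)],  # y = 2z
]
regrets = [regret(envs, e) for e in range(len(envs))]
```

Because the oracle minimizes the loss on the held-out environment within the same hypothesis class, each entry of `regrets` is non-negative, and the shifted third environment incurs the largest regret.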
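Optimization Our regret minimization (RGM) can be thought of as finding a stationary point of a multi-player game with several players: the predictor $f$ and the representation $\phi$, as well as the auxiliary predictors $\{f_{-e}\}$ and $\{f_e\}$. Our predictor $f$ and representation $\phi$ find their best response strategies by minimizing
$$\mathcal{L}(f \circ \phi) + \lambda \sum_e \left[ \mathcal{L}^e(f_{-e} \circ \phi) - \mathcal{L}^e(f_e \circ \phi) \right]$$
assuming that $\{f_{-e}\}$ and $\{f_e\}$ remain fixed. The auxiliary predictors minimize
$$\mathcal{L}^{-e}(f_{-e} \circ \phi) \;\; \forall e, \qquad \mathcal{L}^{e}(f_e \circ \phi) \;\; \forall e$$
where $\mathcal{L}^{-e}(h \circ \phi) = \sum_{k \neq e} \mathcal{L}^k(h \circ \phi)$. The auxiliary objectives depend on the representation $\phi$ but this dependence is not exposed to $\phi$, reflecting an inherent asymmetry in the multi-player game formulation.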
The RGM game objective is solved via stochastic gradient descent. In each step, we randomly choose an environment and sample a batch from it. We also sample an associated batch from the other environments. The ERM loss is computed over both batches. The predictor trained without the chosen environment is updated on the basis of a batch approximation to its loss on the other environments. The losses defining the regret are naturally evaluated based on examples from the chosen environment only. The sign-flipped gradients required by the regret term are implemented by a gradient reversal layer (Ganin et al., 2016). The setup allows us to optimize all the players in a single forward-backward pass operating on the two batches (see Figure 2).
Footnote 1: RGM needs to learn additional predictors that are not included in IRM. The introduction of these additional predictors brings little overhead in practice, however, because the predictors are much simpler than the feature extractor and evaluating over the two batches can be done in parallel.
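The batch construction in each step can be sketched as follows. This is only an illustration of the sampling described above: the function name, the toy environments and the batch size are assumptions, and the model updates themselves are omitted.

```python
import random


def sample_step_batches(envs, batch_size, rng):
    """Pick a held-out environment and draw one batch from it and one from the rest."""
    e = rng.randrange(len(envs))
    held_out = envs[e]
    others = [x for k, env in enumerate(envs) if k != e for x in env]
    batch_e = rng.sample(held_out, min(batch_size, len(held_out)))
    batch_minus_e = rng.sample(others, min(batch_size, len(others)))
    return e, batch_e, batch_minus_e


rng = random.Random(0)
# Hypothetical environments: each example is tagged with its environment name.
envs = [[(name, i) for i in range(6)] for name in ("a", "b", "c")]
e, batch_e, batch_minus_e = sample_step_batches(envs, 4, rng)

# Which loss would use which batch in one step:
#   ERM loss of the main predictor      -> batch_e + batch_minus_e
#   predictor trained without env e     -> batch_minus_e only
#   both regret losses (and the oracle) -> batch_e only
```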
4 Extrapolation to combinatorially defined domains
Both our proposed regret minimization and IRM assume that the set of environments is given as input, provided by the user. The environments exemplify nuisance variation that needs to be discounted, so they play a critical role in determining whether the approach is successful. The setting becomes challenging when the natural environments are combinatorially defined. For example, in molecule property prediction, each environment is defined by a scaffold, which is a subgraph of a molecule (see Figure 1). Since scaffolds are combinatorial descriptors, they often uniquely identify each molecule in the training set. It is not helpful to create single-example environments, as the model would then treat any variation from one example to another as nuisance and could not attribute nuisance primarily to scaffold variation.
A straightforward approach to combinatorial or large numbers of environments is to cluster them into fewer, coarser sets, and apply RGM over the coarse environments. For simplicity, we cluster the training environments into just two coarse environments. The advantage is that we only need to realize two auxiliary predictors, since with two environments the predictor trained without one environment coincides with the oracle of the other, instead of a pair of auxiliary predictors for every original environment. The construction of the coarse environments depends on the application (see §5).
Domain perturbation While using coarse environments is computationally beneficial, this clearly loses the ability to highlight finer nuisance variation from scaffolds or protein families. To counter this, we introduce and measure regret on additional environments created specifically to highlight fine-grained variation of scaffolds or protein families, but in an efficient manner. We define these additional environments via perturbations, as discussed in detail below. Both a coarse environment and its associated perturbed environment serve as held-out environments to the predictor trained without them. These give rise to regret terms relative to oracles that can fit specifically to each environment, now including the perturbed ones. These additional regret terms will further drive the feature representation. The goal is to learn to generalize well to finer-grained variations of scaffolds or protein families that we may encounter at test time.
We propose additional environments through gradient-based domain perturbations. Specifically, for each coarse environment $E_e$, we construct another environment $\tilde{E}_e$ whose representations are perturbed: $\phi(x) + \delta(x)$. Note that $E_e$ and $\tilde{E}_e$ are defined over the same set of examples but differ in the representation that the predictors operate on when calculating the regret. The perturbation $\delta$ is defined through a parametric scaffold (or protein family) classifier $g$. The associated classification loss is $\ell(s, g(\phi(x)))$, where $s$ is the scaffold (or protein family) label of $x$ (see Figure 3a). We define the perturbation in terms of the gradient:
$$\delta(x) = \alpha \, \nabla_{z} \, \ell(s, g(z)) \big|_{z = \phi(x)}$$
where $\alpha$ is a step size parameter. The direction of perturbation, which increases the classification loss, creates a modified representation that contains less information about the scaffold (or protein family) than the original representation $\phi(x)$. The impact on domain classifier output is illustrated in Figure 3b. Note that the variation between $\phi(x)$ and $\phi(x) + \delta(x)$ highlights how finer scaffold information remains in the representation; the associated regret terms then require that this variation does not affect the quality of prediction.
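A minimal sketch of the gradient-based perturbation follows. It assumes a linear softmax environment classifier, for which the gradient of the cross-entropy loss with respect to the representation has a closed form; the weight matrix `W`, the feature vector `z`, the label `s` and the step size are hypothetical toy values, and the classifier used in the paper may differ.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [v / total for v in exps]


def logits_of(W, z):
    """Logits of a linear classifier: one row of W per environment class."""
    return [sum(w * zj for w, zj in zip(row, z)) for row in W]


def ce_loss(W, z, s):
    """Cross-entropy loss of the classifier on representation z with label s."""
    return -math.log(softmax(logits_of(W, z))[s])


def perturbation(W, z, s, alpha):
    """Step of size alpha along the loss gradient w.r.t. z, i.e. W^T (p - onehot(s))."""
    p = softmax(logits_of(W, z))
    err = [pk - (1.0 if k == s else 0.0) for k, pk in enumerate(p)]
    grad = [sum(W[k][j] * err[k] for k in range(len(W))) for j in range(len(z))]
    return [alpha * g for g in grad]


W = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]  # 3 hypothetical scaffold classes, 2-d features
z = [0.5, -0.2]                             # representation of one example
s = 0                                       # its scaffold label
delta = perturbation(W, z, s, alpha=0.5)
z_tilde = [a + b for a, b in zip(z, delta)]  # perturbed representation
```

For this linear classifier the loss is convex in `z`, so stepping along the gradient always raises the classifier's loss on the true label; the perturbed representation is therefore less informative about the scaffold.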
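Integration with RGM We augment the RGM objective in Eq.(4) with two additional terms. First, the scaffold (or protein family) classifier $g$ is trained together with the feature mapping $\phi$ to minimize
$$\mathcal{L}_g(g \circ \phi) = \sum_{(x, s)} \ell(s, g(\phi(x)))$$
Second, we add regret terms specific to the perturbed environments to encourage the model to extrapolate to them as well. The new objective for the main players $\phi$, $f$, and $g$ then becomes
$$\mathcal{L}(f \circ \phi) + \lambda_g \mathcal{L}_g(g \circ \phi) + \lambda \sum_e \left[ R^e(\phi) + R^e(\phi + \delta) \right], \qquad R^e(\phi + \delta) = \mathcal{L}^e(f_{-e} \circ (\phi + \delta)) - \mathcal{L}^e(\tilde{f}_e \circ (\phi + \delta))$$
where we have introduced a new oracle predictor $\tilde{f}_e$ for perturbed environment $\tilde{E}_e$, in addition to $f_e$ for the original environment $E_e$ (see Figure 3c). Note that $f_{-e}$ minimizes a separate objective $\mathcal{L}^{-e}(f_{-e} \circ \phi)$, which does not include the perturbed examples. Perturbations represent additional simulated test scenarios that we wish to generalize to. The training procedure is shown in Algorithm 2.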
Remark While the perturbation is defined on the basis of the representation as well as the classifier, we do not include this dependence during back-propagation. We verified that incorporating this higher order gradient would not improve our empirical results. Another subtlety in the objective is that the representation is adjusted to also help the classifier. In other words, the representation is in part optimized to retain information about molecular scaffolds or protein families. This encourages the perturbation to be meaningful and relevant to downstream tasks.
5 Experiments
We evaluate our method on three tasks. We first construct a synthetic task to verify the weakness of IRM and study the behaviour of RGM. Then we test our method on protein classification and molecule property prediction tasks where the environments are combinatorially defined. In both tasks, we test our method under two settings: 1) RGM combined with domain perturbation (named RGM-DP); 2) standard RGM trained on the coarse environments used in RGM-DP.
Baselines On the synthetic dataset, we mainly compare with IRM Arjovsky et al. (2019). For the other two tasks, we compare our method with ERM (environments aggregated) and more domain extrapolation methods:
DANN Ganin et al. (2016), CDAN Long et al. (2018) and IRM Arjovsky et al. (2019) seek to learn domain-invariant features. As mentioned in section 2, these methods require each domain to have a fair amount of data. Thus, they are trained on the coarse environments used in RGM-DP instead of the original combinatorial environments (i.e., molecular scaffolds and protein superfamilies).
MLDG Li et al. (2018a) simulates domain shift by dividing domains into meta-training and meta-testing.
CrossGrad Shankar et al. (2018) augments the dataset with domain-guided perturbations of inputs. Since it requires the input to be continuous, we perform domain perturbation on learned features instead.
DANN, CDAN and IRM are trained on the coarse environments and are comparable to standard RGM. MLDG and CrossGrad are trained on combinatorial environments and are comparable to RGM-DP.
5.1 Synthetic data
Data We first compare the behavior of IRM and RGM on an inter-twinning moons problem Ganin et al. (2016), where the domain shift is caused by rotation (see Figure 4). The training set contains two environments. For the first environment, we generate a lower moon and an upper moon labeled 0 and 1 respectively, each containing 1000 examples. The second environment is constructed by rotating all examples in the first by a fixed angle. Likewise, our test set contains examples rotated by a fixed angle. As the validation set plays a crucial role in domain generalization, we consider two different ways of constructing the validation set:
Out-of-domain (OOD) validation: , which rotates all examples in by .
In-domain validation: We create by adding Gaussian noise to examples in . We experiment with in-domain validation because OOD validation is not always available in practice.
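The construction above can be sketched with the standard library alone. The rotation angles, the noise levels and the moon geometry below are hypothetical placeholders chosen for illustration; they are not the values used in the experiments, and the rotation is assumed to be about the origin.

```python
import math
import random

# Hypothetical angles in degrees; the paper's actual values are not reproduced here.
TRAIN_ANGLE, VAL_ANGLE, TEST_ANGLE = 20.0, 40.0, 60.0


def make_moons(n_per_class, rng, noise=0.05):
    """Return (x1, x2, label) triples: lower moon labeled 0, upper moon labeled 1."""
    data = []
    for i in range(n_per_class):
        t = math.pi * i / (n_per_class - 1)
        upper = (math.cos(t), math.sin(t))
        lower = (1.0 - math.cos(t), 0.5 - math.sin(t))
        for (x1, x2), label in ((lower, 0), (upper, 1)):
            data.append((x1 + rng.gauss(0, noise), x2 + rng.gauss(0, noise), label))
    return data


def rotate(data, degrees):
    """Rotate every point about the origin, keeping its label."""
    c, s = math.cos(math.radians(degrees)), math.sin(math.radians(degrees))
    return [(c * x1 - s * x2, s * x1 + c * x2, y) for x1, x2, y in data]


def add_noise(data, rng, sigma=0.1):
    """In-domain validation: jitter the points with Gaussian noise."""
    return [(x1 + rng.gauss(0, sigma), x2 + rng.gauss(0, sigma), y) for x1, x2, y in data]


rng = random.Random(0)
env1 = make_moons(1000, rng)           # first training environment
env2 = rotate(env1, TRAIN_ANGLE)       # second training environment
val_ood = rotate(env1, VAL_ANGLE)      # out-of-domain validation
val_in = add_noise(env1, rng)          # in-domain validation
test_set = rotate(env1, TEST_ANGLE)    # test environment
```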
Setup The feature extractor and predictor
is a two-layer MLP with hidden dimension 300 and ReLU activation. For RGM, we set. For IRM, we use the official implementation from Arjovsky et al. (2019) based on gradient penalty. Both methods are optimized by Adam with regularization weight , where IRM performs the best on the OOD validation set.
Results Our results are shown in Figure 5. Our method significantly outperforms IRM (84.6% vs 74.7% with the OOD validation). IRM test accuracy is close to that of ERM under the in-domain validation setting. IRM is able to outperform ERM under the OOD validation setting because the OOD validation set provides an additional extrapolation signal. As an ablation study, we train models with different regularization weights under OOD validation so that the models yield zero training error. As shown in Figure 5 (right), IRM’s test accuracy becomes similar to that of ERM once its training accuracy reaches 100%. This shows that IRM collapses to ERM when the model perfectly fits the training set. In contrast, RGM test accuracy is around 80.0% even when its training error is zero.
5.2 Molecule property prediction
Data The training data is a collection of pairs $(x, y)$, where $x$ is a molecular graph and $y$ is its property label (binary). The environment of each compound is defined as its Murcko scaffold Bemis and Murcko (1996), which is a subgraph of $x$ with side chains removed. We consider the following four datasets:
Tox21, BACE and BBBP are three classification datasets from the MoleculeNet benchmark Wu et al. (2018), which contain 7.8K, 1.5K and 2K molecules respectively. Following Feinberg et al. (2019), we split each dataset based on molecular weight (MW). This setup is much harder than commonly used random split as it requires models to extrapolate to new chemical space. The training set consists of simple molecules with . The test set molecules are more complex, with . The validation set contains molecules with (more details are in the appendix).
COVID-19: During the recent pandemic, many research groups released their experimental data on antiviral activity against COVID-19. However, these datasets are heterogeneous due to different experimental conditions. This requires our model to ignore spurious correlations caused by dataset bias in order to generalize. We consider three antiviral datasets from PubChem, Diamond Light Source and Jeon et al. (2020). The training set contains 10K molecules from PubChem and 700 compounds from Diamond. The validation set contains 180 compounds from Diamond. The test set consists of 50 compounds from Jeon et al. (2020), a different data source from the training set.
Model The feature extractor is a graph convolutional network Yang et al. (2019) which translates a molecular graph into a continuous vector. The predictor is an MLP that takes this vector as input and predicts the label. Since a scaffold is a combinatorial object with a large number of possible values, we train the environment classifier by negative sampling. Specifically, for a given molecule with its scaffold, we randomly sample other molecules and take their associated scaffolds as negative examples. Details of the model architecture are discussed in the appendix.
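The negative sampling step can be sketched as follows. The function name, the number of negatives and the toy scaffold strings are hypothetical; excluding molecules that share the positive scaffold is an added assumption rather than something stated above.

```python
import random


def sample_negative_scaffolds(scaffolds, index, k, rng):
    """Draw k scaffolds of other molecules as negatives for the molecule at `index`."""
    positive = scaffolds[index]
    # Assumption: molecules with the same scaffold are not used as negatives.
    candidates = [s for i, s in enumerate(scaffolds) if i != index and s != positive]
    return rng.sample(candidates, min(k, len(candidates)))


# Toy scaffold SMILES, one entry per training molecule.
scaffolds = ["c1ccccc1", "c1ccncc1", "C1CCCCC1", "c1ccccc1", "c1ccc2ccccc2c1"]
rng = random.Random(0)
negatives = sample_negative_scaffolds(scaffolds, index=0, k=2, rng=rng)
```

The classifier would then be trained to score the true scaffold above the sampled negatives, which avoids a softmax over every possible scaffold.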
RGM setup For RGM-DP, we construct two coarse environments as the following. On the COVID-19 dataset, consists of 700 compounds from the Diamond dataset and is the PubChem dataset. The two coarse groups are created to highlight the dataset bias. For other datasets, consists of molecules with and . We set in all datasets.
Results Following standard practice, we report the AUROC score averaged across five independent runs. As shown in Table 1, our methods significantly outperform the other baselines (e.g., 0.654 vs 0.402 on COVID). On the COVID dataset, the difference between RGM and RGM-DP is small because the domain shift is mostly caused by dataset bias rather than scaffold changes. In contrast, RGM-DP shows a clear improvement over standard RGM on the BACE dataset (0.590 vs 0.532), since there the domain shift is caused by scaffold changes (i.e., complex molecules usually have much larger scaffolds).
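For reference, the reported metric can be computed as in the sketch below. It uses the pairwise (Mann-Whitney) definition of AUROC and averages over runs; the toy labels and scores are illustrative only, and two runs stand in for the five used in the experiments.

```python
def auroc(labels, scores):
    """Probability that a random positive is scored above a random negative (ties count 1/2)."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("AUROC needs at least one positive and one negative example")
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def mean_auroc(runs):
    """Average AUROC over independent runs, each given as (labels, scores)."""
    return sum(auroc(y, s) for y, s in runs) / len(runs)


runs = [
    ([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]),  # one positive ranked below a negative
    ([0, 0, 1, 1], [0.1, 0.2, 0.7, 0.9]),   # perfect ranking
]
average = mean_auroc(runs)
```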
[Table 1: top-1 and top-5 accuracy for protein homology detection; AUROC for each of the four molecule property prediction datasets.]
5.3 Protein homology detection
Data We evaluate our method on a remote homology classification benchmark used in Rao et al. (2019). The dataset consists of pairs $(x, y)$, where $x$ is a protein sequence and $y$ is its fold class. It is split into 12K examples for training, 736 for validation and 718 for testing by Rao et al. (2019). Importantly, the provided split ensures that no protein superfamily appears in both training and testing. Each superfamily represents an evolutionary group, i.e., proteins from different superfamilies are structurally different. This requires models to generalize across large evolutionary gaps. In total, the dataset contains 1823 environments defined by protein superfamilies.
Model The protein encoder contains two modules: the first is a TAPE protein embedding learned by a pre-trained transformer network Rao et al. (2019); the second is an LSTM network that embeds associated protein secondary structures and other features. The predictor is a feed-forward network that takes the encoder output as input and predicts its fold label. The environment classifier also takes the encoder output as input and predicts the superfamily label of the protein (out of 1823 classes).
RGM setup For RGM-DP, we construct two coarse environments as follows. One environment contains all protein superfamilies that have fewer than 10 proteins, and the other contains the remaining superfamilies. The coarse environments are divided based on the size of superfamilies because the validation set mostly contains protein superfamilies of small size. We set and .
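The size-based grouping can be sketched as follows. The threshold of 10 comes from the text; the function name and toy labels are hypothetical, and treating the second environment as the complement of the first is an assumption.

```python
from collections import Counter


def split_by_family_size(family_labels, threshold=10):
    """Split example indices by whether their superfamily has fewer than `threshold` members."""
    counts = Counter(family_labels)
    small = [i for i, fam in enumerate(family_labels) if counts[fam] < threshold]
    large = [i for i, fam in enumerate(family_labels) if counts[fam] >= threshold]
    return small, large


# Toy superfamily label per training protein.
families = ["a"] * 12 + ["b"] * 3 + ["c"] * 9
small_env, large_env = split_by_family_size(families)
```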
Results Following Rao et al. (2019), we report the top-1 and top-5 accuracy in Table 1. For reference, the top-1 and top-5 accuracy of the TAPE transformer Rao et al. (2019) are 21.0% and 37.0%. Our ERM baseline achieves better results as we incorporate additional features. The proposed RGM-DP outperforms all the baselines in both top-1 and top-5 accuracy. The vanilla RGM operating on coarse environments also outperforms other baselines in top-1 accuracy. Indeed, RGM-DP performs better than RGM because it operates on protein superfamilies and receives a stronger extrapolation signal.
Footnote 2: Rao et al. (2019) did not release their best-performing pre-trained LSTM network. Thus we use their pre-trained transformer in our experiments and report its accuracy for reference.
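Top-k accuracy, as reported for fold classification, can be computed as in this short sketch. The score matrix and labels are toy values for illustration.

```python
def top_k_accuracy(scores, labels, k):
    """Fraction of examples whose true class is among the k highest-scoring classes."""
    hits = 0
    for row, y in zip(scores, labels):
        ranked = sorted(range(len(row)), key=lambda c: row[c], reverse=True)
        hits += y in ranked[:k]
    return hits / len(labels)


scores = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.2, 0.3, 0.5]]
labels = [1, 1, 0]
top1 = top_k_accuracy(scores, labels, 1)
top2 = top_k_accuracy(scores, labels, 2)
```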
6 Conclusion
In this paper, we propose regret minimization for domain extrapolation, which seeks to find a predictor that generalizes as well as an oracle that would have hindsight access to unseen domains. Our method significantly outperforms all baselines on both synthetic and real-world tasks.
Broader Impact
Among many benefits, the proposed algorithm advances the state of the art in drug discovery. As the current COVID pandemic illustrates, the lack of quality training data hinders the use of ML algorithms in the search for antivirals. This data issue is not specific to COVID, and is common in many therapeutic areas. The proposed approach enables us to effectively utilize readily available, heterogeneous data to model bioactivity, reducing the prohibitive cost and time associated with the traditional drug discovery workflow. Currently, the method is being used for virtual screening of COVID antivirals. We cannot see negative consequences from this research: at worst, it will degenerate to the performance of the base algorithm that the model aims to improve. In terms of bias, the algorithm is explicitly designed to minimize the impact of nuisance variations on model prediction capacity.
References
- Invariant risk minimization games. arXiv preprint arXiv:2002.04692.
- Invariant risk minimization. arXiv preprint arXiv:1907.02893.
- MetaReg: towards domain generalization using meta-regularization. In Advances in Neural Information Processing Systems, pp. 998–1008.
- The properties of known drugs. 1. Molecular frameworks. Journal of Medicinal Chemistry 39 (15), pp. 2887–2893.
- A theory of learning from different domains. Machine Learning 79 (1-2), pp. 151–175.
- Learning protein sequence embeddings using information from structure. arXiv preprint arXiv:1902.08661.
- Learning from multiple sources. Journal of Machine Learning Research 9 (Aug), pp. 1757–1774.
- Domain generalization via model-agnostic learning of semantic features. In Advances in Neural Information Processing Systems, pp. 6447–6458.
- The Pfam protein families database in 2019. Nucleic Acids Research 47 (D1), pp. D427–D432.
- Unbiased metric learning: on the utilization of multiple datasets and web images for softening bias. In Proceedings of the IEEE International Conference on Computer Vision, pp. 1657–1664.
- Step change improvement in ADMET prediction with PotentialNet deep featurization. arXiv preprint arXiv:1903.11789.
- SCOPe: structural classification of proteins—extended, integrating SCOP and ASTRAL data and classification of new structures. Nucleic Acids Research 42 (D1), pp. D304–D309.
- Domain-adversarial training of neural networks. The Journal of Machine Learning Research 17 (1), pp. 2096–2030.
- Domain generalization for object recognition with multi-task autoencoders. In Proceedings of the IEEE International Conference on Computer Vision, pp. 2551–2559.
- Identification of antiviral drug candidates against SARS-CoV-2 from FDA-approved drugs. bioRxiv.
- Undoing the damage of dataset bias. In European Conference on Computer Vision, pp. 158–171.
- Deeper, broader and artier domain generalization. In Proceedings of the IEEE International Conference on Computer Vision, pp. 5542–5550.
- Learning to generalize: meta-learning for domain generalization. In Thirty-Second AAAI Conference on Artificial Intelligence.
- Episodic training for domain generalization. In Proceedings of the IEEE International Conference on Computer Vision, pp. 1446–1455.
- Domain generalization with adversarial feature learning. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 5400–5409.
- Deep domain generalization via conditional invariant adversarial networks. In Proceedings of the European Conference on Computer Vision (ECCV), pp. 624–639.
- Feature-critic networks for heterogeneous domain generalization. arXiv preprint arXiv:1901.11448.
- Conditional adversarial domain adaptation. In Advances in Neural Information Processing Systems, pp. 1640–1650.
- Unified deep supervised domain adaptation and generalization. In Proceedings of the IEEE International Conference on Computer Vision, pp. 5715–5725.
- Domain generalization via invariant feature representation. In International Conference on Machine Learning, pp. 10–18.
- National Center for Biotechnology Information. PubChem Database. Source=The Scripps Research Institute Molecular Screening Center, AID=1706. PubChem. https://pubchem.ncbi.nlm.nih.gov/bioassay/1706
- Evaluating protein transfer learning with TAPE. In Advances in Neural Information Processing Systems.
- (2020) SARS-CoV-2 main protease structure and XChem fragment screen. Diamond Light Source. www.diamond.ac.uk/covid-19/for-scientists/Main-protease-structure-and-XChem
- Generalizing across domains via cross-gradient training. arXiv preprint arXiv:1804.10745.
- Generalizing to unseen domains via adversarial data augmentation. In Advances in Neural Information Processing Systems, pp. 5334–5344.
- MoleculeNet: a benchmark for molecular machine learning. Chemical Science 9 (2), pp. 513–530.
- How powerful are graph neural networks? arXiv preprint arXiv:1810.00826.
- Analyzing learned molecular representations for property prediction. Journal of Chemical Information and Modeling 59 (8), pp. 3370–3388.
- Understanding deep learning requires rethinking generalization. arXiv preprint arXiv:1611.03530.
- On learning invariant representation for domain adaptation. arXiv preprint arXiv:1901.09453.
Appendix A Additional analysis of IRM
In section 3, we discussed that IRM may collapse to ERM when the predictor is powerful enough to perfectly fit the training set. Under some additional assumptions, we can further show that the ERM optimal predictor is optimal for IRM even if the model has non-zero training error.
In particular, we assume that the conditional distribution of the label is environment-dependent and that the environment can be inferred from the input alone, i.e.,
For molecules and proteins, the second assumption is valid because the environment labels (scaffold, protein family) can be inferred from the input. Under this assumption, we can rephrase the IRM objective as
We claim that IRM will collapse to ERM when the representation is label preserving:
A representation is called feasible if it retains all the information about the label y that is present in the input x.
Under the assumption of Eq. (11), for any feasible representation, its associated ERM-optimal predictor is also optimal under IRM.
Given any feasible representation, the ERM-optimal predictor is
To see that this predictor is optimal, consider
where Eq. (16) holds because the representation is feasible. The ERM-optimal predictor also satisfies the IRM constraints because
Thus the ERM-optimal predictor is also simultaneously optimal across all environments. ∎
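The argument can be summarized as a single chain of equalities. The following is a sketch in assumed notation (input $x$, label $y$, environment map $E$ with $e = E(x)$, representation $\phi$, ERM-optimal predictor $f^{*}$) and under the additional assumption of the log loss, so that the risk minimizer is a conditional distribution. It is not a transcription of the original numbered equations.

```latex
\begin{align*}
f^{*}(\phi(x)) &= p(y \mid \phi(x)) && \text{(ERM optimum under log loss)} \\
               &= p(y \mid x) && \text{(feasibility of } \phi\text{)} \\
               &= \sum_{e} p(e \mid x)\, p(y \mid x, e) && \text{(marginalizing over } e\text{)} \\
               &= p(y \mid x, e = E(x)) && \text{(} e \text{ is determined by } x\text{)}
\end{align*}
```

Under these assumptions, the pooled ERM solution coincides with the best possible predictor inside each environment, so no environment-specific predictor built on the same representation can achieve lower risk there.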
Remark The above analysis implies that the learned representation should not be label-preserving in order for IRM to find a different solution from ERM. As a corollary, the representation must be non-injective. An injective representation is feasible since the input, and hence all of its label information, can be recovered from it (assuming the representation is deterministic). We discuss several injective representations relevant to this paper:
Identity mapping: Ahuja et al. considered an identity representation in their IRM formulation. Since an identity mapping preserves all the environment information, it may not be ideal for IRM.
Graph convolutional network (GCN): Xu et al. suggest using injective aggregation functions so that GCNs are powerful enough to distinguish different graphs (i.e., mapping different graphs to distinct representations). In contrast, our analysis suggests the use of non-injective aggregation functions in GCNs in the context of IRM and domain extrapolation.
Note that the above analysis assumes that the predictor class is powerful enough to contain the ERM-optimal predictor. We leave the analysis of parametrically constrained predictors to future work.
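The contrast between injective and non-injective aggregation in the GCN item above can be illustrated on a toy example. The sketch below is not the paper's GCN: the function names and the two neighbor multisets are hypothetical, and it only shows that sum aggregation separates these two particular multisets while mean and max aggregation map them to the same vector.

```python
# Toy illustration only: hand-written aggregators over tuples of node features.

def aggregate_sum(features):
    """Element-wise sum over a multiset of feature vectors."""
    return tuple(sum(col) for col in zip(*features))

def aggregate_mean(features):
    """Element-wise mean over a multiset of feature vectors."""
    n = len(features)
    return tuple(sum(col) / n for col in zip(*features))

def aggregate_max(features):
    """Element-wise max over a multiset of feature vectors."""
    return tuple(max(col) for col in zip(*features))

# Two different neighbor multisets built from the same one-hot feature.
neighbors_a = [(1.0, 0.0), (1.0, 0.0)]
neighbors_b = [(1.0, 0.0)]

sum_separates = aggregate_sum(neighbors_a) != aggregate_sum(neighbors_b)
mean_merges = aggregate_mean(neighbors_a) == aggregate_mean(neighbors_b)
max_merges = aggregate_max(neighbors_a) == aggregate_max(neighbors_b)
print(sum_separates, mean_merges, max_merges)
```

Mean and max discard the multiplicity of the repeated neighbor, which is the kind of information loss a non-injective aggregator introduces.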
Appendix B Experimental Details
B.1 Rotated moon dataset
Data The data generation script is provided in the supplementary material.
B.2 Molecule property prediction
Data The four property prediction datasets are provided in the supplementary material, along with the train/val/test splits. The sizes of the training, validation and test sets are listed in Table 2.
Model hyperparameters For the feature extractor, we adopt the GCN implementation from Yang et al. We use their default hyperparameters across all the datasets and baselines. Specifically, the GCN contains three convolution layers with hidden dimension 300. The predictor is a two-layer MLP with hidden dimension 300 and ReLU activation. The model is trained with the Adam optimizer for 30 epochs with batch size 50 and a linearly annealed learning rate.
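Linear annealing of the learning rate can be sketched as a simple interpolation between a start and an end value. The endpoint values below are hypothetical placeholders, not the values used in the experiments, and updating the rate once per epoch is likewise an assumption; the helper name `linear_anneal` is introduced only for illustration.

```python
def linear_anneal(step, total_steps, lr_start, lr_end):
    """Linearly interpolate from lr_start (step 0) to lr_end (step total_steps)."""
    if total_steps <= 0:
        raise ValueError("total_steps must be positive")
    frac = min(max(step / total_steps, 0.0), 1.0)  # clamp progress to [0, 1]
    return lr_start + frac * (lr_end - lr_start)

# Hypothetical endpoints, chosen only for illustration.
LR_START, LR_END = 1e-3, 1e-4
TOTAL_STEPS = 30  # assumption: one update of the rate per epoch over 30 epochs

schedule = [linear_anneal(t, TOTAL_STEPS, LR_START, LR_END)
            for t in range(TOTAL_STEPS + 1)]
print(schedule[0], schedule[TOTAL_STEPS // 2], schedule[-1])
```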
The environment classifier is an MLP that maps a compound or its scaffold to a feature vector. The model is trained by negative sampling since a scaffold is a combinatorial object. Specifically, for a given molecule in a mini-batch, we take the scaffolds associated with the other molecules in the batch as negative examples. The probability that the molecule is mapped to its correct scaffold is then defined as
The environment classification loss is evaluated over each mini-batch. The classifier is a two-layer MLP with hidden dimension 300 and ReLU activation.
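The in-batch negative sampling described above can be sketched as follows. This is a minimal illustration under two assumptions that are not confirmed by the text: the probability is taken to be a softmax over dot products between a molecule's feature vector and the feature vectors of all scaffolds in the batch, and the loss is taken to be the mean negative log-probability. The function name `in_batch_scaffold_loss` and the toy vectors are hypothetical.

```python
import math

def in_batch_scaffold_loss(mol_vecs, scaffold_vecs):
    """Mean negative log-probability that each molecule selects its own scaffold.

    mol_vecs[i] and scaffold_vecs[i] are the feature vectors of molecule i and
    of its scaffold; scaffolds of the other molecules in the batch act as negatives.
    """
    losses = []
    for i, mol in enumerate(mol_vecs):
        # Dot-product score of molecule i against every scaffold in the batch.
        scores = [sum(a * b for a, b in zip(mol, s)) for s in scaffold_vecs]
        # Numerically stable log-sum-exp for the softmax normalizer.
        top = max(scores)
        log_norm = top + math.log(sum(math.exp(sc - top) for sc in scores))
        losses.append(log_norm - scores[i])
    return sum(losses) / len(losses)

# Toy batch of two molecules with 2-dimensional features.
mols = [(1.0, 0.0), (0.0, 1.0)]
aligned = [(5.0, 0.0), (0.0, 5.0)]   # each scaffold matches its molecule
swapped = [(0.0, 5.0), (5.0, 0.0)]   # scaffolds assigned to the wrong molecule

print(in_batch_scaffold_loss(mols, aligned) < in_batch_scaffold_loss(mols, swapped))
```

This sketch does not handle batches in which two molecules share the same scaffold; in that case a "negative" would coincide with the positive and would need to be masked out.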
For RGM and RGM-DP, we select the best hyper-parameter for each dataset from the candidate values considered; one setting consistently works best across all the datasets.
B.3 Protein homology detection
Data The protein homology dataset is downloaded from Rao et al. Each protein is represented by a sequence of amino acids, along with the predicted secondary structure labels, predicted solvent accessibility labels, and alignment-based features. For RGM-DP, the two coarse groups have 8594 and 3718 examples, respectively.
Model hyperparameters The protein encoder has two components. The first is a 768-dimensional TAPE embedding given by a pre-trained transformer (Rao et al.). The second is a bidirectional LSTM that embeds the secondary structures, solvent accessibility and alignment-based features. The LSTM has one recurrent layer and its hidden dimension is 300. The predictor is a linear layer, which worked better than an MLP under ERM. The environment classifier is a two-layer MLP whose hidden size is 300 and whose output size is 1823 (the number of protein superfamilies).
The model is trained with the Adam optimizer for 10 epochs, with batch size 32 and a linearly annealed learning rate. For RGM and RGM-DP, the hyper-parameter values are chosen among the candidates considered according to performance on the validation set.