Machine learning models are often deployed in applications where the test (inference time) data distributions differ from their training dataset. For example, a model trained on data from one hospital is used for prediction at other hospitals or an image classification model is trained and deployed on pictures with slightly different orientations. These applications require that a model generalizes well to new data distributions in addition to the training distribution, unlike standard machine learning tasks that focus on minimizing same-distribution error. One approach for generalizing to unseen domains is to learn representations that remain invariant across domains, by enforcing that their distributions stay the same across domains, either marginally Muandet et al. (2013); Li et al. (2018b); Ganin et al. (2016) or conditional on the class label Li et al. (2018c); Ghifary et al. (2016); Li et al. (2018d). Other methods frame invariance in terms of accuracy of the classifier on different domains Arjovsky et al. (2019); Peters et al. (2016). However, it is unclear how to evaluate these different invariance conditions for their applicability, e.g., under class imbalance or other differences between domains.
To this end, we introduce a formal causal framework for the domain generalization task that allows an easy and coherent characterization of the invariance conditions. Specifically, we construct a model for the data generation process that assumes each input is constructed from a mix of inherent (causal) and domain-dependent (non-causal) features. Building on prior work Gong et al. (2016); Heinze-Deml and Meinshausen (2019), we consider domain as a special intervention that changes the non-causal features of an input, and posit that an ideal classifier should be based only on the causal features. Using this model, we show that methods based on enforcing same distribution of representations across domains are inconsistent, confirming claims from past work Arjovsky et al. (2019); Zhao et al. (2019); Johansson et al. (2019). Furthermore,we show that methods that enforce the same distribution conditional on class label are also insufficient, unless additional assumptions are made.
The same causal model also provides us with the right condition that an invariant representation should satisfy. Our invariance condition depends on a special object variable that defines a collection of inputs that share the same causal features. For example, images of the same person from different viewpoints correspond to a single object, and so do images of the same thing in different rotations, color or background. We show that the correct invariance condition is that the learnt representation be the same for each object across domains. When the object variable is available (e.g., in self-collected data or by using dataset augmentation), we propose a matching regularizer for domain generalization that minimizes the distance between representations of the same object across domains.
In practice, however, the underlying objects are not always known or replicated across domains. We therefore propose an approximation of the above invariant condition that uses class as a proxy, under the assumption that inputs from the same class have more similar causal features than those from different classes. Our algorithm, MatchDG has two phases. First, it constructs a representation such that inputs sharing the same causal features are closer to one another, and matches pairs of inputs that are most similar. Second, it uses these learnt matches as a regularizer when building the classifier.
We find that MatchDG outperforms state-of-the-art methods for out-of-domain accuracy on rotated MNIST and Fashion-MNIST, and the PACS datasets. In addition, for the rotated MNIST and Fashion-MNIST where the ground-truth objects are known, MatchDG learns to makes the representation more similar to their ground-truth matches (more than 50% overlap for top-10 matches), even though the method does not have access to them. Our simple class-based approximation of the invariance condition also beats baselines, indicating the importance of enforcing the correct invariance condition.
Contributions. To summarize, our contributions include:
Invariance Condition. We propose an object-invariant condition for learning a common representation for domain generalization and justify its correctness in contrast to previous approaches.
Class-conditional approximation. When object information is not available, we provide an approximation that leads to a simple algorithm using contrastive loss formulation.
MatchDG Algorithm. We provide a novel two-phase learning algorithm that provides state-of-the-art results on domain generalization datasets like PACS and rotated MNIST and Fashion-MNIST.
2 A causal view of the domain generalization problem
We consider a classification task where the learning algorithm has access to i.i.d. data from domains, where where and is a set of domains. Each training input
is sampled from an unknown probability distribution.The task of domain generalization is to learn a single classifier that generalizes well to data from unseen domains and to new data from the same domains Shankar et al. (2018). Thus, the optimum classifier can be written as: , where over .
However, we only have access to
domains during training. The plug-in estimator replacesby following the Empirical Risk Minimization (ERM) principle.
where is a training data of size .
The ERM estimator from (1) learns the true , as the set of training domains and number of data samples .
However, when , ERM can overfit to the training domains. To avoid overfitting, a popular technique is to learn a common feature representation across all training domains that can be subsequently used to train a classifier Muandet et al. (2013); Li et al. (2018b); Ganin et al. (2016); Li et al. (2018c); Ghifary et al. (2016); Li et al. (2018d); Tzeng et al. (2015); Albuquerque et al. (2019); Ghifary et al. (2015); Hu et al. (2019). Due to a lack of formal definition of the problem, different learning objectives have been proposed, such as minimizing the distributional distance between learnt feature representations from different domains Muandet et al. (2013); Li et al. (2018b); Ganin et al. (2016) or minimizing the distance between class-conditional feature representations Li et al. (2018c); Ghifary et al. (2016); Li et al. (2018d). We provide a causal framework that provides a correct invariant condition needed for domain generalization.
2.1 Data-generating process
Consider a classification task of detecting the type of item or screening an image for a medical condition. To build a classifier, a train set is generated by taking photos. Due to human variability or by design (to take advantage of data augmentation), a photo may be taken at different angles and sometimes the same object may have multiple photos taken. Thus, the process yields different number of images for each class, sometimes with multiple images for the same object. In this example, the domain generalization task is to build a classifier that is robust to different views of any new object.
Figure 0(a) shows a structural causal model (SCM) that describes the data-generating process. Here each view can be considered as a different domain , the label for item type or medical condition as the class , and the image pixels as the features . Photos of the same item or the same person correspond to a common object variable Heinze-Deml and Meinshausen (2019), denoted by . To create an image, the data-generating process first samples an object and view (domain) that may be correlated to each other, (shown with dashed arrows). The pixels in the photo are caused by both the object and the view, as shown by the two incoming arrows to . Separately, the object corresponds to high-level causal features that are common to any image of the same object, which in turn are used by humans to label the class .
The above example is typical of a domain generalization problem; a general SCM is shown in Figure 0(b). In general, the underlying object for each input may not be observed. Changing the domain can be seen as an intervention: for each observed , there are a set of (possibly unobserved) counterfactual inputs where , such that all correspond to the same object. Analogous to the causal features , we introduce a node for domain-dependent high-level features of the object . For completeness, we also show the true unobserved label of the object which led to its generation as . Like the object , may be correlated with the domain . Extending the model in Heinze-Deml and Meinshausen (2019), we allow that objects can be correlated with the domain conditioned on . As we shall see, considering the relationship of the Object node becomes the key piece for developing the invariant condition. We can write the following non-parametric equations corresponding to the SCM.
where , , , and are general non-parametric functions. The error is correlated with domain whereas , and are mutually independent error terms that are independent of all other variables. Thus, noise in the class label is independent of domain. Since is common to all inputs of the same object, is a deterministic function of . In addition to these equations, the SCM provides conditional-independence conditions that all data distributions must satisfy, through the concept of d-separation and the perfect map assumption Pearl (2009).
d-separation (Pearl (2009)): Let A,B,C be the three non-intersecting subsets of nodes in a causal graph . For any path between two nodes, a collider is a node where arrows of the path meet head-to-head. A path from to is said to be blocked by if either a non-collider on the path is in , or there is a collider on the path and neither the collider nor its descendants are in .
If all paths from to are blocked, then A is d-separated from B by C: .
2.2 Identifying the invariance condition
From Figure 0(b), is the node that causes . Further, by d-separation, if we condition on , then the class label is independent of domain, . Thus our goal is to learn as where . The ideal loss-minimizing function can be rewritten as (assuming is known):
Since is unobserved, this implies that we need to learn it too through a representation function . Together, leads to the desired classifer .
Conditional independencies from the SCM identify the correct learning goal for learning . By the d-separation criterion, we see that satisfies two conditions: 1) , 2) ; where refers to the object variable and refers to a domain. The first is an invariance condition: does not change with different domains for the same object. To enforce this, we stipulate that the average pairwise distance between for inputs across different domains for the same object should be zero, . Here is a matching function that is 1 for pairs of inputs across domains corresponding to the same object, and 0 otherwise.
However, just the above invariance will not work: we need the representation to be informative of the object too (otherwise even a constant minimizes the above loss). Therefore, the second condition stipulates that should be informative of the object, and hence about . To ensure informativeness, we can add the standard classification loss, leading to the constrained optimization,
where . Here represents the composition
. For example, a neural network withas its r layer, and being the rest of the layers. The proof for the following theorem is in Suppl. A.3.
For a finite number of domains , as the number of examples in each domain ,
that is directly caused by domain, and for P-admissible loss functions
1. The set of representations that satisfy the condition contains the optimal that minimizes the domain generalization loss in (2).
2. Further, assuming that for every high-level feature
that is directly caused by domain, and for P-admissible loss functionsMiller et al. (1993) whose minimization is the conditional expectation (e.g., loss or cross-entropy), a loss-minimizing classifier for the following loss is the true function , for some value of .
2.3 Comparison to prior work on learning common representations
Using our model, we now discuss three main representation learning objectives from past work: domain-invariant, class-conditional domain-invariant, and invariant-optimal-classifier representations.
Domain-invariant representations. The goal is to learn a representation such that its distribution is the same across domains Muandet et al. (2013); Li et al. (2018b); Ganin et al. (2016), assuming that the ideal representation is independent of domain, . Recent work Arjovsky et al. (2019); Zhao et al. (2019); Johansson et al. (2019) has argued that this condition fails when is correlated with . However, using d-separation on the SCM from Figure 0(b), we find that is not sufficient since blocks the path between and . Hence, domain-invariant methods require a stronger condition that both class label and actual objects sampled be independent of domain.
Class-conditional domain-invariant . As a better objective, class-conditional methods Li et al. (2018c); Ghifary et al. (2016); Li et al. (2018d); Gong et al. (2016) aim to obtain representations such that is the same across domains, through minimizing distribution divergence measures such as the maximum mean discrepancy Li et al. (2018c). However, even in the ideal case where we observe , d-separation on the SCM reveals that due to a path through . Thus, having the same distribution per class is not consistent with properties of .
The above discussion indicates that previous representation learning methods optimize an incorrect objective: even with infinite data across domains, they will not learn the true .
The conditions enforced by domain-invariant () or class-conditional domain-invariant () methods are not satisfied by the causal representation . Thus, without additional assumptions, the set of representations that satisfy any of these conditions does not contain , even as .
Invariant-optimal-classifier . Recent work Arjovsky et al. (2019); Ahuja et al. (2020) assumes that remains the same across domains and thus a single classifier over the optimal should be optimal for all domains. That is, which is also satisfied by in the SCM. However, enforcing this condition is difficult in practice and thus the resultant method is restricted to a linear classifier over . In the next section, we propose an alternative condition that supports any architecture, is simple to implement and is also consistent with the SCM.
3 MatchDG: Proposed algorithm
When object information is available, such as in self-collected datasets or when using dataset augmentations, Eq. (4) provided a loss objective to build a classifer based on causal features. In general, however object information is not available. Further, in many datasets there may not be a perfect “counterfactual” match based on same object for an input in other domains. Therefore, we now propose a method that aims to learn when no object information is available. We first provide a invariance condition based on observed data that is consistent with the conditional independencies of . We then provide a two-phase contrastive learning method to learn such an .
3.1 An invariant loss consistent with properties of
The object-invariant condition from Section 2.2 can be interpreted as matching pairs of inputs from different domains that share the same . Thus, to approximate it, our goal is to learn a matching such that pairs having have low difference in and . One simple way is to use the class label and match every input to a random input from the same class. This leads to the following random match regularizer for learning a classifier:
where randomly matches pairs from the same class. Assuming of inputs from the same class is bounded by , we show that this simple matching strategy does include as a possible solution.
Assume that training domains are diverse such that for any two same-class inputs and from domains and , where is any high-level feature that is caused directly by domain. Further, assume that the distance over between same-class inputs from different domains is bounded: and . Then for some , a loss-minimizing classifier for the loss from (5) is the true function , given a P-admissible loss function and a finite number of domains with in each domain.
The proof substitutes in the match condition and uses Lagrange multipliers, detailed in Suppl. A.5 Compared to class-conditional domain-invariant condition, a key distinction is that we are enforcing the difference in individual
representations for inputs from the same class to be low, not just that they have the same distribution. That is, we are also minimizing the variance of the class-conditional distribution. A variation of this loss is used as a contrastive regularizer inMotiian et al. (2017); Dou et al. (2019). However, Theorem 2 cannot provide any guarantee that will be returned by the optimization. The key parameter is . If a dataset has low (as compared to distance between inputs of different classes), then there is a high chance of learning a good representation that is close to (if , we obtain perfect match). But as increases, the matching condition loses any discriminative power. Therefore we need to learn a matching such that is minimized.
3.2 MatchDG: Two-phase algorithm with learnt matches
To learn such a , we utilize recent work from unsupervised contrastive learning Chen et al. (2020); He et al. (2019) and adapt it for domain generalization. Our method works on the assumption that two inputs from the same class have more similar causal features than inputs from a different classes. Thus, if we can build a representation where inputs from the same class are comparatively closer, then a matching built on the closest pairs in this representation are more likely to have lower than a simple class-based random match. Specifically, we optimize a contrastive representation learning loss that minimizes distance between same-class inputs from different domains in comparison to inputs from different classes across domains. Adapting the contrastive loss for a single domain Chen et al. (2020), we consider positive matches as two inputs with the same class but different domains, and negative matches as pairs with different classes. For every positive match pair , we propose the following loss where
is a hyperparameter and
is the cosine similarity.
Our key insight is that matches can be updated during training. Rather than the standard contrastive loss where matches are pre-decided, we can start training with a random match based on class and then after every epochs, update the matches based on nearest same-class pairs in representation space, and iterate until convergence. Under assumption that inputs of same class are closer in causal features, optimizing for the initial random matches should lead to a representation wherein similarity correlates more to similarity in causal features. This completes Phase I of the algorithm. In Phase 2, we utilize the final representation to compute a new match function based on closest same-class pairs and then apply (5) to obtain a classifier regularized on those matches. In Suppl. C.5, we compare the gains due to the proposed iterative matching versus standard contrastive training.
To implement MatchDG, we build a data matrix containing
positive matches for each input and then sample mini-batches from this matrix. The last layer of the contrastive loss network is considered as the learnt representation. Algorithm1 provides an overview; details are in Suppl. B.1. We implement MatchDG as a 2-phase method, unlike previous methods Motiian et al. (2017); Dou et al. (2019) that employed class-based contrastive loss as a regularizer with ERM. This is to avoid the classification loss interfering with the goal of learning an invariant representation across domains (e.g., in datasets where one of the domains has many more samples than others). Therefore, we first learn the match function using only the contrastive loss. Our results in Suppl. C.3 show that the two-phase method provides better overlap with ground-truth perfect matches than optimizing classification and matching simultaneously.
We evaluate MatchDG for classification accuracy and the quality of its generated matches. We compare MatchDG to 3 other training methods: 1) ERM: Standard empirical risk minimization, 2) ERM-RandMatch that implements the loss from Eq. (5), 3) ERM-PerfMatch that has access to object information and uses the loss from Eq. (4) (not for PACS). We use the cross-entropy loss for and distance for in Eq.(4), (5). Details of implementation are in Suppl. B.1. We also evaluate against state-of-the-art methods such as MASF Dou et al. (2019), CSD Piratla et al. (2020) and JiGen Carlucci et al. (2019). All the numbers are averaged over
runs with standard deviation in brackets. We use three domain generalization datasets from prior work: Rotated MNISTPiratla et al. (2020) and Fashion-MNIST Piratla et al. (2020), and PACS Li et al. (2017).
Rotated MNIST & Fashion-MNIST. The datasets contain rotations of grayscale MNIST handwritten digits and fashion article images from to with an interval of Ghifary et al. (2015), where each rotation angle represents a domain. Since different domains’ images are generated from the same base image (object), there exist perfect matches across domains. The task is to predict the correct class label out of classes on a new rotation domain. Following past work Piratla et al. (2020), we report accuracy on and together as the test domain and the rest as the train domains; since these test domains are the hardest to generalize to (being extreme angles not covered in the span of train domains). We use and samples Piratla et al. (2020) from each domain for rotated MNIST and Fashion-MNIST during training, and train models using a 2-layer Motiian et al. (2017) and Resnet-18 architecture Piratla et al. (2020) .
PACS. This dataset contains total images from four domains: Photos (P), Art painting (A), Cartoon (C) and Sketch (S). The task is to classify objects over different classes. Following Dou et al. (2019); Carlucci et al. (2019), we train models with each of the domain as the target using Resnet-18 and Alexnet.
4.1 Results: Rotated MNIST (rotMNIST) and Fashion MNIST (rotFashionMNIST)
Under different training domains, Table 1 shows classification accuracy on rotMNIST and rotFashionMNIST for test domains and using Resnet-18 architecture. On both datasets, MatchDG outperforms all baselines except ERM-PerfMatch which has oracle access to the perfect matches. MatchDG’s accuracy lies between ERM-RandMatch and ERM-PerfMatch, indicating the benefit of learning a matching function. As the number of training domains decrease, the gap between MatchDG and baselines is highlighted: with 3 source domains for rotFashionMNIST, MatchDG achieves accuracy of whereas the next best method ERM-RandMatch achieves . Interestingly, ERM-RandMatch also performs better than prior work (MASF and CSD) on rotFashionMNIST dataset, suggesting the power of the simple matching condition proposed in this paper. Note that our baseline ERM provides better accuracy than MASF and CSD on rotMNIST dataset for training with domains but not when they are reduced.
|96.5 (0.15)||93 (0.2)||94.7 (0.2)||97.5 (0.17)||97.5 (0.36)||98.5 (0.08)|
|30, 45, 60||80.6 (2.9)||69.4 (1.32)||89.1 (0.004)||82.8 (2.3)||88.9 (2.01)||93.6 (0.53)|
|30, 45||64.0 (2.28)||60.8 (1.53)||77.2 (0.04)||69.7 (2.93)||79.3 (4.2)||84.2 (2.33)|
|78.5 (1.15)||72.4 (2.9)||78.9 (0.7)||80.5 (0.97)||83.5 (1.16)||85.1 (0.97)|
|30, 45, 60||33.9 (1.04)||25.7 (1.73)||27.8 (0.01)||35.5 (1.07)||51.7 (2.08)||61.04 (1.33)|
|30, 45||21.85 (0.93)||20.8 (1.26)||20.2 (0.01)||23.9 (0.93)||36.6 (2.17)||42.0 (2.42)|
|16.4 (1.15)||47.3 (1.17)||28.3 (1.25)|
|26.8 (2.49)||58.9 (3.31)||21.4 (2.37)|
|37.5 (0.23)||70.6 (4.99)||12.8 (2.23)|
|2.8 (0.23)||13.5 (0.49)||197.2 (5.89)|
|24.9 (3.43)||51.9 (5.63)||76.1 (14.3)|
|58.3 (3.71)||85.3 (3.78)||10.5 (2.81)|
|97.5 (0.17)||80.5 (0.97)|
|Approx 25%||97.7 (0.33)||82.5 (1.14)|
|Approx 50%||97.7 (0.29)||83.1 (1.19)|
|Approx 75%||98.2 (0.16)||83.9 (0.81)|
|98.5 (0.08)||85.1 (0.97)|
We also evaluate on a simpler 2-layer network to compare MatchDG to other prior works Ganin et al. (2016); Shankar et al. (2018); Ghifary et al. (2015); Motiian et al. (2017); Goodfellow et al. (2014), and on a more difficult setting when the test domain data is sampled from different base objects than the training domain. Results are in Suppl. C.1, C.2.
Why MatchDG works? We compare the matches returned by MatchDG Phase I (on Resnet-18 network) to the ground-truth perfect matches and find that it has significantly higher overlap than matching based on ERM loss (Table 2). We report three metrics on the representation learnt at the end of Phase I: percentage of MatchDG matches that are perfect matches, percentage of inputs for which the perfect match is within the top-10 ranked MatchDG matches, and mean rank of perfect matches measured by distance over the MatchDG representation.
On all three metrics, MatchDG finds a representation whose matches are more consistent with ground-truth perfect matches. For both rotMNISTand rotFashionMNIST datasets, more than 50% of the inputs have their perfect match within top-10 ranked matches based on the representation learnt by MatchDG Phase I. About 25% of all matches learnt by MatchDG are perfect matches. For comparison, we also show metrics for an (oracle) MatchDG method that is initialized with perfect matches: it achieves even better overall and Top-10 values, indicating the importance of knowing ground-truth matches that can be useful in data augmentation settings. Similar results for MatchDG Phase 2 are in Suppl. C.3. Mean rank for rotFashionMNIST may be higher because of the larger sample size per domain; metrics for training with samples are in Suppl. C.4.
Finally, to see how the overlap with perfect matches affects classification accuracy, we simulate random matches with 25%, 50% and 75% overlap with perfect matches (Table 3). We find that accuracy increases as the fraction of perfect matches increase. Interestingly, the accuracy for 25% overlap roughly predicts the reported accuracy for MatchDG (which had about 25% overlap). These results confirm that in addition to high accuracy, Phase I of MatchDG allows for learning a representation where inputs with the same causal features are closer.
4.2 PACS dataset: Accuracy and t-SNE plots
On the PACS dataset too, our matching-based methods outperform prior work on classification accuracy averaged over all domains (Table 4). In particular, on the relatively difficult domains of Sketch and Cartoon, MatchDG and ERM-RandMatch achieve the highest accuracies: MatchDG obtains on the Sketch domain, an improvement of from prior state-of-the-art. Results with AlexNet network and comparisons to more prior work Li et al. (2018d); Arjovsky et al. (2019); Ghifary et al. (2015); Dou et al. (2019); Carlucci et al. (2019); Li et al. (2017, 2018a); Balaji et al. (2018); Li et al. (2019) are in Suppl. D. Beyond accuracy, we investigate the quality of representations learnt by MatchDG using t-SNE Maaten and Hinton (2008) in Figure 2. Comparing the Phase I models for the easiest (Photo) and hardest (Sketch) unseen domains (Figs. 2a,b), we find that MatchDG achieves a higher overlap between train and test domains for Photo than Sketch, highlighting the difficulty of generalizing to the Sketch domain, even as classes are well-separated in the training domains for both models (Figs. 2c,d).
|Photo||95.37 (0.37)||94.99 (0.09)||96.03||95.45||95.59 (0.07)||95.7 (0.52)|
|Art Painting||75.79 (1.38)||80.29 (0.18)||79.42||79.79||78.58 (0.56)||77.9 (0.61)|
|Cartoon||77.82 (0.39)||77.17 (0.08)||75.25||75.04||80.02 (1.01)||78.8 (0.29)|
|Sketch||69.75 (0.58)||71.69 (0.22)||71.35||72.46||76.03 (1.34)||76.3 (0.92)|
5 Related Work
There are four main approaches for the domain generalization task: learning a common representation, dataset augmentation, meta-learning and sharing common parameters.
Learning common representation. To learn a generalizable classifier, several methods enforce same distribution of across domains marginally or conditional on class label, using divergence measures such as maximum mean discrepancy Muandet et al. (2013); Li et al. (2018c), adversarial training with a domain discriminator Li et al. (2018b); Ganin et al. (2016); Li et al. (2018d), use discriminant analysis Ghifary et al. (2016); Hu et al. (2019), and other techniques Ghifary et al. (2015). In Section 2.3, we identified limitations of the above methods. A more recent line of work Arjovsky et al. (2019); Ahuja et al. (2020) enforces domain-invariance of the optimal that we compare to in the evaluation. There is work on use of causal reasoning for domain adaptation Gong et al. (2016); Heinze-Deml and Meinshausen (2019); Magliacane et al. (2018); Rojas-Carulla et al. (2018) that assumes direction and other work Arjovsky et al. (2019); Peters et al. (2016) on connecting causality that assumes . Our SCM model unites these streams by introducing and labelled and develop an invariance condition for domain generalization that is valid under both interpretations. Perhaps the closest to our work is by Heinze-Deml and Meinshausen Heinze-Deml and Meinshausen (2019) who use the object concept in generation of input for a single domain but assume that objects are observed. We provide an algorithm that does not depend on observed objects. In doing so, we provide theoretical justification for the past uses of contrastive loss in domain generalization based on the class label Motiian et al. (2017); Dou et al. (2019) or using augmented data Tian et al. (2019).
Meta-learning. Meta-learning can be applied to domain generalization, by creating meta-train and meta-test domains within each mini-batch and ensuring that the weight updates perform well on the meta-test domains Dou et al. (2019); Li et al. (2018a); Balaji et al. (2018). While we showed that a contrastive training only can achieve promising results, combining meta-learning with our approach is an interesting future direction.
Dataset augmentation.: The data augmentation methods create more out-of-domain samples, from distributions within a bounded distance Volpi et al. (2018) or on a continuous space of domain interventions Shankar et al. (2018).
Parameter Decomposition. Finally there is work that focuses on identifying common model parameters across domains Piratla et al. (2020); Li et al. (2017); Daumé III et al. (2010), rather than a common input representation. We compared our work against one such recent method based on low-rank decomposition (CSD) Piratla et al. (2020).
We presented a causal interpretation of domain generalization and used it to derive a method that matches representations of input pairs that share causal features. We find that combining ERM with a simple invariance condition performs better on benchmarks than prior work, and hope to investigate matching-based methods with domain-dependent noise on class label and object in future work.
We would like to thank Adith Swaminathan, Aditya Nori, Emre Kiciman, Praneeth Netrapalli, Tobias Schnabel, and Vineeth Balasubramanian who provided us valuable feedback on this work. We also thank Vihari Piratla who helped us with reproducing the CSD method and other baselines.
-  (2020) Invariant risk minimization games. arXiv preprint arXiv:2002.04692. Cited by: §2.3, §5.
-  (2019) Adversarial target-invariant representation learning for domain generalization. arXiv preprint arXiv:1911.00804. Cited by: §B.4, Table 11, Appendix D, §2.
-  (2019) Invariant risk minimization. arXiv preprint arXiv:1907.02893. Cited by: Table 11, §1, §1, §2.3, §2.3, §4.2, §5.
-  (2018) Metareg: towards domain generalization using meta-regularization. In Advances in Neural Information Processing Systems, pp. 998–1008. Cited by: Table 11, §4.2, §5.
-  (2019) Domain generalization by solving jigsaw puzzles. In , pp. 2229–2238. Cited by: §B.3, Table 11, §4.2, §4, §4.
-  (2020) A simple framework for contrastive learning of visual representations. arXiv preprint arXiv:2002.05709. Cited by: §3.2.
Frustratingly easy semi-supervised domain adaptation.
Proceedings of the 2010 Workshop on Domain Adaptation for Natural Language Processing, pp. 53–59. Cited by: §5.
-  (2019) Domain generalization via model-agnostic learning of semantic features. In Advances in Neural Information Processing Systems, pp. 6447–6458. Cited by: §B.3, §B.4, Table 11, Appendix D, Appendix D, §3.1, §3.2, §4.2, Table 1, §4, §4, §5, §5.
-  (2016) Domain-adversarial training of neural networks. The Journal of Machine Learning Research 17 (1), pp. 2096–2030. Cited by: §C.1, Table 6, §1, §2.3, §2, §4.1, §5.
-  (2016) Scatter component analysis: a unified framework for domain adaptation and domain generalization. IEEE transactions on pattern analysis and machine intelligence 39 (7), pp. 1414–1430. Cited by: §1, §2.3, §2, §5.
Domain generalization for object recognition with multi-task autoencoders. In Proceedings of the IEEE international conference on computer vision, pp. 2551–2559. Cited by: §C.1, Table 6, Table 11, Appendix D, §2, §4.1, §4.2, §4, §5.
-  (2016) Domain adaptation with conditional transferable components. In International conference on machine learning, pp. 2839–2848. Cited by: §1, §2.3, §5.
-  (2014) Explaining and harnessing adversarial examples. arXiv preprint arXiv:1412.6572. Cited by: §C.1, Table 6, §4.1.
-  (2019) Momentum contrast for unsupervised visual representation learning. arXiv preprint arXiv:1911.05722. Cited by: §3.2.
-  (2019) Conditional variance penalties and domain shift robustness. arXiv preprint arXiv:1710.11469. Cited by: §1, §2.1, §2.1, §5.
Domain generalization via multidomain discriminant analysis.
Uncertainty in artificial intelligence: proceedings of the… conference. Conference on Uncertainty in Artificial Intelligence, Vol. 35. Cited by: §2, §5.
-  (2019) Support and invertibility in domain-invariant representations. In The 22nd International Conference on Artificial Intelligence and Statistics, pp. 527–536. Cited by: §1, §2.3.
-  (2017) Deeper, broader and artier domain generalization. In Proceedings of the IEEE international conference on computer vision, pp. 5542–5550. Cited by: Table 11, Appendix D, §4.2, §4, §5.
-  (2018) Learning to generalize: meta-learning for domain generalization. In Thirty-Second AAAI Conference on Artificial Intelligence, Cited by: §B.3, Table 11, Appendix D, §4.2, §5.
-  (2019) Episodic training for domain generalization. In Proceedings of the IEEE International Conference on Computer Vision, pp. 1446–1455. Cited by: Table 11, Appendix D, §4.2.
-  (2018) Domain generalization with adversarial feature learning. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 5400–5409. Cited by: §1, §2.3, §2, §5.
-  (2018) Domain generalization via conditional invariant representations. In Thirty-Second AAAI Conference on Artificial Intelligence, Cited by: §1, §2.3, §2, §5.
-  (2018) Deep domain generalization via conditional invariant adversarial networks. In Proceedings of the European Conference on Computer Vision (ECCV), pp. 624–639. Cited by: Table 11, Appendix D, §1, §2.3, §2, §4.2, §5.
-  (2008) Visualizing data using t-sne. Journal of machine learning research 9 (Nov), pp. 2579–2605. Cited by: §4.2.
-  (2018) Domain adaptation by using causal inference to predict invariant conditional distributions. In Advances in Neural Information Processing Systems, pp. 10846–10856. Cited by: §5.
On loss functions which minimize to conditional expected values and posterior probabilities. IEEE Transactions on Information Theory 39 (4), pp. 1404–1408. Cited by: Theorem 1.
-  (2017) Unified deep supervised domain adaptation and generalization. In Proceedings of the IEEE International Conference on Computer Vision, pp. 5715–5725. Cited by: §C.1, Table 6, §3.1, §3.2, §4.1, §4, §5.
-  (2013) Domain generalization via invariant feature representation. In International Conference on Machine Learning, pp. 10–18. Cited by: §1, §2.3, §2, §5.
-  (2009) Causality. Cambridge university press. Cited by: §2.1, Definition 1.
Causal inference by using invariant prediction: identification and confidence intervals. Journal of the Royal Statistical Society: Series B (Statistical Methodology) 78 (5), pp. 947–1012. Cited by: §1, §5.
-  (2020) Efficient domain generalization via common-specific low-rank decomposition. arXiv preprint arXiv:2003.12815. Cited by: §B.3, §B.4, §B.4, Table 1, §4, §4, §5.
Invariant models for causal transfer learning. The Journal of Machine Learning Research 19 (1), pp. 1309–1342. Cited by: §5.
-  (2018) Generalizing across domains via cross-gradient training. In International Conference on Learning Representations, Cited by: §B.4, §C.1, Table 6, §2, §4.1, §5.
-  (2019) Contrastive multiview coding. arXiv preprint arXiv:1906.05849. Cited by: §5.
-  (2015) Simultaneous deep transfer across domains and tasks. In Proceedings of the IEEE International Conference on Computer Vision, pp. 4068–4076. Cited by: §2.
-  (2018) Generalizing to unseen domains via adversarial data augmentation. In Advances in Neural Information Processing Systems, pp. 5334–5344. Cited by: §5.
-  (2019) On learning invariant representation for domain adaptation. arXiv preprint arXiv:1901.09453. Cited by: §1, §2.3.
Appendix A Theory and Proofs
We first expand on the d-separation definition, providing a few examples that illustrate conditional independence implications of specific graph structures in Figure 3. We use these three conditions for all the proofs below.
a.2 Proof of Proposition 1
a.3 Proof of Theorem 1
CLAIM 1. The matching condition can be written as:
where for pairs of inputs and from two different domains and that correspond to the same object. The distance metric is non-negative, so the optimal is when is zero.
As in the SCM from Figure 0(b), let represent a feature vector such that it is generated based only on the object
represent a feature vector such that it is generated based only on the objectand that it leads to the optimal classifier in (2). From Sections 2.1 and 2.2, we know that and that . Thus, is the same for inputs from the same object and we can write:
Hence, leads to zero regularizer term and is one of the optimal minimizers for .
CLAIM 2. Further, we show that any other optimal is either a function of or a constant for all inputs. We prove by contradiction.
Let represent the set of unobserved high-level features that are generated based on both the object and the domain . From the SCM from Figure 0(b), a feature vector is independent of given the object, , and . Further, let there be an optimal for such that it depends on some (and is not trivially a constant function). Since is optimal, for all such that , where inputs and correspond to the same object.
Let us assume that there exists at least one object for which the effect of domain is stochastic. That is, due to domain-dependent variation, . for some and . Now consider a pair of inputs and from the same object such that , and their corresponding representations are and . Due to domain-dependent variation, with non-zero probability, the high-level features are not the same for these two input data points, . Since is a deterministic function of that is not independent of , if an input has a different , its value of will also be different. Thus, with non-zero probability, we obtain that , unless the effect of is a constant function. Hence, a contradiction and optimal cannot depend on any that are generated based on the domain.
Therefore, an optimal solution to can only depend on . However, any function of is optimal, including trivial functions like the constant function (that will have low accuracy). Below we show that using the ERM term in (4) ensures that the optimal solution contains only those functions of that also maximize accuracy.
Using (3), the empirical optimizer function can be written as (where we scale the loss by a constant , the total number of training data points):
where denotes all functions of that are optimal for (9), and the last equality is because can be written as . Since we assume that is a P-admissible loss function, its minimizer is the conditional expected value. Thus, for any domain , . Further, by d-separation, . Therefore, . The above equation indicates that the loss minimizer function on any domain is independent of the domain. Thus, for the training domains, we can write:
Now (12) can be rewritten as,
From the equation above, the loss for can be considered as a weighted sum of the average loss on each training domain where the weights are all positive. Since minimizes the average loss on each domain as , it will also minimize the overall weighted loss for all values of the weights. Therefore, for any dataset over domains in , is the optimal function that minimizes the overall loss.
Moreover, we can also write as:
Finally, using a Lagrangian multiplier, minimizing the following soft constraint loss is equivalent to minimizing (11), for some value of .
The result follows. ∎
Comment on Theorem 1. In the case where the effect of a domain is also deterministic, it is possible that (e.g., in artificially created domains like Rotated-MNIST where every object is rotated by the exact same amount in each domain). In that case Theorem 1 does not apply and it is possible to learn a representation that depends on and still minimizes to attain . For example, with two training domains on Rotated-MNIST dataset (, ), it is possible to learn a representation that simply memorizes to “un-rotate ” the angle back to . Such a representation will fail to generalize to domains with different rotation angles, but nonetheless minimizes by attaining the exact same representation for each object.
In practice, we conjecture that such undesirable are avoided by model-size regularization during training. As the number of domains increase, it may be simpler to learn a single transformation (representation) based on (and independent of features like angle) than learn separate angle-wise transformations for each train domain.
a.4 Proof of Corollary 1
As in the SCM from Figure 0(b), let represent an unobserved high-level feature vector such that it is generated based only on the object and that it leads to the optimal classifier in (2). From Sections 2.1 and 2.2, we know that and that . Following a similar proof to Theorem 1 (Claim 1), we check whether satisfies the invariance conditions required by the two methods.
Domain-invariant: The required condition for a representation is that . But using the d-separation criteria on the SCM in Figure 0(b), we find that due to a path through Object .
Class-conditional domain-invariant: The required condition for a representation is that . However using the d-separation criteria on the SCM, we find that due to a path through Object that is not blocked by (nor by if it is observed).
Therefore, under the conditions proposed by these methods, or any function of is not an optimal solution without making any additional assumptions. Hence, even with infinite samples, a method optimizing for these conditions will not retrieve . ∎
a.5 Proof of Theorem 2
Since is a constant for a dataset, the loss from (5) can be written as,
where we added a scaling constant for the total number of matches. For some value of , the above optimization is equivalent to,
where is the maximum difference in between two inputs with the same class but different domains.
We first show that satisfies the constraint in (18), where is any feature vector that is generated based only on the object (), as defined in the proof of Theorem 1. Now since for any two inputs and with the same class and different domains, the constraint from (18) is satisfied by .
In addition, let be a feature vector that is generated based on both the object and a domain () and that is conditionally independent of , . Since we assume that and , such an cannot satisfy the constraint in (18).
Appendix B Evaluation and implementation details
In this section we describe implementation details for our proposed methods. We also discuss the evaluation protocol, including details about hyperparameters and cross-validation.
For the implementation of ERM-RandMatch in Eq. (5) , ERM-PerfMatch in Eq. (4); we use the cross-entropy loss for and distance for in Eq. (4, 5). For both methods, we consider the representation to be the last layer of the network. That is, we take to be identity function in Eq. (4, 5) for simplicity. We use SGD to optimize for the loss for all the datasets, with details about learning rate, epoch, batch size, etc. provided in the section B.3 ahead. For all the different methods, we sample batches from the data matrix consisting of data points matched across domains; hence we ensure an equal number of data points from each source domain in a batch.
When training with MatchDG, for Phase-1 the architecture is selected to be ResNet-18, with the final fully connected layer removed to learn representations of dimension 512. For MatchDG Phase 1, the network is always the same across different datasets and irrespective of the Phase-2 architecture used for the classification task. In our evaluation, the underlying architecture for Phase 2 is kept to be the same for ERM, ERM-RandMatch, ERM-PerfMatch; which is set to ResNet-18 (Table 1) and LeNet (Table 6) in case of MNIST and Fashion-MNIST; while for PACS it is set to ResNet-18 (Table 4) and AlexNet (Table 11).
b.1 MatchDG implementation details
The MatchDG algorithm proceeds in two phases.
Initialization: We construct matches of pairs of same-class data points from different domains. Hence, given each data point we randomly select another data point with the same class from another domain. The matching for each class across domains is done relative to a base domain; which is chosen by taking the domain that has the highest number of samples for that class. This is done to avoid missing out on data points when there is class imbalance across domains. Specifically, we iterate over classes and for each class, we match data points randomly across domains w.r.t a base domain for that class. This leads to matrix of size , where refers to the updated domain size ( sum of the size of base domain for all the classes ) and K refers to the total number of domains. We describe the two phases below:
Phase 1: We samples batches from the matched data matrix , where B is the batch size. For each data point in the batch, we minimize the contrastive loss from (6) by selecting its matched data points across domains as the positive matches and consider every data point with a different class label from to be a negative match.
After every epochs, we periodically update the matched data matrix by using the representations learnt by contrastive loss minimization. We follow the same procedure of selecting a base domain for each class, but instead of randomly matching data points across domains, we find the nearest neighbour for the data point in base domain among the data points in the other domains with the same class label based on the L2 Distance between their representations. At the end of Phase I, we update the matched data matrix based on L2 distance over the final representations learnt. We call these matches as the inferred matches.
Phase 2: We train using the loss from Eq. (5), but instead of random matches, we use the inferred matches generated from Phase 1 (ERM + Inferred Match). We train the network from scratch in Phase 2 and use the representations learnt in Phase 1 to only update the matched data matrix.
The updated data matrix based on representations learnt in Phase 1 may lead to many-to-one matches from the base domain to the other domains. This can lead to certain data points being excluded from the training batches. Therefore, we construct batches such that each batch consists of two parts. The first is sampled as in Phase 1 from the matched data matrix. The second part is sampled randomly from all train domains. Specifically, for each batch sampled from the matched data matrix, we sample an additional part of size with data points selected randomly across domains. The loss for the second part of the batch is simply ERM, along with ERM + InferredMatch Loss on the first part of the batch.
b.2 Metrics for evaluating quality of learnt matches
Here we describe the three metrics used for measuring overlap of the learnt matches with ground-truth “perfect” matches.
Overlap %: Percentage of matches (j, k) as per the perfect match strategy that are also consistent with the learnt match strategy .
Top-10 Overlap %: Percentage of matches (j, k) as per the perfect match strategy that are among the Top-10 matches for the data point j w.r.t the learnt match strategy i.e.
Mean Rank: For the matches (j, k) as per the perfect match strategy , compute the mean rank for the data point j w.r.t the learnt match strategy i.e.
b.3 HyperParameter Tuning
To select hyperparameters, prior works [8, 5, 19] use leave-one-domain-out validation, which means that the hyperparameters are tuned after looking at data from the unseen domain. Such a setup is against the premise of the domain generalization task that assumes that a model should have no access to the test domain.
Therefore, in this work, we construct a validation set using only the source domains and use it for hyperparamter tuning. In the case of PACS, we already have access to the validation indices for each domain and use them to construct a validation set based on the source domains. For MNIST and Fashion MNIST; we construct a validation set by sampling ( where refers to the data points per domain for MNIST () and Fashion-MNIST ()  ) data points and then rotating them by the required angle as per the source domain. Hence, the model does not have access to the data points from the target/test domains at the time of training and validation.
For our proposed methods ERM-PerfMatch and ERM-RandMatch, there is a single hyperparameter to be tuned.
MNIST and Fashion MNIST
In case of MNIST and Fashion MNIST, we tune for the approach ERM-PerfMatch, with access to complete source domains ( 15, 30, 45, 60, 75 ) and use the same hyper parameters for the other methods like MatchDG (Phase 2), ERM, ERM-RandMatch, and also for the evaluation with reduced number of domains in Table 1 and fraction of perfect matches in Table 3. The number of epochs are selected based on the convergence point for the validation dataset for each model. The details of the hyper parameters for MNIST and Fashion-MNIST are provided below: ( refers to the total number of source domains)
For PACS dataset, we tune a different for different source domains using the validation mechanism described above.
ResNet-18: (Table 4). The batch size is ( refers to the total number of source domains) for all the models and learning rate is for Photo as the target domain and for all other target domains. The values for total epochs and for different training methods and target domain are given in Table 5.
ERM ERM-RandMatch MatchDG(Phase 2) Photo Total Epochs 15 15 20 0 0.1 0.1 Art Painting Total Epochs 15 15 15 0 0.01 0.1 Cartoon Total Epochs 15 15 15 0 0.01 0.02 Sketch Total Epochs 20 20 20 0 0.1 0.1 Table 5: Hyper parameter tuning details on PACS with ResNet-18 architecture
Alexnet: (Table 11). The batch size is , learning rate is 0.001 and total epochs 35 for all the models and target domains. For ERM-RandMatch, we keep the same values for batch size, learning rate and total epochs as ERM; while the value for is taken as 1.0 for all the test domains, except for the target domain Photo, where we take as 0.1. For MatchDG, the value for is taken as 1.0 with the batch size as , learning rate as 0.001 and total epochs as 35 .
The MatchDG Phase-1 does not have the hyper parameter () as it minimizes only the contrastive loss; it however has the temperature () hyper parameter. The details are provided below:
MNIST and Fashion MNIST. : 0.5, batch size: 64, lr: 0.01, total epochs: 30
PACS. : 0.5, batch size: 128, lr: 0.01, total epochs: 15
b.4 Reproducing Results from Prior Work
MNIST and Fashion MNIST.
The results for MASF, CSD in Table 1 are taken from ; while the results for prior approaches in Table 6 are taken from . For ResNet-18 architecture (Table 1); the results for the reduced number of domains for CSD and MASF were computed using their code which is available online 111https://github.com/vihari/CSD 222https://github.com/biomedia-mira/masf. The MASF code was hardcoded to run for PACS dataset; which has 3 source domains that gets divided into 2 meta train and 1 meta test domain. Their code requires atleast 2 meta train domains; which leads to an issue for only 2 source domains (30, 45). In Table 1 when there are only 2 source domains; their code considers only 1 meta train domain. To resolve this issue; we create a copy of the 1 meta train domain and hence run MASF for source domains 30, 45 on MNIST.