Domain Generalization via Model-Agnostic Learning of Semantic Features
Generalization capability to unseen domains is crucial for machine learning models when deployed in real-world conditions. We investigate the challenging problem of domain generalization, i.e., training a model on multi-domain source data such that it can directly generalize to target domains with unknown statistics. We adopt a model-agnostic learning paradigm with gradient-based meta-train and meta-test procedures to expose the optimization to domain shift. Further, we introduce two complementary losses which explicitly regularize the semantic structure of the feature space. Globally, we align a derived soft confusion matrix to preserve general knowledge about inter-class relationships. Locally, we promote domain-independent class-specific cohesion and separation of sample features with a metric-learning component. The effectiveness of our method is demonstrated with new state-of-the-art results on two common object recognition benchmarks. Our method also shows consistent improvement on a medical image segmentation task.
Machine learning methods have achieved remarkable success under the assumption that training and test data are sampled from the same distribution. In real-world applications, this assumption is often violated as conditions for data acquisition may change, and a trained system may fail to produce accurate predictions for unseen data with domain shift. To tackle this issue, domain adaptation algorithms normally learn to align source and target data in a domain-invariant discriminative feature space (kumar2010co; ganin2016domain; hoffman2018cycada; long2016unsupervised; luo2018taking; saenko2010adapting; saito2018maximum; tzeng2015simultaneous; tzeng2017adversarial). These methods rely on access to a few labelled (kumar2010co; saenko2010adapting; tzeng2015simultaneous) or unlabelled (ganin2016domain; hoffman2018cycada; long2016unsupervised; luo2018taking; saito2018maximum; tzeng2017adversarial) data samples from the target distribution during training.
An arguably harder problem is domain generalization, which aims to train a model using multi-domain source data such that it can directly generalize to new domains without the need for retraining. This setting differs considerably from domain adaptation, as no information about the new domains is available, a scenario encountered in many real-world applications. In healthcare, for example, medical images acquired at different sites can differ significantly in their data distribution due to varying scanners, imaging protocols or patient cohorts. At deployment, each new hospital can be regarded as a new domain, but it is impractical to collect data each time to adapt a trained system. Learning a model which directly generalizes to new clinical sites would be of great practical value.
Domain generalization is an active research area with a number of approaches being proposed. As no a priori knowledge of the target distribution is available, the key question is how to guide model learning to capture information which is discriminative for the specific task but insensitive to changes of domain-specific statistics. For computer vision applications, the aim is to capture general semantic features for object recognition. Previous work has demonstrated that this can be investigated through regularization of the feature space, e.g., by minimizing the divergence between marginal distributions of data sources (muandet2013domain), or by jointly considering the class-conditional distributions (li2018deep). li2018domain use adversarial feature alignment via maximum mean discrepancy. Leveraging distance metrics of feature vectors is another approach (hsu2017learning; motiian2017unified). Model-agnostic meta-learning (finn2017model) is a recent gradient-based method for fast adaptation of models to new conditions, e.g., a new task in few-shot learning. Meta-learning has been introduced to address domain generalization (balaji2018metareg; li2018learning; li2019feature) by adopting an episodic training paradigm, i.e., splitting the available source domains into meta-train and meta-test at each iteration to simulate domain shift. Promising performance has been demonstrated by deriving the loss from a task error (li2018learning), a classifier regularizer (balaji2018metareg), or a predictive feature-critic module (li2019feature).
We introduce two complementary losses which explicitly regularize the semantic structure of the feature space via a model-agnostic episodic learning procedure. Our optimization objective encourages the model to learn semantically consistent features across training domains that may generalize better to unseen domains. Globally, we align a derived soft confusion matrix to preserve inter-class relationships. Locally, we use a metric-learning component to encourage domain-independent yet class-specific cohesion and separation of sample features. The effectiveness of our approach is demonstrated with new state-of-the-art performance on two common object recognition benchmarks. Our method also shows consistent improvement on a medical image segmentation task. Code for our proposed method is available at: https://github.com/biomedia-mira/masf.
Domain adaptation is based on the central theme of bounding the target error by the source error plus a discrepancy metric between the target and the source (ben2010theory). This is practically performed by narrowing the domain shift between the target and source either in input space (hoffman2018cycada), feature space (kumar2010co; ganin2016domain; long2016unsupervised; saenko2010adapting; tzeng2017adversarial), or output space (luo2018taking; saito2018maximum; tsai2018learning), generally using maximum mean discrepancy (gretton2012optimal; sejdinovic2013equivalence) or adversarial learning (goodfellow2014generative). The success of methods operating on feature representations motivates us to optimize the semantic feature space for domain generalization in this paper.
Domain generalization aims to generalize models to unseen domains without knowledge about the target distribution during training. Different methods have been proposed for learning generalizable and transferable representations. A promising direction is to extract task-specific but domain-invariant features (ghifary2015domain; li2018domain; li2018deep; motiian2017unified; muandet2013domain). muandet2013domain propose a domain-invariant component analysis method with a kernel-based optimization algorithm to minimize the dissimilarity across domains. ghifary2015domain learn multi-task auto-encoders to extract invariant features which are robust to domain variations. li2018deep consider the conditional distribution of the label space over the input space, and minimize the discrepancy of a joint distribution. motiian2017unified use a contrastive loss to encourage samples from the same class to be embedded nearby in latent space across data sources. li2018domain extend adversarial autoencoders by imposing a maximum mean discrepancy measure to align multi-domain distributions. Instead of harmonizing the feature space, others use low-rank parameterized CNNs (li2017deeper) or decompose network parameters into domain-specific and domain-invariant components (khosla2012undoing). Data augmentation strategies, such as gradient-based domain perturbation (shankar2018generalizing) or adversarially perturbed samples (volpi2018generalizing), demonstrate effectiveness for model generalization. A recent method with state-of-the-art performance is JiGen (carlucci2019domain), which leverages self-supervised signals by solving jigsaw puzzles.
Meta-learning (a.k.a. learning to learn (schmidhuber1987evolutionary; Thrun:1998:LL:296635)) is a long-standing topic exploring the training of a meta-learner that learns how to train particular models (finn2017model; li2016learning; nichol2018first; ravi2016optimization). Recently, gradient-based meta-learning methods (finn2017model; nichol2018first) have been successfully applied to few-shot learning, with a procedure purely leveraging gradient descent. The episodic training paradigm, originating from model-agnostic meta-learning (MAML) (finn2017model), has been introduced to address domain generalization (balaji2018metareg; li2018learning; li2019episodic; li2019feature). Epi-FCR (li2019episodic) alternates domain-specific feature extractors and classifiers across domains via episodic training, but without using an inner gradient descent update. The method of MLDG (li2018learning) closely follows the update rule of MAML, back-propagating the gradients from an ordinary task loss on meta-test data. A limitation is that using the task objective might be sub-optimal, as it is highly abstracted from the feature representations (only using class probabilities). Moreover, it may not fit well the scenario where target data are unavailable (as pointed out by balaji2018metareg). A recent method, MetaReg (balaji2018metareg), learns a regularization function (e.g., a weighted loss) specifically for the network’s classification layer, excluding the feature extractor. Instead, li2019feature propose a feature-critic network which learns an auxiliary meta loss (producing a non-negative scalar) depending on the output of the feature extractor. Both (balaji2018metareg) and (li2019feature) lack explicit guidance from the semantics of the feature space, which may contain crucial domain-independent ‘general knowledge’ for model generalization. Our method is orthogonal to previous work, proposing to enforce semantic features via global class alignment and local sample clustering, with losses explicitly derived in an episodic learning procedure.
In the following, we denote the input and label spaces by $\mathcal{X}$ and $\mathcal{Y}$; the domains are different distributions on the joint space $\mathcal{X} \times \mathcal{Y}$. Since domain generalization involves a common predictive task, the label space is shared by all domains. In each domain, samples are drawn from a dataset $\mathcal{D}_k = \{(x^{(k)}_n, y^{(k)}_n)\}_{n=1}^{N_k}$, where $N_k$ is the number of labeled data points in the $k$-th domain. The domain generalization (DG) setting further assumes the existence of domain-invariant patterns in the inputs (e.g., semantic features), which can be extracted to learn a label predictor that performs well across seen and unseen domains. Unlike domain adaptation, DG assumes no access to observations from, or explicit knowledge about, the target distribution.
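In this work, we consider a classification model composed of a feature extractor, $F_\psi: \mathcal{X} \to \mathcal{Z}$, where $\mathcal{Z}$ is a feature space (typically much lower-dimensional than $\mathcal{X}$), and a task network, $T_\theta: \mathcal{Z} \to \mathbb{R}^C$, where $C$ is the number of classes in $\mathcal{Y}$. The final class predictions are given by $\hat{y} = \mathrm{softmax}(T_\theta(F_\psi(x)))$. (For image segmentation, $F_\psi$ extracts feature maps and the task network $T_\theta$ is applied pixel-wise.) The parameters $(\psi, \theta)$ are optimized with respect to a task-specific loss $\ell_{\mathrm{task}}$, e.g., cross-entropy: $\ell_{\mathrm{task}}(y, \hat{y}) = -\sum_{c=1}^{C} \mathbb{1}[y = c] \log \hat{y}_c$. We write $\mathcal{L}_{\mathrm{task}}(\psi, \theta; \mathcal{D})$ for the average of $\ell_{\mathrm{task}}$ over the samples of a set of domains $\mathcal{D}$.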
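To make the composition concrete, here is a minimal pure-Python sketch of $\hat{y} = \mathrm{softmax}(T_\theta(F_\psi(x)))$ and the cross-entropy loss. The single linear-plus-ReLU layer standing in for $F_\psi$ and the single linear layer standing in for $T_\theta$ are hypothetical toy choices, not the networks used in the experiments.

```python
import math


def softmax(logits, tau=1.0):
    # Numerically stable softmax; a temperature tau > 1 softens the distribution.
    m = max(logits)
    exps = [math.exp((v - m) / tau) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def linear(weights, bias, x):
    # One output per row of `weights`: w . x + b.
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, bias)]


def feature_extractor(psi, x):
    # Toy stand-in for F_psi: a linear layer followed by ReLU.
    weights, bias = psi
    return [max(0.0, v) for v in linear(weights, bias, x)]


def task_network(theta, z):
    # Toy stand-in for T_theta: a linear layer producing C logits.
    weights, bias = theta
    return linear(weights, bias, z)


def predict(psi, theta, x):
    # y_hat = softmax(T_theta(F_psi(x)))
    return softmax(task_network(theta, feature_extractor(psi, x)))


def cross_entropy(y, y_hat):
    # l_task(y, y_hat) = -sum_c 1[y = c] log y_hat_c = -log y_hat[y]
    return -math.log(y_hat[y])


psi = ([[1.0, 0.0], [0.0, 1.0]], [0.0, 0.0])                      # Z has 2 dimensions
theta = ([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]], [0.0, 0.0, 0.0])   # C = 3 classes
y_hat = predict(psi, theta, [1.0, -2.0])
loss = cross_entropy(0, y_hat)
```

Here the ReLU zeroes the second feature, so the logits are `[1, 0, 0]` and class 0 receives the highest probability.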
Although the minimization of $\mathcal{L}_{\mathrm{task}}$ may produce highly discriminative features $z = F_\psi(x)$, and hence an excellent predictor for data from the training domains, nothing in this process prevents the model from overfitting to the source domains and suffering from degradation on unseen test domains. We therefore propose to optimize the feature space such that its semantic structure is insensitive to different training domains and generalizes better to new unseen domains. Figure 1 gives an overview of our model-agnostic learning of semantic features (MASF), which we detail in this section.
The core of our learning procedure is an episodic training scheme, originating from model-agnostic meta-learning (finn2017model), which exposes the model optimization to distribution mismatch. In line with our goal of domain generalization, the model is trained on a sequence of simulated episodes with domain shift. Specifically, at each iteration, the available domains $\mathcal{D}$ are randomly split into sets of meta-train domains $\mathcal{D}_{tr}$ and meta-test domains $\mathcal{D}_{te}$. The model is trained to perform well semantically on the held-out $\mathcal{D}_{te}$ after being optimized with one or more steps of gradient descent on the $\mathcal{D}_{tr}$ domains. In our case, the feature extractor’s and task network’s parameters, $\psi$ and $\theta$, are first updated from the task-specific supervised loss $\mathcal{L}_{\mathrm{task}}$ (e.g., cross-entropy for classification), computed on meta-train:
$$(\psi', \theta') = (\psi, \theta) - \alpha \nabla_{\psi, \theta} \mathcal{L}_{\mathrm{task}}(\psi, \theta; \mathcal{D}_{tr}),$$
where $\alpha$ is a learning-rate hyperparameter. This results in a predictive model $T_{\theta'} \circ F_{\psi'}$ with improved task accuracy on the meta-train source domains, $\mathcal{D}_{tr}$.
Once this optimized set of parameters has been obtained, we can apply a meta-learning step, aiming to enforce certain properties that we desire the model to exhibit on the held-out domains $\mathcal{D}_{te}$. Crucially, the objective function quantifying these properties, $\mathcal{L}_{\mathrm{meta}}$, is computed based on the updated parameters, $(\psi', \theta')$, and the gradients are computed with respect to the original parameters, $(\psi, \theta)$. Intuitively, besides the task itself, the training procedure is learning how to generalize under domain shift. In other words, parameters are updated such that future updates with the given source domains also improve the model with respect to generalizable aspects on unseen target domains.
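The effect of evaluating $\mathcal{L}_{\mathrm{meta}}$ at $(\psi', \theta')$ while differentiating with respect to $(\psi, \theta)$ can be seen on a toy problem. This is an illustration under stated assumptions, not the MASF implementation: a single scalar $w$ stands in for all of $(\psi, \theta)$, the quadratics $(w - a)^2$ and $(w' - b)^2$ are hypothetical stand-ins for $\mathcal{L}_{\mathrm{task}}$ on $\mathcal{D}_{tr}$ and $\mathcal{L}_{\mathrm{meta}}$ on $\mathcal{D}_{te}$, and summing the two terms into one episode objective is our assumed way of combining them.

```python
def inner_update(w, a, alpha):
    # w' = w - alpha * dL_tr/dw with the toy meta-train loss L_tr(w) = (w - a)^2.
    return w - alpha * 2.0 * (w - a)


def episode_objective(w, a, b, alpha):
    # L_tr(w) + L_meta(w'), where the toy meta loss L_meta(w') = (w' - b)^2
    # is evaluated at the UPDATED parameter w'.
    w_prime = inner_update(w, a, alpha)
    return (w - a) ** 2 + (w_prime - b) ** 2


def episode_gradient(w, a, b, alpha):
    # Derivative with respect to the ORIGINAL parameter w. The chain rule through
    # w' = w - 2 * alpha * (w - a) contributes dw'/dw = 1 - 2 * alpha, i.e. the
    # second-order term that MAML-style training back-propagates through.
    w_prime = inner_update(w, a, alpha)
    return 2.0 * (w - a) + 2.0 * (w_prime - b) * (1.0 - 2.0 * alpha)


def meta_train(w, a, b, alpha=0.1, outer_lr=0.05, steps=200):
    # Repeated outer updates of w on the summed episode objective.
    for _ in range(steps):
        w -= outer_lr * episode_gradient(w, a, b, alpha)
    return w


g = episode_gradient(0.5, a=1.0, b=-1.0, alpha=0.1)   # approximately 1.56
h = 1e-6
g_fd = (episode_objective(0.5 + h, 1.0, -1.0, 0.1)
        - episode_objective(0.5 - h, 1.0, -1.0, 0.1)) / (2 * h)
w_star = meta_train(0.5, a=1.0, b=-1.0)
```

The analytic gradient matches a finite-difference check, and the meta-trained $w$ settles between the meta-train optimum $a$ and the meta-test optimum $b$ rather than at $a$, which is the behaviour the episodic scheme is meant to encourage.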
In particular, we desire the feature space to encode semantically relevant properties: features from different domains should respect inter-class relationships, and they should be compactly clustered by class labels regardless of domains (cf. Alg. 1). In the remainder of this section, we describe the design of our semantic meta-objective, $\mathcal{L}_{\mathrm{meta}} = \beta_1 \mathcal{L}_{\mathrm{global}} + \beta_2 \mathcal{L}_{\mathrm{local}}$, composed of a global class alignment term and a local sample clustering term with weighting coefficients $\beta_1$ and $\beta_2$.
Relationships between class concepts exist in a purely semantic space, independent of changes in the observation domain. In light of this, compared with individual hard label prediction, aligning class relationships can promote more transferable knowledge towards model generalization. This is also noted by tzeng2015simultaneous in the context of domain adaptation, who aggregate the output probability distribution when fine-tuning the model on a few labelled target samples. In contrast to their work, our goal is to structure the feature space itself to preserve learned class relationships on unseen data, by means of explicit regularization.
Specifically, we formulate this objective in a manner that imposes a global layout of extracted features, such that the relative locations of features from different classes embody the inherent similarity in semantic structures. Inspired by knowledge distillation from neural networks (hinton2014distilling), we exploit what the model has learned about class ambiguities, in the form of per-class soft labels, and enforce them to be consistent between the $\mathcal{D}_{tr}$ and $\mathcal{D}_{te}$ domains. For each domain $\mathcal{D}_k$, we summarize the model’s current ‘concept’ of each class $c$ by computing the class-specific mean feature vector $\bar{z}^{(k)}_c$:
$$\bar{z}^{(k)}_c = \frac{1}{N^{(c)}_k} \sum_{n:\, y^{(k)}_n = c} F_{\psi'}(x^{(k)}_n),$$
where $N^{(c)}_k$ is the number of samples in domain $\mathcal{D}_k$ labelled as class $c$. The obtained $\bar{z}^{(k)}_c$ conveys how samples from a particular class are generally represented. It is then forwarded to the task network $T_{\theta'}$ to compute soft label distributions $s^{(k)}_c$ with a ‘softened’ softmax at temperature $\tau$ (hinton2014distilling):
$$s^{(k)}_c = \mathrm{softmax}\big(T_{\theta'}(\bar{z}^{(k)}_c) / \tau\big).$$
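The class means and softened soft labels can be sketched in pure Python as follows. This assumes features are plain lists of floats, every class occurs at least once among the domain's samples, and `task_net` is any callable standing in for $T_{\theta'}$ (in the usage example, a hypothetical identity map).

```python
import math


def class_means(features, labels, num_classes):
    # bar z_c^(k): mean feature vector of the samples labelled c in one domain k.
    # Assumes every class occurs at least once among the given samples.
    dim = len(features[0])
    sums = [[0.0] * dim for _ in range(num_classes)]
    counts = [0] * num_classes
    for z, y in zip(features, labels):
        counts[y] += 1
        for i, v in enumerate(z):
            sums[y][i] += v
    return [[v / counts[c] for v in sums[c]] for c in range(num_classes)]


def soft_labels(means, task_net, tau):
    # s_c^(k) = softmax(T_theta'(bar z_c^(k)) / tau): one row per class,
    # together forming the domain's "soft confusion matrix".
    rows = []
    for z_bar in means:
        logits = task_net(z_bar)
        m = max(logits)
        exps = [math.exp((v - m) / tau) for v in logits]
        total = sum(exps)
        rows.append([e / total for e in exps])
    return rows


features = [[1.0, 0.0], [3.0, 0.0], [0.0, 2.0]]
labels = [0, 0, 1]
means = class_means(features, labels, num_classes=2)      # [[2.0, 0.0], [0.0, 2.0]]
sharp = soft_labels(means, task_net=lambda z: z, tau=1.0)
soft = soft_labels(means, task_net=lambda z: z, tau=4.0)  # closer to uniform
```

Raising `tau` moves probability mass from the dominant class to the others, which is what exposes the inter-class similarities that the global alignment term compares.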
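The collection of soft labels $[s^{(k)}_c]_{c=1}^{C}$ represents a kind of ‘soft confusion matrix’ associated with a particular domain, encoding the inter-class relationships learned by the model. Such relationships should be preserved as general semantics on meta-test after updating the classification model on meta-train (e.g., cartoon dogs are more easily misclassified as horses than as houses, which likely also holds in unseen domains). Standard supervised training with $\mathcal{L}_{\mathrm{task}}$ focuses only on the dominant hard label prediction; there is no a priori reason for such inter-class relationships to remain consistent across domains. We therefore propose to align the soft class confusion matrix between two domains $\mathcal{D}_i$ and $\mathcal{D}_j$ by minimising their symmetrized Kullback–Leibler (KL) divergence, averaged over all classes:
$$\ell_{\mathrm{global}}(\mathcal{D}_i, \mathcal{D}_j; \psi', \theta') = \frac{1}{C} \sum_{c=1}^{C} \frac{1}{2}\Big[ D_{\mathrm{KL}}\big(s^{(i)}_c \,\|\, s^{(j)}_c\big) + D_{\mathrm{KL}}\big(s^{(j)}_c \,\|\, s^{(i)}_c\big) \Big],$$
where $D_{\mathrm{KL}}(p \,\|\, q) = \sum_r p_r \log (p_r / q_r)$. Other symmetric divergences such as Jensen–Shannon (JS) could also be considered, although our preliminary experiments showed no significant difference between JS and the symmetrized KL. Finally, the global class alignment loss, $\mathcal{L}_{\mathrm{global}}(\mathcal{D}_{tr}, \mathcal{D}_{te}; \psi', \theta')$, is calculated as the average of $\ell_{\mathrm{global}}(\mathcal{D}_i, \mathcal{D}_j; \psi', \theta')$ over all pairs of available meta-train and meta-test domains, $(\mathcal{D}_i, \mathcal{D}_j) \in \mathcal{D}_{tr} \times \mathcal{D}_{te}$ (cf. Alg. 1). The cost of this computation is not problematic in practice, since the number of domains selected in a training mini-batch is limited (as in MAML (finn2017model)); in our experiments, the three available source domains were split between $\mathcal{D}_{tr}$ and $\mathcal{D}_{te}$, so only two domain pairs are involved.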
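A minimal sketch of $\ell_{\mathrm{global}}$, its average over meta-train/meta-test pairs, and the random domain split of one episode is given below. The dictionary-of-matrices layout, the domain names and soft-label values in the usage example, and the small `eps` guarding against $\log 0$ are illustrative assumptions.

```python
import math
import random
from itertools import product


def split_domains(domains, n_test, rng):
    # Randomly partition the source domains into meta-train and meta-test for one episode.
    shuffled = list(domains)
    rng.shuffle(shuffled)
    return shuffled[n_test:], shuffled[:n_test]


def kl(p, q, eps=1e-12):
    # D_KL(p || q) = sum_r p_r log(p_r / q_r); eps avoids log(0) and division by zero.
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q))


def l_global_pair(s_i, s_j):
    # Symmetrized KL between matching class rows of two soft confusion matrices,
    # averaged over the C classes.
    num_classes = len(s_i)
    return sum(0.5 * (kl(s_i[c], s_j[c]) + kl(s_j[c], s_i[c]))
               for c in range(num_classes)) / num_classes


def global_loss(soft_by_domain, meta_train, meta_test):
    # L_global: average of l_global over all (meta-train, meta-test) domain pairs.
    pairs = list(product(meta_train, meta_test))
    return sum(l_global_pair(soft_by_domain[i], soft_by_domain[j])
               for i, j in pairs) / len(pairs)


soft_by_domain = {
    "art": [[0.9, 0.1], [0.2, 0.8]],
    "cartoon": [[0.8, 0.2], [0.3, 0.7]],
    "sketch": [[0.6, 0.4], [0.4, 0.6]],
}
meta_train, meta_test = split_domains(sorted(soft_by_domain), n_test=1, rng=random.Random(0))
loss = global_loss(soft_by_domain, meta_train, meta_test)
```

With three source domains and one held out, `product` yields exactly two domain pairs per episode, so the averaging cost stays negligible.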
In addition to promoting the alignment of class relationships across domains with $\mathcal{L}_{\mathrm{global}}$ as defined above, we further encourage robust semantic features that locally cluster according to class regardless of the domain. This is crucial, as neither of the class-prediction-based losses ($\mathcal{L}_{\mathrm{task}}$ or $\mathcal{L}_{\mathrm{global}}$) ensures that features of samples in the same class lie close to each other and away from those of different classes, a property known as feature compactness (kamnitsas2018semi). If the model cannot project the inputs to semantic feature clusters with domain-independent class-specific cohesion and separation, the predictions may suffer from ambiguous decision boundaries and still be sensitive to unseen kinds of domain shift. We therefore propose a local regularization objective $\mathcal{L}_{\mathrm{local}}$ to boost robustness, by increasing the compactness of class-specific clusters while reducing their overlap. This is complementary to the global class alignment, which semantically structures the relative locations among class clusters.
Our preliminary experiments revealed that applying such regularization explicitly onto the features may constrain the optimization for $\mathcal{L}_{\mathrm{task}}$ and $\mathcal{L}_{\mathrm{global}}$ too heavily, hurting generalization performance on unseen domains. We thus take a metric-learning approach, introducing an embedding network $M_\phi$ that operates on the extracted features, $e_n = M_\phi(z_n)$ with $z_n = F_{\psi'}(x_n)$. This component represents a learnable distance function (chopra2005learning) between feature vectors (rather than between raw inputs):
$$d_\phi(z_n, z_m) = \|e_n - e_m\|_2 = \|M_\phi(z_n) - M_\phi(z_m)\|_2.$$
The sample pairs are randomly drawn from all source domains $\mathcal{D}$, because we expect the updated $F_{\psi'}$ to harmonize the semantic feature space of $\mathcal{D}_{te}$ with that of $\mathcal{D}_{tr}$ in terms of class-specific clustering regardless of domains. The computed embeddings, $e_n$, can then be optimized with any suitable metric-learning loss $\mathcal{L}_{\mathrm{local}}$ to regularize the local sample clustering. Under mild domain shift, the contrastive loss (hadsell2006dimensionality) is a sensible choice, as it attempts to separately collapse each group of same-class exemplars to a distinct single point. It might, however, be over-restrictive for more extreme situations, in which domains are related mainly semantically but have wildly distinct low-level statistics. For such cases, we propose instead to employ the triplet loss (schroff2015facenet).
Contrastive loss is computed for pairs of samples, attracting samples of the same class and repelling samples of different classes (hadsell2006dimensionality). Instead of pushing clusters apart to infinity, the repulsion range is bounded by a distance margin $\xi$.
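Our contrastive loss for a pair of samples $(x_n, x_m)$ is defined as:
$$\ell^{n,m}_{\mathrm{con}} = \begin{cases} d_\phi(z_n, z_m)^2, & \text{if } y_n = y_m, \\ \big(\max\{0,\ \xi - d_\phi(z_n, z_m)\}\big)^2, & \text{if } y_n \neq y_m. \end{cases}$$
The total loss for a training mini-batch, $\mathcal{L}_{\mathrm{local}}$, is normally averaged over all pairs of samples. In cases where full enumeration is intractable (e.g., image segmentation, which would involve all pairs of pixels in all images), we can obtain an unbiased estimator of the loss by, e.g., shuffling the samples and iterating over the resulting pairs.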
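A pure-Python sketch of $d_\phi$, the contrastive loss, and two ways of forming $\mathcal{L}_{\mathrm{local}}$ follows. The embeddings are assumed to be already computed by $M_\phi$, and the half-split pairing used for the sampled estimate is one possible choice of ours, not necessarily the pairing used by the authors. Because a random permutation makes each formed pair a uniformly random pair of distinct samples, the sampled average is an unbiased estimate of the all-pairs mean.

```python
import math
import random
from itertools import combinations


def dist(e_n, e_m):
    # d_phi(z_n, z_m) = ||e_n - e_m||_2 with e = M_phi(z) already computed.
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(e_n, e_m)))


def contrastive(e_n, e_m, same_class, margin):
    d = dist(e_n, e_m)
    if same_class:
        return d ** 2                     # attract same-class pairs
    return max(0.0, margin - d) ** 2      # repel different-class pairs, only up to the margin


def local_loss_all_pairs(embeddings, labels, margin):
    # Exact mean over all pairs; fine for small classification mini-batches.
    pairs = list(combinations(range(len(embeddings)), 2))
    return sum(contrastive(embeddings[n], embeddings[m], labels[n] == labels[m], margin)
               for n, m in pairs) / len(pairs)


def local_loss_sampled(embeddings, labels, margin, rng):
    # Stochastic estimate: shuffle, then pair the first half with the second half.
    idx = list(range(len(embeddings)))
    rng.shuffle(idx)
    half = len(idx) // 2
    pairs = list(zip(idx[:half], idx[half:2 * half]))
    return sum(contrastive(embeddings[n], embeddings[m], labels[n] == labels[m], margin)
               for n, m in pairs) / len(pairs)


embeddings = [[0.0, 0.0], [0.0, 1.0], [3.0, 0.0], [3.0, 1.0]]
labels = [0, 0, 1, 1]
exact = local_loss_all_pairs(embeddings, labels, margin=2.0)   # 2 same-class pairs out of 6
estimate = local_loss_sampled(embeddings, labels, margin=2.0, rng=random.Random(0))
```

In this example the different-class pairs are already farther apart than the margin, so only the two same-class pairs contribute.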
Triplet loss aims to make pairs of samples from the same class closer than pairs from different classes by a certain margin $\xi$ (schroff2015facenet). Given one ‘anchor’ sample $x_a$, one ‘positive’ sample $x_p$ (with $y_p = y_a$), and one ‘negative’ sample $x_n$ (with $y_n \neq y_a$), we compute their triplet loss as follows:
$$\ell^{a,p,n}_{\mathrm{tri}} = \max\{0,\ d_\phi(z_a, z_p)^2 - d_\phi(z_a, z_n)^2 + \xi\}.$$
schroff2015facenet argue that judicious triplet selection is essential for good convergence, as many triplets may already satisfy this constraint and others may be too hard to contribute meaningfully to the learning process. Here we adopt their proposed online ‘semi-hard’ triplet mining strategy, and $\mathcal{L}_{\mathrm{local}}$ is the average over all selected triplets.
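The triplet loss and a simplified online semi-hard mining step can be sketched as below. Following schroff2015facenet, a negative counts as semi-hard when it is farther from the anchor than the positive but still within the margin; picking the closest such negative for each anchor-positive pair, and skipping pairs that have none, are our own assumptions rather than the authors' implementation.

```python
def sq_dist(u, v):
    # Squared Euclidean distance d_phi^2 between two embeddings.
    return sum((a - b) ** 2 for a, b in zip(u, v))


def triplet(e_a, e_p, e_n, margin):
    # l_tri = max{0, d(a, p)^2 - d(a, n)^2 + margin}
    return max(0.0, sq_dist(e_a, e_p) - sq_dist(e_a, e_n) + margin)


def semi_hard_triplet_loss(embeddings, labels, margin):
    losses = []
    n = len(embeddings)
    for a in range(n):
        for p in range(n):
            if p == a or labels[p] != labels[a]:
                continue
            d_ap = sq_dist(embeddings[a], embeddings[p])
            # Semi-hard negatives: farther than the positive, but still inside the margin.
            semi_hard = [k for k in range(n) if labels[k] != labels[a]
                         and d_ap < sq_dist(embeddings[a], embeddings[k]) < d_ap + margin]
            if semi_hard:
                # Take the closest (hardest) of the semi-hard negatives.
                k = min(semi_hard, key=lambda j: sq_dist(embeddings[a], embeddings[j]))
                losses.append(triplet(embeddings[a], embeddings[p], embeddings[k], margin))
    # Average over the selected triplets; 0.0 if none were selected.
    return sum(losses) / len(losses) if losses else 0.0


embeddings = [[0.0], [1.0], [1.5], [5.0]]
labels = [0, 0, 1, 1]
loss = semi_hard_triplet_loss(embeddings, labels, margin=2.0)   # one triplet selected: 0.75
```

Only the anchor at 0.0 with positive 1.0 finds a semi-hard negative (at 1.5); all other triplets are either already satisfied or too hard, which illustrates why mining discards many candidates.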
We evaluate and compare our method on three datasets: 1) the classic VLCS domain generalization benchmark for image classification, 2) the recently introduced PACS benchmark for object recognition with challenging domain shift, 3) a real-world medical imaging task of tissue segmentation in brain MRI. Results with an in-depth analysis and ablation study are presented in the following.
VLCS (fang2013unbiased) is a classic benchmark for domain generalization, which includes images from four datasets: PASCAL VOC2007 (V) (everingham2010pascal), LabelMe (L) (russell2008labelme), Caltech (C) (fei2007learning), and SUN09 (S) (choi2010exploiting). The multi-class object recognition task includes five classes: bird, car, chair, dog and person. Following previous work (carlucci2019domain; li2019episodic; motiian2017unified), we use the publicly available pre-extracted DeCAF features (4096-dimensional vectors) for leave-one-domain-out validation, randomly dividing each domain into training and test subsets, and feed them into two fully connected layers with output sizes of 1024 and 128 and ReLU activations. For our metric embedding $M_\phi$ (taking the 128-dimensional vector as input), we use two fully connected layers with output sizes of 128 and 64. The triplet loss is adopted for computing $\mathcal{L}_{\mathrm{local}}$, with its coefficient $\beta_2$ set such that the term is on a similar scale to $\mathcal{L}_{\mathrm{task}}$ and $\mathcal{L}_{\mathrm{global}}$. We use the Adam optimizer (kingma2015adam) with an exponentially decayed learning rate. For the inner optimization to obtain $(\psi', \theta')$, we clip the gradients by norm to prevent them from exploding, since this step uses plain, non-adaptive gradient descent (with learning rate $\alpha$). Note that, although performing gradient descent on $\mathcal{L}_{\mathrm{meta}}$ involves second-order gradients on $(\psi, \theta)$, their computation does not incur a substantial overhead in training time (finn2017model). We also employ an Adam optimizer for the meta-updates of $\phi$, with a learning rate without decay. A fixed batch size is used for each source domain, and training is performed on an Nvidia TITAN Xp 12 GB GPU. The metric-learning margin hyperparameter $\xi$
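Clipping the inner-loop gradients by norm before the plain gradient-descent step can be sketched as below. We assume clipping by the global L2 norm over all parameters (flattened into one list); the threshold and learning rate in the usage example are hypothetical, since the values used in the experiments are not reproduced here.

```python
import math


def clip_by_norm(grads, threshold):
    # Rescale the whole gradient vector when its global L2 norm exceeds `threshold`;
    # the direction is preserved, only the magnitude is capped.
    norm = math.sqrt(sum(g * g for g in grads))
    if norm > threshold:
        return [g * threshold / norm for g in grads]
    return list(grads)


def inner_sgd_step(params, grads, lr, threshold):
    # Plain, non-adaptive gradient descent on the clipped gradients:
    # (psi', theta') = (psi, theta) - alpha * clip(grad).
    clipped = clip_by_norm(grads, threshold)
    return [p - lr * g for p, g in zip(params, clipped)]


new_params = inner_sgd_step([1.0, 1.0], [3.0, 4.0], lr=0.5, threshold=1.0)  # about [0.7, 0.6]
```

Because the inner step has no adaptive per-parameter scaling, capping the gradient norm is what keeps a single large gradient from moving $(\psi', \theta')$ far from $(\psi, \theta)$.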
Results. Table 1 shows the object recognition accuracies on different target domains. Our DeepAll baseline, i.e., merging all source domains and training by standard supervised learning on $\mathcal{L}_{\mathrm{task}}$ with the same hyperparameters, provides the reference average accuracy over the four domains. Using our episodic training paradigm with regularization of the semantic feature space, we improve upon this baseline and set a new state-of-the-art accuracy on VLCS. We compare with eight different methods (cf. Section 2) which report the previous best results on this benchmark. CCSA (motiian2017unified) combines a contrastive loss with the ordinary cross-entropy, without using the episodic meta-update paradigm. Notably, our approach outperforms MLDG (li2018learning), indicating that explicitly encouraging semantic properties in the feature space is superior to using a highly abstracted task loss on meta-test.
The PACS dataset (li2017deeper) is a recent benchmark with more severe distribution shift between domains, making it more challenging than VLCS. It consists of four domains (art painting, cartoon, photo, sketch), with objects from seven classes: dog, elephant, giraffe, guitar, house, horse, person. Following practice in the literature (balaji2018metareg; carlucci2019domain; li2018learning; li2019episodic), we also use leave-one-domain-out cross-validation, i.e., training on three domains and testing on the remaining unseen one, and adopt an AlexNet (krizhevsky2012imagenet) pre-trained on ImageNet (ILSVRC15). The metric embedding $M_\phi$ is connected to the last fully connected layer (i.e., the fc7 layer with a 4096-dimensional vector) by stacking two fully connected layers with output sizes of 1024 and 256. For $\mathcal{L}_{\mathrm{local}}$, we also use the triplet loss, particularly considering the severe domain shift. We set the initial learning rates and clip the inner-loop gradients by norm. The batch size is 128 for each source domain.
Results. Table 2 summarizes the results of object recognition on the PACS dataset with a comparison to previous work (noting that not all compared methods reported results on both VLCS and PACS). MLDG (li2018learning) and MetaReg (balaji2018metareg) employ episodic training with meta-learning, but from different angles in terms of the meta-learner’s objective (li2018learning minimize the task error, balaji2018metareg learn a classifier regularizer). The promising results of (balaji2018metareg; li2018learning; li2019episodic) indicate that exposing the training procedure to domain shift benefits model generalization to unseen domains. Our method further explicitly considers the semantic structure, regarding both global class alignment and local sample clustering, yielding improved accuracy. Averaged across all domains, our method increases accuracy over the DeepAll baseline; for comparison, the current state-of-the-art JiGen (carlucci2019domain) also improves over its own baseline. In addition, we observe an improvement when the unseen domain is sketch, which has a distinct style and requires more general knowledge about semantic concepts.
Ablation analysis. We conduct an extensive study on the PACS benchmark to investigate two key points: 1) the contribution of each component to our method’s performance, and 2) how the semantic feature space is influenced by our proposed meta losses. First, we test all possible combinations of the key components: episodic meta-learning simulating domain shift, the global class alignment loss and the local sample clustering loss. Accuracies averaged over three runs for the different combinations are given in Table 3. For example, the first row corresponds to the DeepAll baseline with standard training by aggregating all source data. The fifth row corresponds to directly adding the losses $\mathcal{L}_{\mathrm{global}}$ and $\mathcal{L}_{\mathrm{local}}$ on top of $\mathcal{L}_{\mathrm{task}}$ with the standard optimization scheme, i.e., without splitting into meta-train and meta-test domains. From the ablation study, we observe that each component plays its own role in a complementary way. Specifically, the proposed losses that encourage semantic structure in the feature space yield improvements over DeepAll, as well as over pure episodic training (the second row), which corresponds to our implementation of MLDG and thus enables a straightforward comparison. Leveraging the gradient-based update paradigm improves performance further across all settings.
We use t-SNE (maaten2008visualizing) to analyze the feature space learned with our proposed model and the DeepAll baseline (cf. Fig. 2). Our MASF model appears to yield a better separation of classes. We also note that the sketch domain lies further apart from art painting and cartoon, although all three are source domains in this experiment, possibly due to the unique characteristics of sketches. In Figure 3 (a), we plot the difference between the feature distances of negative pairs and those of positive pairs. For the two magenta lines, for MASF and DeepAll respectively, sample pairs are drawn from different training source domains. Both distance margins naturally increase as training progresses. The shaded area highlights that MASF yields a higher distance margin between classes than DeepAll, indicating that sample clusters are better separated with MASF. Similarly, for the two blue lines, sample pairs are drawn from the unseen target domain and a source domain (randomly selected at each iteration). As expected, the margin is not as large as between training domains, yet our method still presents a notably bigger margin than the baseline. In Figure 3 (b), we plot $\ell_{\mathrm{global}}$, which quantifies differences of average class posteriors between the unseen target domain and a source domain during training. We observe that the semantic inter-class relationships, which convey general knowledge about a recognition task, do not naturally converge and generalize to the unseen domain without explicit guidance.
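The quantity in Figure 3 (a) can be sketched as the mean distance over different-class pairs minus the mean distance over same-class pairs, with one sample of each pair taken from each of two domains. Using Euclidean distances between feature vectors and plain means over all cross-domain pairs is our assumption about how such a curve could be computed.

```python
import math


def euclidean(u, v):
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


def distance_margin(feats_a, labels_a, feats_b, labels_b):
    # Pairs take one sample from domain A and one from domain B.
    # Returns mean(negative-pair distance) - mean(positive-pair distance);
    # larger values indicate better-separated class clusters.
    pos, neg = [], []
    for za, ya in zip(feats_a, labels_a):
        for zb, yb in zip(feats_b, labels_b):
            (pos if ya == yb else neg).append(euclidean(za, zb))
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative pair")
    return sum(neg) / len(neg) - sum(pos) / len(pos)


margin = distance_margin([[0.0, 0.0]], [0], [[0.0, 1.0], [3.0, 4.0]], [0, 1])  # 5.0 - 1.0 = 4.0
```

Evaluating this with pairs drawn across two source domains or between the unseen target domain and a source domain gives the two kinds of curves described above.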
Deeper architectures. To provide stronger baseline results, we perform additional preliminary experiments using more up-to-date deep residual architectures (he2016resnet), namely ResNet-18 and ResNet-50. Table 4 shows strong and consistent improvements of MASF over the DeepAll baseline in all PACS splits for both network architectures. This suggests that our proposed algorithm is also beneficial for domain generalization with deeper feature extractors.
We evaluate our method on a real-world medical imaging task of brain tissue segmentation in T1-weighted MRI. Data was acquired from four clinical centers (denoted as Set-A/B/C/D). Domain shift occurs due to differences in scanners, acquisition protocols and many other factors, posing severe limitations for translating learning-based methods to clinical practice (glocker2019multisite). Figure 4 shows example images and intensity histograms. We adapt MASF for the segmentation of four classes: background, grey matter (GM), white matter (WM) and cerebrospinal fluid (CSF). We employ a U-Net (ronneberger2015u), which is commonly used for this task. For $\mathcal{L}_{\mathrm{global}}$, $\bar{z}^{(k)}_c$ is computed by averaging over all pixels of a class. Our metric embedding $M_\phi$ has two convolutional layers, and we use the contrastive loss for $\mathcal{L}_{\mathrm{local}}$. We randomly split each domain into training and test subsets in all experimental settings.
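For segmentation, the class means $\bar{z}^{(k)}_c$ average over every pixel of class $c$. A minimal sketch, assuming feature maps are nested lists of shape height x width x channels with matching integer label maps; returning `None` for classes absent from the given images is our own choice.

```python
def pixelwise_class_means(feature_maps, label_maps, num_classes):
    # feature_maps: list of images, each H x W x channels (nested lists);
    # label_maps: list of matching H x W integer label maps.
    sums = [None] * num_classes
    counts = [0] * num_classes
    for fmap, lmap in zip(feature_maps, label_maps):
        for frow, lrow in zip(fmap, lmap):
            for z, y in zip(frow, lrow):
                if sums[y] is None:
                    sums[y] = [0.0] * len(z)
                for i, v in enumerate(z):
                    sums[y][i] += v
                counts[y] += 1
    # Classes that never occur get None instead of a mean.
    return [None if counts[c] == 0 else [v / counts[c] for v in sums[c]]
            for c in range(num_classes)]


fmaps = [[[[1.0], [3.0]], [[5.0], [7.0]]]]    # one 2x2 image with a single channel
lmaps = [[[0, 0], [1, 1]]]
means = pixelwise_class_means(fmaps, lmaps, num_classes=3)   # [[2.0], [6.0], None]
```

Each pixel is treated as one sample, so the per-class mean is the segmentation analogue of averaging image-level features in the classification setting.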
Results. For easier comparison, we average the Dice scores achieved for the three foreground classes (GM/WM/CSF) and report them in Table 5. Although hardly visible in the gray-scale images, the domain shift in data distribution degrades segmentation performance significantly. DeepAll is a strong baseline, yet our model-agnostic learning scheme provides consistent improvement over naively aggregating data from multiple sources, especially when generalizing to a new clinical site with relatively poorer imaging quality (i.e., Set-D). Figure 3 (c) is the silhouette plot (rousseeuw1987silhouettes) of the embeddings from $M_\phi$, demonstrating that the samples within the same class cluster are tightly grouped and clearly separated from those of other classes.
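The reported metric averages per-class Dice scores, $2|P_c \cap G_c| / (|P_c| + |G_c|)$, over the foreground classes. A minimal sketch on flattened label maps, assuming label 0 is background and labels 1 to 3 stand for GM, WM and CSF (a hypothetical encoding); scoring a class absent from both prediction and ground truth as 1.0 is also our own convention.

```python
def dice(pred, gt, c):
    # Dice for class c: 2 * |P intersection G| / (|P| + |G|) on flattened label maps.
    inter = sum(1 for p, g in zip(pred, gt) if p == c and g == c)
    size = sum(1 for p in pred if p == c) + sum(1 for g in gt if g == c)
    return 1.0 if size == 0 else 2.0 * inter / size


def mean_foreground_dice(pred, gt, foreground=(1, 2, 3)):
    # Average over the foreground classes only (background excluded).
    return sum(dice(pred, gt, c) for c in foreground) / len(foreground)


pred = [0, 1, 1, 2, 3, 3]
gt = [0, 1, 2, 2, 3, 0]
score = mean_foreground_dice(pred, gt)   # each foreground class scores 2/3
```

Excluding the background class keeps the large number of easy background pixels from dominating the averaged score.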
We have presented promising results for a new approach to domain generalization of predictive models by incorporating global and local constraints for learning semantic feature spaces. The better generalization capability is demonstrated by new state-of-the-art results on popular benchmarks and a dense classification task (i.e., semantic segmentation) for medical images. The proposed loss functions are generally orthogonal to other algorithms, and evaluating the benefit of their integration is an appealing future direction. Our learning procedure could also be interesting to explore in the context of generative models, which may greatly benefit from semantic guidance when learning low-dimensional data representations from multiple sources.
This project has received funding from the European Research Council (ERC) under the European Union’s Horizon 2020 research and innovation programme (grant No 757173, project MIRA, ERC-2017-STG) and is supported by an EPSRC Impact Acceleration Award (EP/R511547/1). DCC is also partly supported by CAPES, Ministry of Education, Brazil (BEX 1500/2015-05).