1 Background and Motivation
A major challenge in deploying neural network models in real-world use cases is domain shift. In many real-world applications, the data seen by a deployed model is drawn from a distribution that differs from the training distribution and is often unknown at training time. Domain Generalization aims at training a model from a set of domains (i.e. related distributions) such that the model is able to generalize to a new, unseen domain at test time.
Domain generalization is relevant for a variety of tasks, ranging from personalized medicine, where each patient corresponds to a domain, to predictive maintenance in the context of industrial AI. In the latter use-case, domains can represent different factories where an industrial asset (e.g. a tool machine or a turbine) is operated, or different workers operating the asset. In addition to these discrete domains, domain shift can manifest itself in a continuous manner, where for example the data distribution seen by an industrial asset can change due to wear and tear or due to maintenance procedures.
In many of these use cases, interpretability and human oversight of machine learning models is key. Generative models allow for learning disentangled representations that correspond to specific and interpretable factors of variation, thereby facilitating transparent predictions.
We propose a new weakly-supervised generative model that solves domain generalization problems in an interpretable manner. We build on previous work using autoencoder-based models for domain generalization [kingma2013auto, ilse2019diva] and propose a Hierarchical Domain-Invariant Variational Autoencoder that we refer to as HDIVA. Our major contributions include:
We present a weakly supervised algorithm for domain generalization that is able to account for incomplete and hierarchical domain information.
Our method is able to learn representations that disentangle domain-specific information from class-label specific information even in complex settings.
Our algorithm generates interpretable domain predictions that reveal connections between domains.
2 Related work
In general, domain generalization approaches can be divided into five main categories, which we describe below.
Invariant Feature Learning
While observations from different domains follow different distributions, Invariant Feature Learning approaches try to map the observations from different domains into a common feature space, where domain information is minimized [xie2017controllable, akuzawa2018domain]. These methods work in a mini-max game fashion: a domain classifier tries to predict the domain from the common feature space, while a feature extractor tries to fool this domain classifier and simultaneously help the target label classifier to predict the class label correctly. In [li2017deeper], the authors present a related approach and use tensor decomposition to learn a low-rank embedding for a set of domain-specific models as well as a base model. We classify this method as invariant feature learning because the base model is domain-invariant.
Image Processing based method
In [carlucci2019domain], the authors divide the image into small patches and generate permutations of those small patches. They then use a deep classifier to predict the predefined permutation index so that the model learns the global structure of an image instead of local textures. In [wang2019learning], the authors use a gray-level co-occurrence matrix to extract superficial statistics. They present two methods to encourage the model to ignore the superficial statistics and thereby learn robust representations. This group of methods has been developed for image classification tasks, and it is not clear how it can be extended to other data types.
Adversarial Training based data augmentation
In [Volpi2019], the authors optimize a procedure that searches for worst-case adversarial examples to augment the training domain. In [volpi2018generalizing], the authors use the Wasserstein distance to infer adversarial images that are close to the current training domain, and train an ensemble of models with different search radii in terms of Wasserstein distance.
Meta Learning based method
MLDG [li2018learning] uses model-agnostic training to tackle domain generalization as a zero-shot problem, creating virtual train and test domains and letting the meta-optimizer choose a model with good performance on both virtual train and virtual test domains. In [balaji2018metareg], the authors improve upon MLDG by concatenating a fixed feature network with task-specific networks. They parameterize a learnable regularizer with a neural network and train with a meta-train and a meta-test set.
Auto-encoder based method
DIVA [ilse2019diva] builds on variational auto-encoders and splits the latent representation into three latent variables capturing different sources of variation. We introduce DIVA in more detail in Section 3.2. In [hou2018cross], the authors encode images from different domains into a common content latent code and a domain-specific latent code, while the two types of encoders share layers. Corresponding discriminators are used to predict whether the input is drawn from a prior distribution or generated by the encoder.
Comparing these families of approaches, we can see that only probabilistic auto-encoder based models naturally inherit advantageous properties such as semi-supervised learning, density estimation and variance decomposition. While autoencoder-based approaches such as DIVA offer better interpretability than the other approaches, a main drawback is that explicit domain labels are required during training; we elaborate on this and further drawbacks in later sections.
3 Methods and Technical Solution
3.1 Problem statement and notation
Domain generalization aims to generalize models to unseen domains without knowledge about the target distribution during training. A domain d consists of a joint distribution p(x, y) on X × Y, with X being the input space and Y being the output space [muandet2013domain, ilse2019diva]. For our modelling approach, we employ the framework of variational autoencoders (VAEs) [kingma2013auto]. We use z to represent the latent representation of a VAE. We follow [ilse2019diva] and use three independent latent representations z_x, z_y and z_d to disentangle variability in inputs related to domain-specific sources, label-specific sources and residual variation. We use probabilistic graphical models to illustrate the conditional dependencies of random variables, observables and hyperparameters. In graphical models, solid circles represent observations and white circles represent latent variables. We use half-shaded circles to represent variables that can either be observed or act as latent variables, which is typical in semi-supervised learning. Small solid circles in Figures 1 and 2 represent fixed hyper-parameters. Subscripts represent components of a variable, while we use superscripts to index samples and domains. We use solid arrows to represent the generative path, and dashed arrows to represent the variational inference part. Plates represent repetitions of random variables. We use θ to represent learnable parameters of priors/decoders and φ to represent learnable parameters of variational posterior distributions/encoders.
3.2 From DIVA to HDIVA
We first give a brief introduction to DIVA [ilse2019diva], which forms the basis of our model. In DIVA, three latent variables are used to model distinct sources of variation, denoted as z_x, z_y and z_d. z_y represents class-specific information, z_d represents domain-specific information and z_x models residual variance of the input. DIVA [ilse2019diva] encourages disentanglement among these latent representations using conditional priors p_{θ_d}(z_d | d) and p_{θ_y}(z_y | y) with learnable parameters θ_d and θ_y for z_d and z_y, respectively. For z_x, a standard isotropic Gaussian prior is chosen.
All three latent variables are then used to reconstruct the input x using a neural-network based decoder p_θ(x | z_x, z_y, z_d).
As auxiliary components, DIVA adds a domain classifier q_{ω_d}(d | z_d) based on z_d, as well as a target class label classifier q_{ω_y}(y | z_y) based on z_y. This setup is illustrated in a graphical model, where we use dotted lines to represent these auxiliary classifiers (Figure 1). While we leave the semi-supervised comparison study for DIVA and our algorithm as future work, we use half-shaded variables in our graphical models in Figure 1 and Figure 2 for completeness. We use d to indicate domains, with d ∈ {1, ..., M}, where M is the number of domains at training time.
Since exact inference is intractable in such an autoencoder, [ilse2019diva] perform variational inference and introduce three separate encoders q_{φ_x}(z_x | x), q_{φ_y}(z_y | x) and q_{φ_d}(z_d | x) as a fully factorized approximate posterior. The encoders are parameterized with φ_x, φ_y and φ_d, respectively, using neural networks.
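As a toy illustration of this fully factorized posterior, the three encoders reduce to independent reparameterized Gaussian samples. In the sketch below, linear maps stand in for the encoder neural networks and all dimensions and weights are illustrative assumptions, not values from the paper:

```python
import numpy as np

rng = np.random.default_rng(0)

def gaussian_encoder(x, w_mu, w_logvar):
    """Stand-in for a neural encoder: maps input x to mean and log-variance."""
    return x @ w_mu, x @ w_logvar

def reparameterize(mu, logvar, rng):
    """Draw z = mu + sigma * eps with eps ~ N(0, I) (reparameterization trick)."""
    eps = rng.standard_normal(mu.shape)
    return mu + np.exp(0.5 * logvar) * eps

x = rng.standard_normal((4, 8))          # toy batch of 4 inputs, input dim 8
dims = {"z_x": 3, "z_y": 3, "z_d": 3}    # toy latent dimensions

z = {}
for name, dz in dims.items():
    # Separate (random, untrained) parameters per latent, as in the factorization.
    w_mu = rng.standard_normal((8, dz)) * 0.1
    w_logvar = rng.standard_normal((8, dz)) * 0.1
    mu, logvar = gaussian_encoder(x, w_mu, w_logvar)
    z[name] = reparameterize(mu, logvar, rng)  # independent sample per latent
```

Each latent is sampled from its own encoder, so no information is shared between the three posteriors except through the shared input x.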
While DIVA does not require a domain label at test time, domains need to be explicitly one-hot encoded for training. This can be problematic in a number of settings. In particular, a one-hot encoding of domains does not reflect scenarios where a continuous domain shift can occur. In this case, without knowledge of the causal factor behind the domain shift, it is not clear how such continuous shifts can be one-hot encoded in a meaningful manner. In addition:
Domains can have a hierarchical structure reflected by related sub-domains (e.g. country → factory → machine). One-hot encodings as used in DIVA are not able to model such hierarchical domain structures.
In some applications, domains are not necessarily well-separated, but significant overlap between domains can occur (e.g. a cartoon might look more similar to a pop-art painting than a photograph). One-hot encoding such overlapping domains encourages separated representations, which may harm model performance.
A one-hot encoding of domains mapping to the prior distribution of z_d may limit the generalization power of neural networks, especially when dealing with continuous domain shift.
Finally, DIVA requires an observed domain label. In practice, this label may not be readily available. Recently, [Shu2019] discuss different scenarios where only limited information about the data-generating process is conveyed through additional observation and present a weakly supervised approach for representation learning via distribution matching.
In the same spirit of weakly supervised learning we introduce a novel approach for weak domain supervision. Like DIVA, our model learns disentangled representations for domain generalization. However, we extend DIVA to overcome some of its main limitations stemming from the use of one-hot encodings. To this end, we introduce a latent topic-like representation of domains. These topics are able to capture hierarchical domain structures, can naturally describe continuous domain shift and deal with incomplete domain information at training time in an interpretable manner.
3.3 HDIVA overview
To overcome the limitations of DIVA mentioned in section 3.2, we propose a hierarchical probabilistic graphical model called weakly-domain-supervised Hierarchical Domain Invariant Variational Autoencoder (HDIVA). We introduce a new hierarchical level to the DIVA base-model and use a continuous latent representation s to model (potentially incomplete) knowledge about domains. We place a Dirichlet prior on s such that it can be interpreted as a soft, topic-like version of the one-hot encoded domain in DIVA. We then use z_d to capture domain-specific variation by conditioning its prior on s. Note that in our model the domain is no longer an observable but instead a latent variable to be inferred from data. For clarity, we refer to an observed domain as a nominal domain. Borrowing from topic models in NLP [srivastava2017autoencoding], we refer to s as a topic. We use topics to enable weakly supervised domain generalization, which is explained in more detail in the following and in section 3.5. We illustrate HDIVA in the form of a graphical model in Figure 2. We use N to denote the dimension of the domain representation or topic vector, i.e. s lies on the N-dimensional probability simplex. We use k to index each component of s, i.e. s_k, while we use a superscript d to index a domain as before. Note that in our case, N can be greater than, smaller than or equal to the number of domains M, while in DIVA [ilse2019diva], the one-hot encoded domain label always has size M. This is beneficial for problems with a large number of domains which lie on a lower-dimensional manifold (e.g. thousands of assets in a predictive maintenance task). In this case, when choosing the topic dimension N to be smaller than the number of training domains, our algorithm can be interpreted as an eigen-domain decomposition algorithm. We use stochastic gradient descent to train our model. Accordingly, in Figure 2, the batch size is denoted by B^d for the d-th domain, with a total of T^d batches for domain d. We use t to index a batch and i to index a sample. For simplicity, t and i are omitted whenever convenient and not causing confusion.
3.4 Model implementation
In this section, we first describe the generative model with prior distributions, followed by a discussion on model inference.
3.4.1 Prior distributions for z_x, z_y and z_d
We follow [ilse2019diva] and choose an isotropic Gaussian prior with zero mean and unit variance for z_x, and conditional priors for z_y and z_d.
More specifically, we choose a normal prior for z_y that is conditioned on the target class label y:

p_{θ_y}(z_y | y) = N(z_y | μ_{θ_y}(y), diag(σ²_{θ_y}(y))),   (1)

with μ_{θ_y} and σ_{θ_y} being learnable parameterizations of the mean and standard deviation in the form of neural networks.
Similarly, we choose a normal prior for z_d and condition it on the topic s:

p_{θ_d}(z_d | s) = N(z_d | μ_{θ_d}(s), diag(σ²_{θ_d}(s))),   (2)

where again μ_{θ_d} and σ_{θ_d} parameterize the mean and variance of z_d.
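Sampling from such a conditional Gaussian prior can be sketched as follows. The linear maps below stand in for the prior networks μ and σ, and all names and dimensions are illustrative assumptions:

```python
import numpy as np

rng = np.random.default_rng(1)

def conditional_gaussian_prior(cond, w_mu, w_logsigma):
    """Map a conditioning vector (one-hot y, or topic s) to a diagonal Gaussian."""
    mu = cond @ w_mu
    sigma = np.exp(cond @ w_logsigma)  # exp ensures a positive standard deviation
    return mu, sigma

n_classes, dz = 10, 4
w_mu = rng.standard_normal((n_classes, dz)) * 0.1
w_logsigma = rng.standard_normal((n_classes, dz)) * 0.1

y = np.eye(n_classes)[7]                        # one-hot class label y = 7
mu_y, sigma_y = conditional_gaussian_prior(y, w_mu, w_logsigma)
z_y = mu_y + sigma_y * rng.standard_normal(dz)  # sample z_y ~ p(z_y | y)
```

The same mechanism, conditioned on a topic vector s instead of a one-hot label, yields samples from the prior of z_d.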
3.4.2 Prior distribution for s
We would like s to display topic-like characteristics, facilitating interpretable domain representations. Consequently, we use a Dirichlet prior on s, which is a natural prior for topic modeling [srivastava2017autoencoding].
Let α be the Dirichlet concentration parameter; then the prior distribution of s can be written as:

p(s | α) = Dir(s | α) = (1 / B(α)) ∏_{k=1}^{N} s_k^{α_k − 1},   (3)

where we use B(α) to represent the partition function.
We do not learn the distribution parameter α, but instead leave it as a hyper-parameter. By default, we set α to be a vector of ones, which corresponds to a uniform distribution of topics. We refer to this prior setting as a flat prior. If more prior knowledge about the relation between training domains is available, we use an informative prior instead.
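The flat prior places uniform mass over the topic simplex; a quick NumPy check (dimensions are illustrative):

```python
import numpy as np

rng = np.random.default_rng(2)

N = 3                    # topic dimension (toy value)
alpha = np.ones(N)       # flat prior: Dir(1, ..., 1), uniform over the simplex
s = rng.dirichlet(alpha, size=5)   # five topic samples

# Every sampled topic lies on the probability simplex:
# components are non-negative and sum to one.
assert np.all(s >= 0)
assert np.allclose(s.sum(axis=1), 1.0)

# An informative prior would up-weight components believed to correspond to
# related domains, e.g. alpha = np.array([5.0, 5.0, 1.0]).
```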
3.4.3 Approximate posterior

We factorize the approximate posterior as follows:

q(z_x, z_y, z_d, s | x) = q_{φ_x}(z_x | x) q_{φ_y}(z_y | x) q_{φ_d}(z_d | x) q_{φ_s}(s | x).   (4)
For the approximate posterior distributions of z_x, z_y and z_d, we follow [ilse2019diva] and assume fully factorized Gaussians with parameters given as a function of their input:

q_{φ_•}(z_• | x) = N(z_• | μ_{φ_•}(x), diag(σ²_{φ_•}(x))),  • ∈ {x, y, d}.   (5)

Encoders q_{φ_x}, q_{φ_y} and q_{φ_d} are parameterized by φ_x, φ_y and φ_d, respectively, using separate neural networks to model the respective means and variances as functions of x.
For the form of the approximate posterior distribution of the topic s we choose a Dirichlet distribution:

q_{φ_s}(s | x) = Dir(s | α̂_{φ_s}(x)),   (6)

where α̂_{φ_s} parameterizes the concentration parameter based on x, using a neural network.
3.4.4 Learning objective

Given the priors and factorization described above, we can optimize the model parameters by maximizing the evidence lower bound (ELBO). We can write the ELBO for a given input-output tuple (x, y) as:

ELBO(x, y) = E_q [log p_θ(x | z_x, z_y, z_d)]
  − β KL(q_{φ_x}(z_x | x) || p(z_x))
  − β KL(q_{φ_y}(z_y | x) || p_{θ_y}(z_y | y))
  − β E_q [KL(q_{φ_d}(z_d | x) || p_{θ_d}(z_d | s))]
  − β KL(q_{φ_s}(s | x) || p(s | α)),   (7)

where we use β to represent the multiplier in the Beta-VAE setting [higgins2016beta], further encouraging disentanglement of the latent representations.
As in [ilse2019diva], we add an auxiliary classifier q_{ω_y}(y | z_y), which is parameterized by ω_y, to encourage separation of classes in z_y. The HDIVA objective then becomes:

F(x, y) = ELBO(x, y) + α_y E_{q_{φ_y}(z_y | x)} [log q_{ω_y}(y | z_y)],   (8)

where α_y weighs the auxiliary classification loss.
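The Gaussian KL terms in the ELBO have a closed form. A small NumPy helper for the diagonal-covariance case, with toy inputs (the symbols follow the text; the dimensions are illustrative):

```python
import numpy as np

def kl_diag_gaussians(mu_q, logvar_q, mu_p, logvar_p):
    """KL( N(mu_q, diag(exp(logvar_q))) || N(mu_p, diag(exp(logvar_p))) ),
    summed over the latent dimensions."""
    var_q, var_p = np.exp(logvar_q), np.exp(logvar_p)
    return 0.5 * np.sum(
        logvar_p - logvar_q + (var_q + (mu_q - mu_p) ** 2) / var_p - 1.0,
        axis=-1,
    )

# KL of a toy posterior q(z_x | x) against the standard normal prior p(z_x):
mu_q = np.array([0.5, -0.2])
logvar_q = np.zeros(2)
kl = kl_diag_gaussians(mu_q, logvar_q, np.zeros(2), np.zeros(2))

# For unit variances this reduces to 0.5 * ||mu_q||^2 per the closed form.
assert np.isclose(kl, 0.5 * (0.5**2 + 0.2**2))
```

The same helper applies to the z_y and z_d terms, with the conditional-prior means and variances substituted for the zeros.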
To efficiently perform inference with the dependent stochastic variables s and z_d, we follow [sonderby2016ladder] and adapt the ELBO using the Ladder VAE approach, as detailed in the next section.
3.4.5 Dealing with dependent stochastic variables
The joint posterior of s and z_d can be written as:

q(s, z_d | x) = q_{φ_s}(s | x) q_{φ_d}(z_d | x),   (9)

where conditional independence of z_d from s is assumed. As pointed out in [chen2016variational, tomczak2018vae], this can lead to inactive stochastic units. We follow [sonderby2016ladder] and recursively correct the generative distribution by a data-dependent approximate likelihood. Additionally, we implement a deterministic warm-up period for β following [sonderby2016ladder, ilse2019diva], in order to prevent the posterior of the latent representation from aligning too quickly to its prior distribution.
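Deterministic warm-up simply anneals the KL weight β from 0 to its final value over the first training epochs. A sketch, assuming the linear schedule of [sonderby2016ladder] (the exact schedule shape is an assumption):

```python
def kl_weight(epoch, warmup_epochs=100, beta_final=1.0):
    """Linear deterministic warm-up of the KL multiplier beta.

    During the first `warmup_epochs` epochs the KL terms are down-weighted,
    so the posteriors cannot collapse onto their priors too early.
    """
    if warmup_epochs <= 0:
        return beta_final
    return beta_final * min(1.0, epoch / warmup_epochs)

# Beta grows linearly, then stays at its final value.
assert kl_weight(0) == 0.0
assert kl_weight(50) == 0.5
assert kl_weight(100) == 1.0
assert kl_weight(500) == 1.0
```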
3.5 Weak Supervision on domains
In many scenarios only incomplete domain information is available. For example, due to privacy concerns, data from different customers within a region may be pooled so that information on the nominal domain at customer-level is lost and only higher-level domain information is available. In other settings, substantial heterogeneity may exist in a domain and various unobserved sub-domains may be present. We introduce two techniques for weak supervision on domains, allowing the model to infer such lower-level domain or sub-domain information in the form of a topic s.
3.5.1 Topic Distribution Aggregation
To indicate that a group of samples "weakly" belongs to one domain, we aggregate the concentration parameters of the posterior distribution of s over all samples in a minibatch (note that all samples in a minibatch have the same nominal domain):

α̃ = (1 / B^d) Σ_{i=1}^{B^d} α̂_{φ_s}(x^i).   (10)

We then use the aggregated concentration parameter to sample a topic from a Dirichlet distribution:

s̃ ~ Dir(α̃).   (11)
The conditional prior of z_d (equation 2) then shares this same topic for all samples in the t-th mini-batch. We interpret this topic-sharing across samples in a mini-batch as a form of regularized weak supervision. In one-hot encoded approaches, all samples from the same nominal domain would share the same topic. In contrast, sharing a topic in the conditional prior of the latent representation across samples in a mini-batch provides weak supervision, whilst allowing for an efficient optimisation via SGD. Note that concentration parameters for a mini-batch are only aggregated during training; at test time, sample-specific posterior concentration parameters are used.
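Topic distribution aggregation can be sketched as follows. We assume averaging of the per-sample concentration parameters over the mini-batch, and all dimensions and the random stand-ins for the topic encoder's outputs are toy values:

```python
import numpy as np

rng = np.random.default_rng(3)

batch_size, N = 8, 3
# Per-sample posterior concentration parameters alpha_hat(x_i), as a topic
# encoder would produce them (here: random positive stand-ins).
alpha_hat = rng.uniform(0.5, 5.0, size=(batch_size, N))

# Aggregate over the mini-batch (all samples share one nominal domain) ...
alpha_agg = alpha_hat.mean(axis=0)

# ... and draw a single topic shared by every sample in the mini-batch.
s_shared = rng.dirichlet(alpha_agg)

assert s_shared.shape == (N,)
```

At test time this aggregation step is skipped and each sample keeps its own posterior concentration parameter.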
3.5.2 Weak domain distribution supervision with MMD
DIVA encourages separation of nominal domains in the latent space by fitting an explicit domain classifier, which might limit model performance in the case of incomplete domain information. To mitigate these limitations, but still weakly encourage separation between different nominal domains, we constrain the HDIVA objective based on the Maximum Mean Discrepancy (MMD) [gretton2012kernel] between pairs of nominal domains.
Denoting by ε the minimal distance, computed via MMD, that we require as an inequality constraint, we can write the constrained optimization of equation 8 as follows:

max F(x, y)   s.t.   MMD(q(z_d | x^d), q(z_d | x^{d'})) ≥ ε   for all pairs of nominal domains d ≠ d'.   (12)
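A minimal MMD estimate between the domain representations of two nominal domains might look as follows. The RBF kernel, its bandwidth, and the biased estimator are assumptions for illustration, as are all dimensions:

```python
import numpy as np

def mmd_rbf(a, b, bandwidth=1.0):
    """Biased MMD^2 estimate between sample sets a and b with an RBF kernel."""
    def kernel(u, v):
        # Squared pairwise distances between rows of u and v.
        d2 = ((u[:, None, :] - v[None, :, :]) ** 2).sum(-1)
        return np.exp(-d2 / (2.0 * bandwidth**2))
    return kernel(a, a).mean() + kernel(b, b).mean() - 2.0 * kernel(a, b).mean()

rng = np.random.default_rng(4)
z_d_dom1 = rng.standard_normal((16, 4))        # z_d samples, nominal domain 1
z_d_dom2 = rng.standard_normal((16, 4)) + 2.0  # nominal domain 2, shifted

mmd_far = mmd_rbf(z_d_dom1, z_d_dom2)
mmd_near = mmd_rbf(z_d_dom1, z_d_dom1)
assert mmd_far > mmd_near  # separated domains yield a larger discrepancy
```

In a training loop this discrepancy would be computed on mini-batches of z_d drawn from different nominal domains, and a violation of the constraint (MMD below ε) would be penalized.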
3.6 Practical considerations
In practice, we transform the constrained optimization in Equation 12 with a Lagrange multiplier. This leads to the final loss in Equation 13, where λ denotes the Lagrange multiplier for the MMD constraint (c.f. Equation 12):

L(x, y) = −F(x, y) + λ Σ_{d ≠ d'} max(0, ε − MMD(q(z_d | x^d), q(z_d | x^{d'}))).   (13)
4 Empirical Evaluation
We conduct experiments to answer the following questions:
Can HDIVA mitigate the limitations of DIVA mentioned in section 3.2, especially in terms of domain sub-structure or overlap between nominal domains? We conduct experiments in Section 4.1 and Section 4.3 to address these issues.
In complex scenarios with domain substructure, can HDIVA still robustly disentangle domain-specific variation from class-label specific variation? See details in Section 4.1.
We visualize topics from overlapping nominal domains to illustrate why HDIVA improves upon DIVA in Section 4.3.
How do DIVA and HDIVA perform under standard domain generalization benchmarks compared with other state-of-the-art algorithms? See Section 4.4.
4.1 Subdomains inside nominal domains
To simulate domains with sub-structure, we create sub-domains within nominal domains. All sub-domains within one nominal domain share the same domain label. We adapt color-mnist [metz2016unrolled, rezende2018taming] with the modification that both the foreground and the background are colored per sub-domain, as shown in Figure 3. We use the top 3 sub-domains as the first nominal domain, the middle 3 sub-domains as the second nominal domain, and the bottom 3 sub-domains as the third nominal domain. Thus, we construct 3 nominal domains with sub-structure. For DIVA, we use a one-hot encoded nominal domain label as explicit domain label. For HDIVA, we only use this nominal domain label for weak supervision, as explained in sections 3.5.1 and 3.5.2.
We are interested in evaluating how DIVA and HDIVA behave under this sub-domain scenario, in terms of out-of-domain prediction accuracy and disentanglement performance. We perform a leave-one-domain-out evaluation [li2017deeper], where each test domain is repeated 10 times with 10 different random seeds. We report the out-of-domain test accuracy in Table 1. Table 1 shows that HDIVA outperforms DIVA in terms of out-of-domain performance on all three test domains, while exhibiting much smaller variance than DIVA.
Table 1: Out-of-domain test accuracy (mean ± standard deviation over 10 seeds).

|       | Test Nominal Domain 1 | Test Nominal Domain 2 | Test Nominal Domain 3 |
|-------|-----------------------|-----------------------|-----------------------|
| DIVA  | 0.76 ± 0.06           | 0.486 ± 0.18          | 0.470 ± 0.07          |
| HDIVA | 0.93 ± 0.02           | 0.845 ± 0.06          | 0.506 ± 0.04          |
We further evaluate how robustly DIVA and HDIVA are able to disentangle different sources of variation under this scenario with incomplete sub-domain information.
We sample seed images from different sub-domains, as shown in the first row of Figure 4. We then generate new images by scanning the class label y from 0 to 9 and sampling from the conditional prior distribution of z_y (i.e. p_{θ_y}(z_y | y), eq. 1). We keep the domain representation z_d the same as in the seed image, set the noise component z_x to zero and then use the decoder network to generate an image based on the three latent representations. If the models are able to disentangle domain-specific variation from class-label specific variation in z_d and z_y, we expect the generated images to have the same domain information as the seed image (foreground and background color) while displaying different class labels (numbers from 0 to 9). In Figure 4 we compare DIVA's and HDIVA's generative performance. Due to the sub-structure inside the nominal domains, DIVA could only reconstruct a blur of colors for the first 3 columns in Figure 4a, while HDIVA could generate different numbers for 2 of the three seed images. For the last seed image, both DIVA and HDIVA could conditionally generate numbers, but DIVA did not retain the domain information (the background color, which is dark blue in the seed image, is light blue in the generated images). This indicates that DIVA is not able to disentangle the different sources of variation and that domain information is captured by z_y as well. In contrast, HDIVA was able to separate domain information from class-label information.
4.2 Composite Overlapping Near-Continuous Domain Shift
We consider a composite overlapping domain scenario. From the python library Seaborn, we use the diverging color palette as background and the hls color palette as foreground to construct a color-mnist scenario, where each palette consists of 7 colors. The overlapping color-mnist setup is shown in Figure 5, where each row corresponds to one sub-domain. We construct nominal domains by taking the first 3 rows as one nominal domain, and, starting from the third row, the following 3 consecutive rows as the second nominal domain. That is, the sub-domain corresponding to the third row is the overlap between the first and the second nominal domain. We take the last 3 rows as the third nominal domain, with the overlap with the second nominal domain being the third-to-last row.
Following the leave-one-domain-out setting used in the other experiments, we report the out-of-domain classification accuracy in Table 2. In two out of three test domains, HDIVA significantly outperforms DIVA, while in the remaining test domain, DIVA and HDIVA perform within error range of each other.
Table 2: Out-of-domain classification accuracy (mean ± standard deviation).

|       | Test Nominal Domain 3 | Test Nominal Domain 1 | Test Nominal Domain 2 |
|-------|-----------------------|-----------------------|-----------------------|
| DIVA  | 0.603 ± 0.04          | 0.464 ± 0.02          | 0.959 ± 0.002         |
| HDIVA | 0.653 ± 0.04          | 0.453 ± 0.03          | 0.975 ± 0.008         |
4.3 Domain Embedding
We adapt the standard rotated MNIST benchmark [ilse2019diva] by introducing an overlap between nominal domains: for the first nominal domain, we use 1000 samples of MNIST and rotate them by 15, 30 and 45 degrees, respectively. Thus, the first domain contains 3000 instances, and each rotation angle constitutes one sub-domain. For the second nominal domain, we rotate the same subset of MNIST by 30, 45 and 60 degrees, respectively. In this way, the two nominal domains overlap in two rotation angles, corresponding to 2000 instances that share the same rotation. We use these 2 nominal domains for training, and simulate a continuous domain shift for testing with rotation angles of 0, 22 and 75 degrees. We sample images from both nominal training domains as well as the continuously shifted test domain and plot their topic distributions in Figure 6. We expect the topics of the training domains to overlap substantially, due to the shared rotation angles. We further expect the topics of the test domain to span the entire range of topics from both training domains. Figure 6 illustrates that HDIVA indeed assigns similar domain topics to many instances from both training domains, while samples from the test domain span the entire range of topics.
4.4 State of the art Domain Generalization benchmark
We finally compare HDIVA to DIVA and other state-of-the-art domain generalization algorithms on a standard domain generalization task. Table 3 shows algorithm performance on the PACS dataset [li2017deeper], which is a popular domain generalization benchmark. We find that image augmentation using random centered crops combined with random horizontal flips aids model performance, and we use AUG as a suffix to an algorithm's name to indicate when these image transformations are carried out.
Adapted from https://domaingeneralization.github.io/.
Table 3 shows that the performance of HDIVA is comparable to state-of-the-art performance on the PACS dataset on most of the test domains, except for art-painting. Notably, HDIVA substantially outperforms DIVA on this standard domain generalization task without known sub-domains as well. While the overall performance of methods such as JIGSAW is consistently better than that of HDIVA, JIGSAW is based on complex image manipulations. In contrast, HDIVA is an interpretable model that can be used for different data modalities and a larger number of tasks, including domain prediction and sample generation.
5 Conclusion
We proposed a Hierarchical Domain Invariant Variational Autoencoder with the following improvements: First, in the presence of domain sub-structure, our algorithm is able to robustly disentangle domain-specific variation from class-label specific variation. Second, our algorithm is able to model domain overlap via interpretable topics and generalizes to settings with continuous domain shift. Finally, our algorithm performs significantly better than DIVA on standard domain generalization tasks such as PACS.
6 Hyper-parameters and other experiment details
Since our model shares many components with DIVA [ilse2019diva], we use the same hyper-parameters suggested by [ilse2019diva] for our DIVA implementation, as well as for the shared parts of HDIVA. For example, the latent dimension for each latent code is taken to be 64, and the classifier is a one-layer neural network with ReLU activation. All experiments are run for a maximum of 500 epochs with an early-stopping tolerance of 100 epochs, based on a validation set taken as a 20% random split from each of the training domains. We use a learning rate of 1e-4 for both algorithms. The auxiliary classification multiplier α_y in equation 8 is taken to be 1e5 for both algorithms. Warm-up of the KL divergence loss in Equation 7 is taken to be 100 epochs for both DIVA and HDIVA, while the β value for both algorithms is taken to be 1. We use a topic dimension of N = 3 for HDIVA. The Lagrange multiplier λ in equation 13 is taken to be 1e5.
For the MNIST-based experiments, including the rotated-MNIST overlap in 4.3, the colored-MNIST sub-domain setting in 4.1, and the domain-overlapped color-mnist in 4.2, we use random sub-samples (each containing 1000 instances) pre-sampled from https://github.com/AMLab-Amsterdam/DIVA/tree/master/paper_experiments/rotated_mnist/dataset with commit hash ab590b4c95b5f667e7b5a7730a797356d124.