Deep convolutional neural networks are now widely used across different tasks and exhibit remarkable performance. However, their performance may degrade rapidly when the deployment environment differs from the training one, i.e., when there exists a covariate shift sugiyama2007covariate between training and testing data, so that the learned deep network can be biased. How to obtain a domain-invariant network that generalizes to data collected from an unseen environment remains an active research topic zhang2016understanding; kawaguchi2017generalization.
Domain generalization (DG) aims to tackle the generalization problem where the data from the target domain is inaccessible. The significance of domain generalization has recently attracted many studies in different streams. For example, domain alignment li2018domain; li2018deep aims to regularize the latent features from different domains to follow the same distribution, so that shareable information can be explored and the learned model generalizes better. Data augmentation based methods volpi2018generalizing; zhang2019unseen; wang2020heterogeneous; zhou2020deep have also proven effective in the domain generalization setting, using either adversarial augmentation or mixup. Recently, meta-learning based methods li2018learning; balaji2018metareg have also been investigated to improve the generalization capability of the learned networks. Although some success has been achieved so far, how to efficiently utilize data from multiple source domains to achieve better generalization is still an open problem.
In this work, we propose to conduct feature disentanglement to extract domain-specific and task-specific information for domain generalization. In particular, we present a novel framework named Variational Disentanglement Network (VDN) that minimizes the information gain term of the latent feature and maximizes the posterior probability of the generated samples. More specifically, for the information gain term, we propose to minimize the KL divergence between the latent features of the deep neural network and a pre-defined distribution by variational divergence minimization nowozin2016f, which we empirically find leads to better domain generalization performance. In addition, for the posterior probability term, we design a novel training mechanism that improves the efficiency of sampling from the latent distribution for novel sample generation and estimates the posterior probability of a generated sample even when the corresponding ground-truth is unavailable. The generated samples, which can contain novel attributes, are also used as augmented data for training the task-specific network. We also provide a theoretical analysis of our proposed framework from the perspective of the Variational Bayesian method, showing that our method is equivalent to minimizing the evidence upper bound of the divergence between the distribution of task-specific features and its invariant ground truth. Extensive experiments are conducted to verify our method, and both quantitative and qualitative results show its effectiveness.
The core ideas of domain generalization (DG) are inherited to some extent from domain adaptation huang2006correcting; pan2011domain; zhang2015multi; ghifary2017scatter; blanchard2021domain, e.g., both assume that different domains have something in common even if they look quite different. For example, khosla2012undoing; seo2020learning aim to jointly learn an unbiased model and a set of bias vectors for each domain, yang2013multi uses Canonical Correlation Analysis (CCA) to extract the shared feature, muandet2013domain proposes a domain-invariant analysis method based on MMD, which was further extended by li2018domain, and ghifary2015domain uses multi-task autoencoders to learn a shared feature extractor with multiple decoders. Moreover, various regularization methods for the latent code have been proposed zhao2020domain; wang2020learning, e.g., low-rank regularization xu2014exploiting; li2017deeper; li2020domain; piratla2020efficient. In addition, data augmentation based methods have also proven effective in the domain generalization setting, e.g., GAN-generated samples volpi2018generalizing; zhou2020deep, domain mixup wang2020heterogeneous; zhou2021domain, stacked transformations zhang2019unseen, domain-guided perturbation directions shankar2018generalizing, and solving a jigsaw puzzle problem carlucci2019domain. Meta-learning based methods li2018learning; balaji2018metareg; li2019feature; li2019episodic; dou2019domain; du2020learning have also been explored, learning from episodes that simulate the domain gaps. Recently, invariant risk minimization (IRM) has been proposed to eliminate spurious correlations so that the models are expected to generalize better arjovsky2019invariant; ahuja2020invariant; bellot2020accounting; zunino2020explainable. As for feature disentanglement, current methods are usually based on decomposition li2017deeper; khosla2012undoing; piratla2020efficient; chattopadhyay2020learning. Generation-based disentanglement methods have also been explored. For instance, peng2019domain conducts single domain generalization using an adversarial disentangled auto-encoder, and wang2020cross provides a pair of encoders for disentangled features in each domain.
Feature disentanglement and image translation
Our proposed method is related to feature disentanglement, which has also been widely adopted in cross-domain learning bousmalis2017unsupervised; hoffman2018cycada; russo2018source; saito2018maximum. Much progress in domain generalization has been made by applying feature disentanglement peng2019domain; khosla2012undoing; li2017deeper; piratla2020efficient. For instance, peng2019domain proposes a domain-agnostic learning method based on VAE and adversarial learning, and khosla2012undoing; li2017deeper; piratla2020efficient assume that there exists a shared model, regarded as domain invariant, together with a set of domain-specific weights.
Our work is also related to image translation isola2017image; choi2018stargan; zhu2017unpaired; liu2019few, which can be treated as conducting feature disentanglement by separating the style feature and the content feature. Generally, existing translation methods fall into two streams: supervised/paired isola2017image and unsupervised/unpaired zhu2017unpaired; choi2018stargan. While some progress has been achieved, most existing models require a large amount of data. Recently, some efforts have been made to improve the capability of translation models trained on few samples liu2019few; saito2020coco. In our case, the proposed framework conducts translation in a more efficient manner given limited and diverse data, which helps us obtain a more accurate posterior probability estimation. In addition, the high-quality samples generated by our framework with different combinations of image attributes (e.g., samples with a new combination of angles, shape, and color zhao2018bias) are also used for data augmentation to improve generalization capability.
Assume we observe domains and there are labeled samples in domain . The distributions of input image and its corresponding label in the domain and all source domains are represented as and respectively. Meanwhile, we assume that the latent feature of a random image can be represented by where denotes the task-specific feature, i.e., the feature for classification in our paper, which is domain-invariant, and is the domain-specific feature which encodes the task-independent style code, i.e., the random variables and are independent. In the domain generalization task, we aim to learn a domain-agnostic model that generalizes well to the unseen target domain, based on the assumption that there exists an invariant feature representation among different domains. More specifically, while the distribution of may vary a lot, there exists an invariant causal mechanism, i.e., the distributions are the same among different domains. To this end, we propose to disentangle the task-specific feature and domain-specific feature by maximizing the posterior probability of generated samples with random styles. In addition, an information gain term is used as a regularization term, which shares its core idea with li2018domain.
Our proposed method
In this section, we introduce the overall framework and the optimization details of our proposed method.
The whole framework of our proposed method is illustrated in Fig. 2. It consists of a task-specific encoder , a domain-specific encoder , a generator , a discriminator to distinguish real and generated images, a discriminator to judge whether the task-specific feature comes from a predefined distribution, and a task-specific network for classification. The overall optimization objective is given as
where the loss is defined in Eq. (6) in which acts as a regularization term that aligns the task-specific feature to a predefined distribution and is defined as
where , and are defined in Eq. (8), Eq. (9) and Eq. (10) in Sec. Optimizing the posterior probability term respectively and are optimized jointly to disentangle the task-specific feature and domain-specific feature .
In the test phase, an image from an unseen but related target domain is passed through the task-specific encoder and the task-specific network , and the other networks are not involved. Therefore, our network has the same inference complexity as the vanilla model.
Optimizing the information gain term
To reduce the risk of the learned task-specific feature overfitting to the source domains, similar to li2018domain, we minimize the divergence between which is introduced by the task-specific encoder , and a predefined prior distribution . The divergence can be interpreted as the information gain by introducing a specific image . Ideally, the optimal features only contain the necessary information for classification, and domain-specific information will be ruled out. To this end, we simultaneously minimize the task-specific loss, i.e., the classification loss in Eq. (10), and the information gain term in our framework. For the minimization of the information gain term, we empirically find that directly optimizing this term using the reparameterization trick kingma2013auto may not be feasible due to the high dimensionality of . We therefore optimize its dual form following the core idea of f-GAN nowozin2016f given as
where , denotes its Fenchel conjugate hiriart2004fundamentals which is defined as
One can replace in Eq. (4) using an arbitrary class of functions , as such, the term ① can be represented as
if the capacity of is large enough.
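For reference, the standard definitions behind this dual form can be written as follows. The symbols $P$, $Q$, $f$, $T$ and $\mathcal{T}$ are generic f-GAN notation assumed here for illustration; they are not necessarily the symbols used in the equations of this section.

```latex
\begin{align}
D_f(P \,\|\, Q) &= \int q(x)\, f\!\left(\frac{p(x)}{q(x)}\right) \mathrm{d}x, \\
f^{*}(t) &= \sup_{u \in \mathrm{dom}_f} \{\, ut - f(u) \,\}, \\
D_f(P \,\|\, Q) &\ge \sup_{T \in \mathcal{T}}
  \Big( \mathbb{E}_{x \sim P}\big[T(x)\big] - \mathbb{E}_{x \sim Q}\big[f^{*}(T(x))\big] \Big).
\end{align}
```

For the KL divergence, $f(u) = u \log u$ and $f^{*}(t) = \exp(t - 1)$, and the bound becomes tight at $T(x) = 1 + \log\big(p(x)/q(x)\big)$, which is why a sufficiently expressive function class is needed.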
To optimize Eq. (5), an adversarial training manner is used. More specifically, we optimize the task-specific encoder to maximize this term and the discriminator to minimize it. The final goal of the training for the regularization term is as follows
where the function in Eq. (5) is implemented by a deep neural network with the activation function at the output to retain the original domain .
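To make the variational bound concrete, the following minimal sketch estimates the f-GAN lower bound on a KL divergence by Monte Carlo sampling. It is a toy illustration under our own assumptions, not the training code of the framework: the two distributions are one-dimensional Gaussians, the critic is a fixed closed-form function rather than a neural network, and all function names are hypothetical.

```python
import math
import random


def kl_lower_bound(critic, sample_p, sample_q, n=200_000, seed=0):
    """Monte Carlo estimate of E_P[T(x)] - E_Q[exp(T(x) - 1)].

    exp(t - 1) is the Fenchel conjugate of f(u) = u * log(u), so this is
    the f-GAN variational lower bound on KL(P || Q) for a given critic T.
    """
    rng = random.Random(seed)
    term_p = sum(critic(sample_p(rng)) for _ in range(n)) / n
    term_q = sum(math.exp(critic(sample_q(rng)) - 1.0) for _ in range(n)) / n
    return term_p - term_q


# Toy setup (assumption): P = N(1, 1) and Q = N(0, 1), so KL(P || Q) = 0.5.
def sample_p(rng):
    return rng.gauss(1.0, 1.0)


def sample_q(rng):
    return rng.gauss(0.0, 1.0)


# Optimal critic T*(x) = 1 + log(p(x) / q(x)), which equals x + 0.5 here.
def optimal_critic(x):
    return x + 0.5


# A deliberately suboptimal critic, which should give a looser bound.
def weak_critic(x):
    return 0.5 * x + 0.5


tight = kl_lower_bound(optimal_critic, sample_p, sample_q)
loose = kl_lower_bound(weak_critic, sample_p, sample_q)
print(f"optimal critic: {tight:.3f}, weak critic: {loose:.3f}, true KL: 0.500")
```

With the optimal critic the estimate approaches the true KL value, while the weaker critic yields a strictly smaller value, which illustrates why the bound is only tight when the critic has enough capacity.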
Optimizing the posterior probability term
For better disentanglement, we propose to maximize the posterior probability term . The advantages of maximizing this term can be summarized in two points. First, directly minimizing the information gain term and the task-specific term may cause overfitting to the training samples. For instance, when the dataset is not large enough, directly memorizing the whole dataset may yield less information gain than extracting discriminative features for classification. By maximizing the posterior probability, we can guarantee that and together contain almost all the information of the image, which avoids the loss of discriminative features caused by minimizing the information gain term. Second, by improving the quality of the samples generated from randomly combined task-specific feature and style feature , the task-specific and domain-specific features will be disentangled, since our discriminator can not only distinguish real and fake images but, thanks to the domain labels, can also tell whether the generated samples come from the specific domains.
To maximize the posterior probability term, two obstacles need to be overcome. First, an efficient sampling strategy is needed because the space of and is large and intractable, and it is not feasible to sample ergodically from them. To this end, inspired by VAE kingma2013auto, we propose to adopt the task-specific encoder and domain-specific encoder instead, i.e., and , to conduct code sampling of and which are most likely to generate a realistic sample. To sample and independently, we shuffle the domain-specific features in a batch and ensure that the samples generated from randomly combined features are as realistic as possible. More details can be found in Algorithm 1.
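The batch-level shuffling described above can be sketched as follows. This is a minimal illustration, not a reproduction of Algorithm 1: the function name `recombine_features` is hypothetical, and placeholder strings stand in for the feature tensors produced by the two encoders.

```python
import random


def recombine_features(task_feats, domain_feats, seed=None):
    """Pair each task-specific feature with a domain-specific feature taken
    from a random permutation of the same batch."""
    if len(task_feats) != len(domain_feats):
        raise ValueError("both feature lists must have the same batch size")
    rng = random.Random(seed)
    perm = list(range(len(domain_feats)))
    rng.shuffle(perm)
    # The task-specific features keep their order; only the styles move.
    return [(task_feats[i], domain_feats[j]) for i, j in enumerate(perm)]


# Placeholder strings stand in for feature tensors (assumption).
task_feats = ["z_dog", "z_cat", "z_horse", "z_house"]
domain_feats = ["s_photo", "s_sketch", "s_cartoon", "s_art"]
pairs = recombine_features(task_feats, domain_feats, seed=0)
for z, s in pairs:
    print(z, "+", s)  # each pair would be fed to the generator
```

Because only the domain-specific features are permuted, each generated sample keeps the label associated with its task-specific feature while receiving a style drawn from elsewhere in the batch.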
Unlike VAE kingma2013auto which only computes the reconstruction term , we also need to compute the probability of the generated one using the generator for without a corresponding ground-truth. To this end, we estimate the probability using the following equations
where denotes the Laplace distribution (corresponding to the L1 norm) to measure the pixel reconstruction loss. To optimize the term in Eq. (7), needs to have the capability to distinguish between the real images and the images generated from randomly combined latent features, and the generator is also required to produce realistic outputs. To this end, we train the model in an adversarial manner, which is given as
For the reconstructed images with ground-truth, we empirically find that directly minimizing the pixel level divergence can lead to classification performance degradation. To this end, we minimize the divergence of the semantic feature through a pretrained perceptual network instead. Therefore, we maximize the improved version by minimizing its corresponding L1 loss given as
Last but not least, to further enforce the task-specific encoder to generate more meaningful embeddings, we use label information to guide the model training by minimizing the following equation
where denotes the ground-truth label of the input or its task-specific feature , is the task-specific loss, e.g., cross entropy is used in our work for classification tasks. Besides, we also consider the generated samples for data augmentation purpose, which is the second term of Eq. (10).
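As a rough illustration of a two-term classification objective of this kind, the sketch below adds a cross-entropy term on real samples to a cross-entropy term on generated samples. The equal weighting of the two terms, the function names, and the assumption that a generated sample inherits the label of its task-specific feature are ours; the exact form of Eq. (10) may differ.

```python
import math


def cross_entropy(logits, label):
    """Cross entropy of one sample from raw logits (log-sum-exp stabilized)."""
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_norm - logits[label]


def classification_loss(real_logits, gen_logits, labels):
    """Mean cross entropy on real samples plus mean cross entropy on the
    generated samples, which reuse the labels of their task-specific features."""
    n = len(labels)
    real_term = sum(cross_entropy(l, y) for l, y in zip(real_logits, labels)) / n
    gen_term = sum(cross_entropy(l, y) for l, y in zip(gen_logits, labels)) / n
    return real_term + gen_term


# Toy logits for a two-class problem (placeholder values).
real_logits = [[2.0, 0.0], [0.0, 2.0]]
gen_logits = [[0.0, 0.0], [0.0, 0.0]]
labels = [0, 1]
loss = classification_loss(real_logits, gen_logits, labels)
print(f"total classification loss: {loss:.4f}")
```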
In this section, we give a theoretical insight into our proposed method from the perspective of Variational Bayesian methods.
The KL divergence can be represented as
Based on Lemma 1, we can derive the evidence upper bound of .
The evidence upper bound of KL divergence between the distribution and the ground-truth is as follows:
where is a constant and and can be an arbitrary prior distribution. Due to limited space, the proof is given in the Appendix.
In summary, our proposed method is an effective way to minimize this upper bound. Theorem 1 shows that we can minimize the evidence upper bound of the divergence between the distribution of our extracted features and its invariant ground-truth by minimizing ① information gain term and maximizing ② posterior probability term simultaneously. The term ① is optimized by Eq. (6) and the term ② is optimized by Eq. (2).
We evaluate our method using three benchmarks including Digits-DG zhou2020deep, PACS li2017deeper and mini-DomainNet zhou2020domain; peng2019moment. For the hyper-parameters, we set , and as 0.1, 1.0 and 1.0, respectively, for all experiments. The details of the network architecture are illustrated in the supplementary materials. We report the accuracy of the model at the end of the training.
Settings: Digits-DG zhou2020deep. More specifically, the input images are resized to 32×32 based on RGB channels. The ConvNet is used as the backbone for all methods and is divided into two parts: the task-specific encoder, which includes the first three conv blocks except for the last max-pooling layer, and the rest, which is regarded as the task-specific network. For the and , we use the same learning parameters as zhou2020deep. SGD is used as the optimizer with an initial learning rate of 0.05 and weight decay of 5e-5. For the rest of the networks, we use RMSprop without momentum as the optimizer with an initial learning rate of 5e-5 and the same weight decay. We train the model for 60 epochs using a batch size of 128, and all learning rates are decayed by 0.1 at the 50th epoch. We update and once after every 10 updates of other parts in the framework.
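The training schedule described above can be sketched in a few lines. This is only an illustration of the step decay and of the alternating update frequency; the helper names are hypothetical, and whether the 50th epoch is counted from zero or from one is an assumption (zero-based here).

```python
def step_lr(base_lr, epoch, decay_epoch=50, gamma=0.1):
    """Step decay: multiply the learning rate by gamma from decay_epoch onward."""
    return base_lr * gamma if epoch >= decay_epoch else base_lr


def discriminator_update_steps(total_steps, every=10):
    """Return the 1-based step indices at which the discriminators are updated,
    i.e. once after every `every` updates of the other networks."""
    return [step for step in range(1, total_steps + 1) if step % every == 0]


# Learning rate of the SGD-trained networks before and after the decay point.
print(step_lr(0.05, epoch=49), step_lr(0.05, epoch=50))
# Steps at which the discriminators would be updated within 30 iterations.
print(discriminator_update_steps(30))
```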
Results: The methods for comparison include DeepAll, CCSA yoon2019generalizable, MMD-AAE li2018learning, CrossGrad shankar2018generalizing, and DDAIG zhou2020deep. We repeat the experiment 5 times and report the average accuracy and 95% confidence intervals on the unseen target domain in Table 1. The results demonstrate that our method achieves the best overall performance by a large margin, especially in the SYN domain, with an improvement of more than 3%. In addition, our method is more stable than the second-best method, DDAIG, which also uses a GAN for data augmentation.
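One common way to turn five repeated runs into a mean and a 95% confidence interval is the Student-t interval sketched below. The choice of this estimator is our assumption, since the text does not state which interval is used, and the accuracies in the example are placeholders rather than values from Table 1.

```python
import math
import statistics

# Two-sided 95% Student-t critical value for 4 degrees of freedom (five runs).
T_CRIT_95_DF4 = 2.776


def mean_and_ci95(accuracies):
    """Mean and 95% confidence half-width for exactly five runs."""
    if len(accuracies) != 5:
        raise ValueError("the hard-coded critical value assumes five runs")
    mean = statistics.mean(accuracies)
    sem = statistics.stdev(accuracies) / math.sqrt(len(accuracies))
    return mean, T_CRIT_95_DF4 * sem


# Placeholder accuracies for illustration only, not results from the paper.
mean, half_width = mean_and_ci95([80.0, 81.0, 82.0, 83.0, 84.0])
print(f"{mean:.2f} +/- {half_width:.2f}")
```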
Settings: PACS li2017deeper is a domain generalization benchmark collected from four different domains (photo, art painting, sketch, and cartoon) with relatively large domain gaps. Following the widely used setting in carlucci2019domain, we use only the official split of the training set to train the model, and all the images from the target domain are used for testing. RMSprop without momentum is used to train all the networks with an initial learning rate of 5e-5, and the learning rate is decreased by a factor of 10 at the 50th epoch. The batch size is set to 24 and we sample the same number of images from each domain during training. All the images are cropped to 224×224 for training, and it is worth noting that the only data augmentations we use are random crop with a scale factor of 1.25 and random horizontal flip. Other augmentations such as random grayscale are not used, as they may conceal the true performance of the model by introducing prior knowledge of the target domain. More specifically, the part of the backbone from the beginning to the second residual block inclusive is regarded as the task-specific encoder , and the remaining part acts as the task-specific network . The discriminators and are updated once after every 5 updates of other parts in the framework.
Results: The results based on ResNet-18 are reported in Table 2. As we can observe, our method outperforms other state-of-the-art methods. Moreover, our method achieves much better performance in the sketch domain, by a large margin, compared with other baseline methods. We conjecture that the reason is that the Sketch domain may contain less domain-specific information: as shown in Fig. 4, shuffling domain-specific features has less impact on image generation in the Sketch domain. Due to limited space, the results based on ResNet-50 are reported in the Appendix.
Settings: We then consider a larger benchmark mini-DomainNet zhou2020domain, which is a subset of DomainNet peng2019moment, for evaluation purpose. In mini-DomainNet, there are more than 140k images in total belonging to 4 different domains, namely sketch, real, clipart, and painting.
For the task-specific encoder and task-specific network , we use SGD with a momentum of 0.9 and an initial learning rate of 0.005 as the optimizer. For other parts of our framework, RMSprop without momentum is used with an initial learning rate of 0.0001. For the learning rate schedule of all the optimizers, we use the same cosine annealing rule loshchilov2016sgdr, with a minimum learning rate of 0 after 100 epochs. The batch size is 128, with a random sampler that samples roughly the same number of images from each domain. We consider data augmentations including the random clip with a probability of 0.5, and random crop of the data to the size using the scale factor of 1.25. For a fair comparison, we use the same backbone network, ResNet-18, for all the methods, and the division of and for our model is the same as the setting in PACS. The update frequency of and is also the same as in PACS: the discriminators are updated once after every 5 updates of other parts in the framework.
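For concreteness, the cosine annealing rule without restarts can be written as a small function. This sketch assumes the schedule is evaluated once per epoch and reaches the minimum learning rate exactly at epoch 100; the function name is hypothetical.

```python
import math


def cosine_annealing_lr(base_lr, epoch, total_epochs=100, min_lr=0.0):
    """Cosine annealing without restarts: base_lr at epoch 0, min_lr at total_epochs."""
    progress = min(epoch, total_epochs) / total_epochs
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


# Learning rates of the SGD-trained and RMSprop-trained parts at three epochs.
for epoch in (0, 50, 100):
    print(epoch, cosine_annealing_lr(0.005, epoch), cosine_annealing_lr(0.0001, epoch))
```

Both optimizers follow the same curve shape and differ only in their initial learning rate, so halfway through training each rate has dropped to half of its starting value.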
Results: We compare our method with MLDG li2018learning, JiGen carlucci2019domain, MASF dou2019domain and RSC huang2020self. The results are shown in Table 3. As we can observe, our method achieves better performance than the other baselines, especially in the sketch domain, where the margin is large. Interestingly, similar to the PACS benchmark, the performance improvement in the sketch domain is large, whereas in the real domain the performance of our method degrades somewhat. This may reveal a potential inductive bias of the model. Nevertheless, our method still has the best overall performance.
Ablation study and perceptual results
In this section, we first present the results of the ablation study to illustrate the effectiveness of each component in our proposed method. We further provide some perceptual results of image generation to show the significance of feature disentanglement.
In this section, we conduct the ablation study using PACS benchmark. We first explore the effectiveness of each term in the evidence upper bound we proposed in Theorem 1. In addition, the impacts of different optimization strategies for each term are evaluated.
Effectiveness of each term: We first evaluate the effectiveness of optimizing each term we proposed in the EUB. The results are shown in Table 4 where ’+①’ and ’+②’ represent that we utilize the information gain term ① defined in Eq. (6) and the posterior probability term ② defined in Eq. (2) but without data augmentation respectively. ’+DataAug’ means we use the generated samples to do data augmentation. The results demonstrate that both the information gain term and posterior probability term can improve the generalization capability of the model. In addition, combining these two together can attain larger performance improvements. The results also demonstrate that using the generated samples as augmented data can effectively improve the generalization ability of the model.
Impacts of different optimization strategies: For the evidence upper bound we proposed, there are different optimization strategies for each term. For instance, the reparameterization trick kingma2013auto is a widely used method to optimize the term ①. In addition, directly optimizing the L1 reconstruction loss is a common way to generate sharp reconstructed images. We investigate the impacts of different optimization strategies and the results are shown in Table 5, where ✓ marks the strategy we adopt. As we can see, directly using the reparameterization trick to align the high-dimensional features can lead to side effects. In addition, optimizing the perceptual loss leads to better performance than replacing the reconstruction loss in Eq. (9) with the L1 loss.
To provide an intuitive understanding of the effect of disentangling, we further give some perceptual results. Due to limited space, more visualization results are placed in the supplementary materials.
Generated samples based on source-domain images: To demonstrate the effectiveness of our method, we first visualize the generated samples in a cross-domain setting in which the pairs of input samples come from different source domains, because the quality of the samples reflects the accuracy of the estimated posterior probability and the degree of disentanglement. Some of the generated samples are shown in Fig. 3. The visualization results demonstrate that our method can disentangle the domain-specific and task-specific features well and generate realistic, high-quality novel samples with different styles. In addition, we observe that the reconstructed samples are not necessarily the same as the original ones, mainly due to the perceptual loss we adopt; in this way, we can prevent the latent features from overfitting to the source domain.
Generated samples based on target-domain images: To further demonstrate the effectiveness of our proposed framework, we conduct image generation on the unseen target domain. The samples generated from pairs drawn from the same unseen target domain, under leave-one-domain-out training, are shown in Fig. 4. From the visualization results, we find that our method can still separate the task-specific and domain-specific features well even though the networks have never seen samples from the target domain. More specifically, it can encode the intra-domain style variance, as the model can generate samples with different styles using the domain-specific features from the same target domain. Meanwhile, the results in Fig. 4 also illustrate that the sketch domain may have few domain-specific features and little intra-domain style variance, since the translated and reconstructed samples are almost the same. This observation further supports the effectiveness of our proposed method when Sketch is the target domain, where a significant performance improvement is achieved.
In this paper, we propose to tackle the problem of domain generalization from the perspective of variational disentangling. Specifically, we first propose an efficient framework to minimize the information gain introduced by a specific image and disentangle the task-specific and domain-specific features simultaneously. Then, an analysis is given from the perspective of variational inference and it demonstrates that our framework actually minimizes the evidence upper bound regarding the divergence between the distribution of task-specific features and its invariant ground-truth. Extensive experiments are conducted to verify the significance of our proposed method. Besides, we conduct experiments of image generation which further justify the effectiveness of our proposed disentanglement framework.