
KL Guided Domain Adaptation

06/14/2021
by   A. Tuan Nguyen, et al.

Domain adaptation is an important problem and often needed for real-world applications. In this problem, instead of i.i.d. datapoints, we assume that the source (training) data and the target (testing) data have different distributions. With that setting, the empirical risk minimization training procedure often does not perform well, since it does not account for the change in the distribution. A common approach in the domain adaptation literature is to learn a representation of the input that has the same distributions over the source and the target domain. However, these approaches often require additional networks and/or optimizing an adversarial (minimax) objective, which can be very expensive or unstable in practice. To tackle this problem, we first derive a generalization bound for the target loss based on the training loss and the reverse Kullback-Leibler (KL) divergence between the source and the target representation distributions. Based on this bound, we derive an algorithm that minimizes the KL term to obtain a better generalization to the target domain. We show that with a probabilistic representation network, the KL term can be estimated efficiently via minibatch samples without any additional network or a minimax objective. This leads to a theoretically sound alignment method which is also very efficient and stable in practice. Experimental results also suggest that our method outperforms other representation-alignment approaches.



1 Introduction

With advances in neural network architectures He et al. (2016); Vaswani et al. (2017), machine learning algorithms have achieved state-of-the-art performance in many tasks such as object classification, object detection and natural language processing. However, machine learning models have been focusing mostly on the case of independent and identically distributed (i.i.d.) datapoints; and such an assumption often does not hold in practice. When the i.i.d. assumption is violated and the target domain has a different distribution compared to the source domain, a typical learner trained on the source data via empirical risk minimization would not perform well at test time, since it does not account for the distribution shift. To tackle this problem, many methods have been proposed for domain adaptation Zhao et al. (2019); Zhang et al. (2019); Combes et al. (2020); Tanwani (2020) and domain generalization Khosla et al. (2012); Muandet et al. (2013); Ghifary et al. (2015), the goal of which is to train a machine learning algorithm that can generalize well to the target domain.

A common approach to tackle these problems is to learn a representation such that its distribution does not change across domains. There are two ways this can be learnt: marginal alignment (aligning the marginal distribution of the representation) and conditional alignment (aligning the conditional distribution of the label given the representation) Nguyen et al. (2021); Tanwani (2020). For domain adaptation and domain generalization problems with multiple source domains, we can use the data and labels to align both the marginal and the conditional distributions across the source domains, aiming to generalize to the target domain. However, in a single-source domain adaptation problem, with only unlabeled data from the target domain, it is only possible to align the marginal distribution of the representation. This marginal alignment should help the classifier avoid out-of-distribution data at test time.

This paper focuses on such a single-source domain adaptation problem, which is also one of the most common settings in practice. Current marginal alignment techniques usually require additional computation (e.g., an additional network) Ganin et al. (2016); Li et al. (2018) and/or a minimax objective Ganin et al. (2016); Shen et al. (2018), leading to an expensive and/or unstable training procedure Goodfellow (2016); Kodali et al. (2017). For example, DANN Ganin et al. (2016) employs an adversarial training procedure, with a domain discriminator that classifies the domain of the representation, and maximizes the adversarial loss of the discriminator. When the discriminator is completely fooled, the marginal distribution of the representation is aligned across domains. MMD Li et al. (2018) utilizes maximum mean discrepancy to align the representation distribution. This does not use a minimax objective, thus leading to more stable training; however, it does require additional computation of several Gaussian kernels.

To address the above issues, we derive a generalization bound on the loss of the target domain using the training loss and a reverse Kullback–Leibler (KL) divergence between the source and target distributions. There are existing bounds of the target loss in the literature Ben-David et al. (2010); however, these analyses focus mostly on the case of binary classification, and the bounds use a total variation distance or an $\mathcal{H}$-divergence between the distributions, which are not easy to estimate in practice (for example, Ajakan et al. (2014) require an adversarial network to estimate the $\mathcal{H}$-divergence). In this paper, we show that with a probabilistic representation network, we can estimate the KL divergence easily using samples, leading to an alignment method that requires virtually no additional computation nor a minimax objective. Therefore, our training procedure is simple and stable in practice. Moreover, the reverse KL has the zero-forcing effect Minka and others (2005), which is very effective to alleviate the out-of-distribution problem in practice. This can be explained as follows: the out-of-distribution problem arises when the classifier faces a new representation at test time that is in a (near) zero mass region of the source representation distribution (and thus one it has never faced before). The reverse KL tends to force the target representation distribution to have (near) zero mass wherever the source distribution has (near) zero mass (this is the zero-forcing property), which helps the classifier avoid out-of-distribution data. The reverse KL also has the mode-seeking effect Minka and others (2005) which allows for a more flexible alignment of the representation (to one or some of the modes of the source domain). For example, consider the classification problem of buildings (houses, hotels, etc.) where source images are collected from urban and remote areas of a country (two modes), while the target images are collected from urban areas but from a different country. Ideally, we want to match the representation distribution of the target domain to that of the first mode of the source domain since they are both from urban areas. The reverse KL allows this flexible alignment (as it results in a relatively small value of the reverse KL) due to its mode-seeking property. Meanwhile, other distance metrics aim to match the whole source and target representation distributions, which might collapse the two modes of the source domain.

Our contributions in this work are:

  • We construct a generalization bound of the test loss in the domain adaptation problem using the reverse KL divergence.

  • We propose to reduce the generalization bound by minimizing the above KL term. Furthermore, we show that with a probabilistic representation, the KL term can be estimated easily using minibatches, without any additional computation or a minimax objective as opposed to most existing works.

  • We conduct extensive experiments and show that our method significantly outperforms relevant baselines, namely ERM Bousquet et al. (2003), DANN Ganin et al. (2016), MMD Li et al. (2018), CORAL Sun and Saenko (2016) and WD Shen et al. (2018). We empirically show that the reverse KL divergence is a very effective distance metric for representations since it is very stable and efficient to compute in practice.

2 Related Work

Generalization bound for the distribution shift problem

There exist works studying bounds for the distribution shift problem in the literature Ben-David et al. (2010); Mansour et al. (2009). However, their analyses are limited to the case of binary labels for classification. Moreover, these works assume deterministic labeling functions (with an L1 or L2 distance as the loss between them), which is not true for most datasets in practice. Therefore, their analyses cannot be generalized to the general case of supervised learning. The differences between our bound and theirs are as follows. First of all, our bound works for the general case of supervised learning: it works for both the classification (including multiclass classification) and regression problems, it makes no assumptions about the labeling mechanism (which can be probabilistic or deterministic), and it works for virtually all predictive distributions commonly used in practice. Secondly, our bound uses a different divergence metric, namely KL, which is easier to estimate in practice compared to the total variation or $\mathcal{H}$-divergence. We provide a brief review of the above bounds and discuss their differences to ours in more detail in our appendix. Some specific cases of distribution shift have also been studied. For example, Cortes et al. (2010) and Johansson et al. (2019) study the generalization bound for the covariate shift problem, i.e., $p_S(y|x) = p_T(y|x)$ but $p_S(x) \neq p_T(x)$, where $p_S$ is the source distribution and $p_T$ is the target distribution. In contrast, Azizzadenesheli et al. (2019) provide a generalization bound for the label shift problem, i.e., $p_S(x|y) = p_T(x|y)$ but $p_S(y) \neq p_T(y)$.

Domain adaptation

While the literature on the domain adaptation problem is vast, we cover the works most closely related to ours here. A common method for the domain adaptation problem is to align the marginal distribution of the representation between the source and target domains. DANN Ganin et al. (2016) employs a domain discriminator to classify the domain of a representation and maximizes its adversarial loss (a minimax game). WD Shen et al. (2018) uses a neural network function (which is 1-Lipschitz continuous) to calculate the Wasserstein distance between two distributions and minimizes it. This is also a minimax game since the Wasserstein distance is the supremum over the space of such 1-Lipschitz functions. MMD Li et al. (2018) uses the maximum mean discrepancy to align the representation distribution. This method does not need a minimax objective; however, it requires additional computation of several Gaussian kernels. Finally, CORAL Sun and Saenko (2016) matches the first two moments of the distribution; and while being a simple method, it fails to align more complex distributions. We consider these marginal alignment techniques our main baselines since our method falls into this category, and investigate the effectiveness of the reverse KL divergence as a distance metric between representations. Recently, more sophisticated alignment methods Kang et al. (2019); Xu et al. (2019) have been proposed for the domain adaptation problem, which achieve state-of-the-art performance. Instead of simply aligning the marginal distribution of the representation, these methods minimize the intra-class distance of the representation across domains, and maximize the inter-class distance between them, using the MMD or L2 distance. However, they require pseudo labels for the target domain (obtained via clustering). Moreover, they are complementary to our method, as we conjecture that our method can also be used in conjunction with these, leading to the same algorithms but with the KL distance instead of MMD or L2.

3 Approach

3.1 Problem Statement

In this paper, we consider one of the most common domain adaptation settings, which consists of a single source domain with the joint data distribution $p_S(x, y)$ and a target domain with the data distribution $p_T(x, y)$, where $x$ denotes the input sample and $y$ is the label. We assume that these two domains have the same support sets $\mathcal{X}$ and $\mathcal{Y}$. Regarding the training process of the domain adaptation problem, we further denote a labeled dataset of size $N_S$ sampled from the source domain, $\{(x_S^{(i)}, y_S^{(i)})\}_{i=1}^{N_S}$, where $(x_S^{(i)}, y_S^{(i)}) \sim p_S(x, y)$, and an unlabeled dataset of size $N_T$ from the target domain, $\{x_T^{(i)}\}_{i=1}^{N_T}$, where $x_T^{(i)} \sim p_T(x)$.

The goal of a typical domain adaptation framework is to train a model with the labeled dataset of the source domain together with the unlabeled dataset from the target domain, so that the model will perform decently in the target domain. Note that this is only effective if the labeling mechanism is not too different between the source and the target domains.

In the domain adaptation problem, we expect changes in the marginal distribution, so that $p_S(x) \neq p_T(x)$, or in the conditional distribution, so that $p_S(y|x) \neq p_T(y|x)$, or both, which often makes the typical empirical risk minimization training procedure ineffective. This motivates a line of approaches that learn a representation $z$ of $x$ whose marginal and conditional distributions are more aligned across the domains and use it for the prediction task, aiming at a better generalization performance to the target domain.

The general representation learning framework aims to learn a representation $z$ from $x$ with the mapping $p(z|x)$, which can be deterministic or probabilistic. That latent representation is expected to contain the label-related information, and is then used to predict the label (by a classifier). Note that since the source and target domains have the same support set for $x$ and share the representation mapping $p(z|x)$, they also have the same support set for $z$, denoted by $\mathcal{Z}$. Given the representation $z$, we learn a classifier to predict $y$ through the predictive distribution $\hat{y}(y|z)$ that is an approximation of the ground truth conditional distribution $p_S(y|z)$. During training, the representation network and the classifier are trained jointly on the source domain and we "hope" that they can generalize to the target domain, meaning that both $p(z|x)$ and $p(y|z)$ are kept unchanged between the two domains (this will be formalized below). The graphical model of that representation learning process is depicted in Figure 1. In this paper, we consider a probabilistic representation mapping; specifically, the representation network will output $\mu(x)$ and $\sigma(x)$, and $p(z|x) = \mathcal{N}\big(z; \mu(x), \mathrm{diag}(\sigma(x)^2)\big)$, where $\mathcal{N}$ denotes a Gaussian distribution.
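As an illustration of this parameterization, the following PyTorch sketch shows one way to implement such a probabilistic representation head; the module and attribute names are ours, not the paper's.

```python
import torch
import torch.nn as nn

class ProbabilisticHead(nn.Module):
    """Maps backbone features to a diagonal Gaussian over the representation z.

    Illustrative sketch: the representation network outputs mu(x) and sigma(x)
    of p(z|x) = N(z; mu(x), diag(sigma(x)^2)).
    """
    def __init__(self, feat_dim: int, z_dim: int):
        super().__init__()
        self.mu = nn.Linear(feat_dim, z_dim)
        # parameterize sigma in log space so it is always positive
        self.log_sigma = nn.Linear(feat_dim, z_dim)

    def forward(self, h: torch.Tensor) -> torch.distributions.Normal:
        mu = self.mu(h)
        sigma = self.log_sigma(h).exp()
        return torch.distributions.Normal(mu, sigma)  # diagonal Gaussian p(z|x)

# usage: dist = head(features); z = dist.rsample()  # reparameterized sample
```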

Figure 1: Graphical model. Note that the distribution $p(z|x)$ (green edge), corresponding to our representation network, is shared between the source and target domains.

The joint distributions of $(x, y, z)$ for the source and target domains can be represented as follows:

$$p_S(x, y, z) = p_S(x, y)\,p(z|x), \qquad p_T(x, y, z) = p_T(x, y)\,p(z|x), \tag{1}$$

and we define the predictive distribution of $y$ given $x$ as:

$$\hat{y}(y|x) = \mathbb{E}_{p(z|x)}\big[\hat{y}(y|z)\big]. \tag{2}$$
Remark 1.

On the inference complexity of a probabilistic representation.

Using a probabilistic representation, we need to sample multiple $z$ from $p(z|x)$ to estimate Eq. 2 during test time. However, this is not a big issue for the representation learning framework, since we only need to run the representation network (which is usually deep) once to get a distribution of $z$. After sampling multiple $z$ from that distribution, we only need to rerun the classifier $\hat{y}(y|z)$, which is usually a small network (e.g., often containing one layer). Furthermore, we can also run the classifier (a small network) in parallel for multiple $z$ to reduce inference time if necessary.
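A minimal sketch of this inference scheme (ours; `backbone`, `head` and `classifier` are hypothetical stand-ins for the paper's networks):

```python
import torch

@torch.no_grad()
def predict(backbone, head, classifier, x, num_samples: int = 32):
    """Monte Carlo estimate of Eq. 2: y_hat(y|x) = E_{p(z|x)}[y_hat(y|z)].

    The (deep) backbone runs once per input; only the small classifier is
    re-run, in parallel, for every sampled z.
    """
    dist = head(backbone(x))                 # p(z|x), computed once
    z = dist.sample((num_samples,))          # shape (S, B, z_dim)
    logits = classifier(z.flatten(0, 1))     # classify all samples in one batch
    probs = logits.softmax(-1).unflatten(0, (num_samples, -1))
    return probs.mean(0)                     # average predictive distribution
```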

During training, we usually sample a single $z$ from $p(z|x)$ for each input $x$. The training objective is:

$$\mathcal{L}_S = \mathbb{E}_{p_S(x, y)}\Big[\mathbb{E}_{p(z|x)}\big[-\log \hat{y}(y|z)\big]\Big] \tag{3}$$
$$= \mathbb{E}_{p_S(x, y)\,p(z|x)}\big[\ell(y, z)\big], \tag{4}$$

where $\ell(y, z) := -\log \hat{y}(y|z)$ is the loss of a "data point" $(z, y)$. For common choices of the predictive distribution in the classification and regression problems, this is a non-negative quantity. For example, for a classification problem with a categorical predictive distribution, this becomes the cross-entropy loss, while for a regression problem with a Gaussian predictive distribution (with a fixed variance), it becomes the squared error (with an additive constant).
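As a concrete, self-contained illustration of these two special cases (dummy data, not the paper's code):

```python
import torch
import torch.nn.functional as F

B, Z, C = 8, 16, 10
z = torch.randn(B, Z)               # sampled representations (dummy data)
y = torch.randint(0, C, (B,))       # class labels
classifier = torch.nn.Linear(Z, C)  # small classifier head

# Classification: with a categorical predictive distribution, the per-point
# loss -log y_hat(y|z) is exactly the cross-entropy loss.
loss_cls = F.cross_entropy(classifier(z), y)

# Regression: with a fixed-variance Gaussian predictive distribution,
# -log N(y; mu(z), s^2) = (y - mu(z))^2 / (2 s^2) + const,
# i.e. the squared error up to scale and an additive constant.
y_cont, s = torch.randn(B), 1.0
mu = torch.nn.Linear(Z, 1)(z).squeeze(-1)
loss_reg = ((y_cont - mu) ** 2 / (2 * s ** 2)).mean()
```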

Minimizing $\mathcal{L}_S$ will enforce $\hat{y}(y|z) \approx p_S(y|z)$.

We consider the following two assumptions about the representation on the source domain:

Assumption 1.

$I_S(X; Y) = I_S(Z; Y)$, where $I_S$ is the mutual information term, calculated on the source domain. In particular:

$$I_S(Z; Y) = \mathbb{E}_{p_S(z, y)}\left[\log \frac{p_S(z, y)}{p_S(z)\,p_S(y)}\right]. \tag{5}$$

This is often referred to as the "sufficiency assumption" since it indicates that the representation $z$ has the same information about the label $y$ as the original input $x$, and is sufficient for this prediction task (in the source domain). Note that the data processing inequality indicates that $I_S(Z; Y) \le I_S(X; Y)$, so here we assume that $z$ contains maximum information about $y$.

Remark 2.

Assumption 1 is an optimization goal of the training process on the source domain.

In particular, $\mathcal{L}_S$ (with an additive constant) is an upper bound of $H_S(Y|Z)$, which is an upper bound of $H_S(Y|X)$. Thus, minimizing $\mathcal{L}_S$ will enforce $H_S(Y|Z)$ to be equal to $H_S(Y|X)$, i.e., $I_S(Z; Y)$ to be equal to $I_S(X; Y)$. For a more detailed discussion of this, please refer to, for example, Alemi et al. (2016).

Assumption 2.

$\hat{y}(y|z) = p_S(y|z)$ for all $z \in \mathcal{Z}$ and $y \in \mathcal{Y}$.

When this assumption holds, the predictive distribution $\hat{y}(y|x)$ in Eq. 2 will approximate $p_S(y|x)$, as long as Assumption 1 also holds.

Remark 3.

Assumption 2 is also an optimization goal of the training process on the source domain.

This is because $\mathcal{L}_S = H_S(Y|Z) + \mathbb{E}_{p_S(z)}\big[\mathrm{KL}\big(p_S(y|z)\,\|\,\hat{y}(y|z)\big)\big]$ is an upper bound of $H_S(Y|Z)$, which is an upper bound of $H_S(Y|X)$. Thus, minimizing $\mathcal{L}_S$ will enforce $\hat{y}(y|z)$ to be equal to $p_S(y|z)$. Therefore, Assumption 2 is (approximately) satisfied at the optimum of the training objective.

These two assumptions ensure that our network has good performance on the source domain.

Note also that we only make the above two assumptions about the source domain, where we can enforce them through the training process. We do not make these assumptions about the target domain, since we have no access to the full target distribution. These two assumptions will also be used to prove our later theoretical result (Proposition 2).

3.2 KL Guided Domain Adaptation

Now we will consider the test loss in the domain adaptation problem, and how we can reduce it. The test loss (of the target domain) is:

$$\mathcal{L}_T = \mathbb{E}_{p_T(x, y)}\Big[\mathbb{E}_{p(z|x)}\big[-\log \hat{y}(y|z)\big]\Big] \tag{6}$$
$$= \mathbb{E}_{p_T(x, y)\,p(z|x)}\big[\ell(y, z)\big] \tag{7}$$
$$= \mathbb{E}_{p_T(z, y)}\big[\ell(y, z)\big]. \tag{8}$$

Note that if the representation $z$ is invariant (both marginally and conditionally), then $p_T(z, y) = p_S(z, y)$ and Eq. 8 becomes $\mathcal{L}_S$, and we have a perfect generalization between the source domain and the target domain. However, there is no way to guarantee the invariance, since we do not know the target domain and the target data distribution. In that case, we introduce the following proposition that ensures a generalization bound of the test loss based on the training loss and the KL divergence:

Proposition 1.

If the loss $\ell(y, z) = -\log \hat{y}(y|z)$ is bounded by $M$¹, we have:

$$\mathcal{L}_T \le \mathcal{L}_S + \frac{M}{\sqrt{2}}\sqrt{\mathrm{KL}\big(p_T(z, y)\,\|\,p_S(z, y)\big)} \tag{9}$$
$$= \mathcal{L}_S + \frac{M}{\sqrt{2}}\sqrt{\mathrm{KL}\big(p_T(z)\,\|\,p_S(z)\big) + \mathbb{E}_{p_T(z)}\big[\mathrm{KL}\big(p_T(y|z)\,\|\,p_S(y|z)\big)\big]}. \tag{10}$$

¹In the classification problem, we can enforce this quite easily by augmenting the output softmax of the classifier so that each class probability is always at least $e^{-M}$. For example, if the output softmax is $p = [p_1, \dots, p_K]$, we can augment it into $p' = (p + c)/(1 + Kc)$, where $c > 0$ is chosen so that $c/(1 + Kc) \ge e^{-M}$ and $K$ is the number of classes. This ensures the bound for the loss of a datapoint, while preserving the predicted class.

Proof.

Provided in the appendix. ∎
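The footnote's softmax augmentation can be sketched as follows; the mixing constant `c` below is one concrete choice satisfying the stated probability floor (our assumption), not necessarily the paper's exact construction.

```python
import math
import torch

def bounded_softmax(logits: torch.Tensor, M: float) -> torch.Tensor:
    """Mix the softmax with a uniform floor so each class probability is at
    least exp(-M); then -log y_hat(y|z) <= M for every datapoint, and the
    argmax (predicted class) is unchanged. Requires exp(M) > num_classes.
    """
    K = logits.shape[-1]
    assert math.exp(M) > K, "need exp(M) > K for a valid probability floor"
    # choose c so that the minimum probability c / (1 + K c) equals exp(-M)
    c = math.exp(-M) / (1.0 - K * math.exp(-M))
    p = logits.softmax(-1)
    return (p + c) / (1.0 + K * c)
```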

This bound is similar to other bounds in the literature (e.g., Ben-David et al. (2010)) in the sense that it also contains the training loss, a marginal misalignment term and a conditional misalignment term ($\mathrm{KL}(p_T(z)\,\|\,p_S(z))$ and $\mathbb{E}_{p_T(z)}[\mathrm{KL}(p_T(y|z)\,\|\,p_S(y|z))]$ respectively in our case). However, Ben-David et al. (2010) consider a binary classification problem and their bounds are only practical for a deterministic labeling function, while our bound works for the general case of supervised learning with any labeling mechanism. For a brief review of these bounds and a detailed discussion about their differences to ours, please refer to our appendix. Note that the bound in Proposition 1 also holds when applied to the input space directly (e.g., replacing $z$ with $x$). However, we are more interested in the bound in the representation space, since we can reduce it by regularizing the KL term.

To reduce the generalization gap, we want $p_T(z, y)$ to be close to $p_S(z, y)$. Aligning the marginal distribution (i.e., $p_T(z) = p_S(z)$) helps the classifier network avoid out-of-distribution data, since the target representations it faces at test time belong to the source representation distribution on which it was trained; while aligning the conditional distribution ($p_T(y|z) = p_S(y|z)$) makes sure the classifier gives more accurate predictions on the target domain, since $\hat{y}(y|z)$ was trained to approximate $p_S(y|z)$. In the domain adaptation problem, since we only have the unlabeled data from the target domain, we often align the marginal distribution of $z$ only. However, one problem is that the conditional misalignment $\mathbb{E}_{p_T(z)}[\mathrm{KL}(p_T(y|z)\,\|\,p_S(y|z))]$ also depends on the representation $z$, and when learning a representation that aligns the marginal, we might accidentally increase the conditional misalignment at the same time, leading to a net increase in the above generalization bound. For example, what if (and is it possible that) the conditional misalignment increases to infinity while we learn a representation $z$?

Therefore, it is crucial that we can bound the above conditional misalignment. The proposition below handles this problem.

Proposition 2.

If Assumptions 1 and 2 hold, and if the source and target domains have the same support (as assumed in Section 3.1), we have:

$$\mathbb{E}_{p_T(z)}\big[\mathrm{KL}\big(p_T(y|z)\,\|\,p_S(y|z)\big)\big] \le \mathbb{E}_{p_T(x)}\big[\mathrm{KL}\big(p_T(y|x)\,\|\,p_S(y|x)\big)\big]. \tag{11}$$

Proof.

Provided in the appendix. ∎

This shows that the conditional misalignment in the representation space is bounded by the conditional misalignment in the input space. It then follows that:

$$\mathcal{L}_T \le \mathcal{L}_S + \frac{M}{\sqrt{2}}\sqrt{\mathrm{KL}\big(p_T(z)\,\|\,p_S(z)\big) + \mathbb{E}_{p_T(x)}\big[\mathrm{KL}\big(p_T(y|x)\,\|\,p_S(y|x)\big)\big]}. \tag{12}$$

Since $\mathbb{E}_{p_T(x)}\big[\mathrm{KL}\big(p_T(y|x)\,\|\,p_S(y|x)\big)\big]$ is fixed (i.e., not dependent on the representation $z$), to reduce the generalization bound in Eq. 12, we can focus on minimizing $\mathrm{KL}\big(p_T(z)\,\|\,p_S(z)\big)$, with the objective:

$$\mathcal{L}_S + \lambda\,\mathrm{KL}\big(p_T(z)\,\|\,p_S(z)\big), \tag{13}$$

where $\lambda$ is a hyper-parameter.

In practice, we found that adding an auxiliary forward-KL term $\mathrm{KL}\big(p_S(z)\,\|\,p_T(z)\big)$ with a small coefficient $\lambda_{\mathrm{aux}}$ to the objective helps to align the distributions faster, leading to the objective:

$$\mathcal{L}_S + \lambda\,\mathrm{KL}\big(p_T(z)\,\|\,p_S(z)\big) + \lambda_{\mathrm{aux}}\,\mathrm{KL}\big(p_S(z)\,\|\,p_T(z)\big). \tag{14}$$
Figure 2: Reverse KL allows a flexible alignment of the representation while still effectively preventing the out-of-distribution problem. (a) Source representation distribution (black). Consider the case where the data distribution of the source domain has two modes, then the representation distribution will likely also have two modes; and consider the case where the target distribution has only one mode. (b) An acceptable target representation distribution (green) that helps the classifier avoid the out-of-distribution problem. Reverse KL allows for this type of flexible alignment (match to one/some of the modes) due to its mode-seeking nature. (c) A problematic target representation distribution (red), since the classification network will face out-of-distribution data at test time, in the area between the two modes. Reverse KL will prevent this due to its zero-forcing nature.
Discussion on the use of reverse KL:

Our derivation leads to the reverse KL term as a regularizer of the distance between the two domains' representations. We argue that there are several reasons that make this a good choice as a distance metric between the source and target representation distributions. (1) First of all, as mentioned earlier, the KL term can be computed easily without any additional network or a minimax objective (details in Subsection 3.3). This leads to an efficient and stable training procedure, which often results in improved performance. (2) Secondly, the reverse KL has the zero-forcing/mode-seeking effect Minka and others (2005) that helps to alleviate the out-of-distribution problem. Specifically, the reverse KL forces the target representation distribution to have zero mass wherever the source distribution has zero mass (zero-forcing), thus preventing out-of-distribution data at test time (Figure 2c). On the other hand, its mode-seeking nature allows a flexible alignment of the representation. For example, consider the case where the source domain is a mixture of two components (Figure 2a, i.e., it has two modes), and the target distribution is close to one of the two components. Ideally, we want to learn a representation network that matches the representation of the target domain to that of the corresponding component of the source mixture (Figure 2b). This representation will still perform well at test time since we would not have the out-of-distribution problem (the classification network is already trained on this mode of the source distribution). This flexible alignment (to one or some of the modes) is accepted by the reverse KL since it leads to a relatively small reverse KL value. Meanwhile, methods based on other distance metrics, such as DANN, MMD, CORAL and WD, aim to match the representation distribution of the target domain and that of the whole source domain together, which could compress the representation too much, negatively affecting its expressive power. For instance, in the above example, trying to match the whole distribution of the source and target domains based on other distance metrics might force the two modes of the source domain to collapse. The flexible alignment of the reverse KL (while still being very effective at preventing out-of-distribution data) might be beneficial in some practical cases.
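To make the zero-forcing/mode-seeking intuition concrete, here is an illustrative sketch (ours, not from the paper) that Monte Carlo-estimates the reverse KL of two candidate target distributions against a two-mode source mixture, mirroring the scenarios in Figure 2:

```python
import torch
from torch import distributions as D

# Source representation distribution: a two-mode Gaussian mixture (Fig. 2a).
p_src = D.MixtureSameFamily(
    D.Categorical(torch.tensor([0.5, 0.5])),
    D.Normal(torch.tensor([-4.0, 4.0]), torch.tensor([1.0, 1.0])),
)

def mc_kl(q, p, n=100_000):
    """Monte Carlo estimate of KL(q || p) = E_q[log q(z) - log p(z)]."""
    z = q.sample((n,))
    return (q.log_prob(z) - p.log_prob(z)).mean()

q_one_mode = D.Normal(torch.tensor(4.0), torch.tensor(1.0))  # matches one mode (Fig. 2b)
q_between = D.Normal(torch.tensor(0.0), torch.tensor(3.0))   # mass between modes (Fig. 2c)

print(mc_kl(q_one_mode, p_src))  # small: reverse KL accepts matching a single mode
print(mc_kl(q_between, p_src))   # large: mass where p_src is near zero is heavily penalized
```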

3.3 Optimization

In practice, we estimate Eq. 14 using minibatches. In particular, given a labelled minibatch $\{(x_S^{(i)}, y_S^{(i)})\}_{i=1}^{B}$ of the source domain and an unlabelled one $\{x_T^{(i)}\}_{i=1}^{B}$ of the target domain, and a single sampled representation for each input, $z_S^{(i)} \sim p(z|x_S^{(i)})$ and $z_T^{(i)} \sim p(z|x_T^{(i)})$, we can get an unbiased estimator of the objective 14 as follows:

$$\frac{1}{B}\sum_{i=1}^{B}\Big[-\log \hat{y}\big(y_S^{(i)}|z_S^{(i)}\big) + \lambda\big(\log p_T(z_T^{(i)}) - \log p_S(z_T^{(i)})\big) + \lambda_{\mathrm{aux}}\big(\log p_S(z_S^{(i)}) - \log p_T(z_S^{(i)})\big)\Big]. \tag{15}$$

However, it still requires knowing $p_S(z)$ and $p_T(z)$ to compute Eq. 15. We also use the minibatch to approximate these quantities:

$$p_S(z) \approx \frac{1}{B}\sum_{i=1}^{B} p\big(z|x_S^{(i)}\big), \qquad p_T(z) \approx \frac{1}{B}\sum_{i=1}^{B} p\big(z|x_T^{(i)}\big). \tag{16}$$

Intuitively, we use a minibatch of data to construct a distribution of the representation (which is a mixture of $B$ Gaussian components), and match that distribution between the two domains with the KL divergence.

Although the estimator in Eq. 15 is unbiased, the approximations in Eq. 16 will introduce some bias into our estimator. However, the estimator is still consistent (i.e., it becomes exact when $B \to \infty$). Therefore, we conjecture that the batch size might have an effect on the performance of the model. We verify this observation with an ablation study in Section 4.

As mentioned earlier, we use a Gaussian distribution with a diagonal covariance matrix for $p(z|x)$, and employ the reparameterization trick Kingma and Welling (2013) to sample $z$.
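Putting Eqs. 14–16 together, the following PyTorch sketch evaluates the full objective on one minibatch; the function and coefficient names (`kl_guided_loss`, `lam`, `lam_aux`) are ours, and the coefficients are hyperparameters to be tuned.

```python
import torch
from torch import distributions as D
import torch.nn.functional as F

def mixture_of_gaussians(mu, sigma):
    """Minibatch approximation of p(z) as an equally weighted mixture of the
    per-example Gaussians p(z|x_i) (Eq. 16). mu, sigma have shape (B, d)."""
    B = mu.shape[0]
    return D.MixtureSameFamily(
        D.Categorical(torch.full((B,), 1.0 / B)),
        D.Independent(D.Normal(mu, sigma), 1),
    )

def kl_guided_loss(mu_s, sigma_s, classifier, y_s, mu_t, sigma_t,
                   lam=1.0, lam_aux=0.1):
    """One-minibatch estimate of the training objective (Eqs. 14-15). A single
    z is drawn per input via the reparameterization trick (rsample), so
    gradients flow through mu and sigma.
    """
    z_s = D.Normal(mu_s, sigma_s).rsample()
    z_t = D.Normal(mu_t, sigma_t).rsample()

    p_s = mixture_of_gaussians(mu_s, sigma_s)  # approximates p_S(z)
    p_t = mixture_of_gaussians(mu_t, sigma_t)  # approximates p_T(z)

    ce = F.cross_entropy(classifier(z_s), y_s)               # training loss L_S
    kl_rev = (p_t.log_prob(z_t) - p_s.log_prob(z_t)).mean()  # KL[p_T(z) || p_S(z)]
    kl_fwd = (p_s.log_prob(z_s) - p_t.log_prob(z_s)).mean()  # small forward-KL auxiliary
    return ce + lam * kl_rev + lam_aux * kl_fwd
```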

4 Experiments

4.1 Datasets

RotatedMNIST consists of 70,000 MNIST LeCun et al. (2010) images that are divided into six domains, each with 11,666 images. The images in each domain are rotated counter-clockwise by 0°, 15°, 30°, 45°, 60° and 75°, respectively. We denote the six domains as M0, M15, M30, M45, M60 and M75. We use M0 as the source domain, and perform five experiments, each with one of the remaining five domains as the target domain. The task is classification of the ten digit labels.

PACS Li et al. (2017) contains 9,991 images from four different domains: art painting (A), cartoon (C), photo (P), sketch (S). The task is classification with seven classes. We consider all possible ordered pairs of domains as source–target pairs, leading to 12 experiments in total.

4.2 Baselines

We consider all common marginal alignment methods for domain adaptation as our baselines, including DANN Ganin et al. (2016), MMD Li et al. (2018), CORAL Sun and Saenko (2016) and WD Shen et al. (2018). We also consider ERM Bousquet et al. (2003) (empirical risk minimization) and its variant ERM (prob) (same as ERM but with the probabilistic representation network used in our model). For ERM, DANN, MMD and CORAL, we follow the implementation by Gulrajani and Lopez-Paz (2020); while for ERM (prob) and WD, we use our own implementation in PyTorch Paszke et al. (2019). For the full description of these baselines, please refer to our appendix.

4.3 Experimental Setting

In each experiment, we split both the source and the target data into two portions: 80% and 20%. We use 80% of the source domain data and 80% of the target domain data (without the labels) as the training data. We use the remaining 20% of the source data as the validation set, and the remaining 20% of the target domain data as the test set. Note that we do not use the labeled data from the target domain during training or validation. This evaluation protocol is recommended by Gulrajani and Lopez-Paz (2020).

For the MNIST experiment, we use a simple convolutional neural network with four 3×3 convolutional layers (followed by an average pooling layer) as the representation network. For PACS, we use a ResNet18 as the representation network. Only the last layer of the representation network differs between a deterministic representation (ERM, DANN, CORAL, MMD, WD) and a probabilistic one (ERM (prob) and KL (ours)). For a representation of size $d$, the last layer's dimension of a deterministic representation network is $d$, while that of a probabilistic network is $2d$ ($d$ for $\mu$ and $d$ for $\sigma$).

We train each model for 100 epochs. To avoid hyperparameter bias, we tune the hyperparameters (learning rate, regularizer coefficients, weight decay, representation dimension and dropout rate) for each method and dataset independently. Following Gulrajani and Lopez-Paz (2020), we perform a random search Bergstra and Bengio (2012) over 20 sets of hyperparameters from a predefined grid. We re-run each set of hyperparameters three times. This is an extensive set of experiments, and we have run thousands of models for the RotatedMNIST and PACS experiments (3 runs × 20 sets of hyperparameters × (12+5) experiments for each baseline). We train all models on an NVIDIA Quadro RTX 6000 GPU.

For details about the network and the range of hyperparameters, please refer to our appendix and our source code.

4.4 Results

RotatedMNIST: Table 1 shows the results for the RotatedMNIST experiment. It is clear that in this experiment, aligning the representation between domains does help improve the generalization performance. Among the baselines (DANN, MMD, CORAL, WD), MMD performs the best, which we attribute to the fact that it does not use a minimax objective, leading to more stable optimization. Meanwhile, CORAL performs the worst, since it only matches the first two moments of the distributions and might fail to align complex distributions. Our method, KL, largely outperforms the baselines, indicating its effectiveness.

Model        M15       M30       M45       M60       M75       Average
ERM          97.5±0.2  84.1±0.8  53.9±0.7  34.2±0.4  22.3±0.5  58.4
ERM (prob)   96.8±0.3  83.2±1.6  51.3±0.9  31.4±1.1  20.7±0.7  56.7
DANN         97.3±0.4  90.6±1.1  68.7±4.2  30.8±0.6  19.0±0.6  61.3
MMD          97.5±0.1  95.3±0.4  73.6±2.1  44.2±1.8  32.1±2.1  68.6
CORAL        97.1±0.3  82.3±0.3  56.0±2.4  30.8±0.2  27.1±1.7  58.7
WD           96.7±0.3  93.1±1.2  64.1±3.3  41.4±7.6  27.6±2.0  64.6
KL (ours)    97.5±0.5  96.6±0.4  92.0±0.4  57.8±9.7  58.3±4.2  80.1
Table 1: RotatedMNIST experiments with M0 as the source domain; each remaining column is a target domain.

PACS: Table 2 presents the results for PACS, which is a challenging real-world dataset for domain adaptation/generalization. On this dataset, our model outperforms the ERM baselines by roughly 9% on average, indicating the effectiveness of our representation-alignment technique. Our method is the best performer (with a large margin) on 8 out of 12 experiments, showing a clear benefit over other representation alignment techniques. Together with our method, MMD again performs the best among the representation-alignment baselines (DANN, MMD, CORAL and WD), confirming that a stable training procedure (with no minimax objectives in MMD and our model) is important and often leads to better results. It is also worth noting that our model still outperforms MMD despite being less computationally expensive (in this implementation, MMD needs to compute seven Gaussian kernels for each of three pairs of representation sets in each minibatch).

It is interesting that the ERM baselines perform the best in some experiments (e.g., S → C, S → P). This result also agrees with the one observed in Gulrajani and Lopez-Paz (2020) that domain generalization/adaptation techniques might have negative effects when applied unsuccessfully. It should be noted that the S (sketch) domain is undoubtedly the most different compared to the others (only black sketches on a white background, while the other domains have colors), which might explain the difficulty of learning to transfer between domains.

Experiment   ERM        ERM (prob)  DANN        MMD        CORAL       WD          KL (ours)
A → C        66.1±1.3   63.5±0.8    71.0±3.2    79.5±0.4   62.7±10.4   76.2±0.9    73.1±3.4
A → P        94.3±0.6   93.5±1.3    94.5±0.5    94.5±1.1   86.3±6.8    92.4±1.3    95.4±1.2
A → S        53.6±0.8   60.9±3.5    58.6±12.8   62.1±2.0   46.2±3.5    53.9±2.7    67.4±1.9
C → A        69.7±1.1   70.8±2.3    76.4±1.7    79.5±3.0   75.9±0.9    69.0±2.1    83.3±1.1
C → P        82.0±0.9   81.5±2.1    78.6±3.4    80.8±2.3   78.3±3.6    72.9±8.6    83.1±7.4
C → S        72.2±1.4   70.4±1.5    76.1±1.0    74.1±1.3   56.9±11.0   48.7±6.1    68.2±0.5
P → A        65.7±2.3   63.3±1.2    68.0±2.7    67.7±1.8   70.0±1.5    62.6±1.5    75.5±2.5
P → C        29.1±1.9   27.2±3.3    50.7±5.0    47.4±0.8   47.5±8.6    56.1±1.4    67.7±1.2
P → S        38.0±1.0   35.9±2.3    29.3±9.8    59.7±4.8   15.8±5.3    22.3±15.0   64.5±2.1
S → A        41.3±6.5   40.9±3.9    39.2±3.5    40.0±3.3   39.1±4.8    36.1±9.5    48.2±2.4
S → C        66.7±1.0   67.9±1.4    64.3±2.0    65.7±2.3   59.9±1.5    60.5±2.0    63.5±0.4
S → P        49.3±3.3   46.0±4.7    44.3±4.0    45.1±0.9   37.4±2.7    38.5±5.6    39.1±3.4
Average      60.6       60.2        62.6        66.3       56.3        57.4        69.1
Table 2: PACS experiments.

4.5 Ablation Study

In this subsection, we conduct an ablation study to investigate the effect of the batch size on our model's performance. Table 3 shows the performance of our method on the RotatedMNIST dataset, with M0 as the source domain and M45 as the target domain, for various choices of the batch size. As expected, our model's performance tends to benefit from a bigger batch size, since this alleviates the bias of our objective estimator. We therefore recommend increasing the batch size whenever possible. However, even with a batch size of 64 (which is common for deep learning), the model still performs reasonably well and significantly outperforms the other baselines.

Batch size   256       128       64        32
KL (ours)    92.0±0.4  91.8±1.0  91.3±1.2  83.9±3.5
Table 3: Ablation study: RotatedMNIST experiments with M0 as the source and M45 as the target domain.

5 Conclusion

In conclusion, in this paper we derive a generalization bound of the target loss in the domain adaptation problem using the reverse KL divergence. We then show that with a probabilistic representation, the KL divergence can easily be estimated using Monte Carlo (minibatch) samples, without any additional computation or adversarial objective. By minimizing the KL divergence, we can reduce the generalization bound and have a better guarantee about the test loss. We also empirically show that our method outperforms relevant baselines with large margins, which we attribute to its simple and stable training procedure and the mode-seeking/zero-forcing nature of the reverse KL. We conclude that KL divergence is very effective as a distance metric between representations. In general, a limitation of marginal alignment methods (ours included) is that when the conditional distribution changes significantly from the source domain to the target domain, aligning the marginal would not help the target domain’s performance. This is also reflected in our generalization bound. For future work, we would want to investigate the use of KL divergence in other types of alignment. For example, we can follow the algorithm in Kang et al. (2019) to minimize the intra-class distance of the representation across domains and maximize the inter-class distance between them, but using the KL divergence instead of MMD as the distance metric. Another direction would be using KL divergence to align the conditional distribution across domains in a multi-source setting.

References

  • H. Ajakan, P. Germain, H. Larochelle, F. Laviolette, and M. Marchand (2014) Domain-adversarial neural networks. arXiv preprint arXiv:1412.4446. Cited by: §1.
  • A. A. Alemi, I. Fischer, J. V. Dillon, and K. Murphy (2016) Deep variational information bottleneck. arXiv preprint arXiv:1612.00410. Cited by: §3.1.
  • K. Azizzadenesheli, A. Liu, F. Yang, and A. Anandkumar (2019) Regularized learning for domain adaptation under label shifts. arXiv preprint arXiv:1903.09734. Cited by: §2.
  • S. Ben-David, J. Blitzer, K. Crammer, A. Kulesza, F. Pereira, and J. W. Vaughan (2010) A theory of learning from different domains. Machine learning 79 (1), pp. 151–175. Cited by: §B.1, §B.1, §B.1, §B.1, §B.1, §B.1, §B.1, §B.2, §B.2, §B.2, §1, §2, §3.2.
  • J. Bergstra and Y. Bengio (2012) Random search for hyper-parameter optimization.. Journal of machine learning research 13 (2). Cited by: §4.3.
  • O. Bousquet, S. Boucheron, and G. Lugosi (2003) Introduction to statistical learning theory. In Summer School on Machine Learning, pp. 169–207. Cited by: §C.1, 3rd item, §4.2.
  • R. T. d. Combes, H. Zhao, Y. Wang, and G. Gordon (2020) Domain adaptation with conditional distribution matching and generalized label shift. arXiv preprint arXiv:2003.04475. Cited by: §1.
  • C. Cortes, Y. Mansour, and M. Mohri (2010) Learning bounds for importance weighting.. In Nips, Vol. 10, pp. 442–450. Cited by: §2.
  • Y. Ganin, E. Ustinova, H. Ajakan, P. Germain, H. Larochelle, F. Laviolette, M. Marchand, and V. Lempitsky (2016) Domain-adversarial training of neural networks. The journal of machine learning research 17 (1), pp. 2096–2030. Cited by: §C.1, §C.3.1, §C.3.2, 3rd item, §1, §2, §4.2.
  • M. Ghifary, W. B. Kleijn, M. Zhang, and D. Balduzzi (2015) Domain generalization for object recognition with multi-task autoencoders. In Proceedings of the IEEE International Conference on Computer Vision, pp. 2551–2559. Cited by: §1.
  • I. Goodfellow (2016) NIPS 2016 tutorial: generative adversarial networks. arXiv preprint arXiv:1701.00160. Cited by: §1.
  • I. Gulrajani and D. Lopez-Paz (2020) In search of lost domain generalization. arXiv preprint arXiv:2007.01434. Cited by: §C.2, §4.2, §4.3, §4.3, §4.4.
  • K. He, X. Zhang, S. Ren, and J. Sun (2016) Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778. Cited by: §1.
  • F. D. Johansson, D. Sontag, and R. Ranganath (2019) Support and invertibility in domain-invariant representations. In The 22nd International Conference on Artificial Intelligence and Statistics, pp. 527–536. Cited by: §2.
  • G. Kang, L. Jiang, Y. Yang, and A. G. Hauptmann (2019) Contrastive adaptation network for unsupervised domain adaptation. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 4893–4902. Cited by: §2, §5.
  • A. Khosla, T. Zhou, T. Malisiewicz, A. A. Efros, and A. Torralba (2012) Undoing the damage of dataset bias. In European Conference on Computer Vision, pp. 158–171. Cited by: §1.
  • D. P. Kingma and J. Ba (2014) Adam: a method for stochastic optimization. arXiv preprint arXiv:1412.6980. Cited by: §C.3.
  • D. P. Kingma and M. Welling (2013) Auto-encoding variational bayes. arXiv preprint arXiv:1312.6114. Cited by: §3.3.
  • N. Kodali, J. Abernethy, J. Hays, and Z. Kira (2017) On convergence and stability of gans. arXiv preprint arXiv:1705.07215. Cited by: §1.
  • Y. LeCun, C. Cortes, and C. Burges (2010) MNIST handwritten digit database. ATT Labs [Online]. Available: http://yann.lecun.com/exdb/mnist 2. Cited by: §4.1.
  • D. Li, Y. Yang, Y. Song, and T. Hospedales (2017) Deeper, broader and artier domain generalization. In International Conference on Computer Vision, Cited by: §4.1.
  • H. Li, S. J. Pan, S. Wang, and A. C. Kot (2018) Domain generalization with adversarial feature learning. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 5400–5409. Cited by: §C.1, §C.3.1, §C.3.2, 3rd item, §1, §2, §4.2.
  • Y. Mansour, M. Mohri, and A. Rostamizadeh (2009) Domain adaptation: learning bounds and algorithms. arXiv preprint arXiv:0902.3430. Cited by: §B.2, §B.2, §B.2, §B.2, §2.
  • T. Minka et al. (2005) Divergence measures and message passing. Technical report Citeseer. Cited by: §1, §3.2.
  • K. Muandet, D. Balduzzi, and B. Schölkopf (2013) Domain generalization via invariant feature representation. In International Conference on Machine Learning, pp. 10–18. Cited by: §1.
  • A. T. Nguyen, T. Tran, Y. Gal, and A. G. Baydin (2021) Domain invariant representation learning with domain density transformations. arXiv preprint arXiv:2102.05082. Cited by: §1.
  • A. Paszke, S. Gross, F. Massa, A. Lerer, J. Bradbury, G. Chanan, T. Killeen, Z. Lin, N. Gimelshein, L. Antiga, A. Desmaison, A. Kopf, E. Yang, Z. DeVito, M. Raison, A. Tejani, S. Chilamkurthy, B. Steiner, L. Fang, J. Bai, and S. Chintala (2019) PyTorch: an imperative style, high-performance deep learning library. In Advances in Neural Information Processing Systems 32, H. Wallach, H. Larochelle, A. Beygelzimer, F. d'Alché-Buc, E. Fox, and R. Garnett (Eds.), pp. 8024–8035. Cited by: §4.2.
  • J. Shen, Y. Qu, W. Zhang, and Y. Yu (2018) Wasserstein distance guided representation learning for domain adaptation. In Proceedings of the AAAI Conference on Artificial Intelligence, Vol. 32. Cited by: §C.1, §C.3.1, §C.3.2, 3rd item, §1, §2, §4.2.
  • B. Sun and K. Saenko (2016) Deep coral: correlation alignment for deep domain adaptation. In European conference on computer vision, pp. 443–450. Cited by: §C.1, §C.3.1, §C.3.2, 3rd item, §2, §4.2.
  • A. K. Tanwani (2020) Domain-invariant representation learning for sim-to-real transfer. arXiv preprint arXiv:2011.07589. Cited by: §1, §1.
  • A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, L. Kaiser, and I. Polosukhin (2017) Attention is all you need. arXiv preprint arXiv:1706.03762. Cited by: §1.
  • X. Xu, X. Zhou, R. Venkatesan, G. Swaminathan, and O. Majumder (2019) D-sne: domain adaptation using stochastic neighborhood embedding. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 2497–2506. Cited by: §2.
  • Y. Zhang, T. Liu, M. Long, and M. Jordan (2019) Bridging theory and algorithm for domain adaptation. In International Conference on Machine Learning, pp. 7404–7413. Cited by: §1.
  • H. Zhao, R. T. Des Combes, K. Zhang, and G. Gordon (2019) On learning invariant representations for domain adaptation. In International Conference on Machine Learning, pp. 7523–7532. Cited by: §1.

Appendix A Proofs

For the following proofs, we treat the variables as continuous and always use integrals. If one or some of the variables are discrete, it is straightforward to replace the corresponding integral(s) with summation sign(s), and the proofs still hold.

A.1 Proposition 1

Proof.

We have:

$$\mathcal{L}_T - \mathcal{L}_S = \int \ell(y, z)\,\big(p_T(z, y) - p_S(z, y)\big)\,dz\,dy.$$

Let $B = \{(z, y) : p_T(z, y) > p_S(z, y)\}$ and let $B^c$ be its complement. Using the fact that $0 \le \ell(y, z) \le M$, we have:

$$\int \ell\,(p_T - p_S)\,dz\,dy = \int_B \ell\,(p_T - p_S)\,dz\,dy + \int_{B^c} \ell\,(p_T - p_S)\,dz\,dy \le \int_B \ell\,(p_T - p_S)\,dz\,dy \le M \int_B (p_T - p_S)\,dz\,dy,$$

where the integral over $B^c$ is dropped because it is non-positive. Since $\int (p_T - p_S)\,dz\,dy = 0$, we also have:

$$\int_B (p_T - p_S)\,dz\,dy = \frac{1}{2}\int \big|p_T(z, y) - p_S(z, y)\big|\,dz\,dy,$$

where $|\cdot|$ is the absolute value. Here, $\frac{1}{2}\int |p_T(z, y) - p_S(z, y)|\,dz\,dy$ is also called the total variation $\mathrm{TV}\big(p_T(z, y), p_S(z, y)\big)$ of the two distributions $p_T(z, y)$ and $p_S(z, y)$.

Using Pinsker's inequality, we have:

$$\mathrm{TV}\big(p_T(z, y), p_S(z, y)\big) \le \sqrt{\tfrac{1}{2}\,\mathrm{KL}\big(p_T(z, y)\,\|\,p_S(z, y)\big)}.$$

Therefore, we finally have:

$$\mathcal{L}_T \le \mathcal{L}_S + \frac{M}{\sqrt{2}}\sqrt{\mathrm{KL}\big(p_T(z, y)\,\|\,p_S(z, y)\big)},$$

which concludes our proof. ∎

Also note that the KL divergence between $p_T(z, y)$ and $p_S(z, y)$ can further be decomposed into the marginal misalignment and the conditional misalignment as follows:

$$\mathrm{KL}\big(p_T(z, y)\,\|\,p_S(z, y)\big) = \mathbb{E}_{p_T(z, y)}\left[\log \frac{p_T(z)\,p_T(y|z)}{p_S(z)\,p_S(y|z)}\right] = \mathrm{KL}\big(p_T(z)\,\|\,p_S(z)\big) + \mathbb{E}_{p_T(z)}\big[\mathrm{KL}\big(p_T(y|z)\,\|\,p_S(y|z)\big)\big].$$

A.2 Proposition 2

Proof.

Since the representation $z$ is generated from $x$ alone (Figure 1), we have $p_d(y, z|x) = p_d(y|x)\,p(z|x)$, and hence $p_d(y|x, z) = p_d(y|x)$ for both domains $d \in \{S, T\}$.

According to Assumption 1, $I_S(X; Y) = I_S(Z; Y)$. Because $Y \to X \to Z$ forms a Markov chain in the source domain, equality in the data processing inequality implies $I_S(X; Y \mid Z) = 0$, i.e.:

$$p_S(y|x, z) = p_S(y|z) \quad \text{wherever } p_S(x, z) > 0.$$

Combining the two equalities above, $p_S(y|x) = p_S(y|x, z) = p_S(y|z)$ wherever $p_S(x, z) > 0$. According to Assumption 2, we also have $\hat{y}(y|z) = p_S(y|z)$, so the classifier computes exactly this shared conditional.

Since $p(z|x)$ is a Gaussian distribution, $p(z|x) > 0$ for all $x$ and $z$; together with the same-support assumption ($\mathcal{X}_S = \mathcal{X}_T$), for every $(x, z)$ with $p_T(x, z) = p_T(x)\,p(z|x) > 0$ we also have $p_S(x, z) > 0$. Therefore, $p_T(x, z)$-almost everywhere:

$$p_S(y|z) = \mathbb{E}_{p_T(x|z)}\big[p_S(y|x)\big], \qquad p_T(y|z) = \mathbb{E}_{p_T(x|z)}\big[p_T(y|x, z)\big] = \mathbb{E}_{p_T(x|z)}\big[p_T(y|x)\big].$$

We now only need to apply the joint convexity of the KL divergence:

$$\mathrm{KL}\big(p_T(y|z)\,\|\,p_S(y|z)\big) = \mathrm{KL}\Big(\mathbb{E}_{p_T(x|z)}\big[p_T(y|x)\big]\,\Big\|\,\mathbb{E}_{p_T(x|z)}\big[p_S(y|x)\big]\Big) \le \mathbb{E}_{p_T(x|z)}\Big[\mathrm{KL}\big(p_T(y|x)\,\|\,p_S(y|x)\big)\Big].$$

Taking the expectation of both sides over $p_T(z)$ and using $\mathbb{E}_{p_T(z)}\mathbb{E}_{p_T(x|z)} = \mathbb{E}_{p_T(x)}$, we obtain:

$$\mathbb{E}_{p_T(z)}\big[\mathrm{KL}\big(p_T(y|z)\,\|\,p_S(y|z)\big)\big] \le \mathbb{E}_{p_T(x)}\big[\mathrm{KL}\big(p_T(y|x)\,\|\,p_S(y|x)\big)\big],$$

which concludes our proof. ∎

Appendix B Review of existing generalization bounds

There have been several works studying generalization bounds for the domain adaptation problem. We briefly review the most important and common ones here, with a discussion of how they differ from our proposed bound.

B.1 Ben-David et al. [2010]

Ben-David et al. [2010] consider a binary classification problem. Let $x$ be the input with the support set $\mathcal{X}$ and $y$ be the binary label with the support set $\mathcal{Y} = \{0, 1\}$. Consider a source domain with a distribution $\mathcal{D}_S$ over the input and the true labeling function $f_S$; and similarly a target domain with a distribution $\mathcal{D}_T$ over the input and the true labeling function $f_T$. Note that the authors claim that this labeling function can be probabilistic; in that case, $f(x) \in [0, 1]$ denotes the probability of the positive label. However, we argue that this probabilistic setting is impractical, since we would not know the true underlying function in order to calculate the training loss in practice. Therefore, we found that the bound is only practical for the case of a deterministic labeling mechanism.

The error of the classifier $h$, which is also a deterministic labeling function, on the source domain is:

$$\epsilon_S(h) = \mathbb{E}_{x \sim \mathcal{D}_S}\big[\,|h(x) - f_S(x)|\,\big], \tag{66}$$

and similarly for the target domain:

$$\epsilon_T(h) = \mathbb{E}_{x \sim \mathcal{D}_T}\big[\,|h(x) - f_T(x)|\,\big]. \tag{67}$$

Here $|\cdot|$ is the absolute value, which means the loss of a data point is the L1 distance between the labels.

Consider a hypothesis space $\mathcal{H}$ and let the classifier $h$ be any function from that space. The first theorem in Ben-David et al. [2010] offers a bound of the target loss $\epsilon_T(h)$ based on the source loss $\epsilon_S(h)$, and the total variation between