Deep neural networks have achieved great success on a wide range of tasks such as visual recognition and machine translation (lecun2015deep). They usually require large amounts of labeled data that can be prohibitively expensive to collect, and even with sufficient supervision their performance can still be poor when generalized to a new environment. The discrepancy between the training and testing data distributions is commonly referred to as domain shift or covariate shift (shimodaira2000improving). To alleviate the effect of such shift, domain adaptation aims to make a model trained in a label-rich source domain generalize well to an unlabeled target domain (pan2010survey). Domain adaptation has benefited many practical applications, including object detection under challenging conditions (chen2018domain) and cost-effective learning that uses only synthetic data to generalize to real-world imagery (vazquez2013virtual).
Prevailing methods for unsupervised domain adaptation (UDA) are mostly based on domain alignment, which aims to learn domain-invariant features by reducing the distribution discrepancy between the source and target domains under some pre-defined metric such as maximum mean discrepancy (gretton2007kernel; gretton2012optimal). Recently, ganin15 proposed to achieve domain alignment by domain-adversarial training (DAT), which reverses the gradients of a domain classifier to maximize domain confusion. Owing to its remarkable performance gains, DAT has been employed in many subsequent UDA methods (long2017conditional; chen2019transferability; liu2019transferable).
Nevertheless, DAT with gradient reversal layers still faces three critical restrictions in practical scenarios. (1) DAT cannot continuously provide effective gradients for learning domain-invariant representations. The reason is that the binary domain classifier has a high capacity to discriminate the two domains and thus overwhelms adversarial training, a problem usually handled by manually adjusting the weight of the adversarial loss for each specific task, as in (shu2018dirt). (2) DAT cannot deal with pixel-level domain shift, which is frequently encountered in visual tasks (hoffman18a). (3) The domain-invariant features learned by DAT are motivated by intuition and learning theory (ben2010theory) but are difficult to interpret, which impedes the investigation of the underlying mechanism of adversarial domain adaptation.
To overcome the aforementioned difficulties, we propose a novel adversarial approach, namely Max-margin Domain-Adversarial Training (MDAT), to realize stable and comprehensive (i.e. both feature-level and pixel-level) domain alignment. MDAT is built on a carefully designed Adversarial Reconstruction Network (ARN). Specifically, ARN consists of a shared feature extractor, a label predictor, and a reconstruction network (i.e. decoder) that serves as the domain classifier. MDAT sets up an adversarial game between the feature extractor and the decoder. The decoder learns to reconstruct source samples from their features and to push the reconstruction errors of target samples beyond a margin, while the feature extractor aims to fool the decoder by generating target features that can be reconstructed. In this adversarial way, the three critical issues are addressed: (1) the max-margin loss reduces the discriminative capacity of the domain classifier, balancing and stabilizing domain-adversarial training; (2) without introducing any new network structure, MDAT achieves both pixel-level and feature-level domain alignment; (3) reconstructing adapted features into images reveals how the source and target domains are aligned by adversarial training. We evaluate ARN with MDAT on both visual and non-visual UDA benchmarks. It shows a more stable training procedure and achieves significant improvement over DAT on all tasks with pixel-level or higher-level domain shift. We also observe that it is insensitive to the choice of hyperparameters and is therefore easy to replicate in practice. In principle, our approach is generic and can be used to enhance any domain adaptation method that uses domain alignment as an ingredient.
2 Related Work
Domain adaptation aims to transfer knowledge from one domain to another. ben2010theory provide an upper bound of the test error on the target domain in terms of the source error and the $\mathcal{H}\Delta\mathcal{H}$-distance. As the source error is stationary for a fixed model, the goal of most UDA methods is to minimize the $\mathcal{H}\Delta\mathcal{H}$-distance by reducing some metric such as Maximum Mean Discrepancy (MMD) (TzengHZSD14; Longicml15) or CORAL (sun2016deep). Inspired by Generative Adversarial Networks (GAN) (goodfellow2014generative), ganin15 proposed to learn domain-invariant features by Domain-Adversarial Training (DAT), which has inspired many subsequent UDA methods. For example, zhang2019bridging propose a new divergence for distribution comparison based on minimax optimization, and wang2019negative find that filtering out unrelated source samples helps avoid negative transfer in DAT. Adversarial Discriminative Domain Adaptation (ADDA) (tzeng2017adversarial) trains a separate target encoder to fool a domain discriminator, but not in an end-to-end manner. CyCADA (hoffman18a) and PixelDA (bousmalis2017unsupervised) leverage GANs to conduct both feature-level and pixel-level domain adaptation, which yields significant improvement at the cost of high network complexity. Recent works find that DAT deteriorates feature learning, and propose to overcome this by generating transferable examples (liu2019transferable) or adding an extra regularizer to retain discriminability (chen2019transferability). These approaches can also be directly applied to MDAT for further enhancement.
Another line of approaches relevant to our method is the reconstruction network (i.e. decoder network), which enables unsupervised image-to-image translation by learning pixel-level features (zhu2017unpaired). In UDA, ghifary2016deep employed a decoder network for pixel-level adaptation, and Domain Separation Networks (DSN) (bousmalis2016domain) further leveraged multiple decoder networks to learn domain-specific features. These approaches treat the decoder network as an independent component for augmented feature learning that is unrelated to domain alignment (glorot2011domain). In this paper, we propose to use the decoder network as the domain classifier in MDAT, which enables both feature-level and pixel-level domain alignment in a stable and straightforward fashion.
3 Problem Formulation
3.1 Problem Definition and Notations
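In unsupervised domain adaptation, we assume that a model works with a labeled dataset $\mathcal{X}_S$ and an unlabeled dataset $\mathcal{X}_T$. Let $\mathcal{X}_S = \{(x_s^i, y_s^i)\}_{i=1}^{n_s}$ denote the labeled dataset of $n_s$ samples from the source domain, where the label $y_s^i$ belongs to the label space $\mathcal{Y}$, a finite set ($\mathcal{Y} = \{1, 2, \dots, K\}$). The other dataset $\mathcal{X}_T = \{x_t^i\}_{i=1}^{n_t}$ has $n_t$ samples from the target domain but no labels. We further assume that the two domains have different distributions, i.e. $x_s^i \sim \mathcal{D}_S$ and $x_t^i \sim \mathcal{D}_T$ with $\mathcal{D}_S \neq \mathcal{D}_T$. In other words, there exists some domain shift (ben2010theory) between $\mathcal{D}_S$ and $\mathcal{D}_T$. The ultimate goal is to learn a model that can predict the label $y_t^i$ given the target input $x_t^i$.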
3.2 Unbalanced Minimax Game in Domain-Adversarial Training
To achieve domain alignment, Domain-Adversarial Training (DAT) plays a minimax game between a feature extractor $F$ shared by the two domains and a domain classifier $D$. The domain classifier is trained to determine whether an input sample belongs to the source or the target domain, while the feature extractor learns to deceive the domain classifier, which is formulated as:

$$\min_{F}\max_{D}\; \mathbb{E}_{x_s\sim\mathcal{D}_S}\big[\log D(F(x_s))\big] + \mathbb{E}_{x_t\sim\mathcal{D}_T}\big[\log\big(1 - D(F(x_t))\big)\big]$$
We usually use a Convolutional Neural Network (CNN) as the feature extractor and fully connected (FC) layers as the domain classifier. Theoretically, DAT reduces the cross-domain discrepancy and helps learn domain-invariant representations (ganin2016domain). In practice, however, the training of DAT is rather unstable: without sophisticated tuning of the hyperparameters, DAT often fails to converge. Empirically, we observe that this instability is due to the imbalanced adversarial game between $F$ and $D$. The binary domain discriminator $D$ easily converges to very high accuracy at an early training epoch, while it is much harder for the feature extractor $F$ to fool the domain discriminator and simultaneously perform well on the source domain. In this sense, there is a strong likelihood that the domain classifier overwhelms DAT, and the usual remedy is to weaken the training of $D$ by tuning the hyperparameters for each task. In our method, we restrict the capacity of the domain classifier so that the minimax game is balanced. Inspired by the max-margin loss (i.e. hinge loss) in Support Vector Machines (SVM) (cristianini2000introduction), if we separate the source domain and the target domain by a margin rather than as far as possible, then the task of $F$ to fool $D$ becomes much easier. For a binary domain classifier, we define the margin loss as

$$\mathcal{L}_{\text{margin}}(\hat{y}) = \big[\,m - t\,\hat{y}\,\big]^{+}$$

where $\hat{y}$ is the predicted domain label, $[\cdot]^{+}$ is $\max(0,\cdot)$, $m$ is a positive margin and $t$ is the ground-truth label for the two domains (assuming $t=1$ for the source domain and $t=-1$ for the target domain). Next, we introduce our MDAT scheme based on a new network architecture.
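To make the contrast between the two losses concrete, here is a minimal standard-library sketch comparing the binary cross-entropy that a DAT domain classifier minimizes with the margin loss above. The function names are hypothetical, and we assume $D$ outputs a probability of "source" for the cross-entropy and a real-valued score $\hat{y}$ for the margin loss.

```python
import math

# Illustrative sketch; helper names are hypothetical, not from the paper.


def dat_domain_loss(p_source, p_target):
    """Binary cross-entropy minimized by a DAT domain classifier D.

    p_source / p_target: D's predicted probability of "source" for source and
    target samples, respectively (values strictly between 0 and 1).
    """
    loss_s = -sum(math.log(p) for p in p_source) / len(p_source)
    loss_t = -sum(math.log(1.0 - p) for p in p_target) / len(p_target)
    return loss_s + loss_t


def margin_loss(y_hat, t, m=1.0):
    """Hinge-style margin loss [m - t * y_hat]^+ (t = +1 source, t = -1 target)."""
    return max(0.0, m - t * y_hat)


# A DAT classifier growing more confident: its loss keeps shrinking toward 0.
for conf in (0.9, 0.99, 0.999):
    print(f"BCE at confidence {conf}: {dat_domain_loss([conf], [1.0 - conf]):.4f}")

# The margin loss is exactly zero once a sample lies beyond the margin.
print(margin_loss(2.0, t=+1), margin_loss(-3.0, t=-1), margin_loss(0.5, t=+1))
```

The cross-entropy keeps decreasing, without ever reaching zero, as $D$ becomes more confident, so $D$ is always rewarded for separating the domains further; the margin loss vanishes once $t\,\hat{y} \ge m$, which caps that incentive.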
3.3 Max-margin Domain-Adversarial Training
Besides the training instability, DAT is also restricted to feature-level alignment and lacks pixel-level alignment. To realize stable and comprehensive domain alignment, we first propose an Adversarial Reconstruction Network (ARN) and then elaborate on MDAT.
As depicted in Figure 1, our model consists of three parts: a feature extractor $F$ shared by both domains, a label predictor $C$ and a reconstruction network $R$. Let the feature extractor $F(\cdot;\theta_F)$ be a function parameterized by $\theta_F$ that maps an input sample $x$ to a deep embedding $z$. Let the label predictor $C(\cdot;\theta_C)$ be a task-specific function parameterized by $\theta_C$ that maps an embedding $z$ to a task-specific prediction $\hat{y}$. The reconstruction network $R(\cdot;\theta_R)$ is a decoding function parameterized by $\theta_R$ that maps an embedding $z$ to its corresponding reconstruction $\hat{x}$.
The first learning objective for the feature extractor $F$ and the label predictor $C$ is to perform well in the source domain. For a supervised $K$-way classification problem, this is achieved by minimizing the negative log-likelihood of the ground-truth class for each source sample:

$$\mathcal{L}_{\text{cls}} = -\,\mathbb{E}_{(x_s,y_s)\sim\mathcal{X}_S}\big[\,\mathbf{y}_s^{\top}\log C(F(x_s))\,\big]$$

where $\mathbf{y}_s$ is the one-hot encoding of the class label $y_s$ and the logarithm is applied element-wise to the softmax predictions of the model.
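As an illustration of $\mathcal{L}_{\text{cls}}$, the sketch below computes the mean negative log-likelihood from raw logits in plain Python. It is not the authors' code; `softmax` and `cls_loss` are hypothetical helpers, and a real implementation would normally use a framework's fused cross-entropy instead.

```python
import math

# Illustrative sketch; helper names are hypothetical, not from the paper.


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    shift = max(logits)
    exps = [math.exp(v - shift) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cls_loss(batch_logits, labels):
    """Mean negative log-likelihood of the ground-truth class (L_cls).

    The one-hot vector y_s selects a single entry of log softmax, so
    -y_s^T log C(F(x_s)) reduces to -log p[y_s].
    """
    total = 0.0
    for logits, y in zip(batch_logits, labels):
        total -= math.log(softmax(logits)[y])
    return total / len(labels)


# Two 3-way predictions: one confident and correct, one uniform.
print(cls_loss([[4.0, 0.0, 0.0], [0.0, 0.0, 0.0]], [0, 2]))
```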
The second objective is to make the learned features domain-invariant. This is motivated by the covariate shift assumption (shimodaira2000improving), which indicates that if the feature distributions $P(F(x_s))$ and $P(F(x_t))$ are similar, the source label predictor $C$ can achieve a similar accuracy in the target domain. To this end, we design a decoder network $R$ that serves as a domain classifier. In MDAT, we train the decoder network to reconstruct only the source samples and to push the reconstruction errors of target samples beyond a margin. In this way, the decoder has the functionality of distinguishing the source domain from the target domain. The training objective of $R$ is formulated as

$$\mathcal{L}_R = \mathbb{E}_{x_s\sim\mathcal{D}_S}\big[\mathcal{L}_r(x_s)\big] + \mathbb{E}_{x_t\sim\mathcal{D}_T}\big[\,m - \mathcal{L}_r(x_t)\,\big]^{+} \tag{4}$$
where $m$ is a positive margin and $\mathcal{L}_r(\cdot)$ is the mean squared error (MSE) term for the reconstruction loss, defined as

$$\mathcal{L}_r(x) = \big\lVert x - R(F(x)) \big\rVert_2^2$$
where $\lVert\cdot\rVert_2^2$ denotes the squared $L_2$-norm. Compared with a normal binary domain classifier (e.g. fully connected layers), the decoder network is tailored as a smoothed domain discriminator that separates the two domains by a specific margin rather than as far as possible.
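A minimal sketch of $\mathcal{L}_r$ and the decoder objective $\mathcal{L}_R$, assuming each sample is given as a flattened list of pixel values together with its reconstruction $R(F(x))$; the helpers `recon_mse` and `decoder_objective` are hypothetical. It follows the squared $L_2$-norm form above; a per-pixel mean would only rescale the errors, and therefore the meaningful range of $m$.

```python
# Illustrative sketch; helper names are hypothetical, not from the paper.


def recon_mse(x, x_hat):
    """L_r: squared L2 reconstruction error ||x - x_hat||_2^2 for one flattened sample."""
    return sum((a - b) ** 2 for a, b in zip(x, x_hat))


def decoder_objective(src_pairs, tgt_pairs, m):
    """L_R on a mini-batch of (x, R(F(x))) pairs.

    Source samples should be reconstructed well; target reconstruction errors
    are pushed up to the margin m, beyond which they no longer contribute.
    """
    src_term = sum(recon_mse(x, x_hat) for x, x_hat in src_pairs) / len(src_pairs)
    tgt_term = sum(max(0.0, m - recon_mse(x, x_hat)) for x, x_hat in tgt_pairs) / len(tgt_pairs)
    return src_term + tgt_term


source = [([1.0, 2.0], [1.0, 2.0])]       # perfectly reconstructed source sample
near_target = [([0.0, 0.0], [0.5, 0.5])]  # error 0.5 < m, still penalized
far_target = [([0.0, 0.0], [1.0, 1.0])]   # error 2.0 >= m, no penalty
print(decoder_objective(source, near_target, m=1.0), decoder_objective(source, far_target, m=1.0))
```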
Conversely, to form an adversarial game, the feature extractor learns to deceive $R$ such that the learned target features are indistinguishable from the source ones, which is formulated by:

$$\mathcal{L}_F = \mathbb{E}_{x_t\sim\mathcal{D}_T}\big[\mathcal{L}_r(x_t)\big]$$
Then the whole learning procedure of ARN with MDAT can be formulated by:

$$\min_{F,C}\; \mathcal{L}_{\text{cls}} + \alpha\,\mathcal{L}_F, \qquad \min_{R}\; \mathcal{L}_R$$

where $\mathcal{L}_{\text{cls}}$ denotes the negative log-likelihood of the ground-truth class for labeled source samples and $\alpha$ controls the trade-off between the loss terms. In the following section, we derive an optimal solution of MDAT and provide theoretical justifications of how MDAT reduces the distribution discrepancy for UDA.
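The sketch below evaluates both objectives from per-sample reconstruction errors to show the opposing pulls of the game. The helper `mdat_objectives` and all numbers are hypothetical; in an autodiff framework the first returned loss would be back-propagated only into $F$ and $C$, and the second only into $R$.

```python
# Illustrative sketch; helper name and values are hypothetical, not from the paper.


def mdat_objectives(cls_loss_value, src_errors, tgt_errors, alpha, m):
    """Return (loss for F and C, loss for R) for one mini-batch.

    src_errors / tgt_errors: reconstruction errors L_r(x) of source and target samples.
    """
    # L_F: F wants target samples to be reconstructed as well as source ones.
    loss_f = sum(tgt_errors) / len(tgt_errors)
    # L_R: R reconstructs source samples and pushes target errors up to the margin m.
    loss_r = (sum(src_errors) / len(src_errors)
              + sum(max(0.0, m - e) for e in tgt_errors) / len(tgt_errors))
    return cls_loss_value + alpha * loss_f, loss_r


# Same source batch; target errors before and after F improves the alignment.
unaligned = mdat_objectives(0.3, [0.1, 0.3], [0.9, 0.8], alpha=0.1, m=1.0)
aligned = mdat_objectives(0.3, [0.1, 0.3], [0.2, 0.1], alpha=0.1, m=1.0)
print(unaligned, aligned)
```

For target errors below the margin, lowering them reduces the loss of $F$ but raises $\mathcal{L}_R$, which is the tension that the alternating updates resolve.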
3.4 Optimal Solution of MDAT
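Considering the adversarial game between a reconstruction network and a feature extractor (i.e. $R$ and $F$ in our network, respectively), we prove that if the feature extractor maps both the source domain and the target domain to a common feature space $\mathcal{S}$, the MDAT system reaches a Nash equilibrium. This theoretically explains how MDAT enables the feature extractor to learn domain-invariant features. Similar to EBGAN (zhao2016energy), we assume that $R$ and $F$ have infinite capacity. Denote by $D(z)\ge 0$ the MSE of the reconstruction network for a feature $z$, and by $p_s(z)$ and $p_t(z)$ the feature distributions induced by $F$ on the source and target domains. We first define two objectives:

$$V(R,F) = \int_{z} \Big( p_s(z)\,D(z) + p_t(z)\,\big[m - D(z)\big]^{+} \Big)\,dz$$

$$U(R,F) = \int_{z} p_t(z)\,D(z)\,dz$$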
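In MDAT, we train the feature extractor $F$ to minimize the quantity $U$ and train the reconstruction network $R$ to minimize the quantity $V$. A Nash equilibrium of our system is a pair $(R^*,F^*)$ that satisfies:

$$V(R^*,F^*) \le V(R,F^*)\quad \forall R \tag{11}$$

$$U(R^*,F^*) \le U(R^*,F)\quad \forall F \tag{12}$$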
If a feature extractor $F^*$ maps both the source domain and the target domain to a common feature space $\mathcal{S}$ (i.e. $p_s = p_t$), the system reaches a Nash equilibrium and $V(R^*,F^*) = m$.
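Proof. We first prove Eq.11. With $p_s = p_t = p$,

$$V(R,F^*) = \int_{z} p(z)\Big(D(z) + \big[m - D(z)\big]^{+}\Big)\,dz = \int_{z} p(z)\,\max\big(m, D(z)\big)\,dz \;\ge\; m.$$

As $\max(m,\cdot)$ is monotonically non-decreasing on $[0,+\infty)$ and equals $m$ on $[0,m]$, $V$ reaches its minimum when $D(z)\le m$ for almost every $z$ in $\mathcal{S}$:

$$V(R^*,F^*) = m \le V(R,F^*)\quad \forall R.$$

For Eq.12, under the infinite-capacity assumption we can choose a minimizer $R^*$ that reconstructs perfectly on $\mathcal{S}$, i.e. $D^*(z)=0$, which satisfies $D^*(z)\le m$. Expanding $U$ in Eq.12 then gives

$$U(R^*,F^*) = \int_{z} p_t(z)\,D^*(z)\,dz = 0.$$

As $D^*(z)\ge 0$ for every $z$, we get $U(R^*,F^*) = 0 \le U(R^*,F)$ for all $F$.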
It follows that the optimal solution of MDAT is a Nash equilibrium when the feature extractor maps the two domains into a common feature space, i.e. when it aligns the distributions in the feature space.
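The key step of the proof, $D + [m - D]^{+} = \max(m, D) \ge m$ under aligned features, can be checked numerically. This is a minimal sketch under a stated simplification: the empirical source and target feature distributions are identical, so both domains share the same per-feature errors $D(z)$; `value_v` is a hypothetical helper.

```python
import random

# Illustrative sketch; helper name is hypothetical, not from the paper.


def value_v(d_src, d_tgt, m):
    """Empirical V(R, F) = E_s[D(z)] + E_t[m - D(z)]^+ from per-feature errors D(z) >= 0."""
    return (sum(d_src) / len(d_src)
            + sum(max(0.0, m - d) for d in d_tgt) / len(d_tgt))


rng = random.Random(0)
m = 1.0
for _ in range(1000):
    # Aligned features (p_s = p_t): both domains see the same errors D(z).
    d = [rng.uniform(0.0, 3.0) for _ in range(5)]
    v = value_v(d, d, m)
    # Pointwise D + [m - D]^+ = max(m, D), hence V >= m.
    assert abs(v - sum(max(m, x) for x in d) / len(d)) < 1e-12
    assert v >= m - 1e-12

# V reaches its minimum m as soon as every D(z) <= m (e.g. perfect reconstruction).
print(value_v([0.0, 0.0], [0.0, 0.0], m), value_v([0.2, 0.7], [0.2, 0.7], m))
```

The second half of the proof needs no simulation: with $D^* = 0$ on the common feature space, $U(R^*, F^*) = 0$, and $U$ can never be negative because reconstruction errors are non-negative.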
3.5 Connection to Domain Adaptation Theories
We further investigate how the proposed method connects to the learning theory of domain adaptation. The rationale behind domain alignment comes from the learning theory of the non-conservative domain adaptation problem by Ben-David et al. (ben2010theory):
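Let $\mathcal{H}$ be the hypothesis space where $h\in\mathcal{H}$. Let $(\mathcal{D}_S,\epsilon_S)$ and $(\mathcal{D}_T,\epsilon_T)$ be the two domains and their corresponding generalization error functions. The expected error for the target domain is upper bounded by

$$\epsilon_T(h) \le \epsilon_S(h) + \frac{1}{2}\,d_{\mathcal{H}\Delta\mathcal{H}}(\mathcal{D}_S,\mathcal{D}_T) + \lambda^{*},\quad \forall h\in\mathcal{H}$$

where the ideal risk $\lambda^{*} = \min_{h\in\mathcal{H}}\big[\epsilon_S(h) + \epsilon_T(h)\big]$, and

$$d_{\mathcal{H}\Delta\mathcal{H}}(\mathcal{D}_S,\mathcal{D}_T) = 2\sup_{h_1,h_2\in\mathcal{H}}\Big|\Pr_{x\sim\mathcal{D}_S}\big[h_1(x)\neq h_2(x)\big] - \Pr_{x\sim\mathcal{D}_T}\big[h_1(x)\neq h_2(x)\big]\Big|$$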
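Theoretically, when we minimize the $\mathcal{H}\Delta\mathcal{H}$-distance, the upper bound of the expected error for the target domain is reduced accordingly. As derived in DAT (ganin15), assuming that a family of domain classifiers $\mathcal{H}_d$ is rich enough to contain the symmetric difference hypothesis set of $\mathcal{H}_p$, i.e. $\mathcal{H}_p\Delta\mathcal{H}_p = \{h \mid h = h_1\oplus h_2,\; h_1,h_2\in\mathcal{H}_p\}$ where $\oplus$ is the XOR function, the empirical $\mathcal{H}_p\Delta\mathcal{H}_p$-distance has an upper bound w.r.t. the optimal domain classifier $h$:

$$\hat{d}_{\mathcal{H}_p\Delta\mathcal{H}_p}\big(\hat{\mathcal{D}}_S^{z},\hat{\mathcal{D}}_T^{z}\big) \le 2\sup_{h\in\mathcal{H}_d}\Big|\Pr_{z\sim\hat{\mathcal{D}}_S^{z}}\big[h(z)=1\big] + \Pr_{z\sim\hat{\mathcal{D}}_T^{z}}\big[h(z)=0\big] - 1\Big| \tag{19}$$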
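where $\hat{\mathcal{D}}_S^{z}$ and $\hat{\mathcal{D}}_T^{z}$ denote the distributions of the source and target feature spaces $\mathcal{Z}_S$ and $\mathcal{Z}_T$, respectively. Note that the MSE of $R$, passed through the hinge $[m - \cdot\,]^{+}$ and a ceiling function, is a form of domain classifier, i.e. $h(z) = \big\lceil [\,m - \mathcal{L}_r\,]^{+} \big\rceil$ for $m = 1$. It maps source samples, whose reconstruction error lies below the margin, to 1 and target samples, whose error exceeds it, to 0, which is exactly the upper bound in Eq.19. Hence, our reconstruction network maximizes the domain discrepancy with a margin, and the feature extractor learns to minimize it adversarially.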
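To see how the decoder acts as the classifier $h$ in Eq.19, the following sketch thresholds reconstruction errors with $h(z) = \lceil [m - \mathcal{L}_r]^{+} \rceil$ and evaluates the quantity inside the supremum for this single $h$, so it is only a lower estimate of the supremum itself. The error values are made up for illustration, and the helper names are hypothetical.

```python
import math

# Illustrative sketch; helper names and error values are hypothetical.


def h_from_recon_error(err, m=1.0):
    """Domain label ceil([m - err]^+): 1 = source-like, 0 = target-like.

    For m = 1 and err >= 0, [m - err]^+ lies in [0, 1], so the ceiling is
    1 when err < 1 and 0 when err >= 1.
    """
    return math.ceil(max(0.0, m - err))


def bound_term(src_errors, tgt_errors, m=1.0):
    """2 * |Pr_S[h(z) = 1] + Pr_T[h(z) = 0] - 1| for this particular h."""
    p_s1 = sum(h_from_recon_error(e, m) == 1 for e in src_errors) / len(src_errors)
    p_t0 = sum(h_from_recon_error(e, m) == 0 for e in tgt_errors) / len(tgt_errors)
    return 2.0 * abs(p_s1 + p_t0 - 1.0)


# Made-up errors: well-separated domains vs. overlapping domains.
print(bound_term([0.05, 0.2, 0.4], [1.3, 2.0, 1.1]))  # decoder separates the domains
print(bound_term([0.05, 0.2, 1.4], [0.1, 2.0, 0.3]))  # errors overlap across domains
```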
| Method | MNIST→USPS | USPS→MNIST | SVHN→MNIST | SYN→SVHN |
|---|---|---|---|---|
| Train on target | 96.5 | 99.4 | 99.4 | 91.3 |
| [S] MMD (Longicml15) | 81.1 | - | 71.1 | 88.0 |
| [S] CORAL (sun2016deep) | 80.7 | - | 63.1 | 85.2 |
| [R] DRCN (ghifary2016deep) | 91.8 | 73.7 | 82.0 | 87.5 |
| [R] DSN (bousmalis2016domain) | 91.3 | - | 82.7 | 91.2 |
| [A] DANN (DAT) (ganin2016domain) | 85.1 | 73.0 | 74.7 | 90.3 |
| [A] ADDA (tzeng2017adversarial) | 89.4 | 90.1 | 76.0 | - |
| [A] CDAN (long2017conditional) | 93.9 | 96.9 | 88.5 | - |
| [A] CyCADA (hoffman18a) | 95.6 | 96.5 | 90.4 | - |
| [A] BSP+DANN (chen2019transferability) | 94.5 | 97.7 | 89.4 | - |
| [A] MCD (saito2018maximum) | 96.5 | 94.1 | 96.2 | - |
| [A] CADA (zou2019consensus) | 96.4 | 97.0 | 90.9 | - |
| ARN w.o. MDAT | 93.1 ± 0.3 | 76.5 ± 1.2 | 67.4 ± 0.9 | 86.8 ± 0.5 |
| ARN with MDAT (proposed) | 98.6 ± 0.3 | 98.4 ± 0.1 | 97.4 ± 0.3 | 92.0 ± 0.2 |

Table 1: Accuracy (%) on the Digits transfer tasks. [S], [R] and [A] denote statistic-based, reconstruction-based and adversarial domain adaptation approaches, respectively. We repeated each experiment 3 times and report the mean and standard deviation (std) of the test accuracy in the target domain.
Compared with conventional DAT-based methods, which are usually based on a binary logistic network (ganin15), the proposed ARN with MDAT offers the following merits, both conceptually and theoretically:
(1) Effective gradients and balanced adversarial training. By using the decoder as the domain classifier with a margin loss that restrains its overwhelming capacity, the adversarial game can continuously provide effective gradients for training the feature extractor, leading to better alignment and balanced adversarial training. Moreover, the experiments in Section 4 show that our method has a more stable training procedure and strong robustness to the hyperparameters, i.e. $\alpha$ and $m$, greatly alleviating parameter tuning for model selection.
(2) Richer information for comprehensive domain alignment. Whereas typical DAT uses only one bit of domain information per sample, MDAT uses the reconstruction network as the domain classifier, which captures more domain-specific and pixel-level features during unsupervised reconstruction (bousmalis2016domain). Therefore, MDAT helps address pixel-level domain shift in addition to feature-level shift, leading to comprehensive domain alignment in a straightforward manner.
(3) Feature interpretability for method validation. MDAT allows us to visualize the features by directly reconstructing target features into images with the decoder network. Understanding to what extent the features are aligned is crucial, since it helps reveal the underlying mechanism of adversarial domain adaptation. We interpret these adapted features in Section 4.3.
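| Method | Accuracy (%) |
|---|---|
| [S] MMD (Longicml15) | 61.2 ± 0.5 |
| [R] DRCN (ghifary2016deep) | 69.3 ± 0.3 |
| [A] DANN (DAT) (ganin15) | 68.2 ± 0.2 |
| [A] ADDA (tzeng2017adversarial) | 71.5 ± 0.3 |
| [A] CADA (zou2019consensus) | 88.8 ± 0.1 |
| ARN with MDAT | 91.3 ± 0.2 |

Table 2: Accuracy (%) on the WiFi gesture recognition spatial adaptation task.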
We evaluate the proposed approach on several visual and non-visual UDA tasks with varying degrees of domain shift. We then conduct detailed analyses on a toy dataset, parameter sensitivity, gradients and feature visualization. Dataset descriptions and implementation details are given in the supplementary materials.
Digits (ganin2016domain). We use four digit datasets, MNIST, USPS, SVHN and Synthetic Digits (SYN), which form four transfer tasks: MNIST→USPS, USPS→MNIST, SVHN→MNIST and SYN→SVHN.
Office-Home (venkateswara2017deep) is a challenging UDA dataset including 15,500 images from 65 categories. It comprises four extremely distinct domains: Artistic images (Ar), ClipArt (Cl), Product images (Pr), and Real-World images (Rw). We evaluate on all twelve transfer tasks.
WiFi Gesture Recognition (zou2019consensus) consists of six gestures recorded as Channel State Information (CSI) (xie2018precise). Each CSI sample is a 2D matrix that depicts the gesture together with the surrounding layout environment. Thus, the CSI data collected in two environments form two domains, which constitutes a spatial adaptation problem.
We compare with state-of-the-art UDA methods that perform three ways of domain alignment. Specifically, MMD regularization (Longicml15) and CORAL (sun2016deep) are based on statistical distribution matching. DRCN (ghifary2016deep) and DSN (bousmalis2016domain) use the reconstruction network for UDA, while more prevailing UDA methods adopt domain-adversarial training including DANN (ganin15), ADDA (tzeng2017adversarial), CyCADA (hoffman18a), CDAN (long2017conditional), MCD (saito2018maximum), CADA (zou2019consensus), TransNorm (wang2019transferable) and BSP (chen2019transferability). The baseline results are reported from their original papers where available.
We used PyTorch to implement our model. For the Digits datasets, we follow the same protocol as (hoffman18a) and the same network architecture as (ganin15). For Office-Home, we adopt ResNet-50 pretrained on ImageNet as our backbone. Following the standard protocols in (Longicml15), we use all the labeled source samples and unlabeled target samples for training. For the WiFi Gesture Recognition data, we employ the modified LeNet and the standard protocol in (zou2019consensus). The design of $R$ is the inverse of $F$, with pooling operations replaced by upsampling. We fix $\alpha$ and $m$ in all the experiments to values obtained on SVHN→MNIST by Bayesian optimization (malkomes2016bayesian). We adopt the mini-batch SGD optimizer with a momentum of 0.9 and the progressive training strategy in DANN (ganin15).
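The progressive training strategy of DANN (ganin15) anneals the learning rate and gradually ramps up the adversarial weight. The sketch below reproduces those two schedules with the constants reported in the DANN paper ($\gamma = 10$, initial learning rate $0.01$, decay $10$, power $0.75$). Applying the ramp to $\alpha$ in ARN is our assumption rather than something stated above, and the function names and the value of `alpha` are hypothetical.

```python
import math

# Illustrative sketch of DANN-style schedules; names and alpha value are hypothetical.


def dann_adaptation_factor(progress, gamma=10.0):
    """DANN ramp 2 / (1 + exp(-gamma * p)) - 1, from 0 at p = 0 towards 1 at p = 1."""
    return 2.0 / (1.0 + math.exp(-gamma * progress)) - 1.0


def dann_learning_rate(progress, lr0=0.01, decay=10.0, power=0.75):
    """DANN annealing lr0 / (1 + decay * p) ** power, with p the training progress in [0, 1]."""
    return lr0 / (1.0 + decay * progress) ** power


total_steps = 10_000
alpha = 1.0  # hypothetical value; the paper's alpha is not given in this section
for step in (0, 2_500, 5_000, 10_000):
    p = step / total_steps
    print(f"p={p:.2f}  weight={alpha * dann_adaptation_factor(p):.4f}  lr={dann_learning_rate(p):.5f}")
```

The ramp keeps the adversarial term nearly silent while the features are still noisy early in training, which is one common way to avoid the early imbalance discussed in Section 3.2.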
4.2 Overall Results
The classification accuracies on Digits are shown in Table 1. Our method outperforms all other methods on the four transfer tasks. Specifically, for SVHN→MNIST, where severe pixel-level domain shift exists, our method improves DANN by 22.7%, which justifies the efficacy of ARN in addressing pixel-level shift. Our method also performs well when the target domain is quite small, achieving 98.6% accuracy on MNIST→USPS. In Table 2, our method improves the source-only model by 32.9% on the WiFi spatial adaptation problem, which indicates that MDAT is also helpful for non-visual domain adaptation. Table 3 shows the performance on the large-scale Office-Home dataset, where MDAT yields better performance than the other domain alignment approaches.
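| Method | Ar→Cl | Ar→Pr | Ar→Rw | Cl→Ar | Cl→Pr | Cl→Rw | Pr→Ar | Pr→Cl | Pr→Rw | Rw→Ar | Rw→Cl | Rw→Pr | Avg |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| DANN (DAT) (ganin15) | 45.6 | 59.3 | 70.1 | 47.0 | 58.5 | 60.9 | 46.1 | 43.7 | 68.5 | 63.2 | 51.8 | 76.8 | 57.6 |
| ARN with MDAT (Proposed) | 51.3 | 69.7 | 76.2 | 59.5 | 68.3 | 70.0 | 57.2 | 48.9 | 75.8 | 69.1 | 55.3 | 80.6 | 65.2 |

Table 3: Accuracy (%) on Office-Home (ResNet-50 backbone).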
Ablation study. To verify the contributions of the reconstruction network and MDAT, we discard the margin term $\mathbb{E}_{x_t\sim\mathcal{D}_T}[m - \mathcal{L}_r(x_t)]^{+}$ in Eq.4 and evaluate the resulting method, denoted as ARN w.o. MDAT in Table 1. Comparing it with the source-only model, we can infer the improvement brought by reconstructing target samples. ARN w.o. MDAT improves tasks with low-level domain shift such as MNIST→USPS, which is consistent with our discussion that unsupervised reconstruction is instrumental in learning pixel-level features. Comparing ARN w.o. MDAT with the full ARN, we can infer the contribution of MDAT. Table 1 shows that MDAT achieves an impressive margin of improvement: for USPS→MNIST and SVHN→MNIST, MDAT improves ARN w.o. MDAT by 21.9% and 30.0%, respectively. This demonstrates that MDAT, which helps learn domain-invariant features, is the main reason for the large improvement.
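The gains quoted above are absolute differences in accuracy and can be recomputed directly from the mean values in Table 1; the short script below does only that arithmetic (the task labels and variable names are ours).

```python
# Arithmetic check on Table 1; variable names are ours, numbers are the table's means.
TASKS = ["MNIST->USPS", "USPS->MNIST", "SVHN->MNIST", "SYN->SVHN"]
DANN = [85.1, 73.0, 74.7, 90.3]
ARN_WO_MDAT = [93.1, 76.5, 67.4, 86.8]
ARN_MDAT = [98.6, 98.4, 97.4, 92.0]


def gains(new, base):
    """Absolute improvement (in accuracy points) of `new` over `base`, rounded to 0.1."""
    return [round(a - b, 1) for a, b in zip(new, base)]


for task, g_dann, g_abl in zip(TASKS, gains(ARN_MDAT, DANN), gains(ARN_MDAT, ARN_WO_MDAT)):
    print(f"{task}: +{g_dann} over DANN, +{g_abl} over ARN w.o. MDAT")
```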
Toy dataset. We study the behavior of MDAT on a variant of the inter-twinning moons 2D problem, where the target samples are rotated from the source samples. 300 samples are generated for each domain using scikit-learn (pedregosa2011scikit). The adaptation ability is investigated by comparing MDAT with DANN and the source-only model. As shown in Figure 2, we visualize the changing decision boundaries during 10 epochs of training. In Figure 2(a), the source-only model overfits the source domain, and the decision boundary does not change. In Figures 2(b) and 2(c), both DANN and MDAT adapt the boundaries to the target samples, but MDAT shows faster and better adaptation within 10 epochs. Together with the training curves of SVHN→MNIST in Figure 3(a), this suggests that MDAT provides more effective gradients, leading to better adaptation performance.
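For readers who want to set up a similar toy problem without extra dependencies, here is a sketch of a rotated two-moons pair of domains. The moon construction mirrors what scikit-learn's `make_moons` does, but the noise level (0.1), the rotation angle (30 degrees) and the rotation center are assumptions, since the text does not specify them; the helper names are hypothetical.

```python
import math
import random

# Illustrative sketch; noise, angle, center and helper names are assumptions.


def make_two_moons(n, noise=0.1, seed=0):
    """Two inter-twinning half circles, built like sklearn's make_moons (needs n >= 4)."""
    rng = random.Random(seed)
    n_outer, n_inner = n // 2, n - n // 2
    points, labels = [], []
    for i in range(n_outer):
        t = math.pi * i / (n_outer - 1)
        points.append((math.cos(t), math.sin(t)))
        labels.append(0)
    for i in range(n_inner):
        t = math.pi * i / (n_inner - 1)
        points.append((1.0 - math.cos(t), 0.5 - math.sin(t)))
        labels.append(1)
    # Add isotropic Gaussian noise to every point.
    noisy = [(x + rng.gauss(0.0, noise), y + rng.gauss(0.0, noise)) for x, y in points]
    return noisy, labels


def rotate(points, degrees, center=(0.5, 0.25)):
    """Rotate 2D points counter-clockwise by `degrees` around `center`."""
    a = math.radians(degrees)
    c, s = math.cos(a), math.sin(a)
    cx, cy = center
    return [(cx + c * (x - cx) - s * (y - cy), cy + s * (x - cx) + c * (y - cy))
            for x, y in points]


# 300 labeled source samples and 300 unlabeled, rotated target samples.
source_x, source_y = make_two_moons(300, seed=0)
target_x, _ = make_two_moons(300, seed=1)  # target labels are discarded (unsupervised)
target_x = rotate(target_x, degrees=30.0)   # assumed rotation angle
```

With scikit-learn installed, `sklearn.datasets.make_moons(n_samples=300, noise=0.1)` produces the unrotated moons directly.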
Gradients and stability analysis. We further study the training procedure of MDAT on SVHN→MNIST w.r.t. the loss and the target accuracy in Figures 3(a) and 3(b), respectively. In Figure 3(a), the loss of ARN decreases steadily for all tested settings, whereas the domain loss of DAT becomes extremely small at the very beginning. These observations conform with our intuition: the domain classifier in DAT is so strong that it impedes adversarial training, while MDAT provides more effective gradients for training the feature extractor by restricting the capacity of the domain classifier. With effective gradients, the adversarial game is more balanced, which is validated in Figure 3(b), where the test accuracy of ARN is more stable than that of DAT across training epochs.
Table 5: Source images, target images and reconstructed target images (R-Target Images); images not included.
Parameter sensitivity. We investigate the sensitivity of $\alpha$ and $m$ on SVHN→MNIST. The results in Table 4 show that ARN achieves good performance when $\alpha$ is small enough. Even with larger $\alpha$, ARN still converges. In comparison, denoting $\lambda$ as the weight of the adversarial domain loss in DANN, DANN cannot converge when $\lambda$ is large, due to the imbalanced adversarial game between the overwhelming domain classifier and the feature extractor. For the sensitivity of $m$, the accuracy of ARN exceeds 96.0% once $m$ is large enough. As discussed in Section 3.5, with $m = 1$ the decoder serves as a domain classifier. These analyses validate that ARN is less sensitive to its hyperparameters than DANN. Even in the worst cases, ARN always converges.
t-SNE embeddings. We analyze domain alignment for DANN (DAT) (ganin15) and ARN (MDAT) by plotting t-SNE embeddings of the features $z$ on the task SVHN→MNIST. In Figure 4(a), the source-only model obtains diverse embeddings for each category, but the domains are not aligned. In Figure 4(b), DANN aligns the two domains, but the decision boundaries of the classifier are blurry. In Figure 4(c), the proposed ARN effectively aligns the two domains for all categories, and the classifier boundaries are much clearer.
Interpreting adapted features via reconstruction. One of the key advantages of ARN is that, by visualizing the reconstructed target images, we can infer in what sense the features are domain-invariant. We reconstruct the MDAT features of the target domain and visualize them in Table 5. The target features are reconstructed into source-like images by the decoder $R$. As discussed before, MDAT intuitively forces the target features to mimic the source feature distribution, which conforms with the visualization. Similar to image-to-image translation, this indicates that our method performs an implicit feature-to-feature translation that transfers the target features to source-like features, and hence the features are domain-invariant.
We proposed a new domain alignment approach, namely Max-margin Domain-Adversarial Training (MDAT), and an MDAT-based deep neural network for unsupervised domain adaptation. The proposed method offers effective and stable gradients for feature learning via an adversarial game between the feature extractor and the reconstruction network. The theoretical analysis justifies how it minimizes the distribution discrepancy. Extensive experiments demonstrate the effectiveness of our method, and we further interpret the features through visualizations that conform with our insight. Evaluating the method on semi-supervised learning constitutes our future work.
Hyperparameters. For all tasks, we simply use the same hyperparameters, i.e. a single setting of $\alpha$ and $m$ chosen from the sensitivity analysis. We expect that better results can be obtained by tuning the hyperparameters for specific tasks.
We presented all the results of the sensitivity study in Section 4.3; here we show the detailed training procedures in Figures 5(a) and 5(b). The accuracy increases when $\alpha$ decreases or the margin $m$ increases. The reason is simple: (1) when $\alpha$ is too large, it weakens the effect of supervised training on the source domain; (2) when the margin is small, the divergence between the source and target domains (i.e. the $\mathcal{H}\Delta\mathcal{H}$-distance) cannot be measured well.
Here we provide more visualizations of the reconstructed target samples. In Figure 6, the target samples are shown in the left column and their corresponding reconstructions in the right column. For low-level domain shift such as MNIST→USPS, the reconstructed target samples are very source-like while preserving their original shapes and skeletons. For larger domain shift, as in Figures 6(c) and 6(d), the samples are reconstructed into source-like images of the same digits, and some noise is removed at the same time. Specifically, in Figure 6(d), a target (SVHN) sample may contain more than one digit, and the extra digits act as noise for recognition. After reconstruction, only the correct digit remains. Some target samples suffer from poor illumination, yet their reconstructed digits are very clear.