Deep causal representation learning for unsupervised domain adaptation

10/28/2019 · by Raha Moraffah, et al.

Studies show that the representations learned by deep neural networks can be transferred to similar prediction tasks in other domains for which we do not have enough labeled data. However, as we transition to higher layers in the model, the representations become more task-specific and less generalizable. Recent research on deep domain adaptation proposed to mitigate this problem by forcing the deep model to learn more transferable feature representations across domains. This is achieved by incorporating domain adaptation methods into deep learning pipeline. The majority of existing models learn the transferable feature representations which are highly correlated with the outcome. However, correlations are not always transferable. In this paper, we propose a novel deep causal representation learning framework for unsupervised domain adaptation, in which we propose to learn domain-invariant causal representations of the input from the source domain. We simulate a virtual target domain using reweighted samples from the source domain and estimate the causal effect of features on the outcomes. The extensive comparative study demonstrates the strengths of the proposed model for unsupervised domain adaptation via causal representations.


1 Introduction

Deep neural networks have achieved great success in areas such as image classification [1] and object detection [2]. These models usually require huge amounts of training data, yet collecting and annotating datasets is labor-intensive. Luckily, large amounts of labeled data are often available from other domains and can be leveraged. However, the distributions of the source domain (i.e., the domain the model is trained on) and the target domain (i.e., the domain we wish to apply the trained model to) usually differ, which causes models that transfer knowledge directly from the source to the target domain to fail.

Even though deep neural nets can learn transferable representations, studies show that differences between the distributions of the source and target domains (domain shift) can hurt the performance of these models, and the representations become less transferable in higher layers of the network [3][4]. Moreover, in many real-world cases where little or no labeled data is available in the target domain, training the model in a supervised manner on data from both domains leads to overfitting to the source distribution [5]. Various unsupervised deep domain adaptation methods have been proposed to utilize labeled samples from the source domain and unlabeled samples from the target domain, and to learn a classifier that minimizes the prediction error in the target domain by embedding shallow domain adaptation methods into the deep learning pipeline and learning representations that are both predictive and domain invariant [6][7][8][9]. However, these methods learn feature representations that are merely correlated with the outcome.

Figure 1: A causal diagram depicting the causal features, spurious features, confounders, and the prediction outcome.

Since some correlations can be spurious and therefore not transferable, we propose to learn causal feature representations for unsupervised domain adaptation. Causal feature representations are those that define the structure of the outcome variable rather than its context. For instance, consider a picture of a cat. Features such as the eyes, ears, and shape of the face relate to the structure of a cat and are thus referred to as causal feature representations, whereas features such as the background of the image are called context feature representations. Figure 1 illustrates one possible causal graph over extracted features such as eyes, ears, and background, together with the outcome variable (i.e., an indicator of whether the image is a cat or not). Causal features such as eyes and ears have a direct causal effect on the outcome, whereas context features are spurious features, i.e., the correlations between them and the outcome variable are due to the existence of a confounder. Correlations due to confounding variables are misleading and may not be transferable across different target domains.

The causal mechanism that maps the cause to the effect should not depend on the distribution of the cause [10], and causal features are naturally transferable across different domains [11]. For instance, in our example, if the model learns features related to the structure of cats, such as eyes, ears, and whiskers, instead of learning context features, the learned features are more transferable across domains. This can be achieved by learning causal relationships between the features and the outcome instead of their correlations.

In order to capture the causal effect of the learned representations on the outcome, following [12], our framework simulates a virtual target domain on top of the representations learned from the source data by re-weighting the samples from the source domain. We show that in this simulated target domain, only representations that are causes of the outcome variable remain correlated with it, and all other correlations, such as those induced by confounders, are removed. Therefore, the model can learn the representations with the highest causal effect on the outcome by measuring the correlations between the outcome and the representations in the virtual domain. These representations are then used along with the causal mechanism to perform prediction in a target domain. Learning these weights is embedded into the pipeline of the deep neural net and occurs jointly with learning the parameters of the deep model. This way, the model can learn representations that are both predictive and invariant across different domains.

Our major contributions are summarized as follows:

  • We study a novel problem of learning causal representations for unsupervised domain adaptation;

  • We propose a general framework, DCDAN, to learn causal feature representations and causal mechanisms for prediction in target domains, and show that the learned representations are indeed those with the highest causal effects on the target variables; and

  • We conduct experiments to demonstrate the effectiveness of the proposed framework for unsupervised domain adaptation by learning causal feature representations.

2 Related Work

We review research on deep visual domain adaptation and causal inference in domain adaptation and feature learning.

Deep visual domain adaptation. Despite the achievements of deep neural networks in feature learning, [4] and [3] show that the transferability of the learned features decreases in the higher, more task-specific layers of the network. Deep domain adaptation addresses this issue by embedding domain adaptation into the pipeline of deep models and learning representations that are both predictive and domain invariant. This is achieved using several different criteria. For example, [7], [13] and [6] leverage class labels as a guide for transferring knowledge across different domains. [14], [6] and [9] approach the problem by aligning the statistical distribution shifts between the source and target domains. To compare the distributions of the source and target domains, criteria such as maximum mean discrepancy (MMD) [14], [6], [8], [9], Kullback-Leibler (KL) divergence [15], and correlation alignment (CORAL) [16], [17] have been used. Another line of work, including [7], [18], and [19], focuses on adversarial domain adaptation, which minimizes the distance between the source and target domains through an adversarial objective against a domain discriminator, aiming to encourage domain confusion.
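To make one of these criteria concrete, a squared MMD with a linear kernel reduces to the squared distance between the mean feature vectors of the two domains. The sketch below is a minimal illustration of that special case, not the multi-kernel estimators used by the methods cited above; the data and dimensions are made up for the example.

```python
import numpy as np

def linear_mmd2(source, target):
    """Squared MMD with a linear kernel: the squared Euclidean distance
    between the mean feature vectors of the two domains."""
    return float(np.sum((source.mean(axis=0) - target.mean(axis=0)) ** 2))

rng = np.random.default_rng(0)
src = rng.normal(0.0, 1.0, size=(500, 8))   # source-domain features
tgt = rng.normal(0.5, 1.0, size=(500, 8))   # target features with a mean shift

# shifted domains yield a larger discrepancy than two draws from one domain
print(linear_mmd2(src, tgt) > linear_mmd2(src, rng.normal(0.0, 1.0, size=(500, 8))))  # True
```

Minimizing such a discrepancy between the source and target feature activations is what pushes the learned representations toward domain invariance in the MMD-based methods above.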

Causal inference in domain adaptation. Recent work on shallow domain adaptation proposes to learn invariant features for adaptation to agnostic target domains. For instance, [11] propose a causal framework to identify features that are invariant across different datasets and use them for prediction. [20] propose to find causal features by exploiting the invariance of the conditional distribution of the target variable across different domains. [12] propose a causal approach that selects domain-invariant predictors among all predictors and uses them to perform domain adaptation across unknown environments. However, these methods are all designed for shallow domain adaptation and choose useful predictors rather than learning them from the data.

Causal inference has also been utilized for learning visual features from data. [21] proposes a visual causal feature learning framework which constructs causal variables from micro-variables in observational data with minimal experimental effort. However, this work still requires performing experiments.

3 Preliminaries of Domain Adaptation

In this work, given a source domain with labeled samples, we predict the labels for an unlabeled target domain by leveraging the source samples to minimize the prediction error. Let P and Q denote the distributions of the source and target domains, respectively, with P(X, Y) and Q(X, Y) the joint distributions of inputs and outcomes for the two domains. In general, these two joint distributions are different, i.e., P(X, Y) ≠ Q(X, Y). Following the traditional setting [5], we make two basic assumptions: (i) P(Y | X) = Q(Y | X), meaning the conditional distribution of the outcome given the data remains the same across different environments; and (ii) P(X) ≠ Q(X), indicating that the difference between the joint distributions originates from the difference between the marginal distributions of the inputs. Moreover, we assume that the difference between the marginal input distributions comes from a drift in the feature space of the problem (covariate shift). Specifically, we study the problem of constructing a deep neural network that learns transferable representations Z for which the conditional probability of the outcome remains the same across domains, together with a classifier f(·) that maps the learned representations to the outcome such that the target loss is minimized.
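A toy numerical illustration of this covariate-shift setting, with a made-up labeling rule: the conditional P(Y | X) is fixed and shared by both domains, while only the input distribution changes.

```python
import numpy as np

rng = np.random.default_rng(1)

def label(x):
    """Fixed conditional P(Y | X): the labeling rule is shared by both
    domains (assumption (i) in the text)."""
    return (x > 0.0).astype(int)

# marginals differ (assumption (ii)): the two domains draw inputs
# from different distributions
x_source = rng.normal(-1.0, 1.0, size=10_000)
x_target = rng.normal(+1.0, 1.0, size=10_000)

y_source, y_target = label(x_source), label(x_target)

# the conditional rule is identical, but the class balance shifts
print(y_source.mean(), y_target.mean())  # roughly 0.16 vs 0.84
```

A classifier that approximates the invariant rule P(Y | X) itself, rather than exploiting source-specific correlations, keeps working under this shift; that is the behavior the framework below aims for.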

4 Proposed Framework - DCDAN

We propose a Deep Causal Representation learning framework for unsupervised Domain Adaptation (DCDAN) to learn transferable feature representations for target-domain prediction. A deep neural network combines feature extraction with a classifier and learns feature representations that are highly correlated with the outcome. However, some of these correlations could be due to biases in the data, e.g., confounding bias. We aim to remove the spurious correlations learned by the model by measuring the causal effects of the representations on the outcome. DCDAN consists of a regularization term that learns balancing weights for the source data by balancing the distribution with respect to the feature representations learned from the data. These weights are designed to help the model capture the causal effects of the features on the target variable instead of their correlations. Moreover, our model includes a weighted loss function for the deep neural net, where the weight of each sample comes from the regularization term; the loss function is in charge of learning predictive, domain-invariant features as well as a classifier that maps the learned representations to the output, i.e., a causal mechanism. By embedding the sample weights into the pipeline of the model and learning them jointly with the representations, we can benefit from the deep model to learn causal features that are both transferable and good predictors for the target. The framework is shown in Figure 2.

Figure 2: An overview of DCDAN that learns causal representation of the data for prediction.

4.1 Balancing Sample Weights Regularizer

As discussed in the previous section, in order to learn representations with a causal effect on the outcome (i.e., causal feature representations) and reduce confounding bias, we need to force the deep neural net to learn causal relationships instead of correlations. To do so, following [12], we reweight samples in the source domain with sample weights that enable us to capture the causal effect of each learned feature on the outcome variable by controlling for the effect of all other features in the learned representation. Variable balancing techniques are often used in causal effect estimation from observational data, where the distributions of the covariates differ between the treatment and control groups due to the non-random assignment of treatments to units. In order to obtain an unbiased estimate of the causal effect of a treatment variable, one approach is to balance the distributions of the treatment and control groups by applying balancing weights to the samples. One way to learn these weights is to characterize the distributions of the treatment and control groups by their moments and learn the weights W as follows:

W = argmin_W ‖ X⊤(W ⊙ T) / (W⊤T) − X⊤(W ⊙ (1 − T)) / (W⊤(1 − T)) ‖²₂

where X denotes the covariates, ⊙ the Hadamard product, and T the treatment variable, with T = 1 and T = 0 representing the treatment and control groups, respectively.
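A minimal numeric sketch of this idea, under simplifying assumptions: a single covariate, first moments only, and plain projected gradient descent rather than the optimizer of [12]. Non-negative weights on the control samples are adjusted until their weighted mean matches the treated-group mean.

```python
import numpy as np

rng = np.random.default_rng(2)

# one covariate whose distribution differs between the treated (T = 1)
# and control (T = 0) groups
x_treat = rng.normal(1.0, 1.0, size=200)
x_ctrl = rng.normal(0.0, 1.0, size=500)

target_mean = x_treat.mean()          # first moment of the treated group
w = np.ones_like(x_ctrl)              # one balancing weight per control sample

for _ in range(5000):
    s = w.sum()
    mu = w @ x_ctrl / s               # weighted control mean
    gap = mu - target_mean
    grad = 2.0 * (x_ctrl - mu) * gap / s   # d/dw of gap**2
    w = np.maximum(w - 1.0 * grad, 0.0)    # keep the weights non-negative

raw_gap = abs(x_ctrl.mean() - target_mean)
balanced_gap = abs(w @ x_ctrl / w.sum() - target_mean)
print(raw_gap, balanced_gap)  # the weighted gap shrinks far below the raw gap
```

After reweighting, a comparison of outcomes between the groups is no longer confounded by this covariate, which is the effect the regularizer below generalizes to every learned feature.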

It is shown that by considering each feature in the learned representation as a treatment variable and learning weights that balance the distribution of the source data with respect to every learned feature, we obtain a new domain in which only causal feature representations are correlated with the outcome. These reweighted representations can then be used to estimate the causal contributions of the source representations to the outcome. The sample weights can be learned by minimizing the following function:

G(W) = Σⱼ ‖ Z₋ⱼ⊤(W ⊙ I·ⱼ) / (W⊤I·ⱼ) − Z₋ⱼ⊤(W ⊙ (1 − I·ⱼ)) / (W⊤(1 − I·ⱼ)) ‖²₂

where W denotes the vector of sample weights, Z = h(X) represents the feature representations extracted from the deep model, Z·ⱼ is the vector of the j-th feature representation over all samples, Z₋ⱼ is the set of all feature representations except the j-th one, ⊙ refers to the Hadamard product, and I is an indicator matrix whose entries indicate whether a feature exists in a data sample (an entry equal to one means the sample belongs to the treatment group) or does not exist (an entry equal to zero means the sample belongs to the control group). In order to create the indicator matrix, we binarize the representations by leveraging the methods proposed in [22] for binarizing neural networks. Two different approaches can be used to binarize the representations. The first function is deterministic:

xᵇ = +1 if x ≥ 0, and −1 otherwise,

where xᵇ is the binarized version of the real-valued variable x. This function is easy to implement and is shown to work well in practice. The second function is stochastic:

where is the "hard-sigmoid" function defined as:

This stochastic function makes the overlap assumption (1) more plausible than a deterministic function. Since the derivative of both binarization functions are almost zero everywhere during the back-propagation, following Hinton 2012’s lectures and [22], we use “straight-through estimator" for back propagation.
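The two binarizers and the straight-through trick can be sketched as follows. This follows the formulation in [22]; the gradient clipping window of 1 and the example inputs are illustrative choices, and the ±1 outputs would be mapped to the {1, 0} treatment/control entries of the indicator matrix.

```python
import numpy as np

rng = np.random.default_rng(3)

def binarize_det(x):
    """Deterministic binarization: +1 if x >= 0, -1 otherwise."""
    return np.where(x >= 0.0, 1.0, -1.0)

def hard_sigmoid(x):
    """'Hard sigmoid' used by the stochastic binarizer: clip((x + 1) / 2, 0, 1)."""
    return np.clip((x + 1.0) / 2.0, 0.0, 1.0)

def binarize_stoch(x):
    """Stochastic binarization: +1 with probability hard_sigmoid(x), else -1."""
    return np.where(rng.random(x.shape) < hard_sigmoid(x), 1.0, -1.0)

def ste_grad(upstream, x, clip=1.0):
    """Straight-through estimator: pass the upstream gradient through the
    almost-everywhere-flat binarizer, zeroed where |x| exceeds the clip
    window (as in [22]) so saturated units stop updating."""
    return upstream * (np.abs(x) <= clip)

x = np.array([-2.0, -0.4, 0.0, 0.7, 3.0])
print(binarize_det(x))   # [-1. -1.  1.  1.  1.]
print(hard_sigmoid(x))   # 0, 0.3, 0.5, 0.85, 1
print(ste_grad(np.ones_like(x), x))  # gradient kept only where |x| <= 1
```

The stochastic variant keeps treatment assignment probabilistic for intermediate activations, which is what makes the overlap assumption easier to satisfy.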

4.2 Deep Causal Domain Adaptation

We explore the idea of using a sample-reweighting regularizer for learning transferable features in deep networks. Convolutional neural networks (CNNs) are widely used deep frameworks in computer vision [1] and have achieved great performance on a variety of tasks. However, as discussed earlier, when transferring these models to new domains, the higher, more task-specific layers of the CNN become less transferable. Therefore, vast amounts of labeled data in the target domain are required to fine-tune the model and avoid over-fitting. Unsupervised domain adaptation frameworks such as [9] and [8] leverage unlabeled data in the target domain and learn source-domain representations that transfer better to the target domain. They learn transferable feature representations that are highly correlated with the outcome. Relying on correlations limits their performance, since correlations do not necessarily hold in other domains and thus may not be transferable. If we learn only transferable feature representations, performance on the target domain can be further improved. We propose to learn causal representations of the input, for which the conditional distributions of the outcomes (a.k.a. causal mechanisms) remain the same across different domains even if the distributions of the inputs change [10]. This can be achieved by reweighting the learned representations with sample weights learned by the balancing sample regularizer. The reweighted samples play the role of a virtual target domain in which only representations with causal contributions to the outcome are correlated with the target, while spurious correlations, i.e., correlations between two variables that do not truly exist and are due to confounders, are removed. By doing so, we force the model to learn the features with the highest causal contributions to the outcome.

We implement DCDAN on the Resnet-50 architecture [23], which can easily be replaced with any other convolutional neural network. Resnet-50 is used as a backbone for a variety of computer vision tasks, including deep transfer learning. It consists of 5 stages with convolution and identity blocks, each containing 3 convolution layers. Resnet-50 leverages skip connections, which add the original input to the output of a convolution block when stacking convolution layers together, reducing the risk of vanishing gradients during back-propagation. The empirical risk of the CNN can be written as:

min_θ (1/n) Σᵢ J(f_θ(xᵢ), yᵢ)    (1)

where J is the cross-entropy loss function, θ denotes the parameters of the CNN, and f_θ(xᵢ) is the conditional probability that the CNN assigns to label yᵢ.
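The empirical risk above is simply the mean cross-entropy over the training samples; a minimal sketch with made-up logits and labels:

```python
import numpy as np

def cross_entropy_risk(logits, labels):
    """Empirical risk of Eq (1): the mean cross-entropy J over samples,
    with a softmax turning logits into conditional class probabilities."""
    z = logits - logits.max(axis=1, keepdims=True)   # for numerical stability
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return float(-log_probs[np.arange(len(labels)), labels].mean())

logits = np.array([[2.0, 0.0], [0.0, 2.0], [2.0, 0.0]])
labels = np.array([0, 1, 0])
print(cross_entropy_risk(logits, labels))  # ~0.127: confident, correct predictions
```

The weighted variant used by DCDAN replaces the uniform mean with per-sample weights, as described next.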

To learn causal feature representations that are invariant across domains using the balancing sample-reweighting regularizer, we propose to jointly learn the sample weights from the data, reweight the representations by these weights, and minimize the prediction loss over the reweighted samples. This can be done by reweighting the loss of each sample with its corresponding weight and minimizing the weighted empirical loss of the CNN and the balancing regularizer G simultaneously. This way, we can reduce the bias in the correlations learned by the model and learn the features with the highest causal effect on the output that are also informative for prediction. Therefore, we embed the balancing sample-reweighting regularizer into the CNN framework as:

min_{θ, W} Σᵢ Wᵢ J(f_θ(xᵢ), yᵢ) + λ₁ G(W) + λ₂ ‖W‖²₂ + λ₃ (Σᵢ Wᵢ − 1)²  s.t. W ⪰ 0    (2)

where Σᵢ Wᵢ J(f_θ(xᵢ), yᵢ) is the weighted cross-entropy loss, representing the weighted empirical loss of the CNN model, G is the balancing sample-reweighting regularizer, W is the vector of sample weights, the constraint W ⪰ 0 ensures all the weights are non-negative, the term λ₂ ‖W‖²₂ reduces the sample-weight variance, and the term λ₃ (Σᵢ Wᵢ − 1)² prevents all sample weights from being zero.

To optimize DCDAN, we update the CNN parameters θ and the sample weights W iteratively using mini-batch SGD. A detailed algorithm of the optimization framework can be found in the supplementary material.
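A toy sketch of this alternating scheme, under loud simplifications: logistic regression stands in for the CNN, the penalty weights lam1..lam3 are arbitrary illustrative values, and finite-difference gradients replace back-propagation; it only demonstrates the structure of the updates, not the paper's actual optimizer.

```python
import numpy as np

rng = np.random.default_rng(5)

# toy data: two features, binary label driven by the first feature
n, p = 200, 2
X = rng.normal(size=(n, p))
y = (X[:, 0] + 0.3 * rng.normal(size=n) > 0).astype(float)
I = (X >= 0.0).astype(float)          # binarized "treatment" indicators
lam1, lam2, lam3 = 1.0, 0.1, 1.0      # illustrative hyperparameter values

def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))

def balance(w):
    """First-moment balancing penalty G(W): for each feature j, compare the
    weighted mean of the other features between I[:, j] = 1 and I[:, j] = 0."""
    g = 0.0
    for j in range(p):
        t, others = I[:, j], np.delete(X, j, axis=1)
        m1 = others.T @ (w * t) / (w @ t + 1e-8)
        m0 = others.T @ (w * (1.0 - t)) / (w @ (1.0 - t) + 1e-8)
        g += float(np.sum((m1 - m0) ** 2))
    return g

def objective(theta, w):
    """Weighted cross-entropy plus the three penalty terms of Eq (2)."""
    pred = sigmoid(X @ theta)
    ce = -(y * np.log(pred + 1e-8) + (1.0 - y) * np.log(1.0 - pred + 1e-8))
    return float(w @ ce + lam1 * balance(w)
                 + lam2 * np.sum(w ** 2) + lam3 * (w.sum() - 1.0) ** 2)

def num_grad(f, v, eps=1e-5):
    """Finite-difference gradient; fine at this toy scale."""
    g = np.zeros_like(v)
    for i in range(v.size):
        d = np.zeros_like(v)
        d[i] = eps
        g[i] = (f(v + d) - f(v - d)) / (2.0 * eps)
    return g

theta, w = np.zeros(p), np.full(n, 1.0 / n)
start = objective(theta, w)
for _ in range(50):
    theta -= 0.1 * num_grad(lambda t: objective(t, w), theta)   # theta step, W fixed
    w = np.maximum(w - 0.001 * num_grad(lambda v: objective(theta, v), w), 0.0)  # W step, theta fixed
end = objective(theta, w)
print(round(start, 3), round(end, 3))  # the alternating updates reduce the joint objective
```

The clipping at zero implements the W ⪰ 0 constraint; in the full model each step would be a mini-batch SGD update of the CNN and of the weights, respectively.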

4.3 Theoretical analysis

In this section, we state the key assumption of our proposed method and analyze why the method learns causal features (i.e., the features with the highest causal contributions to the target variable). For our method to work, we need the overlap assumption, which is common in the causal inference literature [12]. The overlap assumption ensures that for each data instance in the treatment group, a counterpart from the control group can be found.

Assumption 1.

[Overlap] For every sample i and feature j in the dataset, the treatment indicator Iᵢⱼ satisfies 0 < P(Iᵢⱼ = 1) < 1.

Proposition 1.

Feature representations learned by DCDAN have the highest causal contributions to the outcome.

Proof of this proposition is in the supplementary material.

Postulate 1 (Independence of mechanism and input).

Following [10], we assume that the causal mechanism is "independent" of the distribution of the cause. In other words, P(C) contains no information about P(E | C), where E is the effect (i.e., the outcome) and C is the cause. This indicates that changes in P(C) have no influence on the mechanism after it is learned.

Postulate 1 implies that once the causal mechanism P(E | C) is learned, we can assume it remains the same even when the distribution of the input, and therefore the distribution of the causes P(C), changes. Therefore, we can address the covariate shift problem [24], where P(Y | X) remains the same across environments while P(X) changes, by learning causal features from the input and a function that maps those features to the outcome.

5 Experiments

Figure 3: An example from the dataset constructed for EQ2: a sample image from the data, the ground-truth causal features extracted from the VQA-X dataset, and the heatmap generated by DCDAN for the causal feature representations.

Our experiments are designed to evaluate the effectiveness of the proposed DCDAN with the following questions:

  • EQ1: How does DCDAN compare to existing unsupervised deep domain adaptation frameworks?

  • EQ2: Are the feature representations learned by DCDAN causal features for predicting outcomes?

  • EQ3: How does varying the causal regularizer and other hyperparameters affect the classification performance of DCDAN?

We introduce the datasets and representative state-of-the-art deep domain adaptation frameworks, then compare the performance of DCDAN on an object recognition task to answer EQ1. For EQ2, we investigate the ability of DCDAN to learn and extract causal features from the data. For EQ3, we perform experiments varying the hyperparameters of the model and report the resulting performance.

5.1 Datasets

To answer EQ1, following the convention for domain adaptation studies, we use Office-31 and Office-10+Caltech-10, two widely adopted, publicly available benchmark datasets. Office-31 consists of 4,652 images in 31 categories collected from three distinct domains: Amazon (A), which contains images downloaded from amazon.com, and Webcam (W) and DSLR (D), which contain images taken by a web camera and a digital SLR camera in an office under different environmental variations. For a comprehensive comparison, following [9], we evaluate the performance of all models on all possible pairs among the three domains: A→W, D→W, W→D, A→D, D→A, and W→A. The Office-10+Caltech-10 dataset consists of the 10 categories shared by Office-31 and Caltech-256 (C). Thus, for all 4 domains (A, W, D, and C), we conduct evaluations on all possible pairs involving C: A→C, W→C, D→C, C→A, C→W, and C→D.

Method A→W D→W W→D A→D D→A W→A Average
ResNet-50 69.685 97.610 99.405 71.485 63.080 61.448 77.118
DDC 77.987 96.981 100.000 81.526 65.246 64.004 80.957
DAN 82.000 97.000 100.000 83.000 66.000 65.000 82.166
DeepCoral 77.800 97.700 99.700 81.500 64.600 64.000 80.883
DANN 82.000 96.900 99.100 79.700 68.200 67.400 82.216
HAFN 82.900 98.100 99.600 83.700 69.700 68.100 83.683
DCDAN 81.000 99.000 100.000 86.000 69.000 70.000 84.166
Table 1: The Prediction Performance on Domain Adaptation on Office-31
Method A→C W→C D→C C→A C→W C→D Average
ResNet-50 86.375 89.671 87.978 93.423 93.559 93.631 90.772
DDC 91.184 89.670 89.586 94.989 95.932 96.815 93.029
DAN 91.000 89.000 87.000 95.000 96.000 96.000 92.333
DeepCoral 90.293 88.691 87.529 95.198 95.593 96.178 92.247
DANN 91.451 87.000 84.862 94.000 94.000 92.000 90.552
HAFN 90.115 91.629 95.302 91.718 95.254 92.356 92.729
DCDAN 92.000 90.000 91.000 94.000 96.000 96.000 93.166
Table 2: The Prediction Performance on Domain Adaptation on Office-10+Caltech-10

To answer EQ2, a dataset with ground-truth causal features is needed. However, obtaining such ground truth is difficult, since existing object recognition datasets contain only image labels, not the causal features of the targets. Therefore, we construct a dataset with reliable ground truth for the causal features of the target variables using a subset of the Visual Question Answering Explanation (VQA-X) dataset proposed in [25], in which a set of images from the MSCOCO dataset (http://cocodataset.org), along with questions, answers, and visual and textual explanations for the questions, are provided by human annotators. To construct the dataset for our study, we extract the subset of VQA-X that contains only single objects with their corresponding MSCOCO labels. Our dataset consists of the images, their labels, and visual explanations of the target variable, which represent the causal feature representations of the target. These visual explanations are given as heatmaps provided by human experts. Figure 3 shows an example of the data.

5.2 Compared Baseline Methods

In this section, we briefly introduce the representative state-of-the-art baseline methods for deep domain adaptation:

  • ResNet-50 [23]: A state-of-the-art convolutional neural network for image classification. It uses a deep residual structure that introduces identity shortcut connections to skip layers, which eases the training of very deep networks.

  • DDC [8]: DDC is a deep domain confusion model that aims to maximize domain invariance by adding an adaptation layer to a convolutional neural network with a single-kernel maximum mean discrepancy (MMD) regularizer, as proposed in [26].

  • DAN [9]: DAN is a deep adaptation network for unsupervised domain adaptation. It reduces domain discrepancy via optimal multi-kernel selection for mean embedding matching.

  • Deep CORAL [16]: Deep CORAL is an unsupervised deep domain adaptation framework which learns a non-linear transformation that aligns the second-order statistics of the source and target feature activations in deep neural networks.

  • DANN [18]: DANN is an adversarial representation learning framework in which a domain classifier aims to distinguish the learned source and target features while the feature extractor tries to confuse it. The resulting minimax optimization is solved through a gradient reversal layer.

  • HAFN [27]: Hard Adaptive Feature Norm is a variant of AFN, a non-parametric Adaptive Feature Norm framework for unsupervised domain adaptation, based on adapting the feature norms of the source and target domains to a large range of values.

5.3 Classification Performance Comparison

To answer EQ1, we compare DCDAN with the aforementioned representative methods. Following existing work on deep domain adaptation, we use accuracy as the evaluation metric. For the baseline methods, we follow the standard evaluation protocol for unsupervised domain adaptation and use all labeled source instances and all unlabeled target instances [9]. We implement all baselines using PyTorch (https://pytorch.org/). In addition, we evaluate all compared approaches through a grid search over the hyperparameter space and report the best results. For MMD-based methods (i.e., DDC and DAN), we use Gaussian kernels. The experimental results are shown in Table 1 and Table 2. From the tables, we make the following observations:

(a) DCDAN
(b) Resnet-50
Figure 4: Heatmaps generated by DCDAN and Resnet-50 on a subset of VQA-X data.
  • Without access to data in the target domain, DCDAN still outperforms the baselines in many cases on both datasets, which validates that causal feature representations are helpful for learning transferable features across domains.

  • DCDAN significantly outperforms Resnet-50, the only baseline that does not use any information from the target domain, which suggests that DCDAN can perform unsupervised domain adaptation.

(a) Accuracy vs. the balancing-regularizer weight
(b) Accuracy vs. the sample-variance penalty weight
(c) Accuracy vs. the weight-sum penalty weight
Figure 5: Accuracy of DCDAN with varying hyperparameters on two transfer tasks

5.4 Causal Feature Evaluation

In this subsection, to answer EQ2, we evaluate the performance of DCDAN at discovering causal features automatically from the data. It is worth mentioning that all of the baselines used in our first experiment are designed to improve image classification performance in the target domain, and none of them were originally proposed to discover interpretable causal features. However, due to the nature of our approach, it is expected to learn more interpretable features by focusing on features that belong to the structure of the object rather than the context. To show the effectiveness of our model in learning causal features, we run DCDAN and a pretrained Resnet-50 on a subset of the VQA-X dataset [25] as described in the "Datasets" section, extract their feature representation heatmaps, and compare them to the visual ground truths provided in VQA-X.

In order to make a fair comparison, we fine-tune our model on a small subset of single-object images from the Imagenet dataset [28] and extract the feature representation heatmaps from the fine-tuned DCDAN and the pre-trained Resnet-50 using the method proposed in [29]. Figure 3 shows an example heatmap generated by DCDAN. To compare the generated heatmaps, following traditional settings, we use rank correlation as our evaluation metric. Following [30], to calculate the rank correlation we first scale both the ground truths and the heatmaps generated by the models to 14x14, rank the pixels according to their spatial attention, and compute the correlation between the two ranked lists. Our experiment shows that the rank correlation for DCDAN is 0.4501 whereas the rank correlation for the pre-trained Resnet-50 is 0.4077, which demonstrates the effectiveness of DCDAN at learning causal representations. Figure 6 shows examples of the heatmaps generated by DCDAN and the pre-trained Resnet-50. As seen in Figure 6, the feature representation heatmaps generated by DCDAN are more focused on the causal features of the object, whereas the features learned by Resnet-50 can contain both causal and context features.
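The rank-correlation step above can be sketched in a few lines. This is a plain Spearman computation on flattened maps and assumes no tied pixel values (ties would require average ranks, as in the full Spearman statistic); the ground-truth map here is random stand-in data, not VQA-X.

```python
import numpy as np

def rank_correlation(a, b):
    """Spearman rank correlation between two attention maps: flatten,
    rank the pixels, then take the Pearson correlation of the ranks."""
    ra = np.argsort(np.argsort(a.ravel())).astype(float)
    rb = np.argsort(np.argsort(b.ravel())).astype(float)
    ra -= ra.mean()
    rb -= rb.mean()
    return float(ra @ rb / np.sqrt((ra @ ra) * (rb @ rb)))

rng = np.random.default_rng(6)
gt = rng.random((14, 14))            # stand-in 14x14 ground-truth heatmap
print(rank_correlation(gt, gt))      # 1.0: identical maps
print(rank_correlation(gt, -gt))     # -1.0: fully reversed ranking
```

A model heatmap that agrees with the human-annotated causal regions on which pixels matter most thus scores close to 1, matching how the 0.4501 vs. 0.4077 comparison above is read.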

5.5 Parameter Analysis

To answer EQ3, we investigate the effect of the model's hyperparameters on performance (i.e., the accuracy of the model) and report the results in Figure 5. The performance is reported on two transfer tasks. Specifically, Figure 5(a) illustrates the effect of the weight of the balancing sample-weights regularizer on the classification performance of DCDAN. We observe that the accuracy of DCDAN first increases and then decreases as this weight grows, which further confirms that a good trade-off between the causal loss and the classification loss results in learning more transferable features. Figure 5(b) shows the effect of the hyperparameter controlling the sample variance on the accuracy of DCDAN. Finally, Figure 5(c) reports the effect of the hyperparameter designed to prevent all sample weights from being zero.

6 Conclusion

In this paper, we propose DCDAN, a novel deep causal representation learning framework for unsupervised domain adaptation, which generates more transferable feature representations by extracting causal feature representations instead of relying only on correlations. In order to learn the causal representations, a virtual target domain consisting of reweighted samples is generated. These sample weights are learned concurrently with the feature representations, within the pipeline of the deep model. Experiments demonstrate the effectiveness of our model in both classification performance and learning causal feature representations.

7 Proof of Proposition 1

Feature representations learned by DCDAN are the features with the highest causal contributions to the outcome.

Proof.

Following [12], it can be proved that (1) the reweighted feature representations in the simulated target domain are independent of each other, and (2) in the new target domain, only causal feature representations are correlated with the outcome, while any spurious correlations between the context features and the outcome vanish. To be more specific, in the simulated target domain the outcome depends only on the causal features:

P_T(Y | C, S) = P_T(Y | C),

where T denotes the simulated target domain, C represents all causal features, S denotes all spurious variables, and Y is the outcome variable.

This shows that by simulating a target domain using the balancing sample weights, the correlation between the simulated representations of the virtual target domain and the outcome variable can be measured to obtain an unbiased estimate of the causal contribution of each representation to the target; by using these virtual representations in the pipeline of the deep model, we can learn the feature representations with the highest causal contributions instead of relying on correlations. In other words, features with high spurious correlations are not learned. ∎

8 Optimization of DCDAN

Algorithm 1 describes the optimization procedure of the proposed framework. DCDAN uses an alternating optimization approach to solve the problem defined in Eq (2). More specifically, DCDAN updates the model parameters θ and the sample weights W iteratively using mini-batch stochastic gradient descent.

(a) Heatmap generated for an instance of pizza class by DCDAN
(b) Heatmap generated for an instance of bird class by DCDAN
(c) Heatmap generated for an instance of dog class by DCDAN
(d) Heatmap generated for an instance of pizza class by Resnet-50
(e) Heatmap generated for an instance of bird class by Resnet-50
(f) Heatmap generated for an instance of dog class by Resnet-50
Figure 6: Heatmaps generated by DCDAN and Resnet-50 on a subset of VQA-X data. The first row shows the feature representation heatmaps generated by our proposed DCDAN and the second row shows the heatmaps generated by the pre-trained Resnet-50
Input: Matrix of input images X and ground-truth labels Y
Output: Updated parameters of the model, W and θ
Initialize the iteration variable t ← 0
Initialize parameters θ and W
Calculate the loss function with (θ, W) as stated in Eq (2) in the paper
repeat
      t ← t + 1
      update θ using stochastic gradient descent while W is fixed
      update W using stochastic gradient descent while θ is fixed
      Calculate the loss function with the updated (θ, W)
until the loss function converges or the maximum iteration is reached
return W, θ
Algorithm 1 Deep Causal Domain Adaptation Network algorithm

9 Additional Case Studies for Learning Causal Feature Representations

In this section, we provide more case studies for the "Causal Feature Evaluation" section of the paper. Both DCDAN and Resnet-50 are trained according to the settings explained in the "Causal Feature Evaluation" section on a subset of the VQA-X dataset [25] as described in the "Datasets" section. Figure 6 shows examples of the heatmaps generated by both our proposed framework and Resnet-50. Heatmaps generated by DCDAN are more focused on regions that belong to the structure of the outcome (i.e., causal features), whereas the features learned by Resnet-50 belong to both the structure and the context.

References